Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/scope/core/pipelines/wan2_1/vace/blocks/vace_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,23 @@ def _encode_with_conditioning(self, components, block_state, current_start):
input_masks_data.shape
)
if mask_channels != 1:
raise ValueError(
f"VaceEncodingBlock._encode_with_conditioning: vace_input_masks must have 1 channel, got {mask_channels}"
)
if mask_channels == 3:
# Depth maps from video-depth-anything arrive as 3-channel RGB
# (depth value replicated across R/G/B). Convert to single-channel
# grayscale by averaging so downstream VACE encoding works correctly.
import logging
logging.getLogger(__name__).warning(
"VaceEncodingBlock._encode_with_conditioning: vace_input_masks has 3 "
"channels (likely an RGB depth map). Auto-converting to single-channel "
"by averaging. Wire a grayscale source to avoid this conversion."
)
# Shape: [B, 3, F, H, W] -> [B, 1, F, H, W]
input_masks_data = input_masks_data.mean(dim=1, keepdim=True)
mask_channels = 1
else:
raise ValueError(
f"VaceEncodingBlock._encode_with_conditioning: vace_input_masks must have 1 channel, got {mask_channels}"
)
if (
mask_frames != num_frames
or mask_height != height
Expand Down
64 changes: 61 additions & 3 deletions src/scope/server/graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@

from __future__ import annotations

from typing import Literal
import logging
from typing import Any, Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

logger = logging.getLogger(__name__)


class GraphNode(BaseModel):
Expand Down Expand Up @@ -72,7 +75,13 @@ class GraphNode(BaseModel):


class GraphEdge(BaseModel):
"""An edge connecting an output port to an input port."""
"""An edge connecting an output port to an input port.

Accepts both the current schema (``from``, ``from_port``, ``to_node``,
``to_port``) and the legacy schema (``source``, ``target``) for backwards
compatibility with older Scope desktop clients. When the legacy keys are
present the port names default to ``"video"``.
"""

from_node: str = Field(..., alias="from", description="Source node id")
from_port: str = Field(
Expand All @@ -87,6 +96,55 @@ class GraphEdge(BaseModel):

model_config = {"populate_by_name": True}

@model_validator(mode="before")
@classmethod
def _coerce_legacy_edge(cls, data: Any) -> Any:
"""Map legacy ``source``/``target`` keys to the current schema.

Older clients send edges as::

{"source": "input", "target": "pipeline"}

The current schema requires ``from``, ``from_port``, ``to_node``,
``to_port``. This validator accepts any mix of legacy and current keys,
mapping them where the canonical field is absent. Port names default to
``"video"`` when the legacy payload omits port information.
"""
if not isinstance(data, dict):
return data

has_legacy = "source" in data or "target" in data
if not has_legacy:
return data

logger.warning(
"GraphEdge: received legacy edge schema (source/target). "
"Please update the Scope client to send 'from'/'to_node' edges. "
"Coercing automatically. source=%r target=%r",
data.get("source"),
data.get("target"),
)

data = dict(data) # make a mutable copy

# Map source → from (only when 'from' is absent)
if "source" in data and "from" not in data:
data["from"] = data.pop("source")
else:
data.pop("source", None)

# Map target → to_node (only when 'to_node' is absent)
if "target" in data and "to_node" not in data:
data["to_node"] = data.pop("target")
else:
data.pop("target", None)

# Apply port defaults when the caller omitted them
data.setdefault("from_port", data.pop("source_port", "video"))
data.setdefault("to_port", data.pop("target_port", "video"))

return data


class GraphConfig(BaseModel):
"""Root graph configuration (graph definition)."""
Expand Down
159 changes: 159 additions & 0 deletions tests/test_graph_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Tests for graph_schema backwards-compatibility (issue #895).

Verifies that GraphEdge and GraphConfig accept both the legacy
``source``/``target`` edge format and the current ``from``/``to_node`` format.
"""

from __future__ import annotations

import logging

import pytest

from scope.server.graph_schema import GraphConfig, GraphEdge, GraphNode


# ---------------------------------------------------------------------------
# GraphEdge unit tests
# ---------------------------------------------------------------------------


class TestGraphEdgeLegacyKeys:
"""GraphEdge should accept the old source/target format."""

def test_legacy_source_target_minimal(self):
"""Basic source/target without ports → defaults applied."""
edge = GraphEdge.model_validate({"source": "input", "target": "pipeline"})
assert edge.from_node == "input"
assert edge.to_node == "pipeline"
assert edge.from_port == "video" # default
assert edge.to_port == "video" # default
assert edge.kind == "stream" # default

def test_legacy_source_target_with_ports(self):
"""Legacy source/target alongside explicit port names."""
edge = GraphEdge.model_validate(
{
"source": "input",
"target": "pipeline",
"from_port": "video",
"to_port": "video",
}
)
assert edge.from_node == "input"
assert edge.to_node == "pipeline"
assert edge.from_port == "video"
assert edge.to_port == "video"

def test_legacy_emits_deprecation_warning(self, caplog):
with caplog.at_level(logging.WARNING, logger="scope.server.graph_schema"):
GraphEdge.model_validate({"source": "a", "target": "b"})
assert any("legacy edge schema" in r.message for r in caplog.records)

def test_legacy_only_source(self):
"""Only 'source' provided (no 'target') — should still parse."""
edge = GraphEdge.model_validate(
{"source": "a", "to_node": "b", "from_port": "video", "to_port": "video"}
)
assert edge.from_node == "a"
assert edge.to_node == "b"

def test_legacy_only_target(self):
"""Only 'target' provided (no 'source') — should still parse."""
edge = GraphEdge.model_validate(
{"from": "a", "target": "b", "from_port": "video", "to_port": "video"}
)
assert edge.from_node == "a"
assert edge.to_node == "b"


class TestGraphEdgeCurrentKeys:
"""Existing schema (from/from_port/to_node/to_port) must still work."""

def test_current_format(self):
edge = GraphEdge.model_validate(
{
"from": "input",
"from_port": "video",
"to_node": "pipeline",
"to_port": "video",
"kind": "stream",
}
)
assert edge.from_node == "input"
assert edge.from_port == "video"
assert edge.to_node == "pipeline"
assert edge.to_port == "video"
assert edge.kind == "stream"

def test_current_format_no_warning(self, caplog):
with caplog.at_level(logging.WARNING, logger="scope.server.graph_schema"):
GraphEdge.model_validate(
{
"from": "input",
"from_port": "video",
"to_node": "pipeline",
"to_port": "video",
}
)
assert not any("Deprecated" in r.message for r in caplog.records)


# ---------------------------------------------------------------------------
# GraphConfig integration test
# ---------------------------------------------------------------------------


class TestGraphConfigLegacyEdges:
"""GraphConfig should parse correctly even when edges use legacy keys."""

def _make_config(self, edges):
return GraphConfig.model_validate(
{
"nodes": [
{"id": "input", "type": "source"},
{"id": "pipeline", "type": "pipeline", "pipeline_id": "my_pipe"},
{"id": "output", "type": "sink"},
],
"edges": edges,
}
)

def test_legacy_edges_in_graph_config(self):
cfg = self._make_config(
[
{"source": "input", "target": "pipeline"},
{"source": "pipeline", "target": "output"},
]
)
assert len(cfg.edges) == 2
assert cfg.edges[0].from_node == "input"
assert cfg.edges[0].to_node == "pipeline"
assert cfg.edges[1].from_node == "pipeline"
assert cfg.edges[1].to_node == "output"

def test_mixed_edges_in_graph_config(self):
"""Mix of legacy and current edge formats in the same config."""
cfg = self._make_config(
[
{"source": "input", "target": "pipeline"},
{
"from": "pipeline",
"from_port": "video",
"to_node": "output",
"to_port": "video",
},
]
)
assert cfg.edges[0].from_node == "input"
assert cfg.edges[1].from_node == "pipeline"

def test_validate_structure_passes(self):
cfg = self._make_config(
[
{"source": "input", "target": "pipeline"},
{"source": "pipeline", "target": "output"},
]
)
errors = cfg.validate_structure()
assert errors == []
Loading