Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 67 additions & 2 deletions src/scope/core/pipelines/wan2_1/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,93 @@ def normalize_lora_key(lora_base_key: str) -> str:
return lora_base_key


def _wait_for_lora_file(
lora_path: str,
timeout: float | None = None,
poll_interval: float = 2.0,
) -> bool:
"""Poll until *lora_path* exists on disk or *timeout* is exceeded.

Handles the race condition where a LoRA file is being downloaded from a
remote source (e.g. Civitai) concurrently with pipeline initialisation.
The pipeline calls ``load_lora_weights`` synchronously while the download
runs in a separate thread; without the poll the load fails with
``FileNotFoundError`` even though the file will be available shortly.

Args:
lora_path: Absolute path of the LoRA file to wait for.
timeout: Maximum seconds to wait. Defaults to the value of the
``SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT`` environment variable
(default: 120 s). Pass 0 to disable waiting entirely.
poll_interval: Seconds between existence checks (default 2 s).

Returns:
``True`` if the file appeared within the timeout, ``False`` otherwise.
"""
import time

if timeout is None:
timeout = float(os.getenv("SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT", "120"))

if timeout <= 0 or os.path.exists(lora_path):
return os.path.exists(lora_path)

logger.info(
"_wait_for_lora_file: '%s' not yet present — waiting up to %.0fs for "
"in-flight download to complete (poll every %.1fs)",
lora_path,
timeout,
poll_interval,
)

deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
time.sleep(poll_interval)
if os.path.exists(lora_path):
waited = timeout - (deadline - time.monotonic())
logger.info(
"_wait_for_lora_file: '%s' appeared after %.1fs",
lora_path,
waited,
)
return True

return False


def load_lora_weights(lora_path: str) -> dict[str, torch.Tensor]:
    """
    Load LoRA weights from .safetensors or .bin file.

    If the file does not exist immediately, this function will poll for up to
    ``SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT`` seconds (default: 120) to allow an
    in-flight Civitai/HuggingFace download to complete before raising. This
    prevents spurious ``FileNotFoundError`` failures during session reinit when
    a LoRA asset download races with pipeline ``__init__``.

    Args:
        lora_path: Path to LoRA file (.safetensors or .bin)

    Returns:
        Dictionary mapping parameter names to tensors

    Raises:
        FileNotFoundError: If the LoRA file does not exist (and did not appear
            within the configured wait timeout).
    """
    # Wait first — an unconditional existence check before the poll would
    # raise immediately and defeat the download-race handling entirely.
    if not _wait_for_lora_file(lora_path):
        raise FileNotFoundError(
            f"load_lora_weights: LoRA file not found: {lora_path}"
        )

    if lora_path.endswith(".safetensors"):
        return load_file(lora_path)
    # NOTE(review): torch.load on a .bin goes through pickle — assumes LoRA
    # sources are trusted; confirm upstream validation if that changes.
    return torch.load(lora_path, map_location="cpu")



def find_lora_pair(
lora_key: str, lora_state: dict[str, torch.Tensor]
) -> tuple[str, str, torch.Tensor, torch.Tensor] | None:
Expand Down
43 changes: 42 additions & 1 deletion src/scope/server/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ class PipelineNotAvailableException(Exception):
pass


class PipelineNotYetRegisteredException(ValueError):
    """Raised when a pipeline ID has no registry entry yet.

    This condition is *transient*: during cloud session initialization the
    frontend may issue a plugin install and a pipeline load concurrently, so
    the load can arrive before the plugin has finished installing and
    registering itself. The registry lookup then comes back empty even though
    the pipeline ID will become valid shortly.

    Callers should treat this as a retriable condition rather than a hard
    error.
    """


class PipelineStatus(Enum):
"""Pipeline loading status enumeration."""

Expand Down Expand Up @@ -336,6 +352,29 @@ def _load_pipeline_by_id_sync(
)
return True

except PipelineNotYetRegisteredException:
# Transient race condition: the pipeline plugin hasn't finished
# installing yet. Log at WARN (not ERROR) and leave the status as
# NOT_LOADED so the frontend doesn't show an error state and the
# load can be retried transparently once the plugin is registered.
self.set_loading_stage(None)
logger.warning(
f"Pipeline '{key}' is not registered — the plugin may still be "
f"installing. This is likely a transient race condition and will "
f"resolve once the plugin is installed."
)
with self._lock:
self._pipeline_statuses[key] = PipelineStatus.NOT_LOADED
if key in self._pipelines:
del self._pipelines[key]
if key in self._pipeline_load_params:
del self._pipeline_load_params[key]
if key in self._pipeline_registry_ids:
del self._pipeline_registry_ids[key]
if key in self._load_events:
self._load_events[key].set()
return False

except Exception as e:
self.set_loading_stage(None)
from .models_config import get_models_dir
Expand Down Expand Up @@ -1385,7 +1424,9 @@ def _load_pipeline_implementation(
logger.info("OpticalFlow pipeline initialized")
return pipeline
else:
raise ValueError(f"Invalid pipeline ID: {pipeline_id}")
raise PipelineNotYetRegisteredException(
f"Invalid pipeline ID: {pipeline_id}. Plugin may not be installed yet."
)

def is_loaded(self) -> bool:
"""Check if pipeline is loaded and ready (thread-safe)."""
Expand Down
131 changes: 131 additions & 0 deletions tests/test_lora_wait_for_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Tests for the LoRA download-wait helper in lora/utils.py.

Covers the race condition where a Civitai LoRA file is still being
downloaded when LongLivePipeline.__init__ calls load_lora_weights.
See: daydreamlive/scope#937
"""

import os
import threading
import time
from pathlib import Path
from unittest.mock import patch

import pytest

from scope.core.pipelines.wan2_1.lora.utils import _wait_for_lora_file


class TestWaitForLoraFile:
    """Unit tests for _wait_for_lora_file."""

    def test_file_already_present_returns_immediately(self, tmp_path: Path):
        """A pre-existing file is detected without entering the poll loop."""
        target = tmp_path / "model.safetensors"
        target.touch()

        started = time.monotonic()
        appeared = _wait_for_lora_file(str(target), timeout=10, poll_interval=0.1)
        duration = time.monotonic() - started

        assert appeared is True
        # No sleep should have occurred for an already-present file.
        assert duration < 0.5

    def test_file_appears_during_wait(self, tmp_path: Path):
        """A file created mid-poll is observed and reported as present."""
        target = tmp_path / "late.safetensors"

        def _touch_after_delay():
            time.sleep(0.3)
            target.touch()

        creator = threading.Thread(target=_touch_after_delay, daemon=True)
        creator.start()
        try:
            appeared = _wait_for_lora_file(str(target), timeout=5, poll_interval=0.1)
        finally:
            creator.join()

        assert appeared is True

    def test_file_never_appears_returns_false(self, tmp_path: Path):
        """A file that never materialises yields False once the timeout lapses."""
        absent = str(tmp_path / "missing.safetensors")

        assert _wait_for_lora_file(absent, timeout=0.3, poll_interval=0.1) is False

    def test_timeout_zero_disables_wait(self, tmp_path: Path):
        """A zero timeout skips polling: a missing file returns False at once."""
        absent = str(tmp_path / "no_wait.safetensors")

        started = time.monotonic()
        appeared = _wait_for_lora_file(absent, timeout=0, poll_interval=0.1)
        duration = time.monotonic() - started

        assert appeared is False
        assert duration < 0.1

    def test_env_var_overrides_default_timeout(self, tmp_path: Path, monkeypatch):
        """SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT supplies the default timeout."""
        absent = str(tmp_path / "env_timeout.safetensors")
        monkeypatch.setenv("SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT", "0.2")

        started = time.monotonic()
        # timeout=None forces the env-var lookup path.
        appeared = _wait_for_lora_file(absent, timeout=None, poll_interval=0.05)
        duration = time.monotonic() - started

        assert appeared is False
        # Respect the 0.2 s limit, with a generous buffer for slow CI hosts.
        assert duration < 1.5


class TestLoadLoraWeightsWaits:
    """Integration-style tests ensuring load_lora_weights uses the poll helper."""

    def test_raises_after_timeout_when_file_never_appears(self, tmp_path: Path):
        """load_lora_weights raises FileNotFoundError once the wait expires."""
        from scope.core.pipelines.wan2_1.lora.utils import load_lora_weights

        absent = str(tmp_path / "never_there.safetensors")
        # Keep the timeout short so the test stays fast.
        with patch.dict(os.environ, {"SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT": "0.2"}):
            with pytest.raises(FileNotFoundError, match="LoRA file not found"):
                load_lora_weights(absent)

    def test_succeeds_when_file_appears_during_wait(self, tmp_path: Path):
        """load_lora_weights succeeds once a delayed file lands on disk.

        Exercises _wait_for_lora_file plus the safetensors load in
        combination (rather than the full load_lora_weights wait path),
        keeping the test fast via a short poll interval.
        """
        import torch
        from safetensors.torch import save_file

        from scope.core.pipelines.wan2_1.lora.utils import (
            _wait_for_lora_file,
            load_lora_weights,
        )

        target = tmp_path / "delayed.safetensors"

        # Persist a minimal safetensors file after a short delay.
        def _save_after_delay():
            time.sleep(0.3)
            save_file({"lora_A.weight": torch.zeros(4, 4)}, str(target))

        writer = threading.Thread(target=_save_after_delay, daemon=True)
        writer.start()
        try:
            # Exercise the poll helper directly with a short interval, then
            # confirm load_lora_weights can read the now-present file.
            appeared = _wait_for_lora_file(str(target), timeout=5, poll_interval=0.1)
        finally:
            writer.join()

        assert appeared is True
        loaded = load_lora_weights(str(target))
        assert "lora_A.weight" in loaded
Loading
Loading