Improve ProjectLoaderDynamic UX: single node, error feedback, auto-refresh

Remove 3 redundant hardcoded nodes (Standard/VACE/LoRA), keeping only the
Dynamic node. Add total_sequences INT output (slot 0) for loop counting.
Add structured error handling: _fetch_json returns typed error dicts,
load_dynamic raises RuntimeError with descriptive messages, JS shows
red border/title on errors. Add 500ms debounced auto-refresh on widget
changes. Add 404s for missing project/file in API endpoints.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-28 22:16:08 +01:00
parent d07a308865
commit 4b5fff5c6e
6 changed files with 295 additions and 255 deletions

View File

@@ -55,13 +55,26 @@ def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
db = _get_db() db = _get_db()
data = db.query_sequence_data(name, file_name, seq) proj = db.get_project(name)
if not proj:
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
df = db.get_data_file_by_names(name, file_name)
if not df:
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
data = db.get_sequence(df["id"], seq)
if data is None: if data is None:
raise HTTPException(status_code=404, detail="Sequence not found") raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
return data return data
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
db = _get_db() db = _get_db()
keys, types = db.query_sequence_keys(name, file_name, seq) proj = db.get_project(name)
return {"keys": keys, "types": types} if not proj:
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
df = db.get_data_file_by_names(name, file_name)
if not df:
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
keys, types = db.get_sequence_keys(df["id"], seq)
total = db.count_sequences(df["id"])
return {"keys": keys, "types": types, "total_sequences": total}

15
db.py
View File

@@ -182,6 +182,21 @@ class ProjectDB:
).fetchall() ).fetchall()
return [r["sequence_number"] for r in rows] return [r["sequence_number"] for r in rows]
def count_sequences(self, data_file_id: int) -> int:
"""Return the number of sequences for a data file."""
row = self.conn.execute(
"SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?",
(data_file_id,),
).fetchone()
return row["cnt"]
def query_total_sequences(self, project_name: str, file_name: str) -> int:
"""Return total sequence count by project and file names."""
df = self.get_data_file_by_names(project_name, file_name)
if not df:
return 0
return self.count_sequences(df["id"])
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]: def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
"""Returns (keys, types) for a sequence's data dict.""" """Returns (keys, types) for a sequence's data dict."""
data = self.get_sequence(data_file_id, sequence_number) data = self.get_sequence(data_file_id, sequence_number)

View File

@@ -39,13 +39,31 @@ def to_int(val: Any) -> int:
def _fetch_json(url: str) -> dict: def _fetch_json(url: str) -> dict:
"""Fetch JSON from a URL using stdlib urllib.""" """Fetch JSON from a URL using stdlib urllib.
On error, returns a dict with an "error" key describing the failure.
"""
try: try:
with urllib.request.urlopen(url, timeout=5) as resp: with urllib.request.urlopen(url, timeout=5) as resp:
return json.loads(resp.read()) return json.loads(resp.read())
except (urllib.error.URLError, json.JSONDecodeError, OSError) as e: except urllib.error.HTTPError as e:
logger.warning(f"Failed to fetch {url}: {e}") # HTTPError is a subclass of URLError — must be caught first
return {} body = ""
try:
raw = e.read()
detail = json.loads(raw)
body = detail.get("detail", str(raw, "utf-8", errors="replace"))
except Exception:
body = str(e)
logger.warning(f"HTTP {e.code} from {url}: {body}")
return {"error": "http_error", "status": e.code, "message": body}
except (urllib.error.URLError, OSError) as e:
reason = str(e.reason) if hasattr(e, "reason") else str(e)
logger.warning(f"Network error fetching {url}: {reason}")
return {"error": "network_error", "message": reason}
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON from {url}: {e}")
return {"error": "parse_error", "message": str(e)}
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
@@ -100,6 +118,9 @@ if PromptServer is not None:
except (ValueError, TypeError): except (ValueError, TypeError):
seq = 1 seq = 1
data = _fetch_keys(manager_url, project, file_name, seq) data = _fetch_keys(manager_url, project, file_name, seq)
if data.get("error") in ("http_error", "network_error", "parse_error"):
status = data.get("status", 502)
return web.json_response(data, status=status)
return web.json_response(data) return web.json_response(data)
@@ -124,15 +145,25 @@ class ProjectLoaderDynamic:
}, },
} }
RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
FUNCTION = "load_dynamic" FUNCTION = "load_dynamic"
CATEGORY = "utils/json/project" CATEGORY = "utils/json/project"
OUTPUT_NODE = False OUTPUT_NODE = False
def load_dynamic(self, manager_url, project_name, file_name, sequence_number, def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
output_keys="", output_types=""): output_keys="", output_types=""):
# Fetch keys metadata (includes total_sequences count)
keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number)
if keys_meta.get("error") in ("http_error", "network_error", "parse_error"):
msg = keys_meta.get("message", "Unknown error")
raise RuntimeError(f"Failed to fetch project keys: {msg}")
total_sequences = keys_meta.get("total_sequences", 0)
data = _fetch_data(manager_url, project_name, file_name, sequence_number) data = _fetch_data(manager_url, project_name, file_name, sequence_number)
if data.get("error") in ("http_error", "network_error", "parse_error"):
msg = data.get("message", "Unknown error")
raise RuntimeError(f"Failed to fetch sequence data: {msg}")
# Parse keys — try JSON array first, fall back to comma-split for compat # Parse keys — try JSON array first, fall back to comma-split for compat
keys = [] keys = []
@@ -171,111 +202,14 @@ class ProjectLoaderDynamic:
while len(results) < MAX_DYNAMIC_OUTPUTS: while len(results) < MAX_DYNAMIC_OUTPUTS:
results.append("") results.append("")
return tuple(results) return (total_sequences,) + tuple(results)
# ==========================================
# 1. STANDARD NODE (Project-based I2V)
# ==========================================
class ProjectLoaderStandard:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
"project_name": ("STRING", {"default": "", "multiline": False}),
"file_name": ("STRING", {"default": "", "multiline": False}),
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
}}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING")
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path")
FUNCTION = "load_standard"
CATEGORY = "utils/json/project"
def load_standard(self, manager_url, project_name, file_name, sequence_number):
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
return (
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
str(data.get("current_prompt", "")), str(data.get("negative", "")),
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
to_int(data.get("seed", 0)), str(data.get("video file path", "")),
str(data.get("reference image path", "")), str(data.get("flf image path", ""))
)
# ==========================================
# 2. VACE NODE (Project-based)
# ==========================================
class ProjectLoaderVACE:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
"project_name": ("STRING", {"default": "", "multiline": False}),
"file_name": ("STRING", {"default": "", "multiline": False}),
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
}}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING")
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path")
FUNCTION = "load_vace"
CATEGORY = "utils/json/project"
def load_vace(self, manager_url, project_name, file_name, sequence_number):
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
return (
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
str(data.get("current_prompt", "")), str(data.get("negative", "")),
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
to_int(data.get("seed", 0)), to_int(data.get("frame_to_skip", 81)),
to_int(data.get("input_a_frames", 16)), to_int(data.get("input_b_frames", 16)),
str(data.get("reference path", "")), to_int(data.get("reference switch", 1)),
to_int(data.get("vace schedule", 1)), str(data.get("video file path", "")),
str(data.get("reference image path", ""))
)
# ==========================================
# 3. LoRA NODE (Project-based)
# ==========================================
class ProjectLoaderLoRA:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
"project_name": ("STRING", {"default": "", "multiline": False}),
"file_name": ("STRING", {"default": "", "multiline": False}),
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
}}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low")
FUNCTION = "load_loras"
CATEGORY = "utils/json/project"
def load_loras(self, manager_url, project_name, file_name, sequence_number):
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
return (
str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")),
str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")),
str(data.get("lora 3 high", "")), str(data.get("lora 3 low", ""))
)
# --- Mappings --- # --- Mappings ---
PROJECT_NODE_CLASS_MAPPINGS = { PROJECT_NODE_CLASS_MAPPINGS = {
"ProjectLoaderDynamic": ProjectLoaderDynamic, "ProjectLoaderDynamic": ProjectLoaderDynamic,
"ProjectLoaderStandard": ProjectLoaderStandard,
"ProjectLoaderVACE": ProjectLoaderVACE,
"ProjectLoaderLoRA": ProjectLoaderLoRA,
} }
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
"ProjectLoaderDynamic": "Project Loader (Dynamic)", "ProjectLoaderDynamic": "Project Loader (Dynamic)",
"ProjectLoaderStandard": "Project Loader (Standard/I2V)",
"ProjectLoaderVACE": "Project Loader (VACE Full)",
"ProjectLoaderLoRA": "Project Loader (LoRAs)",
} }

View File

@@ -172,6 +172,25 @@ class TestSequences:
db.delete_sequences_for_file(df_id) db.delete_sequences_for_file(df_id)
assert db.list_sequences(df_id) == [] assert db.list_sequences(df_id) == []
def test_count_sequences(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
assert db.count_sequences(df_id) == 0
db.upsert_sequence(df_id, 1, {"a": 1})
db.upsert_sequence(df_id, 2, {"b": 2})
db.upsert_sequence(df_id, 3, {"c": 3})
assert db.count_sequences(df_id) == 3
def test_query_total_sequences(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
db.upsert_sequence(df_id, 1, {"a": 1})
db.upsert_sequence(df_id, 2, {"b": 2})
assert db.query_total_sequences("p1", "batch") == 2
def test_query_total_sequences_nonexistent(self, db):
assert db.query_total_sequences("nope", "nope") == 0
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# History trees # History trees

View File

@@ -6,9 +6,6 @@ import pytest
from project_loader import ( from project_loader import (
ProjectLoaderDynamic, ProjectLoaderDynamic,
ProjectLoaderStandard,
ProjectLoaderVACE,
ProjectLoaderLoRA,
_fetch_json, _fetch_json,
_fetch_data, _fetch_data,
_fetch_keys, _fetch_keys,
@@ -32,11 +29,23 @@ class TestFetchHelpers:
result = _fetch_json("http://example.com/api") result = _fetch_json("http://example.com/api")
assert result == data assert result == data
def test_fetch_json_failure(self): def test_fetch_json_network_error(self):
import urllib.error
with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")): with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
result = _fetch_json("http://example.com/api") result = _fetch_json("http://example.com/api")
assert result == {} assert result["error"] == "network_error"
assert "connection refused" in result["message"]
def test_fetch_json_http_error(self):
import urllib.error
err = urllib.error.HTTPError(
"http://example.com/api", 404, "Not Found", {},
BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode())
)
with patch("project_loader.urllib.request.urlopen", side_effect=err):
result = _fetch_json("http://example.com/api")
assert result["error"] == "http_error"
assert result["status"] == 404
assert "not found" in result["message"].lower()
def test_fetch_data_builds_url(self): def test_fetch_data_builds_url(self):
data = {"prompt": "hello"} data = {"prompt": "hello"}
@@ -73,18 +82,23 @@ class TestFetchHelpers:
class TestProjectLoaderDynamic: class TestProjectLoaderDynamic:
def _keys_meta(self, total=5):
return {"keys": [], "types": [], "total_sequences": total}
def test_load_dynamic_with_keys(self): def test_load_dynamic_with_keys(self):
data = {"prompt": "hello", "seed": 42, "cfg": 1.5} data = {"prompt": "hello", "seed": 42, "cfg": 1.5}
node = ProjectLoaderDynamic() node = ProjectLoaderDynamic()
with patch("project_loader._fetch_data", return_value=data): with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
result = node.load_dynamic( with patch("project_loader._fetch_data", return_value=data):
"http://localhost:8080", "proj1", "batch_i2v", 1, result = node.load_dynamic(
output_keys="prompt,seed,cfg" "http://localhost:8080", "proj1", "batch_i2v", 1,
) output_keys="prompt,seed,cfg"
assert result[0] == "hello" )
assert result[1] == 42 assert result[0] == 5 # total_sequences
assert result[2] == 1.5 assert result[1] == "hello"
assert len(result) == MAX_DYNAMIC_OUTPUTS assert result[2] == 42
assert result[3] == 1.5
assert len(result) == MAX_DYNAMIC_OUTPUTS + 1
def test_load_dynamic_with_json_encoded_keys(self): def test_load_dynamic_with_json_encoded_keys(self):
"""JSON-encoded output_keys should be parsed correctly.""" """JSON-encoded output_keys should be parsed correctly."""
@@ -92,13 +106,14 @@ class TestProjectLoaderDynamic:
data = {"my,key": "comma_val", "normal": "ok"} data = {"my,key": "comma_val", "normal": "ok"}
node = ProjectLoaderDynamic() node = ProjectLoaderDynamic()
keys_json = _json.dumps(["my,key", "normal"]) keys_json = _json.dumps(["my,key", "normal"])
with patch("project_loader._fetch_data", return_value=data): with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
result = node.load_dynamic( with patch("project_loader._fetch_data", return_value=data):
"http://localhost:8080", "proj1", "batch_i2v", 1, result = node.load_dynamic(
output_keys=keys_json "http://localhost:8080", "proj1", "batch_i2v", 1,
) output_keys=keys_json
assert result[0] == "comma_val" )
assert result[1] == "ok" assert result[1] == "comma_val"
assert result[2] == "ok"
def test_load_dynamic_type_coercion(self): def test_load_dynamic_type_coercion(self):
"""output_types should coerce values to declared types.""" """output_types should coerce values to declared types."""
@@ -107,41 +122,75 @@ class TestProjectLoaderDynamic:
node = ProjectLoaderDynamic() node = ProjectLoaderDynamic()
keys_json = _json.dumps(["seed", "cfg", "prompt"]) keys_json = _json.dumps(["seed", "cfg", "prompt"])
types_json = _json.dumps(["INT", "FLOAT", "STRING"]) types_json = _json.dumps(["INT", "FLOAT", "STRING"])
with patch("project_loader._fetch_data", return_value=data): with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
result = node.load_dynamic( with patch("project_loader._fetch_data", return_value=data):
"http://localhost:8080", "proj1", "batch_i2v", 1, result = node.load_dynamic(
output_keys=keys_json, output_types=types_json "http://localhost:8080", "proj1", "batch_i2v", 1,
) output_keys=keys_json, output_types=types_json
assert result[0] == 42 # string "42" coerced to int )
assert result[1] == 1.5 # string "1.5" coerced to float assert result[1] == 42 # string "42" coerced to int
assert result[2] == "hello" # string stays string assert result[2] == 1.5 # string "1.5" coerced to float
assert result[3] == "hello" # string stays string
def test_load_dynamic_empty_keys(self): def test_load_dynamic_empty_keys(self):
node = ProjectLoaderDynamic() node = ProjectLoaderDynamic()
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
result = node.load_dynamic( with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
"http://localhost:8080", "proj1", "batch_i2v", 1, result = node.load_dynamic(
output_keys="" "http://localhost:8080", "proj1", "batch_i2v", 1,
) output_keys=""
assert all(v == "" for v in result) )
# Slot 0 is total_sequences (INT), rest are empty strings
assert result[0] == 5
assert all(v == "" for v in result[1:])
def test_load_dynamic_missing_key(self): def test_load_dynamic_missing_key(self):
node = ProjectLoaderDynamic() node = ProjectLoaderDynamic()
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
result = node.load_dynamic( with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
"http://localhost:8080", "proj1", "batch_i2v", 1, result = node.load_dynamic(
output_keys="nonexistent" "http://localhost:8080", "proj1", "batch_i2v", 1,
) output_keys="nonexistent"
assert result[0] == "" )
assert result[1] == ""
def test_load_dynamic_bool_becomes_string(self): def test_load_dynamic_bool_becomes_string(self):
node = ProjectLoaderDynamic() node = ProjectLoaderDynamic()
with patch("project_loader._fetch_data", return_value={"flag": True}): with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
result = node.load_dynamic( with patch("project_loader._fetch_data", return_value={"flag": True}):
"http://localhost:8080", "proj1", "batch_i2v", 1, result = node.load_dynamic(
output_keys="flag" "http://localhost:8080", "proj1", "batch_i2v", 1,
) output_keys="flag"
assert result[0] == "true" )
assert result[1] == "true"
def test_load_dynamic_returns_total_sequences(self):
"""total_sequences should be the first output from keys metadata."""
node = ProjectLoaderDynamic()
with patch("project_loader._fetch_keys", return_value={"keys": [], "types": [], "total_sequences": 42}):
with patch("project_loader._fetch_data", return_value={}):
result = node.load_dynamic(
"http://localhost:8080", "proj1", "batch_i2v", 1,
output_keys=""
)
assert result[0] == 42
def test_load_dynamic_raises_on_network_error(self):
"""Network errors from _fetch_keys should raise RuntimeError."""
node = ProjectLoaderDynamic()
error_resp = {"error": "network_error", "message": "Connection refused"}
with patch("project_loader._fetch_keys", return_value=error_resp):
with pytest.raises(RuntimeError, match="Failed to fetch project keys"):
node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
def test_load_dynamic_raises_on_data_fetch_error(self):
"""Network errors from _fetch_data should raise RuntimeError."""
node = ProjectLoaderDynamic()
error_resp = {"error": "http_error", "status": 404, "message": "Sequence not found"}
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
with patch("project_loader._fetch_data", return_value=error_resp):
with pytest.raises(RuntimeError, match="Failed to fetch sequence data"):
node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
def test_input_types_has_manager_url(self): def test_input_types_has_manager_url(self):
inputs = ProjectLoaderDynamic.INPUT_TYPES() inputs = ProjectLoaderDynamic.INPUT_TYPES()
@@ -154,88 +203,9 @@ class TestProjectLoaderDynamic:
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project" assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
class TestProjectLoaderStandard:
def test_load_standard(self):
data = {
"general_prompt": "hello",
"general_negative": "bad",
"current_prompt": "specific",
"negative": "neg",
"camera": "pan",
"flf": 0.5,
"seed": 42,
"video file path": "/v.mp4",
"reference image path": "/r.png",
"flf image path": "/f.png",
}
node = ProjectLoaderStandard()
with patch("project_loader._fetch_data", return_value=data):
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png")
def test_load_standard_defaults(self):
node = ProjectLoaderStandard()
with patch("project_loader._fetch_data", return_value={}):
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
assert result[0] == "" # general_prompt
assert result[5] == 0.0 # flf
assert result[6] == 0 # seed
class TestProjectLoaderVACE:
def test_load_vace(self):
data = {
"general_prompt": "hello",
"general_negative": "bad",
"current_prompt": "specific",
"negative": "neg",
"camera": "pan",
"flf": 0.5,
"seed": 42,
"frame_to_skip": 81,
"input_a_frames": 16,
"input_b_frames": 16,
"reference path": "/ref",
"reference switch": 1,
"vace schedule": 2,
"video file path": "/v.mp4",
"reference image path": "/r.png",
}
node = ProjectLoaderVACE()
with patch("project_loader._fetch_data", return_value=data):
result = node.load_vace("http://localhost:8080", "proj1", "batch", 1)
assert result[7] == 81 # frame_to_skip
assert result[12] == 2 # vace_schedule
class TestProjectLoaderLoRA:
def test_load_loras(self):
data = {
"lora 1 high": "<lora:model1:1.0>",
"lora 1 low": "<lora:model1:0.5>",
"lora 2 high": "",
"lora 2 low": "",
"lora 3 high": "",
"lora 3 low": "",
}
node = ProjectLoaderLoRA()
with patch("project_loader._fetch_data", return_value=data):
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
assert result[0] == "<lora:model1:1.0>"
assert result[1] == "<lora:model1:0.5>"
def test_load_loras_empty(self):
node = ProjectLoaderLoRA()
with patch("project_loader._fetch_data", return_value={}):
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
assert all(v == "" for v in result)
class TestNodeMappings: class TestNodeMappings:
def test_mappings_exist(self): def test_mappings_exist(self):
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1
assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1
assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4

View File

@@ -32,18 +32,55 @@ app.registerExtension({
this.refreshDynamicOutputs(); this.refreshDynamicOutputs();
}); });
// Auto-refresh with 500ms debounce on widget changes
this._refreshTimer = null;
const autoRefreshWidgets = ["project_name", "file_name", "sequence_number"];
for (const widgetName of autoRefreshWidgets) {
const w = this.widgets?.find(w => w.name === widgetName);
if (w) {
const origCallback = w.callback;
const node = this;
w.callback = function (...args) {
origCallback?.apply(this, args);
clearTimeout(node._refreshTimer);
node._refreshTimer = setTimeout(() => {
node.refreshDynamicOutputs();
}, 500);
};
}
}
queueMicrotask(() => { queueMicrotask(() => {
if (!this._configured) { if (!this._configured) {
// New node (not loading) — remove the 32 Python default outputs // New node (not loading) — remove the Python default outputs
// and add only the fixed total_sequences slot
while (this.outputs.length > 0) { while (this.outputs.length > 0) {
this.removeOutput(0); this.removeOutput(0);
} }
this.addOutput("total_sequences", "INT");
this.setSize(this.computeSize()); this.setSize(this.computeSize());
app.graph?.setDirtyCanvas(true, true); app.graph?.setDirtyCanvas(true, true);
} }
}); });
}; };
nodeType.prototype._setStatus = function (status, message) {
const baseTitle = "Project Loader (Dynamic)";
if (status === "ok") {
this.title = baseTitle;
this.color = undefined;
this.bgcolor = undefined;
} else if (status === "error") {
this.title = baseTitle + " - ERROR";
this.color = "#ff4444";
this.bgcolor = "#331111";
if (message) this.title = baseTitle + ": " + message;
} else if (status === "loading") {
this.title = baseTitle + " - Loading...";
}
app.graph?.setDirtyCanvas(true, true);
};
nodeType.prototype.refreshDynamicOutputs = async function () { nodeType.prototype.refreshDynamicOutputs = async function () {
const urlWidget = this.widgets?.find(w => w.name === "manager_url"); const urlWidget = this.widgets?.find(w => w.name === "manager_url");
const projectWidget = this.widgets?.find(w => w.name === "project_name"); const projectWidget = this.widgets?.find(w => w.name === "project_name");
@@ -52,13 +89,20 @@ app.registerExtension({
if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return; if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return;
this._setStatus("loading");
try { try {
const resp = await api.fetchApi( const resp = await api.fetchApi(
`/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}`
); );
if (!resp.ok) { if (!resp.ok) {
console.warn("[ProjectLoaderDynamic] HTTP error", resp.status, "— keeping existing outputs"); let errorMsg = `HTTP ${resp.status}`;
try {
const errData = await resp.json();
if (errData.message) errorMsg = errData.message;
} catch (_) {}
this._setStatus("error", errorMsg);
return; return;
} }
@@ -68,7 +112,8 @@ app.registerExtension({
// If the API returned an error or missing data, keep existing outputs and links intact // If the API returned an error or missing data, keep existing outputs and links intact
if (data.error || !Array.isArray(keys) || !Array.isArray(types)) { if (data.error || !Array.isArray(keys) || !Array.isArray(types)) {
console.warn("[ProjectLoaderDynamic] API error or missing data, keeping existing outputs:", data.error || "no keys/types"); const errMsg = data.error ? data.message || data.error : "Missing keys/types";
this._setStatus("error", errMsg);
return; return;
} }
@@ -78,14 +123,20 @@ app.registerExtension({
const otWidget = this.widgets?.find(w => w.name === "output_types"); const otWidget = this.widgets?.find(w => w.name === "output_types");
if (otWidget) otWidget.value = JSON.stringify(types); if (otWidget) otWidget.value = JSON.stringify(types);
// Build a map of current output names to slot indices // Slot 0 is always total_sequences (INT) — ensure it exists
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
}
this.outputs[0].type = "INT";
// Build a map of current dynamic output names to slot indices (skip slot 0)
const oldSlots = {}; const oldSlots = {};
for (let i = 0; i < this.outputs.length; i++) { for (let i = 1; i < this.outputs.length; i++) {
oldSlots[this.outputs[i].name] = i; oldSlots[this.outputs[i].name] = i;
} }
// Build new outputs, reusing existing slots to preserve links // Build new dynamic outputs, reusing existing slots to preserve links
const newOutputs = []; const newOutputs = [this.outputs[0]]; // Keep total_sequences at slot 0
for (let k = 0; k < keys.length; k++) { for (let k = 0; k < keys.length; k++) {
const key = keys[k]; const key = keys[k];
const type = types[k] || "*"; const type = types[k] || "*";
@@ -122,10 +173,12 @@ app.registerExtension({
} }
} }
this._setStatus("ok");
this.setSize(this.computeSize()); this.setSize(this.computeSize());
app.graph?.setDirtyCanvas(true, true); app.graph?.setDirtyCanvas(true, true);
} catch (e) { } catch (e) {
console.error("[ProjectLoaderDynamic] Refresh failed:", e); console.error("[ProjectLoaderDynamic] Refresh failed:", e);
this._setStatus("error", "Server unreachable");
} }
}; };
@@ -158,23 +211,59 @@ app.registerExtension({
} }
} }
// Ensure slot 0 is total_sequences (INT)
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
// LiteGraph restores links AFTER onConfigure, so graph.links is
// empty here. Defer link fixup to a microtask that runs after the
// synchronous graph.configure() finishes (including link restoration).
// We must also rebuild output.links arrays because LiteGraph will
// place link IDs on the wrong outputs (shifted by the unshift above).
const node = this;
queueMicrotask(() => {
if (!node.graph) return;
// Clear all output.links — they were populated at old indices
for (const output of node.outputs) {
output.links = null;
}
// Rebuild from graph.links with corrected origin_slot (+1)
for (const linkId in node.graph.links) {
const link = node.graph.links[linkId];
if (!link || link.origin_id !== node.id) continue;
link.origin_slot += 1;
const output = node.outputs[link.origin_slot];
if (output) {
if (!output.links) output.links = [];
output.links.push(link.id);
}
}
app.graph?.setDirtyCanvas(true, true);
});
}
this.outputs[0].type = "INT";
this.outputs[0].name = "total_sequences";
if (keys.length > 0) { if (keys.length > 0) {
// On load, LiteGraph already restored serialized outputs with links. // On load, LiteGraph already restored serialized outputs with links.
// Rename and set types to match stored state (preserves links). // Dynamic outputs start at slot 1. Rename and set types to match stored state.
for (let i = 0; i < this.outputs.length && i < keys.length; i++) { for (let i = 0; i < keys.length; i++) {
this.outputs[i].name = keys[i]; const slotIdx = i + 1; // offset by 1 for total_sequences
if (types[i]) this.outputs[i].type = types[i]; if (slotIdx < this.outputs.length) {
this.outputs[slotIdx].name = keys[i];
if (types[i]) this.outputs[slotIdx].type = types[i];
}
} }
// Remove any extra outputs beyond the key count // Remove any extra outputs beyond keys + total_sequences
while (this.outputs.length > keys.length) { while (this.outputs.length > keys.length + 1) {
this.removeOutput(this.outputs.length - 1); this.removeOutput(this.outputs.length - 1);
} }
} else if (this.outputs.length > 0) { } else if (this.outputs.length > 1) {
// Widget values empty but serialized outputs exist — sync widgets // Widget values empty but serialized dynamic outputs exist — sync widgets
// from the outputs LiteGraph already restored (fallback). // from the outputs LiteGraph already restored (fallback, skip slot 0).
if (okWidget) okWidget.value = JSON.stringify(this.outputs.map(o => o.name)); const dynamicOutputs = this.outputs.slice(1);
if (otWidget) otWidget.value = JSON.stringify(this.outputs.map(o => o.type)); if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name));
if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type));
} }
this.setSize(this.computeSize()); this.setSize(this.computeSize());