Improve ProjectLoaderDynamic UX: single node, error feedback, auto-refresh
Remove 3 redundant hardcoded nodes (Standard/VACE/LoRA), keeping only the Dynamic node. Add total_sequences INT output (slot 0) for loop counting. Add structured error handling: _fetch_json returns typed error dicts, load_dynamic raises RuntimeError with descriptive messages, JS shows red border/title on errors. Add 500ms debounced auto-refresh on widget changes. Add 404s for missing project/file in API endpoints. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -55,13 +55,26 @@ def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
|
||||
|
||||
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
    """Return the raw data dict for one sequence of a project's data file.

    Args:
        name: Project name.
        file_name: Data file name within the project.
        seq: 1-based sequence number (query parameter, defaults to 1).

    Raises:
        HTTPException: 404 when the project, file, or sequence does not exist.
    """
    db = _get_db()
    # Resolve entities one level at a time so each 404 names exactly
    # what was missing — project first, then file, then sequence.
    proj = db.get_project(name)
    if not proj:
        raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
    df = db.get_data_file_by_names(name, file_name)
    if not df:
        raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
    data = db.get_sequence(df["id"], seq)
    if data is None:
        raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
    return data
|
||||
|
||||
|
||||
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
    """Return key/type metadata for one sequence, plus the file's total sequence count.

    The "total_sequences" field lets clients (e.g. the dynamic loader node)
    size loops without a second round trip.

    Raises:
        HTTPException: 404 when the project or file does not exist.
    """
    db = _get_db()
    # Same staged 404 resolution as _get_data: project, then file.
    proj = db.get_project(name)
    if not proj:
        raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
    df = db.get_data_file_by_names(name, file_name)
    if not df:
        raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
    keys, types = db.get_sequence_keys(df["id"], seq)
    total = db.count_sequences(df["id"])
    return {"keys": keys, "types": types, "total_sequences": total}
|
||||
|
||||
15
db.py
15
db.py
@@ -182,6 +182,21 @@ class ProjectDB:
|
||||
).fetchall()
|
||||
return [r["sequence_number"] for r in rows]
|
||||
|
||||
def count_sequences(self, data_file_id: int) -> int:
    """Return how many sequences are stored for the given data file."""
    sql = "SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?"
    result_row = self.conn.execute(sql, (data_file_id,)).fetchone()
    return result_row["cnt"]
|
||||
|
||||
def query_total_sequences(self, project_name: str, file_name: str) -> int:
    """Return total sequence count by project and file names (0 if the file is unknown)."""
    record = self.get_data_file_by_names(project_name, file_name)
    return self.count_sequences(record["id"]) if record else 0
|
||||
|
||||
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
|
||||
"""Returns (keys, types) for a sequence's data dict."""
|
||||
data = self.get_sequence(data_file_id, sequence_number)
|
||||
|
||||
@@ -39,13 +39,31 @@ def to_int(val: Any) -> int:
|
||||
|
||||
|
||||
def _fetch_json(url: str) -> dict:
|
||||
"""Fetch JSON from a URL using stdlib urllib."""
|
||||
"""Fetch JSON from a URL using stdlib urllib.
|
||||
|
||||
On error, returns a dict with an "error" key describing the failure.
|
||||
"""
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=5) as resp:
|
||||
return json.loads(resp.read())
|
||||
except (urllib.error.URLError, json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"Failed to fetch {url}: {e}")
|
||||
return {}
|
||||
except urllib.error.HTTPError as e:
|
||||
# HTTPError is a subclass of URLError — must be caught first
|
||||
body = ""
|
||||
try:
|
||||
raw = e.read()
|
||||
detail = json.loads(raw)
|
||||
body = detail.get("detail", str(raw, "utf-8", errors="replace"))
|
||||
except Exception:
|
||||
body = str(e)
|
||||
logger.warning(f"HTTP {e.code} from {url}: {body}")
|
||||
return {"error": "http_error", "status": e.code, "message": body}
|
||||
except (urllib.error.URLError, OSError) as e:
|
||||
reason = str(e.reason) if hasattr(e, "reason") else str(e)
|
||||
logger.warning(f"Network error fetching {url}: {reason}")
|
||||
return {"error": "network_error", "message": reason}
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid JSON from {url}: {e}")
|
||||
return {"error": "parse_error", "message": str(e)}
|
||||
|
||||
|
||||
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
|
||||
@@ -100,6 +118,9 @@ if PromptServer is not None:
|
||||
except (ValueError, TypeError):
|
||||
seq = 1
|
||||
data = _fetch_keys(manager_url, project, file_name, seq)
|
||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||
status = data.get("status", 502)
|
||||
return web.json_response(data, status=status)
|
||||
return web.json_response(data)
|
||||
|
||||
|
||||
@@ -124,15 +145,25 @@ class ProjectLoaderDynamic:
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
|
||||
RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
|
||||
RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
|
||||
RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
|
||||
FUNCTION = "load_dynamic"
|
||||
CATEGORY = "utils/json/project"
|
||||
OUTPUT_NODE = False
|
||||
|
||||
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
|
||||
output_keys="", output_types=""):
|
||||
# Fetch keys metadata (includes total_sequences count)
|
||||
keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number)
|
||||
if keys_meta.get("error") in ("http_error", "network_error", "parse_error"):
|
||||
msg = keys_meta.get("message", "Unknown error")
|
||||
raise RuntimeError(f"Failed to fetch project keys: {msg}")
|
||||
total_sequences = keys_meta.get("total_sequences", 0)
|
||||
|
||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||
msg = data.get("message", "Unknown error")
|
||||
raise RuntimeError(f"Failed to fetch sequence data: {msg}")
|
||||
|
||||
# Parse keys — try JSON array first, fall back to comma-split for compat
|
||||
keys = []
|
||||
@@ -171,111 +202,14 @@ class ProjectLoaderDynamic:
|
||||
while len(results) < MAX_DYNAMIC_OUTPUTS:
|
||||
results.append("")
|
||||
|
||||
return tuple(results)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 1. STANDARD NODE (Project-based I2V)
|
||||
# ==========================================
|
||||
|
||||
class ProjectLoaderStandard:
    """ComfyUI node: loads the standard I2V field set of one sequence via _fetch_data."""

    @classmethod
    def INPUT_TYPES(s):
        # Widgets addressing the external project-manager HTTP API.
        return {"required": {
            "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
            "project_name": ("STRING", {"default": "", "multiline": False}),
            "file_name": ("STRING", {"default": "", "multiline": False}),
            "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
        }}

    # Ten fixed outputs, mapped 1:1 from keys of the fetched data dict.
    RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING")
    RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path")
    FUNCTION = "load_standard"
    CATEGORY = "utils/json/project"

    def load_standard(self, manager_url, project_name, file_name, sequence_number):
        """Fetch one sequence and return its fields; absent keys fall back to ""/0.0/0."""
        data = _fetch_data(manager_url, project_name, file_name, sequence_number)
        # Some lookup keys contain spaces ("video file path" etc.) — they must
        # match the fetched dict's keys exactly; do not "fix" them to snake_case.
        return (
            str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
            str(data.get("current_prompt", "")), str(data.get("negative", "")),
            str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
            to_int(data.get("seed", 0)), str(data.get("video file path", "")),
            str(data.get("reference image path", "")), str(data.get("flf image path", ""))
        )
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 2. VACE NODE (Project-based)
|
||||
# ==========================================
|
||||
|
||||
class ProjectLoaderVACE:
    """ComfyUI node: loads the full VACE field set of one sequence via _fetch_data."""

    @classmethod
    def INPUT_TYPES(s):
        # Widgets addressing the external project-manager HTTP API.
        return {"required": {
            "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
            "project_name": ("STRING", {"default": "", "multiline": False}),
            "file_name": ("STRING", {"default": "", "multiline": False}),
            "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
        }}

    # Fifteen fixed outputs, mapped 1:1 from keys of the fetched data dict.
    RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING")
    RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path")
    FUNCTION = "load_vace"
    CATEGORY = "utils/json/project"

    def load_vace(self, manager_url, project_name, file_name, sequence_number):
        """Fetch one sequence and return its VACE fields with per-field defaults.

        Defaults: strings -> "", flf -> 0.0, seed -> 0, frame_to_skip -> 81,
        input_a/b_frames -> 16, reference switch / vace schedule -> 1.
        """
        data = _fetch_data(manager_url, project_name, file_name, sequence_number)
        # Several lookup keys contain spaces ("reference path", "vace schedule",
        # ...) — they must match the fetched dict's keys exactly.
        return (
            str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
            str(data.get("current_prompt", "")), str(data.get("negative", "")),
            str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
            to_int(data.get("seed", 0)), to_int(data.get("frame_to_skip", 81)),
            to_int(data.get("input_a_frames", 16)), to_int(data.get("input_b_frames", 16)),
            str(data.get("reference path", "")), to_int(data.get("reference switch", 1)),
            to_int(data.get("vace schedule", 1)), str(data.get("video file path", "")),
            str(data.get("reference image path", ""))
        )
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 3. LoRA NODE (Project-based)
|
||||
# ==========================================
|
||||
|
||||
class ProjectLoaderLoRA:
    """ComfyUI node: loads the six LoRA string fields of one sequence via _fetch_data."""

    @classmethod
    def INPUT_TYPES(s):
        # Widgets addressing the external project-manager HTTP API.
        return {"required": {
            "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
            "project_name": ("STRING", {"default": "", "multiline": False}),
            "file_name": ("STRING", {"default": "", "multiline": False}),
            "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
        }}

    # Three LoRA slots, each with a high/low pair, all plain strings.
    RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
    RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low")
    FUNCTION = "load_loras"
    CATEGORY = "utils/json/project"

    def load_loras(self, manager_url, project_name, file_name, sequence_number):
        """Fetch one sequence and return its LoRA fields; absent keys become ""."""
        data = _fetch_data(manager_url, project_name, file_name, sequence_number)
        # Lookup keys contain spaces ("lora 1 high", ...) — they must match
        # the fetched dict's keys exactly.
        return (
            str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")),
            str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")),
            str(data.get("lora 3 high", "")), str(data.get("lora 3 low", ""))
        )
|
||||
return (total_sequences,) + tuple(results)
|
||||
|
||||
|
||||
# --- Mappings ---
# Exported registries: internal node id -> implementing class, and
# internal node id -> label shown in the UI. The key sets must stay in sync.
PROJECT_NODE_CLASS_MAPPINGS = {
    "ProjectLoaderDynamic": ProjectLoaderDynamic,
    "ProjectLoaderStandard": ProjectLoaderStandard,
    "ProjectLoaderVACE": ProjectLoaderVACE,
    "ProjectLoaderLoRA": ProjectLoaderLoRA,
}

PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
    "ProjectLoaderDynamic": "Project Loader (Dynamic)",
    "ProjectLoaderStandard": "Project Loader (Standard/I2V)",
    "ProjectLoaderVACE": "Project Loader (VACE Full)",
    "ProjectLoaderLoRA": "Project Loader (LoRAs)",
}
|
||||
|
||||
@@ -172,6 +172,25 @@ class TestSequences:
|
||||
db.delete_sequences_for_file(df_id)
|
||||
assert db.list_sequences(df_id) == []
|
||||
|
||||
def test_count_sequences(self, db):
    """count_sequences starts at 0 and tracks the number of upserted sequences."""
    pid = db.create_project("p1", "/p1")
    df_id = db.create_data_file(pid, "batch", "generic")
    assert db.count_sequences(df_id) == 0
    db.upsert_sequence(df_id, 1, {"a": 1})
    db.upsert_sequence(df_id, 2, {"b": 2})
    db.upsert_sequence(df_id, 3, {"c": 3})
    assert db.count_sequences(df_id) == 3
|
||||
|
||||
def test_query_total_sequences(self, db):
    """query_total_sequences resolves project/file by name and counts sequences."""
    pid = db.create_project("p1", "/p1")
    df_id = db.create_data_file(pid, "batch", "generic")
    db.upsert_sequence(df_id, 1, {"a": 1})
    db.upsert_sequence(df_id, 2, {"b": 2})
    assert db.query_total_sequences("p1", "batch") == 2
|
||||
|
||||
def test_query_total_sequences_nonexistent(self, db):
    """Unknown project/file names yield 0 rather than raising."""
    assert db.query_total_sequences("nope", "nope") == 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# History trees
|
||||
|
||||
@@ -6,9 +6,6 @@ import pytest
|
||||
|
||||
from project_loader import (
|
||||
ProjectLoaderDynamic,
|
||||
ProjectLoaderStandard,
|
||||
ProjectLoaderVACE,
|
||||
ProjectLoaderLoRA,
|
||||
_fetch_json,
|
||||
_fetch_data,
|
||||
_fetch_keys,
|
||||
@@ -32,11 +29,23 @@ class TestFetchHelpers:
|
||||
result = _fetch_json("http://example.com/api")
|
||||
assert result == data
|
||||
|
||||
def test_fetch_json_failure(self):
|
||||
import urllib.error
|
||||
def test_fetch_json_network_error(self):
|
||||
with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
|
||||
result = _fetch_json("http://example.com/api")
|
||||
assert result == {}
|
||||
assert result["error"] == "network_error"
|
||||
assert "connection refused" in result["message"]
|
||||
|
||||
def test_fetch_json_http_error(self):
    """HTTP errors become typed dicts carrying the status and the JSON 'detail' text."""
    import urllib.error
    # Simulate a FastAPI-style 404 whose body is {"detail": ...}.
    err = urllib.error.HTTPError(
        "http://example.com/api", 404, "Not Found", {},
        BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode())
    )
    with patch("project_loader.urllib.request.urlopen", side_effect=err):
        result = _fetch_json("http://example.com/api")
    assert result["error"] == "http_error"
    assert result["status"] == 404
    assert "not found" in result["message"].lower()
|
||||
|
||||
def test_fetch_data_builds_url(self):
|
||||
data = {"prompt": "hello"}
|
||||
@@ -73,18 +82,23 @@ class TestFetchHelpers:
|
||||
|
||||
|
||||
class TestProjectLoaderDynamic:
|
||||
def _keys_meta(self, total=5):
|
||||
return {"keys": [], "types": [], "total_sequences": total}
|
||||
|
||||
def test_load_dynamic_with_keys(self):
|
||||
data = {"prompt": "hello", "seed": 42, "cfg": 1.5}
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="prompt,seed,cfg"
|
||||
)
|
||||
assert result[0] == "hello"
|
||||
assert result[1] == 42
|
||||
assert result[2] == 1.5
|
||||
assert len(result) == MAX_DYNAMIC_OUTPUTS
|
||||
assert result[0] == 5 # total_sequences
|
||||
assert result[1] == "hello"
|
||||
assert result[2] == 42
|
||||
assert result[3] == 1.5
|
||||
assert len(result) == MAX_DYNAMIC_OUTPUTS + 1
|
||||
|
||||
def test_load_dynamic_with_json_encoded_keys(self):
|
||||
"""JSON-encoded output_keys should be parsed correctly."""
|
||||
@@ -92,13 +106,14 @@ class TestProjectLoaderDynamic:
|
||||
data = {"my,key": "comma_val", "normal": "ok"}
|
||||
node = ProjectLoaderDynamic()
|
||||
keys_json = _json.dumps(["my,key", "normal"])
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys=keys_json
|
||||
)
|
||||
assert result[0] == "comma_val"
|
||||
assert result[1] == "ok"
|
||||
assert result[1] == "comma_val"
|
||||
assert result[2] == "ok"
|
||||
|
||||
def test_load_dynamic_type_coercion(self):
|
||||
"""output_types should coerce values to declared types."""
|
||||
@@ -107,41 +122,75 @@ class TestProjectLoaderDynamic:
|
||||
node = ProjectLoaderDynamic()
|
||||
keys_json = _json.dumps(["seed", "cfg", "prompt"])
|
||||
types_json = _json.dumps(["INT", "FLOAT", "STRING"])
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys=keys_json, output_types=types_json
|
||||
)
|
||||
assert result[0] == 42 # string "42" coerced to int
|
||||
assert result[1] == 1.5 # string "1.5" coerced to float
|
||||
assert result[2] == "hello" # string stays string
|
||||
assert result[1] == 42 # string "42" coerced to int
|
||||
assert result[2] == 1.5 # string "1.5" coerced to float
|
||||
assert result[3] == "hello" # string stays string
|
||||
|
||||
def test_load_dynamic_empty_keys(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys=""
|
||||
)
|
||||
assert all(v == "" for v in result)
|
||||
# Slot 0 is total_sequences (INT), rest are empty strings
|
||||
assert result[0] == 5
|
||||
assert all(v == "" for v in result[1:])
|
||||
|
||||
def test_load_dynamic_missing_key(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="nonexistent"
|
||||
)
|
||||
assert result[0] == ""
|
||||
assert result[1] == ""
|
||||
|
||||
def test_load_dynamic_bool_becomes_string(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value={"flag": True}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="flag"
|
||||
)
|
||||
assert result[0] == "true"
|
||||
assert result[1] == "true"
|
||||
|
||||
def test_load_dynamic_returns_total_sequences(self):
    """total_sequences should be the first output from keys metadata."""
    node = ProjectLoaderDynamic()
    # Keys metadata supplies the count; slot 0 of the result must carry it.
    with patch("project_loader._fetch_keys", return_value={"keys": [], "types": [], "total_sequences": 42}):
        with patch("project_loader._fetch_data", return_value={}):
            result = node.load_dynamic(
                "http://localhost:8080", "proj1", "batch_i2v", 1,
                output_keys=""
            )
    assert result[0] == 42
|
||||
|
||||
def test_load_dynamic_raises_on_network_error(self):
    """Network errors from _fetch_keys should raise RuntimeError."""
    node = ProjectLoaderDynamic()
    # Typed error dict as produced by _fetch_json's network_error branch.
    error_resp = {"error": "network_error", "message": "Connection refused"}
    with patch("project_loader._fetch_keys", return_value=error_resp):
        with pytest.raises(RuntimeError, match="Failed to fetch project keys"):
            node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
|
||||
|
||||
def test_load_dynamic_raises_on_data_fetch_error(self):
    """Network errors from _fetch_data should raise RuntimeError."""
    node = ProjectLoaderDynamic()
    # Keys fetch succeeds; the subsequent data fetch fails with a typed 404.
    error_resp = {"error": "http_error", "status": 404, "message": "Sequence not found"}
    with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
        with patch("project_loader._fetch_data", return_value=error_resp):
            with pytest.raises(RuntimeError, match="Failed to fetch sequence data"):
                node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
|
||||
|
||||
def test_input_types_has_manager_url(self):
|
||||
inputs = ProjectLoaderDynamic.INPUT_TYPES()
|
||||
@@ -154,88 +203,9 @@ class TestProjectLoaderDynamic:
|
||||
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
|
||||
|
||||
|
||||
class TestProjectLoaderStandard:
|
||||
def test_load_standard(self):
|
||||
data = {
|
||||
"general_prompt": "hello",
|
||||
"general_negative": "bad",
|
||||
"current_prompt": "specific",
|
||||
"negative": "neg",
|
||||
"camera": "pan",
|
||||
"flf": 0.5,
|
||||
"seed": 42,
|
||||
"video file path": "/v.mp4",
|
||||
"reference image path": "/r.png",
|
||||
"flf image path": "/f.png",
|
||||
}
|
||||
node = ProjectLoaderStandard()
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png")
|
||||
|
||||
def test_load_standard_defaults(self):
|
||||
node = ProjectLoaderStandard()
|
||||
with patch("project_loader._fetch_data", return_value={}):
|
||||
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert result[0] == "" # general_prompt
|
||||
assert result[5] == 0.0 # flf
|
||||
assert result[6] == 0 # seed
|
||||
|
||||
|
||||
class TestProjectLoaderVACE:
|
||||
def test_load_vace(self):
|
||||
data = {
|
||||
"general_prompt": "hello",
|
||||
"general_negative": "bad",
|
||||
"current_prompt": "specific",
|
||||
"negative": "neg",
|
||||
"camera": "pan",
|
||||
"flf": 0.5,
|
||||
"seed": 42,
|
||||
"frame_to_skip": 81,
|
||||
"input_a_frames": 16,
|
||||
"input_b_frames": 16,
|
||||
"reference path": "/ref",
|
||||
"reference switch": 1,
|
||||
"vace schedule": 2,
|
||||
"video file path": "/v.mp4",
|
||||
"reference image path": "/r.png",
|
||||
}
|
||||
node = ProjectLoaderVACE()
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_vace("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert result[7] == 81 # frame_to_skip
|
||||
assert result[12] == 2 # vace_schedule
|
||||
|
||||
|
||||
class TestProjectLoaderLoRA:
|
||||
def test_load_loras(self):
|
||||
data = {
|
||||
"lora 1 high": "<lora:model1:1.0>",
|
||||
"lora 1 low": "<lora:model1:0.5>",
|
||||
"lora 2 high": "",
|
||||
"lora 2 low": "",
|
||||
"lora 3 high": "",
|
||||
"lora 3 low": "",
|
||||
}
|
||||
node = ProjectLoaderLoRA()
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert result[0] == "<lora:model1:1.0>"
|
||||
assert result[1] == "<lora:model1:0.5>"
|
||||
|
||||
def test_load_loras_empty(self):
|
||||
node = ProjectLoaderLoRA()
|
||||
with patch("project_loader._fetch_data", return_value={}):
|
||||
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
|
||||
assert all(v == "" for v in result)
|
||||
|
||||
|
||||
class TestNodeMappings:
|
||||
def test_mappings_exist(self):
|
||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4
|
||||
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1
|
||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1
|
||||
|
||||
@@ -32,18 +32,55 @@ app.registerExtension({
|
||||
this.refreshDynamicOutputs();
|
||||
});
|
||||
|
||||
// Auto-refresh with 500ms debounce on widget changes
|
||||
this._refreshTimer = null;
|
||||
const autoRefreshWidgets = ["project_name", "file_name", "sequence_number"];
|
||||
for (const widgetName of autoRefreshWidgets) {
|
||||
const w = this.widgets?.find(w => w.name === widgetName);
|
||||
if (w) {
|
||||
const origCallback = w.callback;
|
||||
const node = this;
|
||||
w.callback = function (...args) {
|
||||
origCallback?.apply(this, args);
|
||||
clearTimeout(node._refreshTimer);
|
||||
node._refreshTimer = setTimeout(() => {
|
||||
node.refreshDynamicOutputs();
|
||||
}, 500);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
queueMicrotask(() => {
|
||||
if (!this._configured) {
|
||||
// New node (not loading) — remove the 32 Python default outputs
|
||||
// New node (not loading) — remove the Python default outputs
|
||||
// and add only the fixed total_sequences slot
|
||||
while (this.outputs.length > 0) {
|
||||
this.removeOutput(0);
|
||||
}
|
||||
this.addOutput("total_sequences", "INT");
|
||||
this.setSize(this.computeSize());
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
nodeType.prototype._setStatus = function (status, message) {
|
||||
const baseTitle = "Project Loader (Dynamic)";
|
||||
if (status === "ok") {
|
||||
this.title = baseTitle;
|
||||
this.color = undefined;
|
||||
this.bgcolor = undefined;
|
||||
} else if (status === "error") {
|
||||
this.title = baseTitle + " - ERROR";
|
||||
this.color = "#ff4444";
|
||||
this.bgcolor = "#331111";
|
||||
if (message) this.title = baseTitle + ": " + message;
|
||||
} else if (status === "loading") {
|
||||
this.title = baseTitle + " - Loading...";
|
||||
}
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
};
|
||||
|
||||
nodeType.prototype.refreshDynamicOutputs = async function () {
|
||||
const urlWidget = this.widgets?.find(w => w.name === "manager_url");
|
||||
const projectWidget = this.widgets?.find(w => w.name === "project_name");
|
||||
@@ -52,13 +89,20 @@ app.registerExtension({
|
||||
|
||||
if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return;
|
||||
|
||||
this._setStatus("loading");
|
||||
|
||||
try {
|
||||
const resp = await api.fetchApi(
|
||||
`/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}`
|
||||
);
|
||||
|
||||
if (!resp.ok) {
|
||||
console.warn("[ProjectLoaderDynamic] HTTP error", resp.status, "— keeping existing outputs");
|
||||
let errorMsg = `HTTP ${resp.status}`;
|
||||
try {
|
||||
const errData = await resp.json();
|
||||
if (errData.message) errorMsg = errData.message;
|
||||
} catch (_) {}
|
||||
this._setStatus("error", errorMsg);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -68,7 +112,8 @@ app.registerExtension({
|
||||
|
||||
// If the API returned an error or missing data, keep existing outputs and links intact
|
||||
if (data.error || !Array.isArray(keys) || !Array.isArray(types)) {
|
||||
console.warn("[ProjectLoaderDynamic] API error or missing data, keeping existing outputs:", data.error || "no keys/types");
|
||||
const errMsg = data.error ? data.message || data.error : "Missing keys/types";
|
||||
this._setStatus("error", errMsg);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -78,14 +123,20 @@ app.registerExtension({
|
||||
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
||||
if (otWidget) otWidget.value = JSON.stringify(types);
|
||||
|
||||
// Build a map of current output names to slot indices
|
||||
// Slot 0 is always total_sequences (INT) — ensure it exists
|
||||
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
||||
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
|
||||
}
|
||||
this.outputs[0].type = "INT";
|
||||
|
||||
// Build a map of current dynamic output names to slot indices (skip slot 0)
|
||||
const oldSlots = {};
|
||||
for (let i = 0; i < this.outputs.length; i++) {
|
||||
for (let i = 1; i < this.outputs.length; i++) {
|
||||
oldSlots[this.outputs[i].name] = i;
|
||||
}
|
||||
|
||||
// Build new outputs, reusing existing slots to preserve links
|
||||
const newOutputs = [];
|
||||
// Build new dynamic outputs, reusing existing slots to preserve links
|
||||
const newOutputs = [this.outputs[0]]; // Keep total_sequences at slot 0
|
||||
for (let k = 0; k < keys.length; k++) {
|
||||
const key = keys[k];
|
||||
const type = types[k] || "*";
|
||||
@@ -122,10 +173,12 @@ app.registerExtension({
|
||||
}
|
||||
}
|
||||
|
||||
this._setStatus("ok");
|
||||
this.setSize(this.computeSize());
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
} catch (e) {
|
||||
console.error("[ProjectLoaderDynamic] Refresh failed:", e);
|
||||
this._setStatus("error", "Server unreachable");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -158,23 +211,59 @@ app.registerExtension({
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure slot 0 is total_sequences (INT)
|
||||
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
||||
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
|
||||
// LiteGraph restores links AFTER onConfigure, so graph.links is
|
||||
// empty here. Defer link fixup to a microtask that runs after the
|
||||
// synchronous graph.configure() finishes (including link restoration).
|
||||
// We must also rebuild output.links arrays because LiteGraph will
|
||||
// place link IDs on the wrong outputs (shifted by the unshift above).
|
||||
const node = this;
|
||||
queueMicrotask(() => {
|
||||
if (!node.graph) return;
|
||||
// Clear all output.links — they were populated at old indices
|
||||
for (const output of node.outputs) {
|
||||
output.links = null;
|
||||
}
|
||||
// Rebuild from graph.links with corrected origin_slot (+1)
|
||||
for (const linkId in node.graph.links) {
|
||||
const link = node.graph.links[linkId];
|
||||
if (!link || link.origin_id !== node.id) continue;
|
||||
link.origin_slot += 1;
|
||||
const output = node.outputs[link.origin_slot];
|
||||
if (output) {
|
||||
if (!output.links) output.links = [];
|
||||
output.links.push(link.id);
|
||||
}
|
||||
}
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
});
|
||||
}
|
||||
this.outputs[0].type = "INT";
|
||||
this.outputs[0].name = "total_sequences";
|
||||
|
||||
if (keys.length > 0) {
|
||||
// On load, LiteGraph already restored serialized outputs with links.
|
||||
// Rename and set types to match stored state (preserves links).
|
||||
for (let i = 0; i < this.outputs.length && i < keys.length; i++) {
|
||||
this.outputs[i].name = keys[i];
|
||||
if (types[i]) this.outputs[i].type = types[i];
|
||||
// Dynamic outputs start at slot 1. Rename and set types to match stored state.
|
||||
for (let i = 0; i < keys.length; i++) {
|
||||
const slotIdx = i + 1; // offset by 1 for total_sequences
|
||||
if (slotIdx < this.outputs.length) {
|
||||
this.outputs[slotIdx].name = keys[i];
|
||||
if (types[i]) this.outputs[slotIdx].type = types[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Remove any extra outputs beyond the key count
|
||||
while (this.outputs.length > keys.length) {
|
||||
// Remove any extra outputs beyond keys + total_sequences
|
||||
while (this.outputs.length > keys.length + 1) {
|
||||
this.removeOutput(this.outputs.length - 1);
|
||||
}
|
||||
} else if (this.outputs.length > 0) {
|
||||
// Widget values empty but serialized outputs exist — sync widgets
|
||||
// from the outputs LiteGraph already restored (fallback).
|
||||
if (okWidget) okWidget.value = JSON.stringify(this.outputs.map(o => o.name));
|
||||
if (otWidget) otWidget.value = JSON.stringify(this.outputs.map(o => o.type));
|
||||
} else if (this.outputs.length > 1) {
|
||||
// Widget values empty but serialized dynamic outputs exist — sync widgets
|
||||
// from the outputs LiteGraph already restored (fallback, skip slot 0).
|
||||
const dynamicOutputs = this.outputs.slice(1);
|
||||
if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name));
|
||||
if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type));
|
||||
}
|
||||
|
||||
this.setSize(this.computeSize());
|
||||
|
||||
Reference in New Issue
Block a user