Improve ProjectLoaderDynamic UX: single node, error feedback, auto-refresh
Remove 3 redundant hardcoded nodes (Standard/VACE/LoRA), keeping only the Dynamic node. Add total_sequences INT output (slot 0) for loop counting. Add structured error handling: _fetch_json returns typed error dicts, load_dynamic raises RuntimeError with descriptive messages, JS shows red border/title on errors. Add 500ms debounced auto-refresh on widget changes. Add 404s for missing project/file in API endpoints. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -55,13 +55,26 @@ def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
|
|||||||
|
|
||||||
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
data = db.query_sequence_data(name, file_name, seq)
|
proj = db.get_project(name)
|
||||||
|
if not proj:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||||
|
df = db.get_data_file_by_names(name, file_name)
|
||||||
|
if not df:
|
||||||
|
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
||||||
|
data = db.get_sequence(df["id"], seq)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise HTTPException(status_code=404, detail="Sequence not found")
|
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
keys, types = db.query_sequence_keys(name, file_name, seq)
|
proj = db.get_project(name)
|
||||||
return {"keys": keys, "types": types}
|
if not proj:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||||
|
df = db.get_data_file_by_names(name, file_name)
|
||||||
|
if not df:
|
||||||
|
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
||||||
|
keys, types = db.get_sequence_keys(df["id"], seq)
|
||||||
|
total = db.count_sequences(df["id"])
|
||||||
|
return {"keys": keys, "types": types, "total_sequences": total}
|
||||||
|
|||||||
15
db.py
15
db.py
@@ -182,6 +182,21 @@ class ProjectDB:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
return [r["sequence_number"] for r in rows]
|
return [r["sequence_number"] for r in rows]
|
||||||
|
|
||||||
|
def count_sequences(self, data_file_id: int) -> int:
|
||||||
|
"""Return the number of sequences for a data file."""
|
||||||
|
row = self.conn.execute(
|
||||||
|
"SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?",
|
||||||
|
(data_file_id,),
|
||||||
|
).fetchone()
|
||||||
|
return row["cnt"]
|
||||||
|
|
||||||
|
def query_total_sequences(self, project_name: str, file_name: str) -> int:
|
||||||
|
"""Return total sequence count by project and file names."""
|
||||||
|
df = self.get_data_file_by_names(project_name, file_name)
|
||||||
|
if not df:
|
||||||
|
return 0
|
||||||
|
return self.count_sequences(df["id"])
|
||||||
|
|
||||||
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
|
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
|
||||||
"""Returns (keys, types) for a sequence's data dict."""
|
"""Returns (keys, types) for a sequence's data dict."""
|
||||||
data = self.get_sequence(data_file_id, sequence_number)
|
data = self.get_sequence(data_file_id, sequence_number)
|
||||||
|
|||||||
@@ -39,13 +39,31 @@ def to_int(val: Any) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def _fetch_json(url: str) -> dict:
|
def _fetch_json(url: str) -> dict:
|
||||||
"""Fetch JSON from a URL using stdlib urllib."""
|
"""Fetch JSON from a URL using stdlib urllib.
|
||||||
|
|
||||||
|
On error, returns a dict with an "error" key describing the failure.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(url, timeout=5) as resp:
|
with urllib.request.urlopen(url, timeout=5) as resp:
|
||||||
return json.loads(resp.read())
|
return json.loads(resp.read())
|
||||||
except (urllib.error.URLError, json.JSONDecodeError, OSError) as e:
|
except urllib.error.HTTPError as e:
|
||||||
logger.warning(f"Failed to fetch {url}: {e}")
|
# HTTPError is a subclass of URLError — must be caught first
|
||||||
return {}
|
body = ""
|
||||||
|
try:
|
||||||
|
raw = e.read()
|
||||||
|
detail = json.loads(raw)
|
||||||
|
body = detail.get("detail", str(raw, "utf-8", errors="replace"))
|
||||||
|
except Exception:
|
||||||
|
body = str(e)
|
||||||
|
logger.warning(f"HTTP {e.code} from {url}: {body}")
|
||||||
|
return {"error": "http_error", "status": e.code, "message": body}
|
||||||
|
except (urllib.error.URLError, OSError) as e:
|
||||||
|
reason = str(e.reason) if hasattr(e, "reason") else str(e)
|
||||||
|
logger.warning(f"Network error fetching {url}: {reason}")
|
||||||
|
return {"error": "network_error", "message": reason}
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"Invalid JSON from {url}: {e}")
|
||||||
|
return {"error": "parse_error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
|
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
|
||||||
@@ -100,6 +118,9 @@ if PromptServer is not None:
|
|||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
seq = 1
|
seq = 1
|
||||||
data = _fetch_keys(manager_url, project, file_name, seq)
|
data = _fetch_keys(manager_url, project, file_name, seq)
|
||||||
|
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
|
status = data.get("status", 502)
|
||||||
|
return web.json_response(data, status=status)
|
||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
|
|
||||||
|
|
||||||
@@ -124,15 +145,25 @@ class ProjectLoaderDynamic:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
|
RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
|
||||||
RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
|
RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
|
||||||
FUNCTION = "load_dynamic"
|
FUNCTION = "load_dynamic"
|
||||||
CATEGORY = "utils/json/project"
|
CATEGORY = "utils/json/project"
|
||||||
OUTPUT_NODE = False
|
OUTPUT_NODE = False
|
||||||
|
|
||||||
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
|
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
|
||||||
output_keys="", output_types=""):
|
output_keys="", output_types=""):
|
||||||
|
# Fetch keys metadata (includes total_sequences count)
|
||||||
|
keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number)
|
||||||
|
if keys_meta.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
|
msg = keys_meta.get("message", "Unknown error")
|
||||||
|
raise RuntimeError(f"Failed to fetch project keys: {msg}")
|
||||||
|
total_sequences = keys_meta.get("total_sequences", 0)
|
||||||
|
|
||||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
||||||
|
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
|
msg = data.get("message", "Unknown error")
|
||||||
|
raise RuntimeError(f"Failed to fetch sequence data: {msg}")
|
||||||
|
|
||||||
# Parse keys — try JSON array first, fall back to comma-split for compat
|
# Parse keys — try JSON array first, fall back to comma-split for compat
|
||||||
keys = []
|
keys = []
|
||||||
@@ -171,111 +202,14 @@ class ProjectLoaderDynamic:
|
|||||||
while len(results) < MAX_DYNAMIC_OUTPUTS:
|
while len(results) < MAX_DYNAMIC_OUTPUTS:
|
||||||
results.append("")
|
results.append("")
|
||||||
|
|
||||||
return tuple(results)
|
return (total_sequences,) + tuple(results)
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 1. STANDARD NODE (Project-based I2V)
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
class ProjectLoaderStandard:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {
|
|
||||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
|
||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path")
|
|
||||||
FUNCTION = "load_standard"
|
|
||||||
CATEGORY = "utils/json/project"
|
|
||||||
|
|
||||||
def load_standard(self, manager_url, project_name, file_name, sequence_number):
|
|
||||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
|
||||||
return (
|
|
||||||
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
|
|
||||||
str(data.get("current_prompt", "")), str(data.get("negative", "")),
|
|
||||||
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
|
|
||||||
to_int(data.get("seed", 0)), str(data.get("video file path", "")),
|
|
||||||
str(data.get("reference image path", "")), str(data.get("flf image path", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 2. VACE NODE (Project-based)
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
class ProjectLoaderVACE:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {
|
|
||||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
|
||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path")
|
|
||||||
FUNCTION = "load_vace"
|
|
||||||
CATEGORY = "utils/json/project"
|
|
||||||
|
|
||||||
def load_vace(self, manager_url, project_name, file_name, sequence_number):
|
|
||||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
|
||||||
return (
|
|
||||||
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
|
|
||||||
str(data.get("current_prompt", "")), str(data.get("negative", "")),
|
|
||||||
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
|
|
||||||
to_int(data.get("seed", 0)), to_int(data.get("frame_to_skip", 81)),
|
|
||||||
to_int(data.get("input_a_frames", 16)), to_int(data.get("input_b_frames", 16)),
|
|
||||||
str(data.get("reference path", "")), to_int(data.get("reference switch", 1)),
|
|
||||||
to_int(data.get("vace schedule", 1)), str(data.get("video file path", "")),
|
|
||||||
str(data.get("reference image path", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 3. LoRA NODE (Project-based)
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
class ProjectLoaderLoRA:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {
|
|
||||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
|
||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low")
|
|
||||||
FUNCTION = "load_loras"
|
|
||||||
CATEGORY = "utils/json/project"
|
|
||||||
|
|
||||||
def load_loras(self, manager_url, project_name, file_name, sequence_number):
|
|
||||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
|
||||||
return (
|
|
||||||
str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")),
|
|
||||||
str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")),
|
|
||||||
str(data.get("lora 3 high", "")), str(data.get("lora 3 low", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Mappings ---
|
# --- Mappings ---
|
||||||
PROJECT_NODE_CLASS_MAPPINGS = {
|
PROJECT_NODE_CLASS_MAPPINGS = {
|
||||||
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
||||||
"ProjectLoaderStandard": ProjectLoaderStandard,
|
|
||||||
"ProjectLoaderVACE": ProjectLoaderVACE,
|
|
||||||
"ProjectLoaderLoRA": ProjectLoaderLoRA,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
||||||
"ProjectLoaderStandard": "Project Loader (Standard/I2V)",
|
|
||||||
"ProjectLoaderVACE": "Project Loader (VACE Full)",
|
|
||||||
"ProjectLoaderLoRA": "Project Loader (LoRAs)",
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -172,6 +172,25 @@ class TestSequences:
|
|||||||
db.delete_sequences_for_file(df_id)
|
db.delete_sequences_for_file(df_id)
|
||||||
assert db.list_sequences(df_id) == []
|
assert db.list_sequences(df_id) == []
|
||||||
|
|
||||||
|
def test_count_sequences(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
assert db.count_sequences(df_id) == 0
|
||||||
|
db.upsert_sequence(df_id, 1, {"a": 1})
|
||||||
|
db.upsert_sequence(df_id, 2, {"b": 2})
|
||||||
|
db.upsert_sequence(df_id, 3, {"c": 3})
|
||||||
|
assert db.count_sequences(df_id) == 3
|
||||||
|
|
||||||
|
def test_query_total_sequences(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
db.upsert_sequence(df_id, 1, {"a": 1})
|
||||||
|
db.upsert_sequence(df_id, 2, {"b": 2})
|
||||||
|
assert db.query_total_sequences("p1", "batch") == 2
|
||||||
|
|
||||||
|
def test_query_total_sequences_nonexistent(self, db):
|
||||||
|
assert db.query_total_sequences("nope", "nope") == 0
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# History trees
|
# History trees
|
||||||
|
|||||||
@@ -6,9 +6,6 @@ import pytest
|
|||||||
|
|
||||||
from project_loader import (
|
from project_loader import (
|
||||||
ProjectLoaderDynamic,
|
ProjectLoaderDynamic,
|
||||||
ProjectLoaderStandard,
|
|
||||||
ProjectLoaderVACE,
|
|
||||||
ProjectLoaderLoRA,
|
|
||||||
_fetch_json,
|
_fetch_json,
|
||||||
_fetch_data,
|
_fetch_data,
|
||||||
_fetch_keys,
|
_fetch_keys,
|
||||||
@@ -32,11 +29,23 @@ class TestFetchHelpers:
|
|||||||
result = _fetch_json("http://example.com/api")
|
result = _fetch_json("http://example.com/api")
|
||||||
assert result == data
|
assert result == data
|
||||||
|
|
||||||
def test_fetch_json_failure(self):
|
def test_fetch_json_network_error(self):
|
||||||
import urllib.error
|
|
||||||
with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
|
with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
|
||||||
result = _fetch_json("http://example.com/api")
|
result = _fetch_json("http://example.com/api")
|
||||||
assert result == {}
|
assert result["error"] == "network_error"
|
||||||
|
assert "connection refused" in result["message"]
|
||||||
|
|
||||||
|
def test_fetch_json_http_error(self):
|
||||||
|
import urllib.error
|
||||||
|
err = urllib.error.HTTPError(
|
||||||
|
"http://example.com/api", 404, "Not Found", {},
|
||||||
|
BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode())
|
||||||
|
)
|
||||||
|
with patch("project_loader.urllib.request.urlopen", side_effect=err):
|
||||||
|
result = _fetch_json("http://example.com/api")
|
||||||
|
assert result["error"] == "http_error"
|
||||||
|
assert result["status"] == 404
|
||||||
|
assert "not found" in result["message"].lower()
|
||||||
|
|
||||||
def test_fetch_data_builds_url(self):
|
def test_fetch_data_builds_url(self):
|
||||||
data = {"prompt": "hello"}
|
data = {"prompt": "hello"}
|
||||||
@@ -73,18 +82,23 @@ class TestFetchHelpers:
|
|||||||
|
|
||||||
|
|
||||||
class TestProjectLoaderDynamic:
|
class TestProjectLoaderDynamic:
|
||||||
|
def _keys_meta(self, total=5):
|
||||||
|
return {"keys": [], "types": [], "total_sequences": total}
|
||||||
|
|
||||||
def test_load_dynamic_with_keys(self):
|
def test_load_dynamic_with_keys(self):
|
||||||
data = {"prompt": "hello", "seed": 42, "cfg": 1.5}
|
data = {"prompt": "hello", "seed": 42, "cfg": 1.5}
|
||||||
node = ProjectLoaderDynamic()
|
node = ProjectLoaderDynamic()
|
||||||
|
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
with patch("project_loader._fetch_data", return_value=data):
|
||||||
result = node.load_dynamic(
|
result = node.load_dynamic(
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||||
output_keys="prompt,seed,cfg"
|
output_keys="prompt,seed,cfg"
|
||||||
)
|
)
|
||||||
assert result[0] == "hello"
|
assert result[0] == 5 # total_sequences
|
||||||
assert result[1] == 42
|
assert result[1] == "hello"
|
||||||
assert result[2] == 1.5
|
assert result[2] == 42
|
||||||
assert len(result) == MAX_DYNAMIC_OUTPUTS
|
assert result[3] == 1.5
|
||||||
|
assert len(result) == MAX_DYNAMIC_OUTPUTS + 1
|
||||||
|
|
||||||
def test_load_dynamic_with_json_encoded_keys(self):
|
def test_load_dynamic_with_json_encoded_keys(self):
|
||||||
"""JSON-encoded output_keys should be parsed correctly."""
|
"""JSON-encoded output_keys should be parsed correctly."""
|
||||||
@@ -92,13 +106,14 @@ class TestProjectLoaderDynamic:
|
|||||||
data = {"my,key": "comma_val", "normal": "ok"}
|
data = {"my,key": "comma_val", "normal": "ok"}
|
||||||
node = ProjectLoaderDynamic()
|
node = ProjectLoaderDynamic()
|
||||||
keys_json = _json.dumps(["my,key", "normal"])
|
keys_json = _json.dumps(["my,key", "normal"])
|
||||||
|
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
with patch("project_loader._fetch_data", return_value=data):
|
||||||
result = node.load_dynamic(
|
result = node.load_dynamic(
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||||
output_keys=keys_json
|
output_keys=keys_json
|
||||||
)
|
)
|
||||||
assert result[0] == "comma_val"
|
assert result[1] == "comma_val"
|
||||||
assert result[1] == "ok"
|
assert result[2] == "ok"
|
||||||
|
|
||||||
def test_load_dynamic_type_coercion(self):
|
def test_load_dynamic_type_coercion(self):
|
||||||
"""output_types should coerce values to declared types."""
|
"""output_types should coerce values to declared types."""
|
||||||
@@ -107,41 +122,75 @@ class TestProjectLoaderDynamic:
|
|||||||
node = ProjectLoaderDynamic()
|
node = ProjectLoaderDynamic()
|
||||||
keys_json = _json.dumps(["seed", "cfg", "prompt"])
|
keys_json = _json.dumps(["seed", "cfg", "prompt"])
|
||||||
types_json = _json.dumps(["INT", "FLOAT", "STRING"])
|
types_json = _json.dumps(["INT", "FLOAT", "STRING"])
|
||||||
|
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
with patch("project_loader._fetch_data", return_value=data):
|
||||||
result = node.load_dynamic(
|
result = node.load_dynamic(
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||||
output_keys=keys_json, output_types=types_json
|
output_keys=keys_json, output_types=types_json
|
||||||
)
|
)
|
||||||
assert result[0] == 42 # string "42" coerced to int
|
assert result[1] == 42 # string "42" coerced to int
|
||||||
assert result[1] == 1.5 # string "1.5" coerced to float
|
assert result[2] == 1.5 # string "1.5" coerced to float
|
||||||
assert result[2] == "hello" # string stays string
|
assert result[3] == "hello" # string stays string
|
||||||
|
|
||||||
def test_load_dynamic_empty_keys(self):
|
def test_load_dynamic_empty_keys(self):
|
||||||
node = ProjectLoaderDynamic()
|
node = ProjectLoaderDynamic()
|
||||||
|
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
||||||
result = node.load_dynamic(
|
result = node.load_dynamic(
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||||
output_keys=""
|
output_keys=""
|
||||||
)
|
)
|
||||||
assert all(v == "" for v in result)
|
# Slot 0 is total_sequences (INT), rest are empty strings
|
||||||
|
assert result[0] == 5
|
||||||
|
assert all(v == "" for v in result[1:])
|
||||||
|
|
||||||
def test_load_dynamic_missing_key(self):
|
def test_load_dynamic_missing_key(self):
|
||||||
node = ProjectLoaderDynamic()
|
node = ProjectLoaderDynamic()
|
||||||
|
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
||||||
result = node.load_dynamic(
|
result = node.load_dynamic(
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||||
output_keys="nonexistent"
|
output_keys="nonexistent"
|
||||||
)
|
)
|
||||||
assert result[0] == ""
|
assert result[1] == ""
|
||||||
|
|
||||||
def test_load_dynamic_bool_becomes_string(self):
|
def test_load_dynamic_bool_becomes_string(self):
|
||||||
node = ProjectLoaderDynamic()
|
node = ProjectLoaderDynamic()
|
||||||
|
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||||
with patch("project_loader._fetch_data", return_value={"flag": True}):
|
with patch("project_loader._fetch_data", return_value={"flag": True}):
|
||||||
result = node.load_dynamic(
|
result = node.load_dynamic(
|
||||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||||
output_keys="flag"
|
output_keys="flag"
|
||||||
)
|
)
|
||||||
assert result[0] == "true"
|
assert result[1] == "true"
|
||||||
|
|
||||||
|
def test_load_dynamic_returns_total_sequences(self):
|
||||||
|
"""total_sequences should be the first output from keys metadata."""
|
||||||
|
node = ProjectLoaderDynamic()
|
||||||
|
with patch("project_loader._fetch_keys", return_value={"keys": [], "types": [], "total_sequences": 42}):
|
||||||
|
with patch("project_loader._fetch_data", return_value={}):
|
||||||
|
result = node.load_dynamic(
|
||||||
|
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||||
|
output_keys=""
|
||||||
|
)
|
||||||
|
assert result[0] == 42
|
||||||
|
|
||||||
|
def test_load_dynamic_raises_on_network_error(self):
|
||||||
|
"""Network errors from _fetch_keys should raise RuntimeError."""
|
||||||
|
node = ProjectLoaderDynamic()
|
||||||
|
error_resp = {"error": "network_error", "message": "Connection refused"}
|
||||||
|
with patch("project_loader._fetch_keys", return_value=error_resp):
|
||||||
|
with pytest.raises(RuntimeError, match="Failed to fetch project keys"):
|
||||||
|
node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
|
||||||
|
|
||||||
|
def test_load_dynamic_raises_on_data_fetch_error(self):
|
||||||
|
"""Network errors from _fetch_data should raise RuntimeError."""
|
||||||
|
node = ProjectLoaderDynamic()
|
||||||
|
error_resp = {"error": "http_error", "status": 404, "message": "Sequence not found"}
|
||||||
|
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||||
|
with patch("project_loader._fetch_data", return_value=error_resp):
|
||||||
|
with pytest.raises(RuntimeError, match="Failed to fetch sequence data"):
|
||||||
|
node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
|
||||||
|
|
||||||
def test_input_types_has_manager_url(self):
|
def test_input_types_has_manager_url(self):
|
||||||
inputs = ProjectLoaderDynamic.INPUT_TYPES()
|
inputs = ProjectLoaderDynamic.INPUT_TYPES()
|
||||||
@@ -154,88 +203,9 @@ class TestProjectLoaderDynamic:
|
|||||||
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
|
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
|
||||||
|
|
||||||
|
|
||||||
class TestProjectLoaderStandard:
|
|
||||||
def test_load_standard(self):
|
|
||||||
data = {
|
|
||||||
"general_prompt": "hello",
|
|
||||||
"general_negative": "bad",
|
|
||||||
"current_prompt": "specific",
|
|
||||||
"negative": "neg",
|
|
||||||
"camera": "pan",
|
|
||||||
"flf": 0.5,
|
|
||||||
"seed": 42,
|
|
||||||
"video file path": "/v.mp4",
|
|
||||||
"reference image path": "/r.png",
|
|
||||||
"flf image path": "/f.png",
|
|
||||||
}
|
|
||||||
node = ProjectLoaderStandard()
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
|
|
||||||
assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png")
|
|
||||||
|
|
||||||
def test_load_standard_defaults(self):
|
|
||||||
node = ProjectLoaderStandard()
|
|
||||||
with patch("project_loader._fetch_data", return_value={}):
|
|
||||||
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
|
|
||||||
assert result[0] == "" # general_prompt
|
|
||||||
assert result[5] == 0.0 # flf
|
|
||||||
assert result[6] == 0 # seed
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectLoaderVACE:
|
|
||||||
def test_load_vace(self):
|
|
||||||
data = {
|
|
||||||
"general_prompt": "hello",
|
|
||||||
"general_negative": "bad",
|
|
||||||
"current_prompt": "specific",
|
|
||||||
"negative": "neg",
|
|
||||||
"camera": "pan",
|
|
||||||
"flf": 0.5,
|
|
||||||
"seed": 42,
|
|
||||||
"frame_to_skip": 81,
|
|
||||||
"input_a_frames": 16,
|
|
||||||
"input_b_frames": 16,
|
|
||||||
"reference path": "/ref",
|
|
||||||
"reference switch": 1,
|
|
||||||
"vace schedule": 2,
|
|
||||||
"video file path": "/v.mp4",
|
|
||||||
"reference image path": "/r.png",
|
|
||||||
}
|
|
||||||
node = ProjectLoaderVACE()
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.load_vace("http://localhost:8080", "proj1", "batch", 1)
|
|
||||||
assert result[7] == 81 # frame_to_skip
|
|
||||||
assert result[12] == 2 # vace_schedule
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectLoaderLoRA:
|
|
||||||
def test_load_loras(self):
|
|
||||||
data = {
|
|
||||||
"lora 1 high": "<lora:model1:1.0>",
|
|
||||||
"lora 1 low": "<lora:model1:0.5>",
|
|
||||||
"lora 2 high": "",
|
|
||||||
"lora 2 low": "",
|
|
||||||
"lora 3 high": "",
|
|
||||||
"lora 3 low": "",
|
|
||||||
}
|
|
||||||
node = ProjectLoaderLoRA()
|
|
||||||
with patch("project_loader._fetch_data", return_value=data):
|
|
||||||
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
|
|
||||||
assert result[0] == "<lora:model1:1.0>"
|
|
||||||
assert result[1] == "<lora:model1:0.5>"
|
|
||||||
|
|
||||||
def test_load_loras_empty(self):
|
|
||||||
node = ProjectLoaderLoRA()
|
|
||||||
with patch("project_loader._fetch_data", return_value={}):
|
|
||||||
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
|
|
||||||
assert all(v == "" for v in result)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNodeMappings:
|
class TestNodeMappings:
|
||||||
def test_mappings_exist(self):
|
def test_mappings_exist(self):
|
||||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
||||||
assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS
|
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1
|
||||||
assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS
|
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1
|
||||||
assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS
|
|
||||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4
|
|
||||||
|
|||||||
@@ -32,18 +32,55 @@ app.registerExtension({
|
|||||||
this.refreshDynamicOutputs();
|
this.refreshDynamicOutputs();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Auto-refresh with 500ms debounce on widget changes
|
||||||
|
this._refreshTimer = null;
|
||||||
|
const autoRefreshWidgets = ["project_name", "file_name", "sequence_number"];
|
||||||
|
for (const widgetName of autoRefreshWidgets) {
|
||||||
|
const w = this.widgets?.find(w => w.name === widgetName);
|
||||||
|
if (w) {
|
||||||
|
const origCallback = w.callback;
|
||||||
|
const node = this;
|
||||||
|
w.callback = function (...args) {
|
||||||
|
origCallback?.apply(this, args);
|
||||||
|
clearTimeout(node._refreshTimer);
|
||||||
|
node._refreshTimer = setTimeout(() => {
|
||||||
|
node.refreshDynamicOutputs();
|
||||||
|
}, 500);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
queueMicrotask(() => {
|
queueMicrotask(() => {
|
||||||
if (!this._configured) {
|
if (!this._configured) {
|
||||||
// New node (not loading) — remove the 32 Python default outputs
|
// New node (not loading) — remove the Python default outputs
|
||||||
|
// and add only the fixed total_sequences slot
|
||||||
while (this.outputs.length > 0) {
|
while (this.outputs.length > 0) {
|
||||||
this.removeOutput(0);
|
this.removeOutput(0);
|
||||||
}
|
}
|
||||||
|
this.addOutput("total_sequences", "INT");
|
||||||
this.setSize(this.computeSize());
|
this.setSize(this.computeSize());
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
app.graph?.setDirtyCanvas(true, true);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
nodeType.prototype._setStatus = function (status, message) {
|
||||||
|
const baseTitle = "Project Loader (Dynamic)";
|
||||||
|
if (status === "ok") {
|
||||||
|
this.title = baseTitle;
|
||||||
|
this.color = undefined;
|
||||||
|
this.bgcolor = undefined;
|
||||||
|
} else if (status === "error") {
|
||||||
|
this.title = baseTitle + " - ERROR";
|
||||||
|
this.color = "#ff4444";
|
||||||
|
this.bgcolor = "#331111";
|
||||||
|
if (message) this.title = baseTitle + ": " + message;
|
||||||
|
} else if (status === "loading") {
|
||||||
|
this.title = baseTitle + " - Loading...";
|
||||||
|
}
|
||||||
|
app.graph?.setDirtyCanvas(true, true);
|
||||||
|
};
|
||||||
|
|
||||||
nodeType.prototype.refreshDynamicOutputs = async function () {
|
nodeType.prototype.refreshDynamicOutputs = async function () {
|
||||||
const urlWidget = this.widgets?.find(w => w.name === "manager_url");
|
const urlWidget = this.widgets?.find(w => w.name === "manager_url");
|
||||||
const projectWidget = this.widgets?.find(w => w.name === "project_name");
|
const projectWidget = this.widgets?.find(w => w.name === "project_name");
|
||||||
@@ -52,13 +89,20 @@ app.registerExtension({
|
|||||||
|
|
||||||
if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return;
|
if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return;
|
||||||
|
|
||||||
|
this._setStatus("loading");
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const resp = await api.fetchApi(
|
const resp = await api.fetchApi(
|
||||||
`/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}`
|
`/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}`
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!resp.ok) {
|
if (!resp.ok) {
|
||||||
console.warn("[ProjectLoaderDynamic] HTTP error", resp.status, "— keeping existing outputs");
|
let errorMsg = `HTTP ${resp.status}`;
|
||||||
|
try {
|
||||||
|
const errData = await resp.json();
|
||||||
|
if (errData.message) errorMsg = errData.message;
|
||||||
|
} catch (_) {}
|
||||||
|
this._setStatus("error", errorMsg);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,7 +112,8 @@ app.registerExtension({
|
|||||||
|
|
||||||
// If the API returned an error or missing data, keep existing outputs and links intact
|
// If the API returned an error or missing data, keep existing outputs and links intact
|
||||||
if (data.error || !Array.isArray(keys) || !Array.isArray(types)) {
|
if (data.error || !Array.isArray(keys) || !Array.isArray(types)) {
|
||||||
console.warn("[ProjectLoaderDynamic] API error or missing data, keeping existing outputs:", data.error || "no keys/types");
|
const errMsg = data.error ? data.message || data.error : "Missing keys/types";
|
||||||
|
this._setStatus("error", errMsg);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,14 +123,20 @@ app.registerExtension({
|
|||||||
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
||||||
if (otWidget) otWidget.value = JSON.stringify(types);
|
if (otWidget) otWidget.value = JSON.stringify(types);
|
||||||
|
|
||||||
// Build a map of current output names to slot indices
|
// Slot 0 is always total_sequences (INT) — ensure it exists
|
||||||
|
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
||||||
|
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
|
||||||
|
}
|
||||||
|
this.outputs[0].type = "INT";
|
||||||
|
|
||||||
|
// Build a map of current dynamic output names to slot indices (skip slot 0)
|
||||||
const oldSlots = {};
|
const oldSlots = {};
|
||||||
for (let i = 0; i < this.outputs.length; i++) {
|
for (let i = 1; i < this.outputs.length; i++) {
|
||||||
oldSlots[this.outputs[i].name] = i;
|
oldSlots[this.outputs[i].name] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build new outputs, reusing existing slots to preserve links
|
// Build new dynamic outputs, reusing existing slots to preserve links
|
||||||
const newOutputs = [];
|
const newOutputs = [this.outputs[0]]; // Keep total_sequences at slot 0
|
||||||
for (let k = 0; k < keys.length; k++) {
|
for (let k = 0; k < keys.length; k++) {
|
||||||
const key = keys[k];
|
const key = keys[k];
|
||||||
const type = types[k] || "*";
|
const type = types[k] || "*";
|
||||||
@@ -122,10 +173,12 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this._setStatus("ok");
|
||||||
this.setSize(this.computeSize());
|
this.setSize(this.computeSize());
|
||||||
app.graph?.setDirtyCanvas(true, true);
|
app.graph?.setDirtyCanvas(true, true);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error("[ProjectLoaderDynamic] Refresh failed:", e);
|
console.error("[ProjectLoaderDynamic] Refresh failed:", e);
|
||||||
|
this._setStatus("error", "Server unreachable");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -158,23 +211,59 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure slot 0 is total_sequences (INT)
|
||||||
|
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
||||||
|
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
|
||||||
|
// LiteGraph restores links AFTER onConfigure, so graph.links is
|
||||||
|
// empty here. Defer link fixup to a microtask that runs after the
|
||||||
|
// synchronous graph.configure() finishes (including link restoration).
|
||||||
|
// We must also rebuild output.links arrays because LiteGraph will
|
||||||
|
// place link IDs on the wrong outputs (shifted by the unshift above).
|
||||||
|
const node = this;
|
||||||
|
queueMicrotask(() => {
|
||||||
|
if (!node.graph) return;
|
||||||
|
// Clear all output.links — they were populated at old indices
|
||||||
|
for (const output of node.outputs) {
|
||||||
|
output.links = null;
|
||||||
|
}
|
||||||
|
// Rebuild from graph.links with corrected origin_slot (+1)
|
||||||
|
for (const linkId in node.graph.links) {
|
||||||
|
const link = node.graph.links[linkId];
|
||||||
|
if (!link || link.origin_id !== node.id) continue;
|
||||||
|
link.origin_slot += 1;
|
||||||
|
const output = node.outputs[link.origin_slot];
|
||||||
|
if (output) {
|
||||||
|
if (!output.links) output.links = [];
|
||||||
|
output.links.push(link.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
app.graph?.setDirtyCanvas(true, true);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
this.outputs[0].type = "INT";
|
||||||
|
this.outputs[0].name = "total_sequences";
|
||||||
|
|
||||||
if (keys.length > 0) {
|
if (keys.length > 0) {
|
||||||
// On load, LiteGraph already restored serialized outputs with links.
|
// On load, LiteGraph already restored serialized outputs with links.
|
||||||
// Rename and set types to match stored state (preserves links).
|
// Dynamic outputs start at slot 1. Rename and set types to match stored state.
|
||||||
for (let i = 0; i < this.outputs.length && i < keys.length; i++) {
|
for (let i = 0; i < keys.length; i++) {
|
||||||
this.outputs[i].name = keys[i];
|
const slotIdx = i + 1; // offset by 1 for total_sequences
|
||||||
if (types[i]) this.outputs[i].type = types[i];
|
if (slotIdx < this.outputs.length) {
|
||||||
|
this.outputs[slotIdx].name = keys[i];
|
||||||
|
if (types[i]) this.outputs[slotIdx].type = types[i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove any extra outputs beyond the key count
|
// Remove any extra outputs beyond keys + total_sequences
|
||||||
while (this.outputs.length > keys.length) {
|
while (this.outputs.length > keys.length + 1) {
|
||||||
this.removeOutput(this.outputs.length - 1);
|
this.removeOutput(this.outputs.length - 1);
|
||||||
}
|
}
|
||||||
} else if (this.outputs.length > 0) {
|
} else if (this.outputs.length > 1) {
|
||||||
// Widget values empty but serialized outputs exist — sync widgets
|
// Widget values empty but serialized dynamic outputs exist — sync widgets
|
||||||
// from the outputs LiteGraph already restored (fallback).
|
// from the outputs LiteGraph already restored (fallback, skip slot 0).
|
||||||
if (okWidget) okWidget.value = JSON.stringify(this.outputs.map(o => o.name));
|
const dynamicOutputs = this.outputs.slice(1);
|
||||||
if (otWidget) otWidget.value = JSON.stringify(this.outputs.map(o => o.type));
|
if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name));
|
||||||
|
if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type));
|
||||||
}
|
}
|
||||||
|
|
||||||
this.setSize(this.computeSize());
|
this.setSize(this.computeSize());
|
||||||
|
|||||||
Reference in New Issue
Block a user