diff --git a/api_routes.py b/api_routes.py index 6d5b42e..62f8512 100644 --- a/api_routes.py +++ b/api_routes.py @@ -55,13 +55,26 @@ def _list_sequences(name: str, file_name: str) -> dict[str, Any]: def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: db = _get_db() - data = db.query_sequence_data(name, file_name, seq) + proj = db.get_project(name) + if not proj: + raise HTTPException(status_code=404, detail=f"Project '{name}' not found") + df = db.get_data_file_by_names(name, file_name) + if not df: + raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'") + data = db.get_sequence(df["id"], seq) if data is None: - raise HTTPException(status_code=404, detail="Sequence not found") + raise HTTPException(status_code=404, detail=f"Sequence {seq} not found") return data def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: db = _get_db() - keys, types = db.query_sequence_keys(name, file_name, seq) - return {"keys": keys, "types": types} + proj = db.get_project(name) + if not proj: + raise HTTPException(status_code=404, detail=f"Project '{name}' not found") + df = db.get_data_file_by_names(name, file_name) + if not df: + raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'") + keys, types = db.get_sequence_keys(df["id"], seq) + total = db.count_sequences(df["id"]) + return {"keys": keys, "types": types, "total_sequences": total} diff --git a/db.py b/db.py index cf428e5..e9088f9 100644 --- a/db.py +++ b/db.py @@ -182,6 +182,21 @@ class ProjectDB: ).fetchall() return [r["sequence_number"] for r in rows] + def count_sequences(self, data_file_id: int) -> int: + """Return the number of sequences for a data file.""" + row = self.conn.execute( + "SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?", + (data_file_id,), + ).fetchone() + return row["cnt"] + + def query_total_sequences(self, project_name: str, file_name: str) -> int: + """Return total sequence count by project and file names.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return 0 + return self.count_sequences(df["id"]) + def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]: """Returns (keys, types) for a sequence's data dict.""" data = self.get_sequence(data_file_id, sequence_number) diff --git a/project_loader.py b/project_loader.py index cc52d2c..6420517 100644 --- a/project_loader.py +++ b/project_loader.py @@ -39,13 +39,31 @@ def to_int(val: Any) -> int: def _fetch_json(url: str) -> dict: - """Fetch JSON from a URL using stdlib urllib.""" + """Fetch JSON from a URL using stdlib urllib. + + On error, returns a dict with an "error" key describing the failure. + """ try: with urllib.request.urlopen(url, timeout=5) as resp: return json.loads(resp.read()) - except (urllib.error.URLError, json.JSONDecodeError, OSError) as e: - logger.warning(f"Failed to fetch {url}: {e}") - return {} + except urllib.error.HTTPError as e: + # HTTPError is a subclass of URLError — must be caught first + body = "" + try: + raw = e.read() + detail = json.loads(raw) + body = detail.get("detail", str(raw, "utf-8", errors="replace")) + except Exception: + body = str(e) + logger.warning(f"HTTP {e.code} from {url}: {body}") + return {"error": "http_error", "status": e.code, "message": body} + except (urllib.error.URLError, OSError) as e: + reason = str(e.reason) if hasattr(e, "reason") else str(e) + logger.warning(f"Network error fetching {url}: {reason}") + return {"error": "network_error", "message": reason} + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON from {url}: {e}") + return {"error": "parse_error", "message": str(e)} def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: @@ -100,6 +118,9 @@ if PromptServer is not None: except (ValueError, TypeError): seq = 1 data = _fetch_keys(manager_url, project, file_name, seq) + if data.get("error") in ("http_error", "network_error", "parse_error"): + status = data.get("status", 502) + return web.json_response(data, status=status) return web.json_response(data) @@ -124,15 +145,25 @@ class ProjectLoaderDynamic: }, } - RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) - RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) + RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) + RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) FUNCTION = "load_dynamic" CATEGORY = "utils/json/project" OUTPUT_NODE = False def load_dynamic(self, manager_url, project_name, file_name, sequence_number, output_keys="", output_types=""): + # Fetch keys metadata (includes total_sequences count) + keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number) + if keys_meta.get("error") in ("http_error", "network_error", "parse_error"): + msg = keys_meta.get("message", "Unknown error") + raise RuntimeError(f"Failed to fetch project keys: {msg}") + total_sequences = keys_meta.get("total_sequences", 0) + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + if data.get("error") in ("http_error", "network_error", "parse_error"): + msg = data.get("message", "Unknown error") + raise RuntimeError(f"Failed to fetch sequence data: {msg}") # Parse keys — try JSON array first, fall back to comma-split for compat keys = [] @@ -171,111 +202,14 @@ class ProjectLoaderDynamic: while len(results) < MAX_DYNAMIC_OUTPUTS: results.append("") - return tuple(results) - - -# ========================================== -# 1. STANDARD NODE (Project-based I2V) -# ========================================== - -class ProjectLoaderStandard: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), - "project_name": ("STRING", {"default": "", "multiline": False}), - "file_name": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path") - FUNCTION = "load_standard" - CATEGORY = "utils/json/project" - - def load_standard(self, manager_url, project_name, file_name, sequence_number): - data = _fetch_data(manager_url, project_name, file_name, sequence_number) - return ( - str(data.get("general_prompt", "")), str(data.get("general_negative", "")), - str(data.get("current_prompt", "")), str(data.get("negative", "")), - str(data.get("camera", "")), to_float(data.get("flf", 0.0)), - to_int(data.get("seed", 0)), str(data.get("video file path", "")), - str(data.get("reference image path", "")), str(data.get("flf image path", "")) - ) - - -# ========================================== -# 2. VACE NODE (Project-based) -# ========================================== - -class ProjectLoaderVACE: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), - "project_name": ("STRING", {"default": "", "multiline": False}), - "file_name": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path") - FUNCTION = "load_vace" - CATEGORY = "utils/json/project" - - def load_vace(self, manager_url, project_name, file_name, sequence_number): - data = _fetch_data(manager_url, project_name, file_name, sequence_number) - return ( - str(data.get("general_prompt", "")), str(data.get("general_negative", "")), - str(data.get("current_prompt", "")), str(data.get("negative", "")), - str(data.get("camera", "")), to_float(data.get("flf", 0.0)), - to_int(data.get("seed", 0)), to_int(data.get("frame_to_skip", 81)), - to_int(data.get("input_a_frames", 16)), to_int(data.get("input_b_frames", 16)), - str(data.get("reference path", "")), to_int(data.get("reference switch", 1)), - to_int(data.get("vace schedule", 1)), str(data.get("video file path", "")), - str(data.get("reference image path", "")) - ) - - -# ========================================== -# 3. LoRA NODE (Project-based) -# ========================================== - -class ProjectLoaderLoRA: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), - "project_name": ("STRING", {"default": "", "multiline": False}), - "file_name": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") - RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low") - FUNCTION = "load_loras" - CATEGORY = "utils/json/project" - - def load_loras(self, manager_url, project_name, file_name, sequence_number): - data = _fetch_data(manager_url, project_name, file_name, sequence_number) - return ( - str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")), - str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")), - str(data.get("lora 3 high", "")), str(data.get("lora 3 low", "")) - ) + return (total_sequences,) + tuple(results) # --- Mappings --- PROJECT_NODE_CLASS_MAPPINGS = { "ProjectLoaderDynamic": ProjectLoaderDynamic, - "ProjectLoaderStandard": ProjectLoaderStandard, - "ProjectLoaderVACE": ProjectLoaderVACE, - "ProjectLoaderLoRA": ProjectLoaderLoRA, } PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { "ProjectLoaderDynamic": "Project Loader (Dynamic)", - "ProjectLoaderStandard": "Project Loader (Standard/I2V)", - "ProjectLoaderVACE": "Project Loader (VACE Full)", - "ProjectLoaderLoRA": "Project Loader (LoRAs)", } diff --git a/tests/test_db.py b/tests/test_db.py index 341edcb..bea102f 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -172,6 +172,25 @@ class TestSequences: db.delete_sequences_for_file(df_id) assert db.list_sequences(df_id) == [] + def test_count_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.count_sequences(df_id) == 0 + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + db.upsert_sequence(df_id, 3, {"c": 3}) + assert db.count_sequences(df_id) == 3 + + def test_query_total_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + assert db.query_total_sequences("p1", "batch") == 2 + + def test_query_total_sequences_nonexistent(self, db): + assert db.query_total_sequences("nope", "nope") == 0 + # ------------------------------------------------------------------ # History trees diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index 58e76ee..dab80b2 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -6,9 +6,6 @@ import pytest from project_loader import ( ProjectLoaderDynamic, - ProjectLoaderStandard, - ProjectLoaderVACE, - ProjectLoaderLoRA, _fetch_json, _fetch_data, _fetch_keys, @@ -32,11 +29,23 @@ class TestFetchHelpers: result = _fetch_json("http://example.com/api") assert result == data - def test_fetch_json_failure(self): - import urllib.error + def test_fetch_json_network_error(self): with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")): result = _fetch_json("http://example.com/api") - assert result == {} + assert result["error"] == "network_error" + assert "connection refused" in result["message"] + + def test_fetch_json_http_error(self): + import urllib.error + err = urllib.error.HTTPError( + "http://example.com/api", 404, "Not Found", {}, + BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode()) + ) + with patch("project_loader.urllib.request.urlopen", side_effect=err): + result = _fetch_json("http://example.com/api") + assert result["error"] == "http_error" + assert result["status"] == 404 + assert "not found" in result["message"].lower() def test_fetch_data_builds_url(self): data = {"prompt": "hello"} @@ -73,18 +82,23 @@ class TestFetchHelpers: class TestProjectLoaderDynamic: + def _keys_meta(self, total=5): + return {"keys": [], "types": [], "total_sequences": total} + def test_load_dynamic_with_keys(self): data = {"prompt": "hello", "seed": 42, "cfg": 1.5} node = ProjectLoaderDynamic() - with patch("project_loader._fetch_data", return_value=data): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys="prompt,seed,cfg" - ) - assert result[0] == "hello" - assert result[1] == 42 - assert result[2] == 1.5 - assert len(result) == MAX_DYNAMIC_OUTPUTS + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="prompt,seed,cfg" + ) + assert result[0] == 5 # total_sequences + assert result[1] == "hello" + assert result[2] == 42 + assert result[3] == 1.5 + assert len(result) == MAX_DYNAMIC_OUTPUTS + 1 def test_load_dynamic_with_json_encoded_keys(self): """JSON-encoded output_keys should be parsed correctly.""" @@ -92,13 +106,14 @@ class TestProjectLoaderDynamic: data = {"my,key": "comma_val", "normal": "ok"} node = ProjectLoaderDynamic() keys_json = _json.dumps(["my,key", "normal"]) - with patch("project_loader._fetch_data", return_value=data): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys=keys_json - ) - assert result[0] == "comma_val" - assert result[1] == "ok" + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json + ) + assert result[1] == "comma_val" + assert result[2] == "ok" def test_load_dynamic_type_coercion(self): """output_types should coerce values to declared types.""" @@ -107,41 +122,75 @@ class TestProjectLoaderDynamic: node = ProjectLoaderDynamic() keys_json = _json.dumps(["seed", "cfg", "prompt"]) types_json = _json.dumps(["INT", "FLOAT", "STRING"]) - with patch("project_loader._fetch_data", return_value=data): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys=keys_json, output_types=types_json - ) - assert result[0] == 42 # string "42" coerced to int - assert result[1] == 1.5 # string "1.5" coerced to float - assert result[2] == "hello" # string stays string + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json, output_types=types_json + ) + assert result[1] == 42 # string "42" coerced to int + assert result[2] == 1.5 # string "1.5" coerced to float + assert result[3] == "hello" # string stays string def test_load_dynamic_empty_keys(self): node = ProjectLoaderDynamic() - with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys="" - ) - assert all(v == "" for v in result) + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="" + ) + # Slot 0 is total_sequences (INT), rest are empty strings + assert result[0] == 5 + assert all(v == "" for v in result[1:]) def test_load_dynamic_missing_key(self): node = ProjectLoaderDynamic() - with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys="nonexistent" - ) - assert result[0] == "" + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="nonexistent" + ) + assert result[1] == "" def test_load_dynamic_bool_becomes_string(self): node = ProjectLoaderDynamic() - with patch("project_loader._fetch_data", return_value={"flag": True}): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys="flag" - ) - assert result[0] == "true" + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"flag": True}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="flag" + ) + assert result[1] == "true" + + def test_load_dynamic_returns_total_sequences(self): + """total_sequences should be the first output from keys metadata.""" + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_keys", return_value={"keys": [], "types": [], "total_sequences": 42}): + with patch("project_loader._fetch_data", return_value={}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="" + ) + assert result[0] == 42 + + def test_load_dynamic_raises_on_network_error(self): + """Network errors from _fetch_keys should raise RuntimeError.""" + node = ProjectLoaderDynamic() + error_resp = {"error": "network_error", "message": "Connection refused"} + with patch("project_loader._fetch_keys", return_value=error_resp): + with pytest.raises(RuntimeError, match="Failed to fetch project keys"): + node.load_dynamic("http://localhost:8080", "proj1", "batch", 1) + + def test_load_dynamic_raises_on_data_fetch_error(self): + """Network errors from _fetch_data should raise RuntimeError.""" + node = ProjectLoaderDynamic() + error_resp = {"error": "http_error", "status": 404, "message": "Sequence not found"} + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=error_resp): + with pytest.raises(RuntimeError, match="Failed to fetch sequence data"): + node.load_dynamic("http://localhost:8080", "proj1", "batch", 1) def test_input_types_has_manager_url(self): inputs = ProjectLoaderDynamic.INPUT_TYPES() @@ -154,88 +203,9 @@ class TestProjectLoaderDynamic: assert ProjectLoaderDynamic.CATEGORY == "utils/json/project" -class TestProjectLoaderStandard: - def test_load_standard(self): - data = { - "general_prompt": "hello", - "general_negative": "bad", - "current_prompt": "specific", - "negative": "neg", - "camera": "pan", - "flf": 0.5, - "seed": 42, - "video file path": "/v.mp4", - "reference image path": "/r.png", - "flf image path": "/f.png", - } - node = ProjectLoaderStandard() - with patch("project_loader._fetch_data", return_value=data): - result = node.load_standard("http://localhost:8080", "proj1", "batch", 1) - assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png") - - def test_load_standard_defaults(self): - node = ProjectLoaderStandard() - with patch("project_loader._fetch_data", return_value={}): - result = node.load_standard("http://localhost:8080", "proj1", "batch", 1) - assert result[0] == "" # general_prompt - assert result[5] == 0.0 # flf - assert result[6] == 0 # seed - - -class TestProjectLoaderVACE: - def test_load_vace(self): - data = { - "general_prompt": "hello", - "general_negative": "bad", - "current_prompt": "specific", - "negative": "neg", - "camera": "pan", - "flf": 0.5, - "seed": 42, - "frame_to_skip": 81, - "input_a_frames": 16, - "input_b_frames": 16, - "reference path": "/ref", - "reference switch": 1, - "vace schedule": 2, - "video file path": "/v.mp4", - "reference image path": "/r.png", - } - node = ProjectLoaderVACE() - with patch("project_loader._fetch_data", return_value=data): - result = node.load_vace("http://localhost:8080", "proj1", "batch", 1) - assert result[7] == 81 # frame_to_skip - assert result[12] == 2 # vace_schedule - - -class TestProjectLoaderLoRA: - def test_load_loras(self): - data = { - "lora 1 high": "", - "lora 1 low": "", - "lora 2 high": "", - "lora 2 low": "", - "lora 3 high": "", - "lora 3 low": "", - } - node = ProjectLoaderLoRA() - with patch("project_loader._fetch_data", return_value=data): - result = node.load_loras("http://localhost:8080", "proj1", "batch", 1) - assert result[0] == "" - assert result[1] == "" - - def test_load_loras_empty(self): - node = ProjectLoaderLoRA() - with patch("project_loader._fetch_data", return_value={}): - result = node.load_loras("http://localhost:8080", "proj1", "batch", 1) - assert all(v == "" for v in result) - - class TestNodeMappings: def test_mappings_exist(self): from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS - assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS - assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS - assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS - assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4 + assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1 + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1 diff --git a/web/project_dynamic.js b/web/project_dynamic.js index f8e7634..783b664 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -32,18 +32,55 @@ app.registerExtension({ this.refreshDynamicOutputs(); }); + // Auto-refresh with 500ms debounce on widget changes + this._refreshTimer = null; + const autoRefreshWidgets = ["project_name", "file_name", "sequence_number"]; + for (const widgetName of autoRefreshWidgets) { + const w = this.widgets?.find(w => w.name === widgetName); + if (w) { + const origCallback = w.callback; + const node = this; + w.callback = function (...args) { + origCallback?.apply(this, args); + clearTimeout(node._refreshTimer); + node._refreshTimer = setTimeout(() => { + node.refreshDynamicOutputs(); + }, 500); + }; + } + } + queueMicrotask(() => { if (!this._configured) { - // New node (not loading) — remove the 32 Python default outputs + // New node (not loading) — remove the Python default outputs + // and add only the fixed total_sequences slot while (this.outputs.length > 0) { this.removeOutput(0); } + this.addOutput("total_sequences", "INT"); this.setSize(this.computeSize()); app.graph?.setDirtyCanvas(true, true); } }); }; + nodeType.prototype._setStatus = function (status, message) { + const baseTitle = "Project Loader (Dynamic)"; + if (status === "ok") { + this.title = baseTitle; + this.color = undefined; + this.bgcolor = undefined; + } else if (status === "error") { + this.title = baseTitle + " - ERROR"; + this.color = "#ff4444"; + this.bgcolor = "#331111"; + if (message) this.title = baseTitle + ": " + message; + } else if (status === "loading") { + this.title = baseTitle + " - Loading..."; + } + app.graph?.setDirtyCanvas(true, true); + }; + nodeType.prototype.refreshDynamicOutputs = async function () { const urlWidget = this.widgets?.find(w => w.name === "manager_url"); const projectWidget = this.widgets?.find(w => w.name === "project_name"); @@ -52,13 +89,20 @@ app.registerExtension({ if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return; + this._setStatus("loading"); + try { const resp = await api.fetchApi( `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` ); if (!resp.ok) { - console.warn("[ProjectLoaderDynamic] HTTP error", resp.status, "— keeping existing outputs"); + let errorMsg = `HTTP ${resp.status}`; + try { + const errData = await resp.json(); + if (errData.message) errorMsg = errData.message; + } catch (_) {} + this._setStatus("error", errorMsg); return; } @@ -68,7 +112,8 @@ app.registerExtension({ // If the API returned an error or missing data, keep existing outputs and links intact if (data.error || !Array.isArray(keys) || !Array.isArray(types)) { - console.warn("[ProjectLoaderDynamic] API error or missing data, keeping existing outputs:", data.error || "no keys/types"); + const errMsg = data.error ? data.message || data.error : "Missing keys/types"; + this._setStatus("error", errMsg); return; } @@ -78,14 +123,20 @@ app.registerExtension({ const otWidget = this.widgets?.find(w => w.name === "output_types"); if (otWidget) otWidget.value = JSON.stringify(types); - // Build a map of current output names to slot indices + // Slot 0 is always total_sequences (INT) — ensure it exists + if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { + this.outputs.unshift({ name: "total_sequences", type: "INT", links: null }); + } + this.outputs[0].type = "INT"; + + // Build a map of current dynamic output names to slot indices (skip slot 0) const oldSlots = {}; - for (let i = 0; i < this.outputs.length; i++) { + for (let i = 1; i < this.outputs.length; i++) { oldSlots[this.outputs[i].name] = i; } - // Build new outputs, reusing existing slots to preserve links - const newOutputs = []; + // Build new dynamic outputs, reusing existing slots to preserve links + const newOutputs = [this.outputs[0]]; // Keep total_sequences at slot 0 for (let k = 0; k < keys.length; k++) { const key = keys[k]; const type = types[k] || "*"; @@ -122,10 +173,12 @@ app.registerExtension({ } } + this._setStatus("ok"); this.setSize(this.computeSize()); app.graph?.setDirtyCanvas(true, true); } catch (e) { console.error("[ProjectLoaderDynamic] Refresh failed:", e); + this._setStatus("error", "Server unreachable"); } }; @@ -158,23 +211,59 @@ app.registerExtension({ } } + // Ensure slot 0 is total_sequences (INT) + if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { + this.outputs.unshift({ name: "total_sequences", type: "INT", links: null }); + // LiteGraph restores links AFTER onConfigure, so graph.links is + // empty here. Defer link fixup to a microtask that runs after the + // synchronous graph.configure() finishes (including link restoration). + // We must also rebuild output.links arrays because LiteGraph will + // place link IDs on the wrong outputs (shifted by the unshift above). + const node = this; + queueMicrotask(() => { + if (!node.graph) return; + // Clear all output.links — they were populated at old indices + for (const output of node.outputs) { + output.links = null; + } + // Rebuild from graph.links with corrected origin_slot (+1) + for (const linkId in node.graph.links) { + const link = node.graph.links[linkId]; + if (!link || link.origin_id !== node.id) continue; + link.origin_slot += 1; + const output = node.outputs[link.origin_slot]; + if (output) { + if (!output.links) output.links = []; + output.links.push(link.id); + } + } + app.graph?.setDirtyCanvas(true, true); + }); + } + this.outputs[0].type = "INT"; + this.outputs[0].name = "total_sequences"; + if (keys.length > 0) { // On load, LiteGraph already restored serialized outputs with links. - // Rename and set types to match stored state (preserves links). - for (let i = 0; i < this.outputs.length && i < keys.length; i++) { - this.outputs[i].name = keys[i]; - if (types[i]) this.outputs[i].type = types[i]; + // Dynamic outputs start at slot 1. Rename and set types to match stored state. + for (let i = 0; i < keys.length; i++) { + const slotIdx = i + 1; // offset by 1 for total_sequences + if (slotIdx < this.outputs.length) { + this.outputs[slotIdx].name = keys[i]; + if (types[i]) this.outputs[slotIdx].type = types[i]; + } } - // Remove any extra outputs beyond the key count - while (this.outputs.length > keys.length) { + // Remove any extra outputs beyond keys + total_sequences + while (this.outputs.length > keys.length + 1) { this.removeOutput(this.outputs.length - 1); } - } else if (this.outputs.length > 0) { - // Widget values empty but serialized outputs exist — sync widgets - // from the outputs LiteGraph already restored (fallback). - if (okWidget) okWidget.value = JSON.stringify(this.outputs.map(o => o.name)); - if (otWidget) otWidget.value = JSON.stringify(this.outputs.map(o => o.type)); + } else if (this.outputs.length > 1) { + // Widget values empty but serialized dynamic outputs exist — sync widgets + // from the outputs LiteGraph already restored (fallback, skip slot 0). + const dynamicOutputs = this.outputs.slice(1); + if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name)); + if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type)); } this.setSize(this.computeSize());