From c15bec98ce664621554283e4da99c7883ba4b8dc Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:12:05 +0100 Subject: [PATCH 01/14] Add SQLite project database + ComfyUI connector nodes - db.py: ProjectDB class with SQLite schema (projects, data_files, sequences, history_trees), WAL mode, CRUD, import, and query helpers - api_routes.py: REST API endpoints on NiceGUI/FastAPI for ComfyUI to query project data over the network - project_loader.py: ComfyUI nodes (ProjectLoaderDynamic, Standard, VACE, LoRA) that fetch data from NiceGUI REST API via HTTP - web/project_dynamic.js: Frontend JS for dynamic project loader node - tab_projects_ng.py: Projects management tab in NiceGUI UI - state.py: Added db, current_project, db_enabled fields - main.py: DB init, API route registration, projects tab - utils.py: sync_to_db() dual-write helper - tab_batch_ng.py, tab_raw_ng.py, tab_timeline_ng.py: dual-write sync calls after save_json when project DB is enabled - __init__.py: Merged project node class mappings - tests/test_db.py: 30 tests for database layer - tests/test_project_loader.py: 17 tests for ComfyUI connector nodes Co-Authored-By: Claude Opus 4.6 --- __init__.py | 4 + api_routes.py | 67 ++++++++ db.py | 285 ++++++++++++++++++++++++++++++++++ main.py | 26 ++++ project_loader.py | 255 +++++++++++++++++++++++++++++++ state.py | 5 + tab_batch_ng.py | 16 +- tab_projects_ng.py | 161 ++++++++++++++++++++ tab_raw_ng.py | 4 +- tab_timeline_ng.py | 14 +- tests/test_db.py | 286 +++++++++++++++++++++++++++++++++++ tests/test_project_loader.py | 201 ++++++++++++++++++++++++ utils.py | 37 +++++ web/project_dynamic.js | 139 +++++++++++++++++ 14 files changed, 1495 insertions(+), 5 deletions(-) create mode 100644 api_routes.py create mode 100644 db.py create mode 100644 project_loader.py create mode 100644 tab_projects_ng.py create mode 100644 tests/test_db.py create mode 100644 tests/test_project_loader.py create mode 100644 web/project_dynamic.js diff --git a/__init__.py b/__init__.py index 43198c8..5e657ee 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,8 @@ from .json_loader import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS +from .project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS + +NODE_CLASS_MAPPINGS.update(PROJECT_NODE_CLASS_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) WEB_DIRECTORY = "./web" diff --git a/api_routes.py b/api_routes.py new file mode 100644 index 0000000..36e84cd --- /dev/null +++ b/api_routes.py @@ -0,0 +1,67 @@ +"""REST API endpoints for ComfyUI to query project data from SQLite. + +All endpoints are read-only. Mounted on the NiceGUI/FastAPI server. +""" + +import logging +from typing import Any + +from fastapi import HTTPException, Query +from nicegui import app + +from db import ProjectDB + +logger = logging.getLogger(__name__) + +# The DB instance is set by register_api_routes() +_db: ProjectDB | None = None + + +def register_api_routes(db: ProjectDB) -> None: + """Register all REST API routes with the NiceGUI/FastAPI app.""" + global _db + _db = db + + app.add_api_route("/api/projects", _list_projects, methods=["GET"]) + app.add_api_route("/api/projects/{name}/files", _list_files, methods=["GET"]) + app.add_api_route("/api/projects/{name}/files/{file_name}/sequences", _list_sequences, methods=["GET"]) + app.add_api_route("/api/projects/{name}/files/{file_name}/data", _get_data, methods=["GET"]) + app.add_api_route("/api/projects/{name}/files/{file_name}/keys", _get_keys, methods=["GET"]) + + +def _get_db() -> ProjectDB: + if _db is None: + raise HTTPException(status_code=503, detail="Database not initialized") + return _db + + +async def _list_projects() -> dict[str, Any]: + db = _get_db() + projects = db.list_projects() + return {"projects": [p["name"] for p in projects]} + + +async def _list_files(name: str) -> dict[str, Any]: + db = _get_db() + files = db.list_project_files(name) + return {"files": [{"name": f["name"], "data_type": f["data_type"]} for f in files]} + + +async def _list_sequences(name: str, file_name: str) -> dict[str, Any]: + db = _get_db() + seqs = db.list_project_sequences(name, file_name) + return {"sequences": seqs} + + +async def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: + db = _get_db() + data = db.query_sequence_data(name, file_name, seq) + if data is None: + raise HTTPException(status_code=404, detail="Sequence not found") + return data + + +async def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: + db = _get_db() + keys, types = db.query_sequence_keys(name, file_name, seq) + return {"keys": keys, "types": types} diff --git a/db.py b/db.py new file mode 100644 index 0000000..11efc7d --- /dev/null +++ b/db.py @@ -0,0 +1,285 @@ +import json +import logging +import sqlite3 +import time +from pathlib import Path +from typing import Any + +from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE + +logger = logging.getLogger(__name__) + +DEFAULT_DB_PATH = Path.home() / ".comfyui_json_manager" / "projects.db" + +SCHEMA_SQL = """ +CREATE TABLE IF NOT EXISTS projects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + folder_path TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + created_at REAL NOT NULL, + updated_at REAL NOT NULL +); + +CREATE TABLE IF NOT EXISTS data_files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + name TEXT NOT NULL, + data_type TEXT NOT NULL DEFAULT 'generic', + top_level TEXT NOT NULL DEFAULT '{}', + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + UNIQUE(project_id, name) +); + +CREATE TABLE IF NOT EXISTS sequences ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE, + sequence_number INTEGER NOT NULL, + data TEXT NOT NULL DEFAULT '{}', + updated_at REAL NOT NULL, + UNIQUE(data_file_id, sequence_number) +); + +CREATE TABLE IF NOT EXISTS history_trees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE, + tree_data TEXT NOT NULL DEFAULT '{}', + updated_at REAL NOT NULL +); +""" + + +class ProjectDB: + """SQLite database for project-based data management.""" + + def __init__(self, db_path: str | Path | None = None): + self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self.conn.row_factory = sqlite3.Row + self.conn.execute("PRAGMA journal_mode=WAL") + self.conn.execute("PRAGMA foreign_keys=ON") + self.conn.executescript(SCHEMA_SQL) + self.conn.commit() + + def close(self): + self.conn.close() + + # ------------------------------------------------------------------ + # Projects CRUD + # ------------------------------------------------------------------ + + def create_project(self, name: str, folder_path: str, description: str = "") -> int: + now = time.time() + cur = self.conn.execute( + "INSERT INTO projects (name, folder_path, description, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?)", + (name, folder_path, description, now, now), + ) + self.conn.commit() + return cur.lastrowid + + def list_projects(self) -> list[dict]: + rows = self.conn.execute( + "SELECT id, name, folder_path, description, created_at, updated_at " + "FROM projects ORDER BY name" + ).fetchall() + return [dict(r) for r in rows] + + def get_project(self, name: str) -> dict | None: + row = self.conn.execute( + "SELECT id, name, folder_path, description, created_at, updated_at " + "FROM projects WHERE name = ?", + (name,), + ).fetchone() + return dict(row) if row else None + + def delete_project(self, name: str) -> bool: + cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,)) + self.conn.commit() + return cur.rowcount > 0 + + # ------------------------------------------------------------------ + # Data files + # ------------------------------------------------------------------ + + def create_data_file( + self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None + ) -> int: + now = time.time() + tl = json.dumps(top_level or {}) + cur = self.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (project_id, name, data_type, tl, now, now), + ) + self.conn.commit() + return cur.lastrowid + + def list_data_files(self, project_id: int) -> list[dict]: + rows = self.conn.execute( + "SELECT id, project_id, name, data_type, created_at, updated_at " + "FROM data_files WHERE project_id = ? ORDER BY name", + (project_id,), + ).fetchall() + return [dict(r) for r in rows] + + def get_data_file(self, project_id: int, name: str) -> dict | None: + row = self.conn.execute( + "SELECT id, project_id, name, data_type, top_level, created_at, updated_at " + "FROM data_files WHERE project_id = ? AND name = ?", + (project_id, name), + ).fetchone() + if row is None: + return None + d = dict(row) + d["top_level"] = json.loads(d["top_level"]) + return d + + def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None: + row = self.conn.execute( + "SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, " + "df.created_at, df.updated_at " + "FROM data_files df JOIN projects p ON df.project_id = p.id " + "WHERE p.name = ? AND df.name = ?", + (project_name, file_name), + ).fetchone() + if row is None: + return None + d = dict(row) + d["top_level"] = json.loads(d["top_level"]) + return d + + # ------------------------------------------------------------------ + # Sequences + # ------------------------------------------------------------------ + + def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None: + now = time.time() + self.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?) " + "ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at", + (data_file_id, sequence_number, json.dumps(data), now), + ) + self.conn.commit() + + def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None: + row = self.conn.execute( + "SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?", + (data_file_id, sequence_number), + ).fetchone() + return json.loads(row["data"]) if row else None + + def list_sequences(self, data_file_id: int) -> list[int]: + rows = self.conn.execute( + "SELECT sequence_number FROM sequences WHERE data_file_id = ? ORDER BY sequence_number", + (data_file_id,), + ).fetchall() + return [r["sequence_number"] for r in rows] + + def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]: + """Returns (keys, types) for a sequence's data dict.""" + data = self.get_sequence(data_file_id, sequence_number) + if not data: + return [], [] + keys = [] + types = [] + for k, v in data.items(): + keys.append(k) + if isinstance(v, bool): + types.append("STRING") + elif isinstance(v, int): + types.append("INT") + elif isinstance(v, float): + types.append("FLOAT") + else: + types.append("STRING") + return keys, types + + def delete_sequences_for_file(self, data_file_id: int) -> None: + self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,)) + self.conn.commit() + + # ------------------------------------------------------------------ + # History trees + # ------------------------------------------------------------------ + + def save_history_tree(self, data_file_id: int, tree_data: dict) -> None: + now = time.time() + self.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) " + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (data_file_id, json.dumps(tree_data), now), + ) + self.conn.commit() + + def get_history_tree(self, data_file_id: int) -> dict | None: + row = self.conn.execute( + "SELECT tree_data FROM history_trees WHERE data_file_id = ?", + (data_file_id,), + ).fetchone() + return json.loads(row["tree_data"]) if row else None + + # ------------------------------------------------------------------ + # Import + # ------------------------------------------------------------------ + + def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int: + """Import a JSON file into the database, splitting batch_data into sequences.""" + json_path = Path(json_path) + data, _ = load_json(json_path) + file_name = json_path.stem + + # Extract top-level keys that aren't batch_data or history_tree + top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} + + df_id = self.create_data_file(project_id, file_name, data_type, top_level) + + # Import sequences from batch_data + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + for item in batch_data: + seq_num = int(item.get("sequence_number", 0)) + self.upsert_sequence(df_id, seq_num, item) + + # Import history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + self.save_history_tree(df_id, history_tree) + + return df_id + + # ------------------------------------------------------------------ + # Query helpers (for REST API) + # ------------------------------------------------------------------ + + def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None: + """Query a single sequence by project name, file name, and sequence number.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return None + return self.get_sequence(df["id"], sequence_number) + + def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]: + """Query keys and types for a sequence.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return [], [] + return self.get_sequence_keys(df["id"], sequence_number) + + def list_project_files(self, project_name: str) -> list[dict]: + """List data files for a project by name.""" + proj = self.get_project(project_name) + if not proj: + return [] + return self.list_data_files(proj["id"]) + + def list_project_sequences(self, project_name: str, file_name: str) -> list[int]: + """List sequence numbers for a file in a project.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return [] + return self.list_sequences(df["id"]) diff --git a/main.py b/main.py index b6455fc..5936aa8 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import json +import logging from pathlib import Path from nicegui import ui @@ -14,6 +15,11 @@ from tab_batch_ng import render_batch_processor from tab_timeline_ng import render_timeline_tab from tab_raw_ng import render_raw_editor from tab_comfy_ng import render_comfy_monitor +from tab_projects_ng import render_projects_tab +from db import ProjectDB +from api_routes import register_api_routes + +logger = logging.getLogger(__name__) @ui.page('/') @@ -156,7 +162,17 @@ def index(): config=config, current_dir=Path(config.get('last_dir', str(Path.cwd()))), snippets=load_snippets(), + db_enabled=config.get('db_enabled', False), + current_project=config.get('current_project', ''), ) + + # Initialize project database + try: + state.db = ProjectDB() + except Exception as e: + logger.warning(f"Failed to initialize ProjectDB: {e}") + state.db = None + dual_pane = {'active': False, 'state': None} # ------------------------------------------------------------------ @@ -178,6 +194,7 @@ def index(): ui.tab('batch', label='Batch Processor') ui.tab('timeline', label='Timeline') ui.tab('raw', label='Raw Editor') + ui.tab('projects', label='Projects') with ui.tab_panels(tabs, value='batch').classes('w-full'): with ui.tab_panel('batch'): @@ -186,6 +203,8 @@ def index(): render_timeline_tab(state) with ui.tab_panel('raw'): render_raw_editor(state) + with ui.tab_panel('projects'): + render_projects_tab(state) if state.show_comfy_monitor: ui.separator() @@ -481,4 +500,11 @@ def render_sidebar(state: AppState, dual_pane: dict): ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle) +# Register REST API routes for ComfyUI connectivity +try: + _api_db = ProjectDB() + register_api_routes(_api_db) +except Exception as e: + logger.warning(f"Failed to register API routes: {e}") + ui.run(title='AI Settings Manager', port=8080, reload=True) diff --git a/project_loader.py b/project_loader.py new file mode 100644 index 0000000..2adc6c3 --- /dev/null +++ b/project_loader.py @@ -0,0 +1,255 @@ +import json +import logging +import urllib.request +import urllib.error +from typing import Any + +logger = logging.getLogger(__name__) + +MAX_DYNAMIC_OUTPUTS = 32 + + +class AnyType(str): + """Universal connector type that matches any ComfyUI type.""" + def __ne__(self, __value: object) -> bool: + return False + +any_type = AnyType("*") + + +try: + from server import PromptServer + from aiohttp import web +except ImportError: + PromptServer = None + + +def to_float(val: Any) -> float: + try: + return float(val) + except (ValueError, TypeError): + return 0.0 + +def to_int(val: Any) -> int: + try: + return int(float(val)) + except (ValueError, TypeError): + return 0 + + +def _fetch_json(url: str) -> dict: + """Fetch JSON from a URL using stdlib urllib.""" + try: + with urllib.request.urlopen(url, timeout=5) as resp: + return json.loads(resp.read()) + except (urllib.error.URLError, json.JSONDecodeError, OSError) as e: + logger.warning(f"Failed to fetch {url}: {e}") + return {} + + +def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: + """Fetch sequence data from the NiceGUI REST API.""" + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/data?seq={seq}" + return _fetch_json(url) + + +def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict: + """Fetch keys/types from the NiceGUI REST API.""" + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/keys?seq={seq}" + return _fetch_json(url) + + +# --- ComfyUI-side proxy endpoints (for frontend JS) --- +if PromptServer is not None: + @PromptServer.instance.routes.get("/json_manager/list_projects") + async def list_projects_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + url = f"{manager_url.rstrip('/')}/api/projects" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/list_project_files") + async def list_project_files_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = request.query.get("project", "") + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/list_project_sequences") + async def list_project_sequences_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = request.query.get("project", "") + file_name = request.query.get("file", "") + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/get_project_keys") + async def get_project_keys_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = request.query.get("project", "") + file_name = request.query.get("file", "") + try: + seq = int(request.query.get("seq", "1")) + except (ValueError, TypeError): + seq = 1 + data = _fetch_keys(manager_url, project, file_name, seq) + return web.json_response(data) + + +# ========================================== +# 0. DYNAMIC NODE (Project-based) +# ========================================== + +class ProjectLoaderDynamic: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }, + "optional": { + "output_keys": ("STRING", {"default": ""}), + "output_types": ("STRING", {"default": ""}), + }, + } + + RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) + RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) + FUNCTION = "load_dynamic" + CATEGORY = "utils/json/project" + OUTPUT_NODE = False + + def load_dynamic(self, manager_url, project_name, file_name, sequence_number, + output_keys="", output_types=""): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + + keys = [k.strip() for k in output_keys.split(",") if k.strip()] if output_keys else [] + + results = [] + for key in keys: + val = data.get(key, "") + if isinstance(val, bool): + results.append(str(val).lower()) + elif isinstance(val, int): + results.append(val) + elif isinstance(val, float): + results.append(val) + else: + results.append(str(val)) + + while len(results) < MAX_DYNAMIC_OUTPUTS: + results.append("") + + return tuple(results) + + +# ========================================== +# 1. STANDARD NODE (Project-based I2V) +# ========================================== + +class ProjectLoaderStandard: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }} + + RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING") + RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path") + FUNCTION = "load_standard" + CATEGORY = "utils/json/project" + + def load_standard(self, manager_url, project_name, file_name, sequence_number): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + return ( + str(data.get("general_prompt", "")), str(data.get("general_negative", "")), + str(data.get("current_prompt", "")), str(data.get("negative", "")), + str(data.get("camera", "")), to_float(data.get("flf", 0.0)), + to_int(data.get("seed", 0)), str(data.get("video file path", "")), + str(data.get("reference image path", "")), str(data.get("flf image path", "")) + ) + + +# ========================================== +# 2. VACE NODE (Project-based) +# ========================================== + +class ProjectLoaderVACE: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }} + + RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING") + RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path") + FUNCTION = "load_vace" + CATEGORY = "utils/json/project" + + def load_vace(self, manager_url, project_name, file_name, sequence_number): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + return ( + str(data.get("general_prompt", "")), str(data.get("general_negative", "")), + str(data.get("current_prompt", "")), str(data.get("negative", "")), + str(data.get("camera", "")), to_float(data.get("flf", 0.0)), + to_int(data.get("seed", 0)), to_int(data.get("frame_to_skip", 81)), + to_int(data.get("input_a_frames", 16)), to_int(data.get("input_b_frames", 16)), + str(data.get("reference path", "")), to_int(data.get("reference switch", 1)), + to_int(data.get("vace schedule", 1)), str(data.get("video file path", "")), + str(data.get("reference image path", "")) + ) + + +# ========================================== +# 3. LoRA NODE (Project-based) +# ========================================== + +class ProjectLoaderLoRA: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }} + + RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") + RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low") + FUNCTION = "load_loras" + CATEGORY = "utils/json/project" + + def load_loras(self, manager_url, project_name, file_name, sequence_number): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + return ( + str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")), + str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")), + str(data.get("lora 3 high", "")), str(data.get("lora 3 low", "")) + ) + + +# --- Mappings --- +PROJECT_NODE_CLASS_MAPPINGS = { + "ProjectLoaderDynamic": ProjectLoaderDynamic, + "ProjectLoaderStandard": ProjectLoaderStandard, + "ProjectLoaderVACE": ProjectLoaderVACE, + "ProjectLoaderLoRA": ProjectLoaderLoRA, +} + +PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { + "ProjectLoaderDynamic": "Project Loader (Dynamic)", + "ProjectLoaderStandard": "Project Loader (Standard/I2V)", + "ProjectLoaderVACE": "Project Loader (VACE Full)", + "ProjectLoaderLoRA": "Project Loader (LoRAs)", +} diff --git a/state.py b/state.py index e4aeab4..891a14e 100644 --- a/state.py +++ b/state.py @@ -17,6 +17,11 @@ class AppState: live_toggles: dict = field(default_factory=dict) show_comfy_monitor: bool = True + # Project DB fields + db: Any = None + current_project: str = "" + db_enabled: bool = False + # Set at runtime by main.py / tab_comfy_ng.py _render_main: Any = None _load_file: Callable | None = None diff --git a/tab_batch_ng.py b/tab_batch_ng.py index 34845f6..36abee8 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -6,7 +6,7 @@ from nicegui import ui from state import AppState from utils import ( - DEFAULTS, save_json, load_json, + DEFAULTS, save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, ) from history_tree import HistoryTree @@ -161,6 +161,8 @@ def render_batch_processor(state: AppState): new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {}, KEY_PROMPT_HISTORY: []} save_json(new_path, new_data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, new_path, new_data) ui.notify(f'Created {new_name}', type='positive') ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch) @@ -215,6 +217,8 @@ def render_batch_processor(state: AppState): batch_list.append(new_item) data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) render_sequence_list.refresh() with ui.row().classes('q-mt-sm'): @@ -250,6 +254,8 @@ def render_batch_processor(state: AppState): batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0))) data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Sorted by sequence number!', type='positive') render_sequence_list.refresh() @@ -289,6 +295,8 @@ def render_batch_processor(state: AppState): htree.commit(snapshot_payload, note=note) data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) state.restored_indicator = None commit_input.set_value('') ui.notify('Batch Saved & Snapshot Created!', type='positive') @@ -306,6 +314,8 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, def commit(message=None): data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) if message: ui.notify(message, type='positive') refresh_list.refresh() @@ -567,6 +577,8 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list): shifted += 1 data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive') refresh_list.refresh() @@ -712,6 +724,8 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}") data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify(f'Updated {len(targets)} sequences', type='positive') if refresh_list: refresh_list.refresh() diff --git a/tab_projects_ng.py b/tab_projects_ng.py new file mode 100644 index 0000000..afa8422 --- /dev/null +++ b/tab_projects_ng.py @@ -0,0 +1,161 @@ +import logging +from pathlib import Path + +from nicegui import ui + +from state import AppState +from db import ProjectDB +from utils import save_config, sync_to_db, KEY_BATCH_DATA + +logger = logging.getLogger(__name__) + + +def render_projects_tab(state: AppState): + """Render the Projects management tab.""" + + # --- DB toggle --- + def on_db_toggle(e): + state.db_enabled = e.value + state.config['db_enabled'] = e.value + save_config(state.current_dir, state.config.get('favorites', []), state.config) + render_project_content.refresh() + + ui.switch('Enable Project Database', value=state.db_enabled, + on_change=on_db_toggle).classes('q-mb-md') + + @ui.refreshable + def render_project_content(): + if not state.db_enabled: + ui.label('Project database is disabled. Enable it above to manage projects.').classes( + 'text-caption q-pa-md') + return + + if not state.db: + ui.label('Database not initialized.').classes('text-warning q-pa-md') + return + + # --- Create project form --- + with ui.card().classes('w-full q-pa-md q-mb-md'): + ui.label('Create New Project').classes('section-header') + name_input = ui.input('Project Name', placeholder='my_project').classes('w-full') + desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full') + + def create_project(): + name = name_input.value.strip() + if not name: + ui.notify('Please enter a project name', type='warning') + return + try: + state.db.create_project(name, str(state.current_dir), desc_input.value.strip()) + name_input.set_value('') + desc_input.set_value('') + ui.notify(f'Created project "{name}"', type='positive') + render_project_list.refresh() + except Exception as e: + ui.notify(f'Error: {e}', type='negative') + + ui.button('Create Project', icon='add', on_click=create_project).classes('w-full') + + # --- Active project indicator --- + if state.current_project: + ui.label(f'Active Project: {state.current_project}').classes( + 'text-bold text-primary q-pa-sm') + + # --- Project list --- + @ui.refreshable + def render_project_list(): + projects = state.db.list_projects() + if not projects: + ui.label('No projects yet. Create one above.').classes('text-caption q-pa-md') + return + + for proj in projects: + is_active = proj['name'] == state.current_project + card_style = 'border-left: 3px solid var(--accent);' if is_active else '' + + with ui.card().classes('w-full q-pa-sm q-mb-sm').style(card_style): + with ui.row().classes('w-full items-center'): + with ui.column().classes('col'): + ui.label(proj['name']).classes('text-bold') + if proj['description']: + ui.label(proj['description']).classes('text-caption') + ui.label(f'Path: {proj["folder_path"]}').classes('text-caption') + files = state.db.list_data_files(proj['id']) + ui.label(f'{len(files)} data file(s)').classes('text-caption') + + with ui.row().classes('q-gutter-xs'): + if not is_active: + def activate(name=proj['name']): + state.current_project = name + state.config['current_project'] = name + save_config(state.current_dir, + state.config.get('favorites', []), + state.config) + ui.notify(f'Activated project "{name}"', type='positive') + render_project_list.refresh() + + ui.button('Activate', icon='check_circle', + on_click=activate).props('flat dense color=primary') + else: + def deactivate(): + state.current_project = '' + state.config['current_project'] = '' + save_config(state.current_dir, + state.config.get('favorites', []), + state.config) + ui.notify('Deactivated project', type='info') + render_project_list.refresh() + + ui.button('Deactivate', icon='cancel', + on_click=deactivate).props('flat dense') + + def import_folder(pid=proj['id'], pname=proj['name']): + _import_folder(state, pid, pname, render_project_list) + + ui.button('Import Folder', icon='folder_open', + on_click=import_folder).props('flat dense') + + def delete_proj(name=proj['name']): + state.db.delete_project(name) + if state.current_project == name: + state.current_project = '' + ui.notify(f'Deleted project "{name}"', type='positive') + render_project_list.refresh() + + ui.button(icon='delete', + on_click=delete_proj).props('flat dense color=negative') + + render_project_list() + + render_project_content() + + +def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn): + """Bulk import all .json files from current directory into a project.""" + json_files = sorted(state.current_dir.glob('*.json')) + json_files = [f for f in json_files if f.name not in ( + '.editor_config.json', '.editor_snippets.json')] + + if not json_files: + ui.notify('No JSON files in current directory', type='warning') + return + + imported = 0 + skipped = 0 + for jf in json_files: + file_name = jf.stem + existing = state.db.get_data_file(project_id, file_name) + if existing: + skipped += 1 + continue + try: + state.db.import_json_file(project_id, jf) + imported += 1 + except Exception as e: + logger.warning(f"Failed to import {jf}: {e}") + + msg = f'Imported {imported} file(s)' + if skipped: + msg += f', skipped {skipped} existing' + ui.notify(msg, type='positive') + refresh_fn.refresh() diff --git a/tab_raw_ng.py b/tab_raw_ng.py index 39ec6f3..bdc4933 100644 --- a/tab_raw_ng.py +++ b/tab_raw_ng.py @@ -4,7 +4,7 @@ import json from nicegui import ui from state import AppState -from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY +from utils import save_json, sync_to_db, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY def render_raw_editor(state: AppState): @@ -52,6 +52,8 @@ def render_raw_editor(state: AppState): input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY] save_json(file_path, input_data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, input_data) data.clear() data.update(input_data) diff --git a/tab_timeline_ng.py b/tab_timeline_ng.py index daaf460..d0467f4 100644 --- a/tab_timeline_ng.py +++ b/tab_timeline_ng.py @@ -5,7 +5,7 @@ from nicegui import ui from state import AppState from history_tree import HistoryTree -from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE +from utils import save_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE def _delete_nodes(htree, data, file_path, node_ids): @@ -134,6 +134,8 @@ def _render_batch_delete(htree, data, file_path, state, refresh_fn): def do_batch_delete(): current_valid = state.timeline_selected_nodes & set(htree.nodes.keys()) _delete_nodes(htree, data, file_path, current_valid) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) state.timeline_selected_nodes = set() ui.notify( f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!', @@ -179,7 +181,7 @@ def _find_branch_for_node(htree, node_id): def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn, - selected): + selected, state=None): """Render branch-grouped node manager with restore, rename, delete, and preview.""" ui.label('Manage Version').classes('section-header') @@ -291,6 +293,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_ htree.nodes[sel_id]['note'] = rename_input.value data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state and state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Label updated', type='positive') refresh_fn() @@ -304,6 +308,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_ def delete_selected(): if sel_id in htree.nodes: _delete_nodes(htree, data, file_path, {sel_id}) + if state and state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Node Deleted', type='positive') refresh_fn() @@ -377,7 +383,7 @@ def render_timeline_tab(state: AppState): _render_node_manager( all_nodes, htree, data, file_path, _restore_and_refresh, render_timeline.refresh, - selected) + selected, state=state) def _toggle_select(nid, checked): if checked: @@ -492,6 +498,8 @@ def _restore_node(data, node, htree, file_path, state: AppState): htree.head_id = node['id'] data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) label = f"{node.get('note', 'Step')} ({node['id'][:4]})" state.restored_indicator = label ui.notify('Restored!', type='positive') diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..73427c1 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,286 @@ +import json +from pathlib import Path + +import pytest + +from db import ProjectDB +from utils import KEY_BATCH_DATA, KEY_HISTORY_TREE + + +@pytest.fixture +def db(tmp_path): + """Create a fresh ProjectDB in a temp directory.""" + db_path = tmp_path / "test.db" + pdb = ProjectDB(db_path) + yield pdb + pdb.close() + + +# ------------------------------------------------------------------ +# Projects CRUD +# ------------------------------------------------------------------ + +class TestProjects: + def test_create_and_get(self, db): + pid = db.create_project("proj1", "/some/path", "A test project") + assert pid > 0 + proj = db.get_project("proj1") + assert proj is not None + assert proj["name"] == "proj1" + assert proj["folder_path"] == "/some/path" + assert proj["description"] == "A test project" + + def test_list_projects(self, db): + db.create_project("beta", "/b") + db.create_project("alpha", "/a") + projects = db.list_projects() + assert len(projects) == 2 + assert projects[0]["name"] == "alpha" + assert projects[1]["name"] == "beta" + + def test_get_nonexistent(self, db): + assert db.get_project("nope") is None + + def test_delete_project(self, db): + db.create_project("to_delete", "/x") + assert db.delete_project("to_delete") is True + assert db.get_project("to_delete") is None + + def test_delete_nonexistent(self, db): + assert db.delete_project("nope") is False + + def test_unique_name_constraint(self, db): + db.create_project("dup", "/a") + with pytest.raises(Exception): + db.create_project("dup", "/b") + + +# ------------------------------------------------------------------ +# Data files +# ------------------------------------------------------------------ + +class TestDataFiles: + def test_create_and_list(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch_i2v", "i2v", {"extra": "meta"}) + assert df_id > 0 + files = db.list_data_files(pid) + assert len(files) == 1 + assert files[0]["name"] == "batch_i2v" + assert files[0]["data_type"] == "i2v" + + def test_get_data_file(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v", {"key": "value"}) + df = db.get_data_file(pid, "batch_i2v") + assert df is not None + assert df["top_level"] == {"key": "value"} + + def test_get_data_file_by_names(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v") + df = db.get_data_file_by_names("p1", "batch_i2v") + assert df is not None + assert df["name"] == "batch_i2v" + + def test_get_nonexistent_data_file(self, db): + pid = db.create_project("p1", "/p1") + assert db.get_data_file(pid, "nope") is None + + def test_unique_constraint(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v") + with pytest.raises(Exception): + db.create_data_file(pid, "batch_i2v", "vace") + + def test_cascade_delete(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "hello"}) + db.save_history_tree(df_id, {"nodes": {}}) + db.delete_project("p1") + assert db.get_data_file(pid, "batch_i2v") is None + + +# ------------------------------------------------------------------ +# Sequences +# ------------------------------------------------------------------ + +class TestSequences: + def test_upsert_and_get(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"prompt": "hello", "seed": 42}) + data = db.get_sequence(df_id, 1) + assert data == {"prompt": "hello", "seed": 42} + + def test_upsert_updates_existing(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"prompt": "v1"}) + db.upsert_sequence(df_id, 1, {"prompt": "v2"}) + data = db.get_sequence(df_id, 1) + assert data["prompt"] == "v2" + + def test_list_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 3, {"a": 1}) + db.upsert_sequence(df_id, 1, {"b": 2}) + db.upsert_sequence(df_id, 2, {"c": 3}) + seqs = db.list_sequences(df_id) + assert seqs == [1, 2, 3] + + def test_get_nonexistent_sequence(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.get_sequence(df_id, 99) is None + + def test_get_sequence_keys(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, { + "prompt": "hello", + "seed": 42, + "cfg": 1.5, + "flag": True, + }) + keys, types = db.get_sequence_keys(df_id, 1) + assert "prompt" in keys + assert "seed" in keys + idx_prompt = keys.index("prompt") + idx_seed = keys.index("seed") + idx_cfg = keys.index("cfg") + idx_flag = keys.index("flag") + assert types[idx_prompt] == "STRING" + assert types[idx_seed] == "INT" + assert types[idx_cfg] == "FLOAT" + assert types[idx_flag] == "STRING" # bools -> STRING + + def test_get_sequence_keys_nonexistent(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + keys, types = db.get_sequence_keys(df_id, 99) + assert keys == [] + assert types == [] + + def test_delete_sequences_for_file(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + db.delete_sequences_for_file(df_id) + assert db.list_sequences(df_id) == [] + + +# ------------------------------------------------------------------ +# History trees +# ------------------------------------------------------------------ + +class TestHistoryTrees: + def test_save_and_get(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + tree = {"nodes": {"abc": {"id": "abc"}}, "head_id": "abc"} + db.save_history_tree(df_id, tree) + result = db.get_history_tree(df_id) + assert result == tree + + def test_upsert_updates(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.save_history_tree(df_id, {"v": 1}) + db.save_history_tree(df_id, {"v": 2}) + result = db.get_history_tree(df_id) + assert result == {"v": 2} + + def test_get_nonexistent(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.get_history_tree(df_id) is None + + +# ------------------------------------------------------------------ +# Import +# ------------------------------------------------------------------ + +class TestImport: + def test_import_json_file(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "batch_prompt_i2v.json" + data = { + KEY_BATCH_DATA: [ + {"sequence_number": 1, "prompt": "hello", "seed": 42}, + {"sequence_number": 2, "prompt": "world", "seed": 99}, + ], + KEY_HISTORY_TREE: {"nodes": {}, "head_id": None}, + } + json_path.write_text(json.dumps(data)) + + df_id = db.import_json_file(pid, json_path, "i2v") + assert df_id > 0 + + seqs = db.list_sequences(df_id) + assert seqs == [1, 2] + + s1 = db.get_sequence(df_id, 1) + assert s1["prompt"] == "hello" + assert s1["seed"] == 42 + + tree = db.get_history_tree(df_id) + assert tree == {"nodes": {}, "head_id": None} + + def test_import_file_name_from_stem(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "my_batch.json" + json_path.write_text(json.dumps({KEY_BATCH_DATA: [{"sequence_number": 1}]})) + db.import_json_file(pid, json_path) + df = db.get_data_file(pid, "my_batch") + assert df is not None + + def test_import_no_batch_data(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "simple.json" + json_path.write_text(json.dumps({"prompt": "flat file"})) + df_id = db.import_json_file(pid, json_path) + seqs = db.list_sequences(df_id) + assert seqs == [] + + +# ------------------------------------------------------------------ +# Query helpers +# ------------------------------------------------------------------ + +class TestQueryHelpers: + def test_query_sequence_data(self, db): + pid = db.create_project("myproject", "/mp") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7}) + result = db.query_sequence_data("myproject", "batch_i2v", 1) + assert result == {"prompt": "test", "seed": 7} + + def test_query_sequence_data_not_found(self, db): + assert db.query_sequence_data("nope", "nope", 1) is None + + def test_query_sequence_keys(self, db): + pid = db.create_project("myproject", "/mp") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7}) + keys, types = db.query_sequence_keys("myproject", "batch_i2v", 1) + assert "prompt" in keys + assert "seed" in keys + + def test_list_project_files(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "file_a", "i2v") + db.create_data_file(pid, "file_b", "vace") + files = db.list_project_files("p1") + assert len(files) == 2 + + def test_list_project_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {}) + db.upsert_sequence(df_id, 2, {}) + seqs = db.list_project_sequences("p1", "batch") + assert seqs == [1, 2] diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py new file mode 100644 index 0000000..00a59a7 --- /dev/null +++ b/tests/test_project_loader.py @@ -0,0 +1,201 @@ +import json +from unittest.mock import patch, MagicMock +from io import BytesIO + +import pytest + +from project_loader import ( + ProjectLoaderDynamic, + ProjectLoaderStandard, + ProjectLoaderVACE, + ProjectLoaderLoRA, + _fetch_json, + _fetch_data, + _fetch_keys, + MAX_DYNAMIC_OUTPUTS, +) + + +def _mock_urlopen(data: dict): + """Create a mock context manager for urllib.request.urlopen.""" + response = MagicMock() + response.read.return_value = json.dumps(data).encode() + response.__enter__ = lambda s: s + response.__exit__ = MagicMock(return_value=False) + return response + + +class TestFetchHelpers: + def test_fetch_json_success(self): + data = {"key": "value"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)): + result = _fetch_json("http://example.com/api") + assert result == data + + def test_fetch_json_failure(self): + import urllib.error + with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")): + result = _fetch_json("http://example.com/api") + assert result == {} + + def test_fetch_data_builds_url(self): + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + result = _fetch_data("http://localhost:8080", "proj1", "batch_i2v", 1) + assert result == data + called_url = mock.call_args[0][0] + assert "/api/projects/proj1/files/batch_i2v/data?seq=1" in called_url + + def test_fetch_keys_builds_url(self): + data = {"keys": ["prompt"], "types": ["STRING"]} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + result = _fetch_keys("http://localhost:8080", "proj1", "batch_i2v", 1) + assert result == data + called_url = mock.call_args[0][0] + assert "/api/projects/proj1/files/batch_i2v/keys?seq=1" in called_url + + def test_fetch_data_strips_trailing_slash(self): + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + _fetch_data("http://localhost:8080/", "proj1", "file1", 1) + called_url = mock.call_args[0][0] + assert "//api" not in called_url + + +class TestProjectLoaderDynamic: + def test_load_dynamic_with_keys(self): + data = {"prompt": "hello", "seed": 42, "cfg": 1.5} + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="prompt,seed,cfg" + ) + assert result[0] == "hello" + assert result[1] == 42 + assert result[2] == 1.5 + assert len(result) == MAX_DYNAMIC_OUTPUTS + + def test_load_dynamic_empty_keys(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="" + ) + assert all(v == "" for v in result) + + def test_load_dynamic_missing_key(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="nonexistent" + ) + assert result[0] == "" + + def test_load_dynamic_bool_becomes_string(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_data", return_value={"flag": True}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="flag" + ) + assert result[0] == "true" + + def test_input_types_has_manager_url(self): + inputs = ProjectLoaderDynamic.INPUT_TYPES() + assert "manager_url" in inputs["required"] + assert "project_name" in inputs["required"] + assert "file_name" in inputs["required"] + assert "sequence_number" in inputs["required"] + + def test_category(self): + assert ProjectLoaderDynamic.CATEGORY == "utils/json/project" + + +class TestProjectLoaderStandard: + def test_load_standard(self): + data = { + "general_prompt": "hello", + "general_negative": "bad", + "current_prompt": "specific", + "negative": "neg", + "camera": "pan", + "flf": 0.5, + "seed": 42, + "video file path": "/v.mp4", + "reference image path": "/r.png", + "flf image path": "/f.png", + } + node = ProjectLoaderStandard() + with patch("project_loader._fetch_data", return_value=data): + result = node.load_standard("http://localhost:8080", "proj1", "batch", 1) + assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png") + + def test_load_standard_defaults(self): + node = ProjectLoaderStandard() + with patch("project_loader._fetch_data", return_value={}): + result = node.load_standard("http://localhost:8080", "proj1", "batch", 1) + assert result[0] == "" # general_prompt + assert result[5] == 0.0 # flf + assert result[6] == 0 # seed + + +class TestProjectLoaderVACE: + def test_load_vace(self): + data = { + "general_prompt": "hello", + "general_negative": "bad", + "current_prompt": "specific", + "negative": "neg", + "camera": "pan", + "flf": 0.5, + "seed": 42, + "frame_to_skip": 81, + "input_a_frames": 16, + "input_b_frames": 16, + "reference path": "/ref", + "reference switch": 1, + "vace schedule": 2, + "video file path": "/v.mp4", + "reference image path": "/r.png", + } + node = ProjectLoaderVACE() + with patch("project_loader._fetch_data", return_value=data): + result = node.load_vace("http://localhost:8080", "proj1", "batch", 1) + assert result[7] == 81 # frame_to_skip + assert result[12] == 2 # vace_schedule + + +class TestProjectLoaderLoRA: + def test_load_loras(self): + data = { + "lora 1 high": "", + "lora 1 low": "", + "lora 2 high": "", + "lora 2 low": "", + "lora 3 high": "", + "lora 3 low": "", + } + node = ProjectLoaderLoRA() + with patch("project_loader._fetch_data", return_value=data): + result = node.load_loras("http://localhost:8080", "proj1", "batch", 1) + assert result[0] == "" + assert result[1] == "" + + def test_load_loras_empty(self): + node = ProjectLoaderLoRA() + with patch("project_loader._fetch_data", return_value={}): + result = node.load_loras("http://localhost:8080", "proj1", "batch", 1) + assert all(v == "" for v in result) + + +class TestNodeMappings: + def test_mappings_exist(self): + from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS + assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4 diff --git a/utils.py b/utils.py index 2e49007..ea5cdc9 100644 --- a/utils.py +++ b/utils.py @@ -160,6 +160,43 @@ def get_file_mtime(path: str | Path) -> float: return path.stat().st_mtime return 0 +def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: + """Dual-write helper: sync JSON data to the project database. + + Resolves (or creates) the data_file, upserts all sequences from batch_data, + and saves the history_tree. + """ + if not db or not project_name: + return + try: + proj = db.get_project(project_name) + if not proj: + return + file_name = Path(file_path).stem + df = db.get_data_file(proj["id"], file_name) + if not df: + top_level = {k: v for k, v in data.items() + if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} + df_id = db.create_data_file(proj["id"], file_name, "generic", top_level) + else: + df_id = df["id"] + + # Sync sequences + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + db.delete_sequences_for_file(df_id) + for item in batch_data: + seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) + db.upsert_sequence(df_id, seq_num, item) + + # Sync history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + db.save_history_tree(df_id, history_tree) + except Exception as e: + logger.warning(f"sync_to_db failed: {e}") + + def generate_templates(current_dir: Path) -> None: """Creates batch template files if folder is empty.""" first = DEFAULTS.copy() diff --git a/web/project_dynamic.js b/web/project_dynamic.js new file mode 100644 index 0000000..9d58b78 --- /dev/null +++ b/web/project_dynamic.js @@ -0,0 +1,139 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; + +app.registerExtension({ + name: "json.manager.project.dynamic", + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name !== "ProjectLoaderDynamic") return; + + const origOnNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + origOnNodeCreated?.apply(this, arguments); + + // Hide internal widgets (managed by JS) + for (const name of ["output_keys", "output_types"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } + } + + // Remove all 32 default outputs from Python RETURN_TYPES + while (this.outputs.length > 0) { + this.removeOutput(0); + } + + // Add Refresh button + this.addWidget("button", "Refresh Outputs", null, () => { + this.refreshDynamicOutputs(); + }); + + this.setSize(this.computeSize()); + }; + + nodeType.prototype.refreshDynamicOutputs = async function () { + const urlWidget = this.widgets?.find(w => w.name === "manager_url"); + const projectWidget = this.widgets?.find(w => w.name === "project_name"); + const fileWidget = this.widgets?.find(w => w.name === "file_name"); + const seqWidget = this.widgets?.find(w => w.name === "sequence_number"); + + if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return; + + try { + const resp = await api.fetchApi( + `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` + ); + const { keys, types } = await resp.json(); + + // Store keys and types in hidden widgets for persistence + const okWidget = this.widgets?.find(w => w.name === "output_keys"); + if (okWidget) okWidget.value = keys.join(","); + const otWidget = this.widgets?.find(w => w.name === "output_types"); + if (otWidget) otWidget.value = types.join(","); + + // Build a map of current output names to slot indices + const oldSlots = {}; + for (let i = 0; i < this.outputs.length; i++) { + oldSlots[this.outputs[i].name] = i; + } + + // Build new outputs, reusing existing slots to preserve links + const newOutputs = []; + for (let k = 0; k < keys.length; k++) { + const key = keys[k]; + const type = types[k] || "*"; + if (key in oldSlots) { + const slot = this.outputs[oldSlots[key]]; + slot.type = type; + newOutputs.push(slot); + delete oldSlots[key]; + } else { + newOutputs.push({ name: key, type: type, links: null }); + } + } + + // Disconnect links on slots that are being removed + for (const name in oldSlots) { + const idx = oldSlots[name]; + if (this.outputs[idx]?.links?.length) { + for (const linkId of [...this.outputs[idx].links]) { + this.graph?.removeLink(linkId); + } + } + } + + // Reassign the outputs array and fix link slot indices + this.outputs = newOutputs; + if (this.graph) { + for (let i = 0; i < this.outputs.length; i++) { + const links = this.outputs[i].links; + if (!links) continue; + for (const linkId of links) { + const link = this.graph.links[linkId]; + if (link) link.origin_slot = i; + } + } + } + + this.setSize(this.computeSize()); + app.graph.setDirtyCanvas(true, true); + } catch (e) { + console.error("[ProjectLoaderDynamic] Refresh failed:", e); + } + }; + + // Restore state on workflow load + const origOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + origOnConfigure?.apply(this, arguments); + + // Hide internal widgets + for (const name of ["output_keys", "output_types"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } + } + + const okWidget = this.widgets?.find(w => w.name === "output_keys"); + const otWidget = this.widgets?.find(w => w.name === "output_types"); + + const keys = okWidget?.value + ? okWidget.value.split(",").filter(k => k.trim()) + : []; + const types = otWidget?.value + ? otWidget.value.split(",") + : []; + + // Rename and set types to match stored state (preserves links) + for (let i = 0; i < this.outputs.length && i < keys.length; i++) { + this.outputs[i].name = keys[i].trim(); + if (types[i]) this.outputs[i].type = types[i]; + } + + // Remove any extra outputs beyond the key count + while (this.outputs.length > keys.length) { + this.removeOutput(this.outputs.length - 1); + } + + this.setSize(this.computeSize()); + }; + }, +}); From 6b7e9ea68234b2358ad3d3572f6653b077bd0a9a Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:15:56 +0100 Subject: [PATCH 02/14] Port dynamic node improvements from ComfyUI-JSON-Dynamic - Deferred output cleanup (_configured flag + queueMicrotask) to prevent breaking links when other nodes (e.g. Kijai Set/Get) resolve outputs during graph loading - file_not_found error handling in refresh to keep existing outputs intact - Fallback widget sync in onConfigure when widget values are empty but serialized outputs exist Applied to both json_dynamic.js and project_dynamic.js. Co-Authored-By: Claude Opus 4.6 --- json_loader.py | 2 ++ web/json_dynamic.js | 60 +++++++++++++++++++++++++++++++----------- web/project_dynamic.js | 58 ++++++++++++++++++++++++++++++---------- 3 files changed, 90 insertions(+), 30 deletions(-) diff --git a/json_loader.py b/json_loader.py index eed69fb..d780a52 100644 --- a/json_loader.py +++ b/json_loader.py @@ -75,6 +75,8 @@ if PromptServer is not None: except (ValueError, TypeError): seq = 1 data = read_json_data(json_path) + if not data: + return web.json_response({"keys": [], "types": [], "error": "file_not_found"}) target = get_batch_item(data, seq) keys = [] types = [] diff --git a/web/json_dynamic.js b/web/json_dynamic.js index 81e11f1..22ca527 100644 --- a/web/json_dynamic.js +++ b/web/json_dynamic.js @@ -17,17 +17,31 @@ app.registerExtension({ if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } } - // Remove all 32 default outputs from Python RETURN_TYPES - while (this.outputs.length > 0) { - this.removeOutput(0); - } + // Do NOT remove default outputs synchronously here. + // During graph loading, ComfyUI creates all nodes (firing onNodeCreated) + // before configuring them. Other nodes (e.g. Kijai Set/Get) may resolve + // links to our outputs during their configure step. If we remove outputs + // here, those nodes find no output slot and error out. + // + // Instead, defer cleanup: for loaded workflows onConfigure sets _configured + // before this runs; for new nodes the defaults are cleaned up. + this._configured = false; // Add Refresh button this.addWidget("button", "Refresh Outputs", null, () => { this.refreshDynamicOutputs(); }); - this.setSize(this.computeSize()); + queueMicrotask(() => { + if (!this._configured) { + // New node (not loading) — remove the 32 Python default outputs + while (this.outputs.length > 0) { + this.removeOutput(0); + } + this.setSize(this.computeSize()); + app.graph?.setDirtyCanvas(true, true); + } + }); }; nodeType.prototype.refreshDynamicOutputs = async function () { @@ -39,7 +53,14 @@ app.registerExtension({ const resp = await api.fetchApi( `/json_manager/get_keys?path=${encodeURIComponent(pathWidget.value)}&sequence_number=${seqWidget?.value || 1}` ); - const { keys, types } = await resp.json(); + const data = await resp.json(); + const { keys, types } = data; + + // If the file wasn't found, keep existing outputs and links intact + if (data.error === "file_not_found") { + console.warn("[JSONLoaderDynamic] File not found, keeping existing outputs:", pathWidget.value); + return; + } // Store keys and types in hidden widgets for persistence const okWidget = this.widgets?.find(w => w.name === "output_keys"); @@ -82,7 +103,6 @@ app.registerExtension({ // Reassign the outputs array and fix link slot indices this.outputs = newOutputs; - // Update link origin_slot to match new positions if (this.graph) { for (let i = 0; i < this.outputs.length; i++) { const links = this.outputs[i].links; @@ -105,6 +125,7 @@ app.registerExtension({ const origOnConfigure = nodeType.prototype.onConfigure; nodeType.prototype.onConfigure = function (info) { origOnConfigure?.apply(this, arguments); + this._configured = true; // Hide internal widgets for (const name of ["output_keys", "output_types"]) { @@ -122,16 +143,23 @@ app.registerExtension({ ? otWidget.value.split(",") : []; - // On load, LiteGraph already restored serialized outputs with links. - // Rename and set types to match stored state (preserves links). - for (let i = 0; i < this.outputs.length && i < keys.length; i++) { - this.outputs[i].name = keys[i].trim(); - if (types[i]) this.outputs[i].type = types[i]; - } + if (keys.length > 0) { + // On load, LiteGraph already restored serialized outputs with links. + // Rename and set types to match stored state (preserves links). + for (let i = 0; i < this.outputs.length && i < keys.length; i++) { + this.outputs[i].name = keys[i].trim(); + if (types[i]) this.outputs[i].type = types[i]; + } - // Remove any extra outputs beyond the key count - while (this.outputs.length > keys.length) { - this.removeOutput(this.outputs.length - 1); + // Remove any extra outputs beyond the key count + while (this.outputs.length > keys.length) { + this.removeOutput(this.outputs.length - 1); + } + } else if (this.outputs.length > 0) { + // Widget values empty but serialized outputs exist — sync widgets + // from the outputs LiteGraph already restored (fallback). + if (okWidget) okWidget.value = this.outputs.map(o => o.name).join(","); + if (otWidget) otWidget.value = this.outputs.map(o => o.type).join(","); } this.setSize(this.computeSize()); diff --git a/web/project_dynamic.js b/web/project_dynamic.js index 9d58b78..9346f3d 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -17,17 +17,31 @@ app.registerExtension({ if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } } - // Remove all 32 default outputs from Python RETURN_TYPES - while (this.outputs.length > 0) { - this.removeOutput(0); - } + // Do NOT remove default outputs synchronously here. + // During graph loading, ComfyUI creates all nodes (firing onNodeCreated) + // before configuring them. Other nodes (e.g. Kijai Set/Get) may resolve + // links to our outputs during their configure step. If we remove outputs + // here, those nodes find no output slot and error out. + // + // Instead, defer cleanup: for loaded workflows onConfigure sets _configured + // before this runs; for new nodes the defaults are cleaned up. + this._configured = false; // Add Refresh button this.addWidget("button", "Refresh Outputs", null, () => { this.refreshDynamicOutputs(); }); - this.setSize(this.computeSize()); + queueMicrotask(() => { + if (!this._configured) { + // New node (not loading) — remove the 32 Python default outputs + while (this.outputs.length > 0) { + this.removeOutput(0); + } + this.setSize(this.computeSize()); + app.graph?.setDirtyCanvas(true, true); + } + }); }; nodeType.prototype.refreshDynamicOutputs = async function () { @@ -42,7 +56,14 @@ app.registerExtension({ const resp = await api.fetchApi( `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` ); - const { keys, types } = await resp.json(); + const data = await resp.json(); + const { keys, types } = data; + + // If the API returned an error, keep existing outputs and links intact + if (data.error) { + console.warn("[ProjectLoaderDynamic] API error, keeping existing outputs:", data.error); + return; + } // Store keys and types in hidden widgets for persistence const okWidget = this.widgets?.find(w => w.name === "output_keys"); @@ -105,6 +126,7 @@ app.registerExtension({ const origOnConfigure = nodeType.prototype.onConfigure; nodeType.prototype.onConfigure = function (info) { origOnConfigure?.apply(this, arguments); + this._configured = true; // Hide internal widgets for (const name of ["output_keys", "output_types"]) { @@ -122,15 +144,23 @@ app.registerExtension({ ? otWidget.value.split(",") : []; - // Rename and set types to match stored state (preserves links) - for (let i = 0; i < this.outputs.length && i < keys.length; i++) { - this.outputs[i].name = keys[i].trim(); - if (types[i]) this.outputs[i].type = types[i]; - } + if (keys.length > 0) { + // On load, LiteGraph already restored serialized outputs with links. + // Rename and set types to match stored state (preserves links). + for (let i = 0; i < this.outputs.length && i < keys.length; i++) { + this.outputs[i].name = keys[i].trim(); + if (types[i]) this.outputs[i].type = types[i]; + } - // Remove any extra outputs beyond the key count - while (this.outputs.length > keys.length) { - this.removeOutput(this.outputs.length - 1); + // Remove any extra outputs beyond the key count + while (this.outputs.length > keys.length) { + this.removeOutput(this.outputs.length - 1); + } + } else if (this.outputs.length > 0) { + // Widget values empty but serialized outputs exist — sync widgets + // from the outputs LiteGraph already restored (fallback). + if (okWidget) okWidget.value = this.outputs.map(o => o.name).join(","); + if (otWidget) otWidget.value = this.outputs.map(o => o.type).join(","); } this.setSize(this.computeSize()); From ba8f104bc1310a4d67f8221d380de88a1875ba40 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:25:31 +0100 Subject: [PATCH 03/14] Fix 6 bugs found during code review - Fix NameError: pass state to _render_vace_settings (tab_batch_ng.py) - Fix non-atomic sync_to_db: use BEGIN IMMEDIATE transaction with rollback - Fix create_secondary() missing db/current_project/db_enabled fields - Fix URL encoding: percent-encode project/file names in API URLs - Fix import_json_file crash on re-import: upsert instead of insert - Fix dual DB instances: share single ProjectDB between UI and API routes - Also fixes top_level metadata never being updated on existing data_files Co-Authored-By: Claude Opus 4.6 --- db.py | 21 ++++++++++-- main.py | 24 ++++++------- project_loader.py | 16 ++++++--- state.py | 3 ++ tab_batch_ng.py | 4 +-- tests/test_db.py | 26 ++++++++++++++ tests/test_project_loader.py | 10 ++++++ utils.py | 66 ++++++++++++++++++++++++++---------- 8 files changed, 131 insertions(+), 39 deletions(-) diff --git a/db.py b/db.py index 11efc7d..b9b17a4 100644 --- a/db.py +++ b/db.py @@ -228,15 +228,30 @@ class ProjectDB: # ------------------------------------------------------------------ def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int: - """Import a JSON file into the database, splitting batch_data into sequences.""" + """Import a JSON file into the database, splitting batch_data into sequences. + + Safe to call repeatedly — existing data_file is updated, sequences are + replaced, and history_tree is upserted. + """ json_path = Path(json_path) data, _ = load_json(json_path) file_name = json_path.stem - # Extract top-level keys that aren't batch_data or history_tree top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} - df_id = self.create_data_file(project_id, file_name, data_type, top_level) + existing = self.get_data_file(project_id, file_name) + if existing: + df_id = existing["id"] + now = time.time() + self.conn.execute( + "UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?", + (data_type, json.dumps(top_level), now, df_id), + ) + self.conn.commit() + # Clear old sequences before re-importing + self.delete_sequences_for_file(df_id) + else: + df_id = self.create_data_file(project_id, file_name, data_type, top_level) # Import sequences from batch_data batch_data = data.get(KEY_BATCH_DATA, []) diff --git a/main.py b/main.py index 5936aa8..efa2619 100644 --- a/main.py +++ b/main.py @@ -21,6 +21,13 @@ from api_routes import register_api_routes logger = logging.getLogger(__name__) +# Single shared DB instance for both the UI and API routes +_shared_db: ProjectDB | None = None +try: + _shared_db = ProjectDB() +except Exception as e: + logger.warning(f"Failed to initialize ProjectDB: {e}") + @ui.page('/') def index(): @@ -166,12 +173,8 @@ def index(): current_project=config.get('current_project', ''), ) - # Initialize project database - try: - state.db = ProjectDB() - except Exception as e: - logger.warning(f"Failed to initialize ProjectDB: {e}") - state.db = None + # Use the shared DB instance + state.db = _shared_db dual_pane = {'active': False, 'state': None} @@ -500,11 +503,8 @@ def render_sidebar(state: AppState, dual_pane: dict): ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle) -# Register REST API routes for ComfyUI connectivity -try: - _api_db = ProjectDB() - register_api_routes(_api_db) -except Exception as e: - logger.warning(f"Failed to register API routes: {e}") +# Register REST API routes for ComfyUI connectivity (uses the shared DB instance) +if _shared_db is not None: + register_api_routes(_shared_db) ui.run(title='AI Settings Manager', port=8080, reload=True) diff --git a/project_loader.py b/project_loader.py index 2adc6c3..1634135 100644 --- a/project_loader.py +++ b/project_loader.py @@ -1,5 +1,6 @@ import json import logging +import urllib.parse import urllib.request import urllib.error from typing import Any @@ -49,13 +50,17 @@ def _fetch_json(url: str) -> dict: def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: """Fetch sequence data from the NiceGUI REST API.""" - url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/data?seq={seq}" + p = urllib.parse.quote(project, safe='') + f = urllib.parse.quote(file, safe='') + url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/data?seq={seq}" return _fetch_json(url) def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict: """Fetch keys/types from the NiceGUI REST API.""" - url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/keys?seq={seq}" + p = urllib.parse.quote(project, safe='') + f = urllib.parse.quote(file, safe='') + url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/keys?seq={seq}" return _fetch_json(url) @@ -71,7 +76,7 @@ if PromptServer is not None: @PromptServer.instance.routes.get("/json_manager/list_project_files") async def list_project_files_proxy(request): manager_url = request.query.get("url", "http://localhost:8080") - project = request.query.get("project", "") + project = urllib.parse.quote(request.query.get("project", ""), safe='') url = f"{manager_url.rstrip('/')}/api/projects/{project}/files" data = _fetch_json(url) return web.json_response(data) @@ -79,8 +84,8 @@ if PromptServer is not None: @PromptServer.instance.routes.get("/json_manager/list_project_sequences") async def list_project_sequences_proxy(request): manager_url = request.query.get("url", "http://localhost:8080") - project = request.query.get("project", "") - file_name = request.query.get("file", "") + project = urllib.parse.quote(request.query.get("project", ""), safe='') + file_name = urllib.parse.quote(request.query.get("file", ""), safe='') url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences" data = _fetch_json(url) return web.json_response(data) @@ -98,6 +103,7 @@ if PromptServer is not None: return web.json_response(data) + # ========================================== # 0. DYNAMIC NODE (Project-based) # ========================================== diff --git a/state.py b/state.py index 891a14e..bef8818 100644 --- a/state.py +++ b/state.py @@ -34,4 +34,7 @@ class AppState: config=self.config, current_dir=self.current_dir, snippets=self.snippets, + db=self.db, + current_project=self.current_project, + db_enabled=self.db_enabled, ) diff --git a/tab_batch_ng.py b/tab_batch_ng.py index 36abee8..7ec49ae 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -457,7 +457,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, # --- VACE Settings (full width) --- with ui.expansion('VACE Settings', icon='settings').classes('w-full'): - _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list) + _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list) # --- LoRA Settings --- with ui.expansion('LoRA Settings', icon='style').classes('w-full'): @@ -539,7 +539,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, # VACE Settings sub-section # ====================================================================== -def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list): +def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list): # VACE Schedule (needed early for both columns) sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1)) diff --git a/tests/test_db.py b/tests/test_db.py index 73427c1..a0dc7ab 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -246,6 +246,32 @@ class TestImport: seqs = db.list_sequences(df_id) assert seqs == [] + def test_reimport_updates_existing(self, db, tmp_path): + """Re-importing the same file should update data, not crash.""" + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "batch.json" + + # First import + data_v1 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]} + json_path.write_text(json.dumps(data_v1)) + df_id_1 = db.import_json_file(pid, json_path, "i2v") + + # Second import (same file, updated data) + data_v2 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v2"}, {"sequence_number": 2, "prompt": "new"}]} + json_path.write_text(json.dumps(data_v2)) + df_id_2 = db.import_json_file(pid, json_path, "vace") + + # Should reuse the same data_file row + assert df_id_1 == df_id_2 + # Data type should be updated + df = db.get_data_file(pid, "batch") + assert df["data_type"] == "vace" + # Sequences should reflect v2 + seqs = db.list_sequences(df_id_2) + assert seqs == [1, 2] + s1 = db.get_sequence(df_id_2, 1) + assert s1["prompt"] == "v2" + # ------------------------------------------------------------------ # Query helpers diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index 00a59a7..fd77eb9 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -61,6 +61,16 @@ class TestFetchHelpers: called_url = mock.call_args[0][0] assert "//api" not in called_url + def test_fetch_data_encodes_special_chars(self): + """Project/file names with spaces or special chars should be percent-encoded.""" + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + _fetch_data("http://localhost:8080", "my project", "batch file", 1) + called_url = mock.call_args[0][0] + assert "my%20project" in called_url + assert "batch%20file" in called_url + assert " " not in called_url.split("?")[0] # no raw spaces in path + class TestProjectLoaderDynamic: def test_load_dynamic_with_keys(self): diff --git a/utils.py b/utils.py index ea5cdc9..44707a2 100644 --- a/utils.py +++ b/utils.py @@ -164,7 +164,7 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: """Dual-write helper: sync JSON data to the project database. Resolves (or creates) the data_file, upserts all sequences from batch_data, - and saves the history_tree. + and saves the history_tree. All writes happen in a single transaction. """ if not db or not project_name: return @@ -173,26 +173,58 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: if not proj: return file_name = Path(file_path).stem - df = db.get_data_file(proj["id"], file_name) - if not df: + + # Use a single transaction for atomicity + db.conn.execute("BEGIN IMMEDIATE") + try: + df = db.get_data_file(proj["id"], file_name) top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} - df_id = db.create_data_file(proj["id"], file_name, "generic", top_level) - else: - df_id = df["id"] + if not df: + now = __import__('time').time() + cur = db.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (proj["id"], file_name, "generic", json.dumps(top_level), now, now), + ) + df_id = cur.lastrowid + else: + df_id = df["id"] + # Update top_level metadata + now = __import__('time').time() + db.conn.execute( + "UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?", + (json.dumps(top_level), now, df_id), + ) - # Sync sequences - batch_data = data.get(KEY_BATCH_DATA, []) - if isinstance(batch_data, list): - db.delete_sequences_for_file(df_id) - for item in batch_data: - seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) - db.upsert_sequence(df_id, seq_num, item) + # Sync sequences + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) + for item in batch_data: + seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) + now = __import__('time').time() + db.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?)", + (df_id, seq_num, json.dumps(item), now), + ) - # Sync history tree - history_tree = data.get(KEY_HISTORY_TREE) - if history_tree and isinstance(history_tree, dict): - db.save_history_tree(df_id, history_tree) + # Sync history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + now = __import__('time').time() + db.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) " + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (df_id, json.dumps(history_tree), now), + ) + + db.conn.commit() + except Exception: + db.conn.rollback() + raise except Exception as e: logger.warning(f"sync_to_db failed: {e}") From b499eb4dfd71c3277f19250d8f58586a47a328d8 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:32:35 +0100 Subject: [PATCH 04/14] Fix 8 bugs from second code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HIGH: - Fix JS TypeError on empty API response: validate keys/types are arrays before using them; add HTTP status check (resp.ok) - Fix BEGIN IMMEDIATE conflict: set isolation_level=None (autocommit) on SQLite connection so explicit transactions work without implicit ones MEDIUM: - Fix import_json_file non-atomic: wrap entire operation in BEGIN/COMMIT with ROLLBACK on error — no more partial imports - Fix crash on non-dict batch_data items: skip non-dict elements - Fix comma-in-key corruption: store keys/types as JSON arrays in hidden widgets instead of comma-delimited strings (backward-compat fallback) - Fix blocking I/O in API routes: change async def to def so FastAPI auto-threads the synchronous SQLite calls LOW: - Fix missing ?. on app.graph.setDirtyCanvas in refreshDynamicOutputs Co-Authored-By: Claude Opus 4.6 --- api_routes.py | 10 ++--- db.py | 85 ++++++++++++++++++++++++------------ project_loader.py | 8 +++- tests/test_db.py | 38 ++++++++++++++++ tests/test_project_loader.py | 14 ++++++ utils.py | 4 +- web/project_dynamic.js | 48 ++++++++++++-------- 7 files changed, 155 insertions(+), 52 deletions(-) diff --git a/api_routes.py b/api_routes.py index 36e84cd..6d5b42e 100644 --- a/api_routes.py +++ b/api_routes.py @@ -35,25 +35,25 @@ def _get_db() -> ProjectDB: return _db -async def _list_projects() -> dict[str, Any]: +def _list_projects() -> dict[str, Any]: db = _get_db() projects = db.list_projects() return {"projects": [p["name"] for p in projects]} -async def _list_files(name: str) -> dict[str, Any]: +def _list_files(name: str) -> dict[str, Any]: db = _get_db() files = db.list_project_files(name) return {"files": [{"name": f["name"], "data_type": f["data_type"]} for f in files]} -async def _list_sequences(name: str, file_name: str) -> dict[str, Any]: +def _list_sequences(name: str, file_name: str) -> dict[str, Any]: db = _get_db() seqs = db.list_project_sequences(name, file_name) return {"sequences": seqs} -async def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: +def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: db = _get_db() data = db.query_sequence_data(name, file_name, seq) if data is None: @@ -61,7 +61,7 @@ async def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> d return data -async def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: +def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: db = _get_db() keys, types = db.query_sequence_keys(name, file_name, seq) return {"keys": keys, "types": types} diff --git a/db.py b/db.py index b9b17a4..000f577 100644 --- a/db.py +++ b/db.py @@ -56,12 +56,15 @@ class ProjectDB: def __init__(self, db_path: str | Path | None = None): self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self.conn = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + isolation_level=None, # autocommit — explicit BEGIN/COMMIT only + ) self.conn.row_factory = sqlite3.Row self.conn.execute("PRAGMA journal_mode=WAL") self.conn.execute("PRAGMA foreign_keys=ON") self.conn.executescript(SCHEMA_SQL) - self.conn.commit() def close(self): self.conn.close() @@ -231,7 +234,7 @@ class ProjectDB: """Import a JSON file into the database, splitting batch_data into sequences. Safe to call repeatedly — existing data_file is updated, sequences are - replaced, and history_tree is upserted. + replaced, and history_tree is upserted. Atomic: all-or-nothing. """ json_path = Path(json_path) data, _ = load_json(json_path) @@ -239,33 +242,61 @@ class ProjectDB: top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} - existing = self.get_data_file(project_id, file_name) - if existing: - df_id = existing["id"] - now = time.time() - self.conn.execute( - "UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?", - (data_type, json.dumps(top_level), now, df_id), - ) - self.conn.commit() - # Clear old sequences before re-importing - self.delete_sequences_for_file(df_id) - else: - df_id = self.create_data_file(project_id, file_name, data_type, top_level) + self.conn.execute("BEGIN IMMEDIATE") + try: + existing = self.conn.execute( + "SELECT id FROM data_files WHERE project_id = ? AND name = ?", + (project_id, file_name), + ).fetchone() - # Import sequences from batch_data - batch_data = data.get(KEY_BATCH_DATA, []) - if isinstance(batch_data, list): - for item in batch_data: - seq_num = int(item.get("sequence_number", 0)) - self.upsert_sequence(df_id, seq_num, item) + if existing: + df_id = existing["id"] + now = time.time() + self.conn.execute( + "UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?", + (data_type, json.dumps(top_level), now, df_id), + ) + self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) + else: + now = time.time() + cur = self.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (project_id, file_name, data_type, json.dumps(top_level), now, now), + ) + df_id = cur.lastrowid - # Import history tree - history_tree = data.get(KEY_HISTORY_TREE) - if history_tree and isinstance(history_tree, dict): - self.save_history_tree(df_id, history_tree) + # Import sequences from batch_data + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + for item in batch_data: + if not isinstance(item, dict): + continue + seq_num = int(item.get("sequence_number", 0)) + now = time.time() + self.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?) " + "ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at", + (df_id, seq_num, json.dumps(item), now), + ) - return df_id + # Import history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + now = time.time() + self.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) " + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (df_id, json.dumps(history_tree), now), + ) + + self.conn.execute("COMMIT") + return df_id + except Exception: + self.conn.execute("ROLLBACK") + raise # ------------------------------------------------------------------ # Query helpers (for REST API) diff --git a/project_loader.py b/project_loader.py index 1634135..c7a2cc9 100644 --- a/project_loader.py +++ b/project_loader.py @@ -134,7 +134,13 @@ class ProjectLoaderDynamic: output_keys="", output_types=""): data = _fetch_data(manager_url, project_name, file_name, sequence_number) - keys = [k.strip() for k in output_keys.split(",") if k.strip()] if output_keys else [] + # Parse keys — try JSON array first, fall back to comma-split for compat + keys = [] + if output_keys: + try: + keys = json.loads(output_keys) + except (json.JSONDecodeError, TypeError): + keys = [k.strip() for k in output_keys.split(",") if k.strip()] results = [] for key in keys: diff --git a/tests/test_db.py b/tests/test_db.py index a0dc7ab..341edcb 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -272,6 +272,44 @@ class TestImport: s1 = db.get_sequence(df_id_2, 1) assert s1["prompt"] == "v2" + def test_import_skips_non_dict_batch_items(self, db, tmp_path): + """Non-dict elements in batch_data should be silently skipped, not crash.""" + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "mixed.json" + data = {KEY_BATCH_DATA: [ + {"sequence_number": 1, "prompt": "valid"}, + "not a dict", + 42, + None, + {"sequence_number": 3, "prompt": "also valid"}, + ]} + json_path.write_text(json.dumps(data)) + df_id = db.import_json_file(pid, json_path) + + seqs = db.list_sequences(df_id) + assert seqs == [1, 3] + + def test_import_atomic_on_error(self, db, tmp_path): + """If import fails partway, no partial data should be committed.""" + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "batch.json" + data = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "hello"}]} + json_path.write_text(json.dumps(data)) + db.import_json_file(pid, json_path) + + # Now try to import with bad data that will cause an error + # (overwrite the file with invalid sequence_number that causes int() to fail) + bad_data = {KEY_BATCH_DATA: [{"sequence_number": "not_a_number", "prompt": "bad"}]} + json_path.write_text(json.dumps(bad_data)) + with pytest.raises(ValueError): + db.import_json_file(pid, json_path) + + # Original data should still be intact (rollback worked) + df = db.get_data_file(pid, "batch") + assert df is not None + s1 = db.get_sequence(df["id"], 1) + assert s1["prompt"] == "hello" + # ------------------------------------------------------------------ # Query helpers diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index fd77eb9..41399ca 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -86,6 +86,20 @@ class TestProjectLoaderDynamic: assert result[2] == 1.5 assert len(result) == MAX_DYNAMIC_OUTPUTS + def test_load_dynamic_with_json_encoded_keys(self): + """JSON-encoded output_keys should be parsed correctly.""" + import json as _json + data = {"my,key": "comma_val", "normal": "ok"} + node = ProjectLoaderDynamic() + keys_json = _json.dumps(["my,key", "normal"]) + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json + ) + assert result[0] == "comma_val" + assert result[1] == "ok" + def test_load_dynamic_empty_keys(self): node = ProjectLoaderDynamic() with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): diff --git a/utils.py b/utils.py index 44707a2..809d58f 100644 --- a/utils.py +++ b/utils.py @@ -221,9 +221,9 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: (df_id, json.dumps(history_tree), now), ) - db.conn.commit() + db.conn.execute("COMMIT") except Exception: - db.conn.rollback() + db.conn.execute("ROLLBACK") raise except Exception as e: logger.warning(f"sync_to_db failed: {e}") diff --git a/web/project_dynamic.js b/web/project_dynamic.js index 9346f3d..ed830a8 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -56,20 +56,27 @@ app.registerExtension({ const resp = await api.fetchApi( `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` ); - const data = await resp.json(); - const { keys, types } = data; - // If the API returned an error, keep existing outputs and links intact - if (data.error) { - console.warn("[ProjectLoaderDynamic] API error, keeping existing outputs:", data.error); + if (!resp.ok) { + console.warn("[ProjectLoaderDynamic] HTTP error", resp.status, "— keeping existing outputs"); return; } - // Store keys and types in hidden widgets for persistence + const data = await resp.json(); + const keys = data.keys; + const types = data.types; + + // If the API returned an error or missing data, keep existing outputs and links intact + if (data.error || !Array.isArray(keys) || !Array.isArray(types)) { + console.warn("[ProjectLoaderDynamic] API error or missing data, keeping existing outputs:", data.error || "no keys/types"); + return; + } + + // Store keys and types in hidden widgets for persistence (JSON-encoded) const okWidget = this.widgets?.find(w => w.name === "output_keys"); - if (okWidget) okWidget.value = keys.join(","); + if (okWidget) okWidget.value = JSON.stringify(keys); const otWidget = this.widgets?.find(w => w.name === "output_types"); - if (otWidget) otWidget.value = types.join(","); + if (otWidget) otWidget.value = JSON.stringify(types); // Build a map of current output names to slot indices const oldSlots = {}; @@ -116,7 +123,7 @@ app.registerExtension({ } this.setSize(this.computeSize()); - app.graph.setDirtyCanvas(true, true); + app.graph?.setDirtyCanvas(true, true); } catch (e) { console.error("[ProjectLoaderDynamic] Refresh failed:", e); } @@ -137,12 +144,19 @@ app.registerExtension({ const okWidget = this.widgets?.find(w => w.name === "output_keys"); const otWidget = this.widgets?.find(w => w.name === "output_types"); - const keys = okWidget?.value - ? okWidget.value.split(",").filter(k => k.trim()) - : []; - const types = otWidget?.value - ? otWidget.value.split(",") - : []; + // Parse keys/types — try JSON array first, fall back to comma-split + let keys = []; + if (okWidget?.value) { + try { keys = JSON.parse(okWidget.value); } catch (_) { + keys = okWidget.value.split(",").map(k => k.trim()).filter(Boolean); + } + } + let types = []; + if (otWidget?.value) { + try { types = JSON.parse(otWidget.value); } catch (_) { + types = otWidget.value.split(","); + } + } if (keys.length > 0) { // On load, LiteGraph already restored serialized outputs with links. @@ -159,8 +173,8 @@ app.registerExtension({ } else if (this.outputs.length > 0) { // Widget values empty but serialized outputs exist — sync widgets // from the outputs LiteGraph already restored (fallback). - if (okWidget) okWidget.value = this.outputs.map(o => o.name).join(","); - if (otWidget) otWidget.value = this.outputs.map(o => o.type).join(","); + if (okWidget) okWidget.value = JSON.stringify(this.outputs.map(o => o.name)); + if (otWidget) otWidget.value = JSON.stringify(this.outputs.map(o => o.type)); } this.setSize(this.computeSize()); From c4d107206fc638e4d66e21e3df8e4d5c9ed04b28 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:38:37 +0100 Subject: [PATCH 05/14] Fix 4 bugs from third code review - Fix delete_proj not persisting cleared current_project to config: page reload after deleting active project restored deleted name, silently breaking all DB sync - Fix sync_to_db crash on non-dict batch_data items: add isinstance guard matching import_json_file - Fix output_types ignored in load_dynamic: parse declared types and use to_int()/to_float() to coerce values, so downstream ComfyUI nodes receive correct types even when API returns strings - Fix backward-compat comma-split for types not trimming whitespace: legacy workflows with "STRING, INT" got types " INT" breaking ComfyUI connection type-matching Co-Authored-By: Claude Opus 4.6 --- project_loader.py | 18 ++++++++++++++++-- tab_projects_ng.py | 4 ++++ tests/test_project_loader.py | 16 ++++++++++++++++ utils.py | 2 ++ web/project_dynamic.js | 4 ++-- 5 files changed, 40 insertions(+), 4 deletions(-) diff --git a/project_loader.py b/project_loader.py index c7a2cc9..cc52d2c 100644 --- a/project_loader.py +++ b/project_loader.py @@ -142,10 +142,24 @@ class ProjectLoaderDynamic: except (json.JSONDecodeError, TypeError): keys = [k.strip() for k in output_keys.split(",") if k.strip()] + # Parse types for coercion + types = [] + if output_types: + try: + types = json.loads(output_types) + except (json.JSONDecodeError, TypeError): + types = [t.strip() for t in output_types.split(",")] + results = [] - for key in keys: + for i, key in enumerate(keys): val = data.get(key, "") - if isinstance(val, bool): + declared_type = types[i] if i < len(types) else "" + # Coerce based on declared output type when possible + if declared_type == "INT": + results.append(to_int(val)) + elif declared_type == "FLOAT": + results.append(to_float(val)) + elif isinstance(val, bool): results.append(str(val).lower()) elif isinstance(val, int): results.append(val) diff --git a/tab_projects_ng.py b/tab_projects_ng.py index afa8422..32494ac 100644 --- a/tab_projects_ng.py +++ b/tab_projects_ng.py @@ -119,6 +119,10 @@ def render_projects_tab(state: AppState): state.db.delete_project(name) if state.current_project == name: state.current_project = '' + state.config['current_project'] = '' + save_config(state.current_dir, + state.config.get('favorites', []), + state.config) ui.notify(f'Deleted project "{name}"', type='positive') render_project_list.refresh() diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index 41399ca..58e76ee 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -100,6 +100,22 @@ class TestProjectLoaderDynamic: assert result[0] == "comma_val" assert result[1] == "ok" + def test_load_dynamic_type_coercion(self): + """output_types should coerce values to declared types.""" + import json as _json + data = {"seed": "42", "cfg": "1.5", "prompt": "hello"} + node = ProjectLoaderDynamic() + keys_json = _json.dumps(["seed", "cfg", "prompt"]) + types_json = _json.dumps(["INT", "FLOAT", "STRING"]) + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json, output_types=types_json + ) + assert result[0] == 42 # string "42" coerced to int + assert result[1] == 1.5 # string "1.5" coerced to float + assert result[2] == "hello" # string stays string + def test_load_dynamic_empty_keys(self): node = ProjectLoaderDynamic() with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): diff --git a/utils.py b/utils.py index 809d58f..805ec79 100644 --- a/utils.py +++ b/utils.py @@ -202,6 +202,8 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: if isinstance(batch_data, list): db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) for item in batch_data: + if not isinstance(item, dict): + continue seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) now = __import__('time').time() db.conn.execute( diff --git a/web/project_dynamic.js b/web/project_dynamic.js index ed830a8..f8e7634 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -154,7 +154,7 @@ app.registerExtension({ let types = []; if (otWidget?.value) { try { types = JSON.parse(otWidget.value); } catch (_) { - types = otWidget.value.split(","); + types = otWidget.value.split(",").map(t => t.trim()).filter(Boolean); } } @@ -162,7 +162,7 @@ app.registerExtension({ // On load, LiteGraph already restored serialized outputs with links. // Rename and set types to match stored state (preserves links). for (let i = 0; i < this.outputs.length && i < keys.length; i++) { - this.outputs[i].name = keys[i].trim(); + this.outputs[i].name = keys[i]; if (types[i]) this.outputs[i].type = types[i]; } From d07a30886576caf1b0bef79dba4021e0db6a6830 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:44:12 +0100 Subject: [PATCH 06/14] Harden ROLLBACK against I/O errors in transactions If the original error (e.g., disk full) also prevents ROLLBACK from executing, catch and suppress the ROLLBACK failure so the original exception propagates cleanly and the connection isn't left in a permanently broken state. Co-Authored-By: Claude Opus 4.6 --- db.py | 5 ++++- utils.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/db.py b/db.py index 000f577..cf428e5 100644 --- a/db.py +++ b/db.py @@ -295,7 +295,10 @@ class ProjectDB: self.conn.execute("COMMIT") return df_id except Exception: - self.conn.execute("ROLLBACK") + try: + self.conn.execute("ROLLBACK") + except Exception: + pass raise # ------------------------------------------------------------------ diff --git a/utils.py b/utils.py index 805ec79..af80c60 100644 --- a/utils.py +++ b/utils.py @@ -225,7 +225,10 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: db.conn.execute("COMMIT") except Exception: - db.conn.execute("ROLLBACK") + try: + db.conn.execute("ROLLBACK") + except Exception: + pass raise except Exception as e: logger.warning(f"sync_to_db failed: {e}") From 4b5fff5c6e44a7c0a3d42ed38889c560c2dcabc0 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 22:16:08 +0100 Subject: [PATCH 07/14] Improve ProjectLoaderDynamic UX: single node, error feedback, auto-refresh Remove 3 redundant hardcoded nodes (Standard/VACE/LoRA), keeping only the Dynamic node. Add total_sequences INT output (slot 0) for loop counting. Add structured error handling: _fetch_json returns typed error dicts, load_dynamic raises RuntimeError with descriptive messages, JS shows red border/title on errors. Add 500ms debounced auto-refresh on widget changes. Add 404s for missing project/file in API endpoints. Co-Authored-By: Claude Opus 4.6 --- api_routes.py | 21 +++- db.py | 15 +++ project_loader.py | 142 ++++++---------------- tests/test_db.py | 19 +++ tests/test_project_loader.py | 228 +++++++++++++++-------------------- web/project_dynamic.js | 125 ++++++++++++++++--- 6 files changed, 295 insertions(+), 255 deletions(-) diff --git a/api_routes.py b/api_routes.py index 6d5b42e..62f8512 100644 --- a/api_routes.py +++ b/api_routes.py @@ -55,13 +55,26 @@ def _list_sequences(name: str, file_name: str) -> dict[str, Any]: def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: db = _get_db() - data = db.query_sequence_data(name, file_name, seq) + proj = db.get_project(name) + if not proj: + raise HTTPException(status_code=404, detail=f"Project '{name}' not found") + df = db.get_data_file_by_names(name, file_name) + if not df: + raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'") + data = db.get_sequence(df["id"], seq) if data is None: - raise HTTPException(status_code=404, detail="Sequence not found") + raise HTTPException(status_code=404, detail=f"Sequence {seq} not found") return data def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: db = _get_db() - keys, types = db.query_sequence_keys(name, file_name, seq) - return {"keys": keys, "types": types} + proj = db.get_project(name) + if not proj: + raise HTTPException(status_code=404, detail=f"Project '{name}' not found") + df = db.get_data_file_by_names(name, file_name) + if not df: + raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'") + keys, types = db.get_sequence_keys(df["id"], seq) + total = db.count_sequences(df["id"]) + return {"keys": keys, "types": types, "total_sequences": total} diff --git a/db.py b/db.py index cf428e5..e9088f9 100644 --- a/db.py +++ b/db.py @@ -182,6 +182,21 @@ class ProjectDB: ).fetchall() return [r["sequence_number"] for r in rows] + def count_sequences(self, data_file_id: int) -> int: + """Return the number of sequences for a data file.""" + row = self.conn.execute( + "SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?", + (data_file_id,), + ).fetchone() + return row["cnt"] + + def query_total_sequences(self, project_name: str, file_name: str) -> int: + """Return total sequence count by project and file names.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return 0 + return self.count_sequences(df["id"]) + def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]: """Returns (keys, types) for a sequence's data dict.""" data = self.get_sequence(data_file_id, sequence_number) diff --git a/project_loader.py b/project_loader.py index cc52d2c..6420517 100644 --- a/project_loader.py +++ b/project_loader.py @@ -39,13 +39,31 @@ def to_int(val: Any) -> int: def _fetch_json(url: str) -> dict: - """Fetch JSON from a URL using stdlib urllib.""" + """Fetch JSON from a URL using stdlib urllib. + + On error, returns a dict with an "error" key describing the failure. + """ try: with urllib.request.urlopen(url, timeout=5) as resp: return json.loads(resp.read()) - except (urllib.error.URLError, json.JSONDecodeError, OSError) as e: - logger.warning(f"Failed to fetch {url}: {e}") - return {} + except urllib.error.HTTPError as e: + # HTTPError is a subclass of URLError — must be caught first + body = "" + try: + raw = e.read() + detail = json.loads(raw) + body = detail.get("detail", str(raw, "utf-8", errors="replace")) + except Exception: + body = str(e) + logger.warning(f"HTTP {e.code} from {url}: {body}") + return {"error": "http_error", "status": e.code, "message": body} + except (urllib.error.URLError, OSError) as e: + reason = str(e.reason) if hasattr(e, "reason") else str(e) + logger.warning(f"Network error fetching {url}: {reason}") + return {"error": "network_error", "message": reason} + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON from {url}: {e}") + return {"error": "parse_error", "message": str(e)} def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: @@ -100,6 +118,9 @@ if PromptServer is not None: except (ValueError, TypeError): seq = 1 data = _fetch_keys(manager_url, project, file_name, seq) + if data.get("error") in ("http_error", "network_error", "parse_error"): + status = data.get("status", 502) + return web.json_response(data, status=status) return web.json_response(data) @@ -124,15 +145,25 @@ class ProjectLoaderDynamic: }, } - RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) - RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) + RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) + RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) FUNCTION = "load_dynamic" CATEGORY = "utils/json/project" OUTPUT_NODE = False def load_dynamic(self, manager_url, project_name, file_name, sequence_number, output_keys="", output_types=""): + # Fetch keys metadata (includes total_sequences count) + keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number) + if keys_meta.get("error") in ("http_error", "network_error", "parse_error"): + msg = keys_meta.get("message", "Unknown error") + raise RuntimeError(f"Failed to fetch project keys: {msg}") + total_sequences = keys_meta.get("total_sequences", 0) + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + if data.get("error") in ("http_error", "network_error", "parse_error"): + msg = data.get("message", "Unknown error") + raise RuntimeError(f"Failed to fetch sequence data: {msg}") # Parse keys — try JSON array first, fall back to comma-split for compat keys = [] @@ -171,111 +202,14 @@ class ProjectLoaderDynamic: while len(results) < MAX_DYNAMIC_OUTPUTS: results.append("") - return tuple(results) - - -# ========================================== -# 1. STANDARD NODE (Project-based I2V) -# ========================================== - -class ProjectLoaderStandard: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), - "project_name": ("STRING", {"default": "", "multiline": False}), - "file_name": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path") - FUNCTION = "load_standard" - CATEGORY = "utils/json/project" - - def load_standard(self, manager_url, project_name, file_name, sequence_number): - data = _fetch_data(manager_url, project_name, file_name, sequence_number) - return ( - str(data.get("general_prompt", "")), str(data.get("general_negative", "")), - str(data.get("current_prompt", "")), str(data.get("negative", "")), - str(data.get("camera", "")), to_float(data.get("flf", 0.0)), - to_int(data.get("seed", 0)), str(data.get("video file path", "")), - str(data.get("reference image path", "")), str(data.get("flf image path", "")) - ) - - -# ========================================== -# 2. VACE NODE (Project-based) -# ========================================== - -class ProjectLoaderVACE: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), - "project_name": ("STRING", {"default": "", "multiline": False}), - "file_name": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path") - FUNCTION = "load_vace" - CATEGORY = "utils/json/project" - - def load_vace(self, manager_url, project_name, file_name, sequence_number): - data = _fetch_data(manager_url, project_name, file_name, sequence_number) - return ( - str(data.get("general_prompt", "")), str(data.get("general_negative", "")), - str(data.get("current_prompt", "")), str(data.get("negative", "")), - str(data.get("camera", "")), to_float(data.get("flf", 0.0)), - to_int(data.get("seed", 0)), to_int(data.get("frame_to_skip", 81)), - to_int(data.get("input_a_frames", 16)), to_int(data.get("input_b_frames", 16)), - str(data.get("reference path", "")), to_int(data.get("reference switch", 1)), - to_int(data.get("vace schedule", 1)), str(data.get("video file path", "")), - str(data.get("reference image path", "")) - ) - - -# ========================================== -# 3. LoRA NODE (Project-based) -# ========================================== - -class ProjectLoaderLoRA: - @classmethod - def INPUT_TYPES(s): - return {"required": { - "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), - "project_name": ("STRING", {"default": "", "multiline": False}), - "file_name": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") - RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low") - FUNCTION = "load_loras" - CATEGORY = "utils/json/project" - - def load_loras(self, manager_url, project_name, file_name, sequence_number): - data = _fetch_data(manager_url, project_name, file_name, sequence_number) - return ( - str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")), - str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")), - str(data.get("lora 3 high", "")), str(data.get("lora 3 low", "")) - ) + return (total_sequences,) + tuple(results) # --- Mappings --- PROJECT_NODE_CLASS_MAPPINGS = { "ProjectLoaderDynamic": ProjectLoaderDynamic, - "ProjectLoaderStandard": ProjectLoaderStandard, - "ProjectLoaderVACE": ProjectLoaderVACE, - "ProjectLoaderLoRA": ProjectLoaderLoRA, } PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { "ProjectLoaderDynamic": "Project Loader (Dynamic)", - "ProjectLoaderStandard": "Project Loader (Standard/I2V)", - "ProjectLoaderVACE": "Project Loader (VACE Full)", - "ProjectLoaderLoRA": "Project Loader (LoRAs)", } diff --git a/tests/test_db.py b/tests/test_db.py index 341edcb..bea102f 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -172,6 +172,25 @@ class TestSequences: db.delete_sequences_for_file(df_id) assert db.list_sequences(df_id) == [] + def test_count_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.count_sequences(df_id) == 0 + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + db.upsert_sequence(df_id, 3, {"c": 3}) + assert db.count_sequences(df_id) == 3 + + def test_query_total_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + assert db.query_total_sequences("p1", "batch") == 2 + + def test_query_total_sequences_nonexistent(self, db): + assert db.query_total_sequences("nope", "nope") == 0 + # ------------------------------------------------------------------ # History trees diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index 58e76ee..dab80b2 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -6,9 +6,6 @@ import pytest from project_loader import ( ProjectLoaderDynamic, - ProjectLoaderStandard, - ProjectLoaderVACE, - ProjectLoaderLoRA, _fetch_json, _fetch_data, _fetch_keys, @@ -32,11 +29,23 @@ class TestFetchHelpers: result = _fetch_json("http://example.com/api") assert result == data - def test_fetch_json_failure(self): - import urllib.error + def test_fetch_json_network_error(self): with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")): result = _fetch_json("http://example.com/api") - assert result == {} + assert result["error"] == "network_error" + assert "connection refused" in result["message"] + + def test_fetch_json_http_error(self): + import urllib.error + err = urllib.error.HTTPError( + "http://example.com/api", 404, "Not Found", {}, + BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode()) + ) + with patch("project_loader.urllib.request.urlopen", side_effect=err): + result = _fetch_json("http://example.com/api") + assert result["error"] == "http_error" + assert result["status"] == 404 + assert "not found" in result["message"].lower() def test_fetch_data_builds_url(self): data = {"prompt": "hello"} @@ -73,18 +82,23 @@ class TestFetchHelpers: class TestProjectLoaderDynamic: + def _keys_meta(self, total=5): + return {"keys": [], "types": [], "total_sequences": total} + def test_load_dynamic_with_keys(self): data = {"prompt": "hello", "seed": 42, "cfg": 1.5} node = ProjectLoaderDynamic() - with patch("project_loader._fetch_data", return_value=data): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys="prompt,seed,cfg" - ) - assert result[0] == "hello" - assert result[1] == 42 - assert result[2] == 1.5 - assert len(result) == MAX_DYNAMIC_OUTPUTS + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="prompt,seed,cfg" + ) + assert result[0] == 5 # total_sequences + assert result[1] == "hello" + assert result[2] == 42 + assert result[3] == 1.5 + assert len(result) == MAX_DYNAMIC_OUTPUTS + 1 def test_load_dynamic_with_json_encoded_keys(self): """JSON-encoded output_keys should be parsed correctly.""" @@ -92,13 +106,14 @@ class TestProjectLoaderDynamic: data = {"my,key": "comma_val", "normal": "ok"} node = ProjectLoaderDynamic() keys_json = _json.dumps(["my,key", "normal"]) - with patch("project_loader._fetch_data", return_value=data): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys=keys_json - ) - assert result[0] == "comma_val" - assert result[1] == "ok" + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json + ) + assert result[1] == "comma_val" + assert result[2] == "ok" def test_load_dynamic_type_coercion(self): """output_types should coerce values to declared types.""" @@ -107,41 +122,75 @@ class TestProjectLoaderDynamic: node = ProjectLoaderDynamic() keys_json = _json.dumps(["seed", "cfg", "prompt"]) types_json = _json.dumps(["INT", "FLOAT", "STRING"]) - with patch("project_loader._fetch_data", return_value=data): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys=keys_json, output_types=types_json - ) - assert result[0] == 42 # string "42" coerced to int - assert result[1] == 1.5 # string "1.5" coerced to float - assert result[2] == "hello" # string stays string + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json, output_types=types_json + ) + assert result[1] == 42 # string "42" coerced to int + assert result[2] == 1.5 # string "1.5" coerced to float + assert result[3] == "hello" # string stays string def test_load_dynamic_empty_keys(self): node = ProjectLoaderDynamic() - with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys="" - ) - assert all(v == "" for v in result) + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="" + ) + # Slot 0 is total_sequences (INT), rest are empty strings + assert result[0] == 5 + assert all(v == "" for v in result[1:]) def test_load_dynamic_missing_key(self): node = ProjectLoaderDynamic() - with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys="nonexistent" - ) - assert result[0] == "" + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="nonexistent" + ) + assert result[1] == "" def test_load_dynamic_bool_becomes_string(self): node = ProjectLoaderDynamic() - with patch("project_loader._fetch_data", return_value={"flag": True}): - result = node.load_dynamic( - "http://localhost:8080", "proj1", "batch_i2v", 1, - output_keys="flag" - ) - assert result[0] == "true" + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"flag": True}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="flag" + ) + assert result[1] == "true" + + def test_load_dynamic_returns_total_sequences(self): + """total_sequences should be the first output from keys metadata.""" + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_keys", return_value={"keys": [], "types": [], "total_sequences": 42}): + with patch("project_loader._fetch_data", return_value={}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="" + ) + assert result[0] == 42 + + def test_load_dynamic_raises_on_network_error(self): + """Network errors from _fetch_keys should raise RuntimeError.""" + node = ProjectLoaderDynamic() + error_resp = {"error": "network_error", "message": "Connection refused"} + with patch("project_loader._fetch_keys", return_value=error_resp): + with pytest.raises(RuntimeError, match="Failed to fetch project keys"): + node.load_dynamic("http://localhost:8080", "proj1", "batch", 1) + + def test_load_dynamic_raises_on_data_fetch_error(self): + """Network errors from _fetch_data should raise RuntimeError.""" + node = ProjectLoaderDynamic() + error_resp = {"error": "http_error", "status": 404, "message": "Sequence not found"} + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=error_resp): + with pytest.raises(RuntimeError, match="Failed to fetch sequence data"): + node.load_dynamic("http://localhost:8080", "proj1", "batch", 1) def test_input_types_has_manager_url(self): inputs = ProjectLoaderDynamic.INPUT_TYPES() @@ -154,88 +203,9 @@ class TestProjectLoaderDynamic: assert ProjectLoaderDynamic.CATEGORY == "utils/json/project" -class TestProjectLoaderStandard: - def test_load_standard(self): - data = { - "general_prompt": "hello", - "general_negative": "bad", - "current_prompt": "specific", - "negative": "neg", - "camera": "pan", - "flf": 0.5, - "seed": 42, - "video file path": "/v.mp4", - "reference image path": "/r.png", - "flf image path": "/f.png", - } - node = ProjectLoaderStandard() - with patch("project_loader._fetch_data", return_value=data): - result = node.load_standard("http://localhost:8080", "proj1", "batch", 1) - assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png") - - def test_load_standard_defaults(self): - node = ProjectLoaderStandard() - with patch("project_loader._fetch_data", return_value={}): - result = node.load_standard("http://localhost:8080", "proj1", "batch", 1) - assert result[0] == "" # general_prompt - assert result[5] == 0.0 # flf - assert result[6] == 0 # seed - - -class TestProjectLoaderVACE: - def test_load_vace(self): - data = { - "general_prompt": "hello", - "general_negative": "bad", - "current_prompt": "specific", - "negative": "neg", - "camera": "pan", - "flf": 0.5, - "seed": 42, - "frame_to_skip": 81, - "input_a_frames": 16, - "input_b_frames": 16, - "reference path": "/ref", - "reference switch": 1, - "vace schedule": 2, - "video file path": "/v.mp4", - "reference image path": "/r.png", - } - node = ProjectLoaderVACE() - with patch("project_loader._fetch_data", return_value=data): - result = node.load_vace("http://localhost:8080", "proj1", "batch", 1) - assert result[7] == 81 # frame_to_skip - assert result[12] == 2 # vace_schedule - - -class TestProjectLoaderLoRA: - def test_load_loras(self): - data = { - "lora 1 high": "", - "lora 1 low": "", - "lora 2 high": "", - "lora 2 low": "", - "lora 3 high": "", - "lora 3 low": "", - } - node = ProjectLoaderLoRA() - with patch("project_loader._fetch_data", return_value=data): - result = node.load_loras("http://localhost:8080", "proj1", "batch", 1) - assert result[0] == "" - assert result[1] == "" - - def test_load_loras_empty(self): - node = ProjectLoaderLoRA() - with patch("project_loader._fetch_data", return_value={}): - result = node.load_loras("http://localhost:8080", "proj1", "batch", 1) - assert all(v == "" for v in result) - - class TestNodeMappings: def test_mappings_exist(self): from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS - assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS - assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS - assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS - assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4 + assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1 + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1 diff --git a/web/project_dynamic.js b/web/project_dynamic.js index f8e7634..783b664 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -32,18 +32,55 @@ app.registerExtension({ this.refreshDynamicOutputs(); }); + // Auto-refresh with 500ms debounce on widget changes + this._refreshTimer = null; + const autoRefreshWidgets = ["project_name", "file_name", "sequence_number"]; + for (const widgetName of autoRefreshWidgets) { + const w = this.widgets?.find(w => w.name === widgetName); + if (w) { + const origCallback = w.callback; + const node = this; + w.callback = function (...args) { + origCallback?.apply(this, args); + clearTimeout(node._refreshTimer); + node._refreshTimer = setTimeout(() => { + node.refreshDynamicOutputs(); + }, 500); + }; + } + } + queueMicrotask(() => { if (!this._configured) { - // New node (not loading) — remove the 32 Python default outputs + // New node (not loading) — remove the Python default outputs + // and add only the fixed total_sequences slot while (this.outputs.length > 0) { this.removeOutput(0); } + this.addOutput("total_sequences", "INT"); this.setSize(this.computeSize()); app.graph?.setDirtyCanvas(true, true); } }); }; + nodeType.prototype._setStatus = function (status, message) { + const baseTitle = "Project Loader (Dynamic)"; + if (status === "ok") { + this.title = baseTitle; + this.color = undefined; + this.bgcolor = undefined; + } else if (status === "error") { + this.title = baseTitle + " - ERROR"; + this.color = "#ff4444"; + this.bgcolor = "#331111"; + if (message) this.title = baseTitle + ": " + message; + } else if (status === "loading") { + this.title = baseTitle + " - Loading..."; + } + app.graph?.setDirtyCanvas(true, true); + }; + nodeType.prototype.refreshDynamicOutputs = async function () { const urlWidget = this.widgets?.find(w => w.name === "manager_url"); const projectWidget = this.widgets?.find(w => w.name === "project_name"); @@ -52,13 +89,20 @@ app.registerExtension({ if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return; + this._setStatus("loading"); + try { const resp = await api.fetchApi( `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` ); if (!resp.ok) { - console.warn("[ProjectLoaderDynamic] HTTP error", resp.status, "— keeping existing outputs"); + let errorMsg = `HTTP ${resp.status}`; + try { + const errData = await resp.json(); + if (errData.message) errorMsg = errData.message; + } catch (_) {} + this._setStatus("error", errorMsg); return; } @@ -68,7 +112,8 @@ app.registerExtension({ // If the API returned an error or missing data, keep existing outputs and links intact if (data.error || !Array.isArray(keys) || !Array.isArray(types)) { - console.warn("[ProjectLoaderDynamic] API error or missing data, keeping existing outputs:", data.error || "no keys/types"); + const errMsg = data.error ? data.message || data.error : "Missing keys/types"; + this._setStatus("error", errMsg); return; } @@ -78,14 +123,20 @@ app.registerExtension({ const otWidget = this.widgets?.find(w => w.name === "output_types"); if (otWidget) otWidget.value = JSON.stringify(types); - // Build a map of current output names to slot indices + // Slot 0 is always total_sequences (INT) — ensure it exists + if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { + this.outputs.unshift({ name: "total_sequences", type: "INT", links: null }); + } + this.outputs[0].type = "INT"; + + // Build a map of current dynamic output names to slot indices (skip slot 0) const oldSlots = {}; - for (let i = 0; i < this.outputs.length; i++) { + for (let i = 1; i < this.outputs.length; i++) { oldSlots[this.outputs[i].name] = i; } - // Build new outputs, reusing existing slots to preserve links - const newOutputs = []; + // Build new dynamic outputs, reusing existing slots to preserve links + const newOutputs = [this.outputs[0]]; // Keep total_sequences at slot 0 for (let k = 0; k < keys.length; k++) { const key = keys[k]; const type = types[k] || "*"; @@ -122,10 +173,12 @@ app.registerExtension({ } } + this._setStatus("ok"); this.setSize(this.computeSize()); app.graph?.setDirtyCanvas(true, true); } catch (e) { console.error("[ProjectLoaderDynamic] Refresh failed:", e); + this._setStatus("error", "Server unreachable"); } }; @@ -158,23 +211,59 @@ app.registerExtension({ } } + // Ensure slot 0 is total_sequences (INT) + if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { + this.outputs.unshift({ name: "total_sequences", type: "INT", links: null }); + // LiteGraph restores links AFTER onConfigure, so graph.links is + // empty here. Defer link fixup to a microtask that runs after the + // synchronous graph.configure() finishes (including link restoration). + // We must also rebuild output.links arrays because LiteGraph will + // place link IDs on the wrong outputs (shifted by the unshift above). + const node = this; + queueMicrotask(() => { + if (!node.graph) return; + // Clear all output.links — they were populated at old indices + for (const output of node.outputs) { + output.links = null; + } + // Rebuild from graph.links with corrected origin_slot (+1) + for (const linkId in node.graph.links) { + const link = node.graph.links[linkId]; + if (!link || link.origin_id !== node.id) continue; + link.origin_slot += 1; + const output = node.outputs[link.origin_slot]; + if (output) { + if (!output.links) output.links = []; + output.links.push(link.id); + } + } + app.graph?.setDirtyCanvas(true, true); + }); + } + this.outputs[0].type = "INT"; + this.outputs[0].name = "total_sequences"; + if (keys.length > 0) { // On load, LiteGraph already restored serialized outputs with links. - // Rename and set types to match stored state (preserves links). - for (let i = 0; i < this.outputs.length && i < keys.length; i++) { - this.outputs[i].name = keys[i]; - if (types[i]) this.outputs[i].type = types[i]; + // Dynamic outputs start at slot 1. Rename and set types to match stored state. + for (let i = 0; i < keys.length; i++) { + const slotIdx = i + 1; // offset by 1 for total_sequences + if (slotIdx < this.outputs.length) { + this.outputs[slotIdx].name = keys[i]; + if (types[i]) this.outputs[slotIdx].type = types[i]; + } } - // Remove any extra outputs beyond the key count - while (this.outputs.length > keys.length) { + // Remove any extra outputs beyond keys + total_sequences + while (this.outputs.length > keys.length + 1) { this.removeOutput(this.outputs.length - 1); } - } else if (this.outputs.length > 0) { - // Widget values empty but serialized outputs exist — sync widgets - // from the outputs LiteGraph already restored (fallback). - if (okWidget) okWidget.value = JSON.stringify(this.outputs.map(o => o.name)); - if (otWidget) otWidget.value = JSON.stringify(this.outputs.map(o => o.type)); + } else if (this.outputs.length > 1) { + // Widget values empty but serialized dynamic outputs exist — sync widgets + // from the outputs LiteGraph already restored (fallback, skip slot 0). + const dynamicOutputs = this.outputs.slice(1); + if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name)); + if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type)); } this.setSize(this.computeSize()); From 86693f608afb6087033085a6d1b11dfad710302e Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 22:33:51 +0100 Subject: [PATCH 08/14] Remove 9 redundant JSON loader nodes, keep only JSONLoaderDynamic JSONLoaderDynamic auto-discovers keys at runtime, making the hardcoded Standard, Batch, and Custom nodes unnecessary. Co-Authored-By: Claude Opus 4.6 --- json_loader.py | 233 ------------------------------------------------- 1 file changed, 233 deletions(-) diff --git a/json_loader.py b/json_loader.py index d780a52..24f8e50 100644 --- a/json_loader.py +++ b/json_loader.py @@ -143,244 +143,11 @@ class JSONLoaderDynamic: return tuple(results) -# ========================================== -# 1. STANDARD NODES (Single File) -# ========================================== - -class JSONLoaderLoRA: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") - RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low") - FUNCTION = "load_loras" - CATEGORY = "utils/json" - - def load_loras(self, json_path): - data = read_json_data(json_path) - return ( - str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")), - str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")), - str(data.get("lora 3 high", "")), str(data.get("lora 3 low", "")) - ) - -class JSONLoaderStandard: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path") - FUNCTION = "load_standard" - CATEGORY = "utils/json" - - def load_standard(self, json_path): - data = read_json_data(json_path) - return ( - str(data.get("general_prompt", "")), str(data.get("general_negative", "")), - str(data.get("current_prompt", "")), str(data.get("negative", "")), - str(data.get("camera", "")), to_float(data.get("flf", 0.0)), - to_int(data.get("seed", 0)), str(data.get("video file path", "")), - str(data.get("reference image path", "")), str(data.get("flf image path", "")) - ) - -class JSONLoaderVACE: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path") - FUNCTION = "load_vace" - CATEGORY = "utils/json" - - def load_vace(self, json_path): - data = read_json_data(json_path) - return ( - str(data.get("general_prompt", "")), str(data.get("general_negative", "")), - str(data.get("current_prompt", "")), str(data.get("negative", "")), - str(data.get("camera", "")), to_float(data.get("flf", 0.0)), - to_int(data.get("seed", 0)), - to_int(data.get("frame_to_skip", 81)), to_int(data.get("input_a_frames", 16)), - to_int(data.get("input_b_frames", 16)), str(data.get("reference path", "")), - to_int(data.get("reference switch", 1)), to_int(data.get("vace schedule", 1)), - str(data.get("video file path", "")), str(data.get("reference image path", "")) - ) - -# ========================================== -# 2. BATCH NODES -# ========================================== - -class JSONLoaderBatchLoRA: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}} - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") - RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low") - FUNCTION = "load_batch_loras" - CATEGORY = "utils/json" - - def load_batch_loras(self, json_path, sequence_number): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - return ( - str(target_data.get("lora 1 high", "")), str(target_data.get("lora 1 low", "")), - str(target_data.get("lora 2 high", "")), str(target_data.get("lora 2 low", "")), - str(target_data.get("lora 3 high", "")), str(target_data.get("lora 3 low", "")) - ) - -class JSONLoaderBatchI2V: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}} - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path") - FUNCTION = "load_batch_i2v" - CATEGORY = "utils/json" - - def load_batch_i2v(self, json_path, sequence_number): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - - return ( - str(target_data.get("general_prompt", "")), str(target_data.get("general_negative", "")), - str(target_data.get("current_prompt", "")), str(target_data.get("negative", "")), - str(target_data.get("camera", "")), to_float(target_data.get("flf", 0.0)), - to_int(target_data.get("seed", 0)), str(target_data.get("video file path", "")), - str(target_data.get("reference image path", "")), str(target_data.get("flf image path", "")) - ) - -class JSONLoaderBatchVACE: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}} - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path") - FUNCTION = "load_batch_vace" - CATEGORY = "utils/json" - - def load_batch_vace(self, json_path, sequence_number): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - - return ( - str(target_data.get("general_prompt", "")), str(target_data.get("general_negative", "")), - str(target_data.get("current_prompt", "")), str(target_data.get("negative", "")), - str(target_data.get("camera", "")), to_float(target_data.get("flf", 0.0)), - to_int(target_data.get("seed", 0)), to_int(target_data.get("frame_to_skip", 81)), - to_int(target_data.get("input_a_frames", 16)), to_int(target_data.get("input_b_frames", 16)), - str(target_data.get("reference path", "")), to_int(target_data.get("reference switch", 1)), - to_int(target_data.get("vace schedule", 1)), str(target_data.get("video file path", "")), - str(target_data.get("reference image path", "")) - ) - -# ========================================== -# 3. UNIVERSAL CUSTOM NODES (1, 3, 6 Slots) -# ========================================== - -class JSONLoaderCustom1: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "json_path": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }, - "optional": { "key_1": ("STRING", {"default": "", "multiline": False}) } - } - RETURN_TYPES = ("STRING",) - RETURN_NAMES = ("val_1",) - FUNCTION = "load_custom" - CATEGORY = "utils/json" - - def load_custom(self, json_path, sequence_number, key_1=""): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - return (str(target_data.get(key_1, "")),) - -class JSONLoaderCustom3: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "json_path": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }, - "optional": { - "key_1": ("STRING", {"default": "", "multiline": False}), - "key_2": ("STRING", {"default": "", "multiline": False}), - "key_3": ("STRING", {"default": "", "multiline": False}) - } - } - RETURN_TYPES = ("STRING", "STRING", "STRING") - RETURN_NAMES = ("val_1", "val_2", "val_3") - FUNCTION = "load_custom" - CATEGORY = "utils/json" - - def load_custom(self, json_path, sequence_number, key_1="", key_2="", key_3=""): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - return ( - str(target_data.get(key_1, "")), - str(target_data.get(key_2, "")), - str(target_data.get(key_3, "")) - ) - -class JSONLoaderCustom6: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "json_path": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }, - "optional": { - "key_1": ("STRING", {"default": "", "multiline": False}), - "key_2": ("STRING", {"default": "", "multiline": False}), - "key_3": ("STRING", {"default": "", "multiline": False}), - "key_4": ("STRING", {"default": "", "multiline": False}), - "key_5": ("STRING", {"default": "", "multiline": False}), - "key_6": ("STRING", {"default": "", "multiline": False}) - } - } - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") - RETURN_NAMES = ("val_1", "val_2", "val_3", "val_4", "val_5", "val_6") - FUNCTION = "load_custom" - CATEGORY = "utils/json" - - def load_custom(self, json_path, sequence_number, key_1="", key_2="", key_3="", key_4="", key_5="", key_6=""): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - return ( - str(target_data.get(key_1, "")), str(target_data.get(key_2, "")), - str(target_data.get(key_3, "")), str(target_data.get(key_4, "")), - str(target_data.get(key_5, "")), str(target_data.get(key_6, "")) - ) - # --- Mappings --- NODE_CLASS_MAPPINGS = { "JSONLoaderDynamic": JSONLoaderDynamic, - "JSONLoaderLoRA": JSONLoaderLoRA, - "JSONLoaderStandard": JSONLoaderStandard, - "JSONLoaderVACE": JSONLoaderVACE, - "JSONLoaderBatchLoRA": JSONLoaderBatchLoRA, - "JSONLoaderBatchI2V": JSONLoaderBatchI2V, - "JSONLoaderBatchVACE": JSONLoaderBatchVACE, - "JSONLoaderCustom1": JSONLoaderCustom1, - "JSONLoaderCustom3": JSONLoaderCustom3, - "JSONLoaderCustom6": JSONLoaderCustom6 } NODE_DISPLAY_NAME_MAPPINGS = { "JSONLoaderDynamic": "JSON Loader (Dynamic)", - "JSONLoaderLoRA": "JSON Loader (LoRAs Only)", - "JSONLoaderStandard": "JSON Loader (Standard/I2V)", - "JSONLoaderVACE": "JSON Loader (VACE Full)", - "JSONLoaderBatchLoRA": "JSON Batch Loader (LoRAs)", - "JSONLoaderBatchI2V": "JSON Batch Loader (I2V)", - "JSONLoaderBatchVACE": "JSON Batch Loader (VACE)", - "JSONLoaderCustom1": "JSON Loader (Custom 1)", - "JSONLoaderCustom3": "JSON Loader (Custom 3)", - "JSONLoaderCustom6": "JSON Loader (Custom 6)" } From 027ef8e78a453908d5525a7fd721d6c129914380 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 23:04:14 +0100 Subject: [PATCH 09/14] Fix ProjectLoaderDynamic output names lost on page reload Hidden widget values for output_keys/output_types were not reliably restored by ComfyUI on workflow reload. Store keys/types in node.properties (always persisted by LiteGraph) as primary storage, with hidden widgets as fallback. Co-Authored-By: Claude Opus 4.6 --- web/project_dynamic.js | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/web/project_dynamic.js b/web/project_dynamic.js index 783b664..801d5e1 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -117,7 +117,12 @@ app.registerExtension({ return; } - // Store keys and types in hidden widgets for persistence (JSON-encoded) + // Store keys and types for persistence + // Properties are always reliably serialized by LiteGraph + this.properties = this.properties || {}; + this.properties._output_keys = keys; + this.properties._output_types = types; + // Also update hidden widgets for Python-side access const okWidget = this.widgets?.find(w => w.name === "output_keys"); if (okWidget) okWidget.value = JSON.stringify(keys); const otWidget = this.widgets?.find(w => w.name === "output_types"); @@ -197,15 +202,19 @@ app.registerExtension({ const okWidget = this.widgets?.find(w => w.name === "output_keys"); const otWidget = this.widgets?.find(w => w.name === "output_types"); - // Parse keys/types — try JSON array first, fall back to comma-split + // Read keys/types — properties (always persisted) first, then widgets let keys = []; - if (okWidget?.value) { + if (Array.isArray(this.properties?._output_keys) && this.properties._output_keys.length > 0) { + keys = this.properties._output_keys; + } else if (okWidget?.value) { try { keys = JSON.parse(okWidget.value); } catch (_) { keys = okWidget.value.split(",").map(k => k.trim()).filter(Boolean); } } let types = []; - if (otWidget?.value) { + if (Array.isArray(this.properties?._output_types) && this.properties._output_types.length > 0) { + types = this.properties._output_types; + } else if (otWidget?.value) { try { types = JSON.parse(otWidget.value); } catch (_) { types = otWidget.value.split(",").map(t => t.trim()).filter(Boolean); } @@ -259,11 +268,16 @@ app.registerExtension({ this.removeOutput(this.outputs.length - 1); } } else if (this.outputs.length > 1) { - // Widget values empty but serialized dynamic outputs exist — sync widgets - // from the outputs LiteGraph already restored (fallback, skip slot 0). + // Widget/property values empty but serialized dynamic outputs exist — + // sync from the outputs LiteGraph already restored (fallback, skip slot 0). const dynamicOutputs = this.outputs.slice(1); - if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name)); - if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type)); + const restoredKeys = dynamicOutputs.map(o => o.name); + const restoredTypes = dynamicOutputs.map(o => o.type); + this.properties = this.properties || {}; + this.properties._output_keys = restoredKeys; + this.properties._output_types = restoredTypes; + if (okWidget) okWidget.value = JSON.stringify(restoredKeys); + if (otWidget) otWidget.value = JSON.stringify(restoredTypes); } this.setSize(this.computeSize()); From 5b71d1b276416fb922dc074b8178ae5df95d355a Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 23:11:04 +0100 Subject: [PATCH 10/14] Fix output name persistence: use comma-separated like reference impl JSON.stringify format for hidden widget values didn't survive ComfyUI's serialization round-trip. Switch to comma-separated strings matching the proven ComfyUI-JSON-Dynamic implementation. Remove properties-based approach in favor of the simpler, working pattern. Co-Authored-By: Claude Opus 4.6 --- web/project_dynamic.js | 49 ++++++++++++------------------------------ 1 file changed, 14 insertions(+), 35 deletions(-) diff --git a/web/project_dynamic.js b/web/project_dynamic.js index 801d5e1..c510f1e 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -117,16 +117,11 @@ app.registerExtension({ return; } - // Store keys and types for persistence - // Properties are always reliably serialized by LiteGraph - this.properties = this.properties || {}; - this.properties._output_keys = keys; - this.properties._output_types = types; - // Also update hidden widgets for Python-side access + // Store keys and types in hidden widgets for persistence (comma-separated) const okWidget = this.widgets?.find(w => w.name === "output_keys"); - if (okWidget) okWidget.value = JSON.stringify(keys); + if (okWidget) okWidget.value = keys.join(","); const otWidget = this.widgets?.find(w => w.name === "output_types"); - if (otWidget) otWidget.value = JSON.stringify(types); + if (otWidget) otWidget.value = types.join(","); // Slot 0 is always total_sequences (INT) — ensure it exists if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { @@ -202,23 +197,12 @@ app.registerExtension({ const okWidget = this.widgets?.find(w => w.name === "output_keys"); const otWidget = this.widgets?.find(w => w.name === "output_types"); - // Read keys/types — properties (always persisted) first, then widgets - let keys = []; - if (Array.isArray(this.properties?._output_keys) && this.properties._output_keys.length > 0) { - keys = this.properties._output_keys; - } else if (okWidget?.value) { - try { keys = JSON.parse(okWidget.value); } catch (_) { - keys = okWidget.value.split(",").map(k => k.trim()).filter(Boolean); - } - } - let types = []; - if (Array.isArray(this.properties?._output_types) && this.properties._output_types.length > 0) { - types = this.properties._output_types; - } else if (otWidget?.value) { - try { types = JSON.parse(otWidget.value); } catch (_) { - types = otWidget.value.split(",").map(t => t.trim()).filter(Boolean); - } - } + const keys = okWidget?.value + ? okWidget.value.split(",").filter(k => k.trim()) + : []; + const types = otWidget?.value + ? otWidget.value.split(",") + : []; // Ensure slot 0 is total_sequences (INT) if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { @@ -258,7 +242,7 @@ app.registerExtension({ for (let i = 0; i < keys.length; i++) { const slotIdx = i + 1; // offset by 1 for total_sequences if (slotIdx < this.outputs.length) { - this.outputs[slotIdx].name = keys[i]; + this.outputs[slotIdx].name = keys[i].trim(); if (types[i]) this.outputs[slotIdx].type = types[i]; } } @@ -268,16 +252,11 @@ app.registerExtension({ this.removeOutput(this.outputs.length - 1); } } else if (this.outputs.length > 1) { - // Widget/property values empty but serialized dynamic outputs exist — - // sync from the outputs LiteGraph already restored (fallback, skip slot 0). + // Widget values empty but serialized dynamic outputs exist — sync widgets + // from the outputs LiteGraph already restored (fallback, skip slot 0). const dynamicOutputs = this.outputs.slice(1); - const restoredKeys = dynamicOutputs.map(o => o.name); - const restoredTypes = dynamicOutputs.map(o => o.type); - this.properties = this.properties || {}; - this.properties._output_keys = restoredKeys; - this.properties._output_types = restoredTypes; - if (okWidget) okWidget.value = JSON.stringify(restoredKeys); - if (otWidget) otWidget.value = JSON.stringify(restoredTypes); + if (okWidget) okWidget.value = dynamicOutputs.map(o => o.name).join(","); + if (otWidget) otWidget.value = dynamicOutputs.map(o => o.type).join(","); } this.setSize(this.computeSize()); From bf2fca53e0a50713af33ae40ea220f96752f9247 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 23:18:39 +0100 Subject: [PATCH 11/14] Remove JSONLoaderDynamic, handled by ComfyUI-JSON-Dynamic extension The separate ComfyUI-JSON-Dynamic extension provides the same node. Removes json_loader.py, web/json_dynamic.js, and their tests. Only ProjectLoaderDynamic remains in this extension. Co-Authored-By: Claude Opus 4.6 --- __init__.py | 5 +- json_loader.py | 153 ---------------------------------- tests/test_json_loader.py | 165 ------------------------------------- web/json_dynamic.js | 168 -------------------------------------- 4 files changed, 2 insertions(+), 489 deletions(-) delete mode 100644 json_loader.py delete mode 100644 tests/test_json_loader.py delete mode 100644 web/json_dynamic.js diff --git a/__init__.py b/__init__.py index 5e657ee..1e2aadb 100644 --- a/__init__.py +++ b/__init__.py @@ -1,8 +1,7 @@ -from .json_loader import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS from .project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS -NODE_CLASS_MAPPINGS.update(PROJECT_NODE_CLASS_MAPPINGS) -NODE_DISPLAY_NAME_MAPPINGS.update(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) +NODE_CLASS_MAPPINGS = PROJECT_NODE_CLASS_MAPPINGS +NODE_DISPLAY_NAME_MAPPINGS = PROJECT_NODE_DISPLAY_NAME_MAPPINGS WEB_DIRECTORY = "./web" diff --git a/json_loader.py b/json_loader.py deleted file mode 100644 index 24f8e50..0000000 --- a/json_loader.py +++ /dev/null @@ -1,153 +0,0 @@ -import json -import os -import logging -from typing import Any - -logger = logging.getLogger(__name__) - -KEY_BATCH_DATA = "batch_data" -MAX_DYNAMIC_OUTPUTS = 32 - - -class AnyType(str): - """Universal connector type that matches any ComfyUI type.""" - def __ne__(self, __value: object) -> bool: - return False - -any_type = AnyType("*") - - -try: - from server import PromptServer - from aiohttp import web -except ImportError: - PromptServer = None - - -def to_float(val: Any) -> float: - try: - return float(val) - except (ValueError, TypeError): - return 0.0 - -def to_int(val: Any) -> int: - try: - return int(float(val)) - except (ValueError, TypeError): - return 0 - -def get_batch_item(data: dict[str, Any], sequence_number: int) -> dict[str, Any]: - """Resolve batch item by sequence_number field, falling back to array index.""" - if KEY_BATCH_DATA in data and isinstance(data[KEY_BATCH_DATA], list) and len(data[KEY_BATCH_DATA]) > 0: - # Search by sequence_number field first - for item in data[KEY_BATCH_DATA]: - if int(item.get("sequence_number", 0)) == sequence_number: - return item - # Fallback to array index - idx = max(0, min(sequence_number - 1, len(data[KEY_BATCH_DATA]) - 1)) - logger.warning(f"No item with sequence_number={sequence_number}, falling back to index {idx}") - return data[KEY_BATCH_DATA][idx] - return data - -# --- Shared Helper --- -def read_json_data(json_path: str) -> dict[str, Any]: - if not os.path.exists(json_path): - logger.warning(f"File not found at {json_path}") - return {} - try: - with open(json_path, 'r') as f: - data = json.load(f) - except (json.JSONDecodeError, IOError) as e: - logger.warning(f"Error reading {json_path}: {e}") - return {} - if not isinstance(data, dict): - logger.warning(f"Expected dict from {json_path}, got {type(data).__name__}") - return {} - return data - -# --- API Route --- -if PromptServer is not None: - @PromptServer.instance.routes.get("/json_manager/get_keys") - async def get_keys_route(request): - json_path = request.query.get("path", "") - try: - seq = int(request.query.get("sequence_number", "1")) - except (ValueError, TypeError): - seq = 1 - data = read_json_data(json_path) - if not data: - return web.json_response({"keys": [], "types": [], "error": "file_not_found"}) - target = get_batch_item(data, seq) - keys = [] - types = [] - if isinstance(target, dict): - for k, v in target.items(): - keys.append(k) - if isinstance(v, bool): - types.append("STRING") - elif isinstance(v, int): - types.append("INT") - elif isinstance(v, float): - types.append("FLOAT") - else: - types.append("STRING") - return web.json_response({"keys": keys, "types": types}) - - -# ========================================== -# 0. DYNAMIC NODE -# ========================================== - -class JSONLoaderDynamic: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "json_path": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }, - "optional": { - "output_keys": ("STRING", {"default": ""}), - "output_types": ("STRING", {"default": ""}), - }, - } - - RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) - RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) - FUNCTION = "load_dynamic" - CATEGORY = "utils/json" - OUTPUT_NODE = False - - def load_dynamic(self, json_path, sequence_number, output_keys="", output_types=""): - data = read_json_data(json_path) - target = get_batch_item(data, sequence_number) - - keys = [k.strip() for k in output_keys.split(",") if k.strip()] if output_keys else [] - - results = [] - for key in keys: - val = target.get(key, "") - if isinstance(val, bool): - results.append(str(val).lower()) - elif isinstance(val, int): - results.append(val) - elif isinstance(val, float): - results.append(val) - else: - results.append(str(val)) - - # Pad to MAX_DYNAMIC_OUTPUTS - while len(results) < MAX_DYNAMIC_OUTPUTS: - results.append("") - - return tuple(results) - - -# --- Mappings --- -NODE_CLASS_MAPPINGS = { - "JSONLoaderDynamic": JSONLoaderDynamic, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "JSONLoaderDynamic": "JSON Loader (Dynamic)", -} diff --git a/tests/test_json_loader.py b/tests/test_json_loader.py deleted file mode 100644 index 2729e4d..0000000 --- a/tests/test_json_loader.py +++ /dev/null @@ -1,165 +0,0 @@ -import json -import os - -import pytest - -from json_loader import ( - to_float, to_int, get_batch_item, read_json_data, - JSONLoaderDynamic, MAX_DYNAMIC_OUTPUTS, -) - - -class TestToFloat: - def test_valid(self): - assert to_float("3.14") == 3.14 - assert to_float(5) == 5.0 - - def test_invalid(self): - assert to_float("abc") == 0.0 - - def test_none(self): - assert to_float(None) == 0.0 - - -class TestToInt: - def test_valid(self): - assert to_int("7") == 7 - assert to_int(3.9) == 3 - - def test_invalid(self): - assert to_int("xyz") == 0 - - def test_none(self): - assert to_int(None) == 0 - - -class TestGetBatchItem: - def test_lookup_by_sequence_number_field(self): - data = {"batch_data": [ - {"sequence_number": 1, "a": "first"}, - {"sequence_number": 5, "a": "fifth"}, - {"sequence_number": 3, "a": "third"}, - ]} - assert get_batch_item(data, 5) == {"sequence_number": 5, "a": "fifth"} - assert get_batch_item(data, 3) == {"sequence_number": 3, "a": "third"} - - def test_fallback_to_index(self): - data = {"batch_data": [{"a": 1}, {"a": 2}, {"a": 3}]} - assert get_batch_item(data, 2) == {"a": 2} - - def test_clamp_high(self): - data = {"batch_data": [{"a": 1}, {"a": 2}]} - assert get_batch_item(data, 99) == {"a": 2} - - def test_clamp_low(self): - data = {"batch_data": [{"a": 1}, {"a": 2}]} - assert get_batch_item(data, 0) == {"a": 1} - - def test_no_batch_data(self): - data = {"key": "val"} - assert get_batch_item(data, 1) == data - - -class TestReadJsonData: - def test_missing_file(self, tmp_path): - assert read_json_data(str(tmp_path / "nope.json")) == {} - - def test_invalid_json(self, tmp_path): - p = tmp_path / "bad.json" - p.write_text("{broken") - assert read_json_data(str(p)) == {} - - def test_non_dict_json(self, tmp_path): - p = tmp_path / "list.json" - p.write_text(json.dumps([1, 2, 3])) - assert read_json_data(str(p)) == {} - - def test_valid(self, tmp_path): - p = tmp_path / "ok.json" - p.write_text(json.dumps({"key": "val"})) - assert read_json_data(str(p)) == {"key": "val"} - - -class TestJSONLoaderDynamic: - def _make_json(self, tmp_path, data): - p = tmp_path / "test.json" - p.write_text(json.dumps(data)) - return str(p) - - def test_known_keys(self, tmp_path): - path = self._make_json(tmp_path, {"name": "alice", "age": 30, "score": 9.5}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="name,age,score") - assert result[0] == "alice" - assert result[1] == 30 - assert result[2] == 9.5 - - def test_empty_output_keys(self, tmp_path): - path = self._make_json(tmp_path, {"name": "alice"}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="") - assert len(result) == MAX_DYNAMIC_OUTPUTS - assert all(v == "" for v in result) - - def test_pads_to_max(self, tmp_path): - path = self._make_json(tmp_path, {"a": "1", "b": "2"}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="a,b") - assert len(result) == MAX_DYNAMIC_OUTPUTS - assert result[0] == "1" - assert result[1] == "2" - assert all(v == "" for v in result[2:]) - - def test_type_preservation_int(self, tmp_path): - path = self._make_json(tmp_path, {"count": 42}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="count") - assert result[0] == 42 - assert isinstance(result[0], int) - - def test_type_preservation_float(self, tmp_path): - path = self._make_json(tmp_path, {"rate": 3.14}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="rate") - assert result[0] == 3.14 - assert isinstance(result[0], float) - - def test_type_preservation_str(self, tmp_path): - path = self._make_json(tmp_path, {"label": "hello"}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="label") - assert result[0] == "hello" - assert isinstance(result[0], str) - - def test_bool_becomes_string(self, tmp_path): - path = self._make_json(tmp_path, {"flag": True, "off": False}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="flag,off") - assert result[0] == "true" - assert result[1] == "false" - assert isinstance(result[0], str) - - def test_missing_key_returns_empty_string(self, tmp_path): - path = self._make_json(tmp_path, {"a": "1"}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="a,nonexistent") - assert result[0] == "1" - assert result[1] == "" - - def test_missing_file_returns_all_empty(self, tmp_path): - loader = JSONLoaderDynamic() - result = loader.load_dynamic(str(tmp_path / "nope.json"), 1, output_keys="a,b") - assert len(result) == MAX_DYNAMIC_OUTPUTS - assert result[0] == "" - assert result[1] == "" - - def test_batch_data(self, tmp_path): - path = self._make_json(tmp_path, { - "batch_data": [ - {"sequence_number": 1, "x": "first"}, - {"sequence_number": 2, "x": "second"}, - ] - }) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 2, output_keys="x") - assert result[0] == "second" diff --git a/web/json_dynamic.js b/web/json_dynamic.js deleted file mode 100644 index 22ca527..0000000 --- a/web/json_dynamic.js +++ /dev/null @@ -1,168 +0,0 @@ -import { app } from "../../scripts/app.js"; -import { api } from "../../scripts/api.js"; - -app.registerExtension({ - name: "json.manager.dynamic", - - async beforeRegisterNodeDef(nodeType, nodeData, app) { - if (nodeData.name !== "JSONLoaderDynamic") return; - - const origOnNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = function () { - origOnNodeCreated?.apply(this, arguments); - - // Hide internal widgets (managed by JS) - for (const name of ["output_keys", "output_types"]) { - const w = this.widgets?.find(w => w.name === name); - if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } - } - - // Do NOT remove default outputs synchronously here. - // During graph loading, ComfyUI creates all nodes (firing onNodeCreated) - // before configuring them. Other nodes (e.g. Kijai Set/Get) may resolve - // links to our outputs during their configure step. If we remove outputs - // here, those nodes find no output slot and error out. - // - // Instead, defer cleanup: for loaded workflows onConfigure sets _configured - // before this runs; for new nodes the defaults are cleaned up. - this._configured = false; - - // Add Refresh button - this.addWidget("button", "Refresh Outputs", null, () => { - this.refreshDynamicOutputs(); - }); - - queueMicrotask(() => { - if (!this._configured) { - // New node (not loading) — remove the 32 Python default outputs - while (this.outputs.length > 0) { - this.removeOutput(0); - } - this.setSize(this.computeSize()); - app.graph?.setDirtyCanvas(true, true); - } - }); - }; - - nodeType.prototype.refreshDynamicOutputs = async function () { - const pathWidget = this.widgets?.find(w => w.name === "json_path"); - const seqWidget = this.widgets?.find(w => w.name === "sequence_number"); - if (!pathWidget?.value) return; - - try { - const resp = await api.fetchApi( - `/json_manager/get_keys?path=${encodeURIComponent(pathWidget.value)}&sequence_number=${seqWidget?.value || 1}` - ); - const data = await resp.json(); - const { keys, types } = data; - - // If the file wasn't found, keep existing outputs and links intact - if (data.error === "file_not_found") { - console.warn("[JSONLoaderDynamic] File not found, keeping existing outputs:", pathWidget.value); - return; - } - - // Store keys and types in hidden widgets for persistence - const okWidget = this.widgets?.find(w => w.name === "output_keys"); - if (okWidget) okWidget.value = keys.join(","); - const otWidget = this.widgets?.find(w => w.name === "output_types"); - if (otWidget) otWidget.value = types.join(","); - - // Build a map of current output names to slot indices - const oldSlots = {}; - for (let i = 0; i < this.outputs.length; i++) { - oldSlots[this.outputs[i].name] = i; - } - - // Build new outputs, reusing existing slots to preserve links - const newOutputs = []; - for (let k = 0; k < keys.length; k++) { - const key = keys[k]; - const type = types[k] || "*"; - if (key in oldSlots) { - // Reuse existing slot object (keeps links intact) - const slot = this.outputs[oldSlots[key]]; - slot.type = type; - newOutputs.push(slot); - delete oldSlots[key]; - } else { - // New key — create a fresh slot - newOutputs.push({ name: key, type: type, links: null }); - } - } - - // Disconnect links on slots that are being removed - for (const name in oldSlots) { - const idx = oldSlots[name]; - if (this.outputs[idx]?.links?.length) { - for (const linkId of [...this.outputs[idx].links]) { - this.graph?.removeLink(linkId); - } - } - } - - // Reassign the outputs array and fix link slot indices - this.outputs = newOutputs; - if (this.graph) { - for (let i = 0; i < this.outputs.length; i++) { - const links = this.outputs[i].links; - if (!links) continue; - for (const linkId of links) { - const link = this.graph.links[linkId]; - if (link) link.origin_slot = i; - } - } - } - - this.setSize(this.computeSize()); - app.graph.setDirtyCanvas(true, true); - } catch (e) { - console.error("[JSONLoaderDynamic] Refresh failed:", e); - } - }; - - // Restore state on workflow load - const origOnConfigure = nodeType.prototype.onConfigure; - nodeType.prototype.onConfigure = function (info) { - origOnConfigure?.apply(this, arguments); - this._configured = true; - - // Hide internal widgets - for (const name of ["output_keys", "output_types"]) { - const w = this.widgets?.find(w => w.name === name); - if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } - } - - const okWidget = this.widgets?.find(w => w.name === "output_keys"); - const otWidget = this.widgets?.find(w => w.name === "output_types"); - - const keys = okWidget?.value - ? okWidget.value.split(",").filter(k => k.trim()) - : []; - const types = otWidget?.value - ? otWidget.value.split(",") - : []; - - if (keys.length > 0) { - // On load, LiteGraph already restored serialized outputs with links. - // Rename and set types to match stored state (preserves links). - for (let i = 0; i < this.outputs.length && i < keys.length; i++) { - this.outputs[i].name = keys[i].trim(); - if (types[i]) this.outputs[i].type = types[i]; - } - - // Remove any extra outputs beyond the key count - while (this.outputs.length > keys.length) { - this.removeOutput(this.outputs.length - 1); - } - } else if (this.outputs.length > 0) { - // Widget values empty but serialized outputs exist — sync widgets - // from the outputs LiteGraph already restored (fallback). - if (okWidget) okWidget.value = this.outputs.map(o => o.name).join(","); - if (otWidget) otWidget.value = this.outputs.map(o => o.type).join(","); - } - - this.setSize(this.computeSize()); - }; - }, -}); From d55b3198e870e400c649f53445230ef66e2fca16 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 23:43:34 +0100 Subject: [PATCH 12/14] Fix output names not surviving page refresh in ProjectLoaderDynamic Read output names from info.outputs (serialized node data) instead of hidden widget values, which ComfyUI may not persist across reloads. Co-Authored-By: Claude Opus 4.6 --- web/project_dynamic.js | 62 ++++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/web/project_dynamic.js b/web/project_dynamic.js index c510f1e..7a5e766 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -197,29 +197,47 @@ app.registerExtension({ const okWidget = this.widgets?.find(w => w.name === "output_keys"); const otWidget = this.widgets?.find(w => w.name === "output_types"); - const keys = okWidget?.value - ? okWidget.value.split(",").filter(k => k.trim()) - : []; - const types = otWidget?.value - ? otWidget.value.split(",") - : []; + // Primary source: read output names from serialized node info. + // Hidden widget values may not survive ComfyUI's serialization, + // but info.outputs always contains the correct saved output names. + let keys = []; + let types = []; + const savedOutputs = info.outputs || []; + for (let i = 0; i < savedOutputs.length; i++) { + if (savedOutputs[i].name === "total_sequences") continue; + if (/^output_\d+$/.test(savedOutputs[i].name)) continue; + keys.push(savedOutputs[i].name); + types.push(savedOutputs[i].type || "*"); + } + + // Fallback: try hidden widget values + if (keys.length === 0) { + const wKeys = okWidget?.value + ? okWidget.value.split(",").filter(k => k.trim()) + : []; + if (wKeys.length > 0) { + keys = wKeys; + types = otWidget?.value + ? otWidget.value.split(",") + : []; + } + } + + // Update hidden widgets so the Python backend has keys for execution + if (keys.length > 0) { + if (okWidget) okWidget.value = keys.join(","); + if (otWidget) otWidget.value = types.join(","); + } // Ensure slot 0 is total_sequences (INT) if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { this.outputs.unshift({ name: "total_sequences", type: "INT", links: null }); - // LiteGraph restores links AFTER onConfigure, so graph.links is - // empty here. Defer link fixup to a microtask that runs after the - // synchronous graph.configure() finishes (including link restoration). - // We must also rebuild output.links arrays because LiteGraph will - // place link IDs on the wrong outputs (shifted by the unshift above). const node = this; queueMicrotask(() => { if (!node.graph) return; - // Clear all output.links — they were populated at old indices for (const output of node.outputs) { output.links = null; } - // Rebuild from graph.links with corrected origin_slot (+1) for (const linkId in node.graph.links) { const link = node.graph.links[linkId]; if (!link || link.origin_id !== node.id) continue; @@ -237,26 +255,24 @@ app.registerExtension({ this.outputs[0].name = "total_sequences"; if (keys.length > 0) { - // On load, LiteGraph already restored serialized outputs with links. - // Dynamic outputs start at slot 1. Rename and set types to match stored state. for (let i = 0; i < keys.length; i++) { - const slotIdx = i + 1; // offset by 1 for total_sequences + const slotIdx = i + 1; if (slotIdx < this.outputs.length) { this.outputs[slotIdx].name = keys[i].trim(); if (types[i]) this.outputs[slotIdx].type = types[i]; } } - - // Remove any extra outputs beyond keys + total_sequences while (this.outputs.length > keys.length + 1) { this.removeOutput(this.outputs.length - 1); } } else if (this.outputs.length > 1) { - // Widget values empty but serialized dynamic outputs exist — sync widgets - // from the outputs LiteGraph already restored (fallback, skip slot 0). - const dynamicOutputs = this.outputs.slice(1); - if (okWidget) okWidget.value = dynamicOutputs.map(o => o.name).join(","); - if (otWidget) otWidget.value = dynamicOutputs.map(o => o.type).join(","); + const dynamicOutputs = this.outputs.slice(1).filter( + o => !/^output_\d+$/.test(o.name) + ); + if (dynamicOutputs.length > 0) { + if (okWidget) okWidget.value = dynamicOutputs.map(o => o.name).join(","); + if (otWidget) otWidget.value = dynamicOutputs.map(o => o.type).join(","); + } } this.setSize(this.computeSize()); From a0d8cb8bbf12d6fadf40ab660cda1a5d5923f13b Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 1 Mar 2026 00:10:50 +0100 Subject: [PATCH 13/14] Fix: set output label alongside name for LiteGraph rendering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LiteGraph renders slot.label over slot.name — we were updating name but the display uses label. Co-Authored-By: Claude Opus 4.6 --- web/project_dynamic.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/project_dynamic.js b/web/project_dynamic.js index 7a5e766..15b525e 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -146,7 +146,7 @@ app.registerExtension({ newOutputs.push(slot); delete oldSlots[key]; } else { - newOutputs.push({ name: key, type: type, links: null }); + newOutputs.push({ name: key, label: key, type: type, links: null }); } } @@ -259,6 +259,7 @@ app.registerExtension({ const slotIdx = i + 1; if (slotIdx < this.outputs.length) { this.outputs[slotIdx].name = keys[i].trim(); + this.outputs[slotIdx].label = keys[i].trim(); if (types[i]) this.outputs[slotIdx].type = types[i]; } } From 187b85b054764bfd6cc8f34be5640b353bc08616 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 1 Mar 2026 00:12:30 +0100 Subject: [PATCH 14/14] Clean up: remove unnecessary info.outputs logic, set label on reused slots The actual fix was setting slot.label alongside slot.name. Reverted onConfigure to read from widget values (which work correctly) and ensured label is set on both new and reused output slots. Co-Authored-By: Claude Opus 4.6 --- web/project_dynamic.js | 49 ++++++++++-------------------------------- 1 file changed, 11 insertions(+), 38 deletions(-) diff --git a/web/project_dynamic.js b/web/project_dynamic.js index 15b525e..7f2de6e 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -143,6 +143,7 @@ app.registerExtension({ if (key in oldSlots) { const slot = this.outputs[oldSlots[key]]; slot.type = type; + slot.label = key; newOutputs.push(slot); delete oldSlots[key]; } else { @@ -197,37 +198,12 @@ app.registerExtension({ const okWidget = this.widgets?.find(w => w.name === "output_keys"); const otWidget = this.widgets?.find(w => w.name === "output_types"); - // Primary source: read output names from serialized node info. - // Hidden widget values may not survive ComfyUI's serialization, - // but info.outputs always contains the correct saved output names. - let keys = []; - let types = []; - const savedOutputs = info.outputs || []; - for (let i = 0; i < savedOutputs.length; i++) { - if (savedOutputs[i].name === "total_sequences") continue; - if (/^output_\d+$/.test(savedOutputs[i].name)) continue; - keys.push(savedOutputs[i].name); - types.push(savedOutputs[i].type || "*"); - } - - // Fallback: try hidden widget values - if (keys.length === 0) { - const wKeys = okWidget?.value - ? okWidget.value.split(",").filter(k => k.trim()) - : []; - if (wKeys.length > 0) { - keys = wKeys; - types = otWidget?.value - ? otWidget.value.split(",") - : []; - } - } - - // Update hidden widgets so the Python backend has keys for execution - if (keys.length > 0) { - if (okWidget) okWidget.value = keys.join(","); - if (otWidget) otWidget.value = types.join(","); - } + const keys = okWidget?.value + ? okWidget.value.split(",").filter(k => k.trim()) + : []; + const types = otWidget?.value + ? otWidget.value.split(",") + : []; // Ensure slot 0 is total_sequences (INT) if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { @@ -267,13 +243,10 @@ app.registerExtension({ this.removeOutput(this.outputs.length - 1); } } else if (this.outputs.length > 1) { - const dynamicOutputs = this.outputs.slice(1).filter( - o => !/^output_\d+$/.test(o.name) - ); - if (dynamicOutputs.length > 0) { - if (okWidget) okWidget.value = dynamicOutputs.map(o => o.name).join(","); - if (otWidget) otWidget.value = dynamicOutputs.map(o => o.type).join(","); - } + // Widget values empty but serialized dynamic outputs exist — sync widgets + const dynamicOutputs = this.outputs.slice(1); + if (okWidget) okWidget.value = dynamicOutputs.map(o => o.name).join(","); + if (otWidget) otWidget.value = dynamicOutputs.map(o => o.type).join(","); } this.setSize(this.computeSize());