"""REST API endpoints for ComfyUI to query project data from SQLite.

All endpoints are read-only. Mounted on the NiceGUI/FastAPI server.
"""

import logging
from typing import Any

from fastapi import HTTPException, Query
from nicegui import app

from db import ProjectDB

logger = logging.getLogger(__name__)

# The DB instance is set by register_api_routes()
_db: ProjectDB | None = None


def register_api_routes(db: ProjectDB) -> None:
    """Register all REST API routes with the NiceGUI/FastAPI app.

    Args:
        db: Shared ProjectDB instance used by every endpoint below.
    """
    global _db
    _db = db

    app.add_api_route("/api/projects", _list_projects, methods=["GET"])
    app.add_api_route("/api/projects/{name}/files", _list_files, methods=["GET"])
    app.add_api_route("/api/projects/{name}/files/{file_name}/sequences", _list_sequences, methods=["GET"])
    app.add_api_route("/api/projects/{name}/files/{file_name}/data", _get_data, methods=["GET"])
    app.add_api_route("/api/projects/{name}/files/{file_name}/keys", _get_keys, methods=["GET"])


def _get_db() -> ProjectDB:
    """Return the registered DB, or fail with 503 before registration."""
    if _db is None:
        raise HTTPException(status_code=503, detail="Database not initialized")
    return _db


def _resolve_data_file(name: str, file_name: str) -> dict[str, Any]:
    """Resolve (project, file) names to a data_file row.

    Raises HTTPException 404 with a message that pinpoints which level
    (project or file) is missing.  Shared by _get_data and _get_keys,
    which previously duplicated this logic verbatim.
    """
    db = _get_db()
    if not db.get_project(name):
        raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
    df = db.get_data_file_by_names(name, file_name)
    if not df:
        raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
    return df


def _list_projects() -> dict[str, Any]:
    """GET /api/projects — names of all projects."""
    db = _get_db()
    return {"projects": [p["name"] for p in db.list_projects()]}


def _list_files(name: str) -> dict[str, Any]:
    """GET /api/projects/{name}/files — name/data_type pairs for a project.

    Unknown project names yield an empty list (list_project_files returns
    [] in that case), matching the original behavior.
    """
    db = _get_db()
    files = db.list_project_files(name)
    return {"files": [{"name": f["name"], "data_type": f["data_type"]} for f in files]}


def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
    """GET .../sequences — sorted sequence numbers for a file ([] if unknown)."""
    db = _get_db()
    return {"sequences": db.list_project_sequences(name, file_name)}


def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
    """GET .../data?seq=N — full data dict for one sequence.

    Raises 404 when the project, file, or sequence does not exist.
    """
    db = _get_db()
    df = _resolve_data_file(name, file_name)
    data = db.get_sequence(df["id"], seq)
    if data is None:
        raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
    return data


def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
    """GET .../keys?seq=N — keys/types of one sequence plus total count.

    Raises 404 when the project or file does not exist; a missing sequence
    yields empty key/type lists (get_sequence_keys returns ([], [])).
    """
    db = _get_db()
    df = _resolve_data_file(name, file_name)
    keys, types = db.get_sequence_keys(df["id"], seq)
    total = db.count_sequences(df["id"])
    return {"keys": keys, "types": types, "total_sequences": total}
data_files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + name TEXT NOT NULL, + data_type TEXT NOT NULL DEFAULT 'generic', + top_level TEXT NOT NULL DEFAULT '{}', + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + UNIQUE(project_id, name) +); + +CREATE TABLE IF NOT EXISTS sequences ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE, + sequence_number INTEGER NOT NULL, + data TEXT NOT NULL DEFAULT '{}', + updated_at REAL NOT NULL, + UNIQUE(data_file_id, sequence_number) +); + +CREATE TABLE IF NOT EXISTS history_trees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE, + tree_data TEXT NOT NULL DEFAULT '{}', + updated_at REAL NOT NULL +); +""" + + +class ProjectDB: + """SQLite database for project-based data management.""" + + def __init__(self, db_path: str | Path | None = None): + self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self.conn = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + isolation_level=None, # autocommit — explicit BEGIN/COMMIT only + ) + self.conn.row_factory = sqlite3.Row + self.conn.execute("PRAGMA journal_mode=WAL") + self.conn.execute("PRAGMA foreign_keys=ON") + self.conn.executescript(SCHEMA_SQL) + + def close(self): + self.conn.close() + + # ------------------------------------------------------------------ + # Projects CRUD + # ------------------------------------------------------------------ + + def create_project(self, name: str, folder_path: str, description: str = "") -> int: + now = time.time() + cur = self.conn.execute( + "INSERT INTO projects (name, folder_path, description, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?)", + (name, folder_path, description, now, now), + ) + self.conn.commit() + return cur.lastrowid + + def 
list_projects(self) -> list[dict]: + rows = self.conn.execute( + "SELECT id, name, folder_path, description, created_at, updated_at " + "FROM projects ORDER BY name" + ).fetchall() + return [dict(r) for r in rows] + + def get_project(self, name: str) -> dict | None: + row = self.conn.execute( + "SELECT id, name, folder_path, description, created_at, updated_at " + "FROM projects WHERE name = ?", + (name,), + ).fetchone() + return dict(row) if row else None + + def delete_project(self, name: str) -> bool: + cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,)) + self.conn.commit() + return cur.rowcount > 0 + + # ------------------------------------------------------------------ + # Data files + # ------------------------------------------------------------------ + + def create_data_file( + self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None + ) -> int: + now = time.time() + tl = json.dumps(top_level or {}) + cur = self.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (project_id, name, data_type, tl, now, now), + ) + self.conn.commit() + return cur.lastrowid + + def list_data_files(self, project_id: int) -> list[dict]: + rows = self.conn.execute( + "SELECT id, project_id, name, data_type, created_at, updated_at " + "FROM data_files WHERE project_id = ? ORDER BY name", + (project_id,), + ).fetchall() + return [dict(r) for r in rows] + + def get_data_file(self, project_id: int, name: str) -> dict | None: + row = self.conn.execute( + "SELECT id, project_id, name, data_type, top_level, created_at, updated_at " + "FROM data_files WHERE project_id = ? 
AND name = ?", + (project_id, name), + ).fetchone() + if row is None: + return None + d = dict(row) + d["top_level"] = json.loads(d["top_level"]) + return d + + def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None: + row = self.conn.execute( + "SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, " + "df.created_at, df.updated_at " + "FROM data_files df JOIN projects p ON df.project_id = p.id " + "WHERE p.name = ? AND df.name = ?", + (project_name, file_name), + ).fetchone() + if row is None: + return None + d = dict(row) + d["top_level"] = json.loads(d["top_level"]) + return d + + # ------------------------------------------------------------------ + # Sequences + # ------------------------------------------------------------------ + + def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None: + now = time.time() + self.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?) " + "ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at", + (data_file_id, sequence_number, json.dumps(data), now), + ) + self.conn.commit() + + def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None: + row = self.conn.execute( + "SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?", + (data_file_id, sequence_number), + ).fetchone() + return json.loads(row["data"]) if row else None + + def list_sequences(self, data_file_id: int) -> list[int]: + rows = self.conn.execute( + "SELECT sequence_number FROM sequences WHERE data_file_id = ? 
ORDER BY sequence_number", + (data_file_id,), + ).fetchall() + return [r["sequence_number"] for r in rows] + + def count_sequences(self, data_file_id: int) -> int: + """Return the number of sequences for a data file.""" + row = self.conn.execute( + "SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?", + (data_file_id,), + ).fetchone() + return row["cnt"] + + def query_total_sequences(self, project_name: str, file_name: str) -> int: + """Return total sequence count by project and file names.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return 0 + return self.count_sequences(df["id"]) + + def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]: + """Returns (keys, types) for a sequence's data dict.""" + data = self.get_sequence(data_file_id, sequence_number) + if not data: + return [], [] + keys = [] + types = [] + for k, v in data.items(): + keys.append(k) + if isinstance(v, bool): + types.append("STRING") + elif isinstance(v, int): + types.append("INT") + elif isinstance(v, float): + types.append("FLOAT") + else: + types.append("STRING") + return keys, types + + def delete_sequences_for_file(self, data_file_id: int) -> None: + self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,)) + self.conn.commit() + + # ------------------------------------------------------------------ + # History trees + # ------------------------------------------------------------------ + + def save_history_tree(self, data_file_id: int, tree_data: dict) -> None: + now = time.time() + self.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) 
" + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (data_file_id, json.dumps(tree_data), now), + ) + self.conn.commit() + + def get_history_tree(self, data_file_id: int) -> dict | None: + row = self.conn.execute( + "SELECT tree_data FROM history_trees WHERE data_file_id = ?", + (data_file_id,), + ).fetchone() + return json.loads(row["tree_data"]) if row else None + + # ------------------------------------------------------------------ + # Import + # ------------------------------------------------------------------ + + def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int: + """Import a JSON file into the database, splitting batch_data into sequences. + + Safe to call repeatedly — existing data_file is updated, sequences are + replaced, and history_tree is upserted. Atomic: all-or-nothing. + """ + json_path = Path(json_path) + data, _ = load_json(json_path) + file_name = json_path.stem + + top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} + + self.conn.execute("BEGIN IMMEDIATE") + try: + existing = self.conn.execute( + "SELECT id FROM data_files WHERE project_id = ? AND name = ?", + (project_id, file_name), + ).fetchone() + + if existing: + df_id = existing["id"] + now = time.time() + self.conn.execute( + "UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? 
WHERE id = ?", + (data_type, json.dumps(top_level), now, df_id), + ) + self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) + else: + now = time.time() + cur = self.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (project_id, file_name, data_type, json.dumps(top_level), now, now), + ) + df_id = cur.lastrowid + + # Import sequences from batch_data + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + for item in batch_data: + if not isinstance(item, dict): + continue + seq_num = int(item.get("sequence_number", 0)) + now = time.time() + self.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?) " + "ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at", + (df_id, seq_num, json.dumps(item), now), + ) + + # Import history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + now = time.time() + self.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) 
" + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (df_id, json.dumps(history_tree), now), + ) + + self.conn.execute("COMMIT") + return df_id + except Exception: + try: + self.conn.execute("ROLLBACK") + except Exception: + pass + raise + + # ------------------------------------------------------------------ + # Query helpers (for REST API) + # ------------------------------------------------------------------ + + def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None: + """Query a single sequence by project name, file name, and sequence number.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return None + return self.get_sequence(df["id"], sequence_number) + + def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]: + """Query keys and types for a sequence.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return [], [] + return self.get_sequence_keys(df["id"], sequence_number) + + def list_project_files(self, project_name: str) -> list[dict]: + """List data files for a project by name.""" + proj = self.get_project(project_name) + if not proj: + return [] + return self.list_data_files(proj["id"]) + + def list_project_sequences(self, project_name: str, file_name: str) -> list[int]: + """List sequence numbers for a file in a project.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return [] + return self.list_sequences(df["id"]) diff --git a/json_loader.py b/json_loader.py deleted file mode 100644 index eed69fb..0000000 --- a/json_loader.py +++ /dev/null @@ -1,384 +0,0 @@ -import json -import os -import logging -from typing import Any - -logger = logging.getLogger(__name__) - -KEY_BATCH_DATA = "batch_data" -MAX_DYNAMIC_OUTPUTS = 32 - - -class AnyType(str): - """Universal connector type that matches any ComfyUI 
type.""" - def __ne__(self, __value: object) -> bool: - return False - -any_type = AnyType("*") - - -try: - from server import PromptServer - from aiohttp import web -except ImportError: - PromptServer = None - - -def to_float(val: Any) -> float: - try: - return float(val) - except (ValueError, TypeError): - return 0.0 - -def to_int(val: Any) -> int: - try: - return int(float(val)) - except (ValueError, TypeError): - return 0 - -def get_batch_item(data: dict[str, Any], sequence_number: int) -> dict[str, Any]: - """Resolve batch item by sequence_number field, falling back to array index.""" - if KEY_BATCH_DATA in data and isinstance(data[KEY_BATCH_DATA], list) and len(data[KEY_BATCH_DATA]) > 0: - # Search by sequence_number field first - for item in data[KEY_BATCH_DATA]: - if int(item.get("sequence_number", 0)) == sequence_number: - return item - # Fallback to array index - idx = max(0, min(sequence_number - 1, len(data[KEY_BATCH_DATA]) - 1)) - logger.warning(f"No item with sequence_number={sequence_number}, falling back to index {idx}") - return data[KEY_BATCH_DATA][idx] - return data - -# --- Shared Helper --- -def read_json_data(json_path: str) -> dict[str, Any]: - if not os.path.exists(json_path): - logger.warning(f"File not found at {json_path}") - return {} - try: - with open(json_path, 'r') as f: - data = json.load(f) - except (json.JSONDecodeError, IOError) as e: - logger.warning(f"Error reading {json_path}: {e}") - return {} - if not isinstance(data, dict): - logger.warning(f"Expected dict from {json_path}, got {type(data).__name__}") - return {} - return data - -# --- API Route --- -if PromptServer is not None: - @PromptServer.instance.routes.get("/json_manager/get_keys") - async def get_keys_route(request): - json_path = request.query.get("path", "") - try: - seq = int(request.query.get("sequence_number", "1")) - except (ValueError, TypeError): - seq = 1 - data = read_json_data(json_path) - target = get_batch_item(data, seq) - keys = [] - types = [] - if 
isinstance(target, dict): - for k, v in target.items(): - keys.append(k) - if isinstance(v, bool): - types.append("STRING") - elif isinstance(v, int): - types.append("INT") - elif isinstance(v, float): - types.append("FLOAT") - else: - types.append("STRING") - return web.json_response({"keys": keys, "types": types}) - - -# ========================================== -# 0. DYNAMIC NODE -# ========================================== - -class JSONLoaderDynamic: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "json_path": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }, - "optional": { - "output_keys": ("STRING", {"default": ""}), - "output_types": ("STRING", {"default": ""}), - }, - } - - RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) - RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) - FUNCTION = "load_dynamic" - CATEGORY = "utils/json" - OUTPUT_NODE = False - - def load_dynamic(self, json_path, sequence_number, output_keys="", output_types=""): - data = read_json_data(json_path) - target = get_batch_item(data, sequence_number) - - keys = [k.strip() for k in output_keys.split(",") if k.strip()] if output_keys else [] - - results = [] - for key in keys: - val = target.get(key, "") - if isinstance(val, bool): - results.append(str(val).lower()) - elif isinstance(val, int): - results.append(val) - elif isinstance(val, float): - results.append(val) - else: - results.append(str(val)) - - # Pad to MAX_DYNAMIC_OUTPUTS - while len(results) < MAX_DYNAMIC_OUTPUTS: - results.append("") - - return tuple(results) - - -# ========================================== -# 1. 
STANDARD NODES (Single File) -# ========================================== - -class JSONLoaderLoRA: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") - RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low") - FUNCTION = "load_loras" - CATEGORY = "utils/json" - - def load_loras(self, json_path): - data = read_json_data(json_path) - return ( - str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")), - str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")), - str(data.get("lora 3 high", "")), str(data.get("lora 3 low", "")) - ) - -class JSONLoaderStandard: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path") - FUNCTION = "load_standard" - CATEGORY = "utils/json" - - def load_standard(self, json_path): - data = read_json_data(json_path) - return ( - str(data.get("general_prompt", "")), str(data.get("general_negative", "")), - str(data.get("current_prompt", "")), str(data.get("negative", "")), - str(data.get("camera", "")), to_float(data.get("flf", 0.0)), - to_int(data.get("seed", 0)), str(data.get("video file path", "")), - str(data.get("reference image path", "")), str(data.get("flf image path", "")) - ) - -class JSONLoaderVACE: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}} - - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING") - 
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path") - FUNCTION = "load_vace" - CATEGORY = "utils/json" - - def load_vace(self, json_path): - data = read_json_data(json_path) - return ( - str(data.get("general_prompt", "")), str(data.get("general_negative", "")), - str(data.get("current_prompt", "")), str(data.get("negative", "")), - str(data.get("camera", "")), to_float(data.get("flf", 0.0)), - to_int(data.get("seed", 0)), - to_int(data.get("frame_to_skip", 81)), to_int(data.get("input_a_frames", 16)), - to_int(data.get("input_b_frames", 16)), str(data.get("reference path", "")), - to_int(data.get("reference switch", 1)), to_int(data.get("vace schedule", 1)), - str(data.get("video file path", "")), str(data.get("reference image path", "")) - ) - -# ========================================== -# 2. 
BATCH NODES -# ========================================== - -class JSONLoaderBatchLoRA: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}} - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") - RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low") - FUNCTION = "load_batch_loras" - CATEGORY = "utils/json" - - def load_batch_loras(self, json_path, sequence_number): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - return ( - str(target_data.get("lora 1 high", "")), str(target_data.get("lora 1 low", "")), - str(target_data.get("lora 2 high", "")), str(target_data.get("lora 2 low", "")), - str(target_data.get("lora 3 high", "")), str(target_data.get("lora 3 low", "")) - ) - -class JSONLoaderBatchI2V: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}} - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path") - FUNCTION = "load_batch_i2v" - CATEGORY = "utils/json" - - def load_batch_i2v(self, json_path, sequence_number): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - - return ( - str(target_data.get("general_prompt", "")), str(target_data.get("general_negative", "")), - str(target_data.get("current_prompt", "")), str(target_data.get("negative", "")), - str(target_data.get("camera", "")), to_float(target_data.get("flf", 0.0)), - to_int(target_data.get("seed", 0)), str(target_data.get("video file path", "")), - 
str(target_data.get("reference image path", "")), str(target_data.get("flf image path", "")) - ) - -class JSONLoaderBatchVACE: - @classmethod - def INPUT_TYPES(s): - return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}} - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING") - RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path") - FUNCTION = "load_batch_vace" - CATEGORY = "utils/json" - - def load_batch_vace(self, json_path, sequence_number): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - - return ( - str(target_data.get("general_prompt", "")), str(target_data.get("general_negative", "")), - str(target_data.get("current_prompt", "")), str(target_data.get("negative", "")), - str(target_data.get("camera", "")), to_float(target_data.get("flf", 0.0)), - to_int(target_data.get("seed", 0)), to_int(target_data.get("frame_to_skip", 81)), - to_int(target_data.get("input_a_frames", 16)), to_int(target_data.get("input_b_frames", 16)), - str(target_data.get("reference path", "")), to_int(target_data.get("reference switch", 1)), - to_int(target_data.get("vace schedule", 1)), str(target_data.get("video file path", "")), - str(target_data.get("reference image path", "")) - ) - -# ========================================== -# 3. 
UNIVERSAL CUSTOM NODES (1, 3, 6 Slots) -# ========================================== - -class JSONLoaderCustom1: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "json_path": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }, - "optional": { "key_1": ("STRING", {"default": "", "multiline": False}) } - } - RETURN_TYPES = ("STRING",) - RETURN_NAMES = ("val_1",) - FUNCTION = "load_custom" - CATEGORY = "utils/json" - - def load_custom(self, json_path, sequence_number, key_1=""): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - return (str(target_data.get(key_1, "")),) - -class JSONLoaderCustom3: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "json_path": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }, - "optional": { - "key_1": ("STRING", {"default": "", "multiline": False}), - "key_2": ("STRING", {"default": "", "multiline": False}), - "key_3": ("STRING", {"default": "", "multiline": False}) - } - } - RETURN_TYPES = ("STRING", "STRING", "STRING") - RETURN_NAMES = ("val_1", "val_2", "val_3") - FUNCTION = "load_custom" - CATEGORY = "utils/json" - - def load_custom(self, json_path, sequence_number, key_1="", key_2="", key_3=""): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - return ( - str(target_data.get(key_1, "")), - str(target_data.get(key_2, "")), - str(target_data.get(key_3, "")) - ) - -class JSONLoaderCustom6: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "json_path": ("STRING", {"default": "", "multiline": False}), - "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), - }, - "optional": { - "key_1": ("STRING", {"default": "", "multiline": False}), - "key_2": ("STRING", {"default": "", "multiline": False}), - "key_3": ("STRING", {"default": "", "multiline": False}), 
- "key_4": ("STRING", {"default": "", "multiline": False}), - "key_5": ("STRING", {"default": "", "multiline": False}), - "key_6": ("STRING", {"default": "", "multiline": False}) - } - } - RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") - RETURN_NAMES = ("val_1", "val_2", "val_3", "val_4", "val_5", "val_6") - FUNCTION = "load_custom" - CATEGORY = "utils/json" - - def load_custom(self, json_path, sequence_number, key_1="", key_2="", key_3="", key_4="", key_5="", key_6=""): - data = read_json_data(json_path) - target_data = get_batch_item(data, sequence_number) - return ( - str(target_data.get(key_1, "")), str(target_data.get(key_2, "")), - str(target_data.get(key_3, "")), str(target_data.get(key_4, "")), - str(target_data.get(key_5, "")), str(target_data.get(key_6, "")) - ) - -# --- Mappings --- -NODE_CLASS_MAPPINGS = { - "JSONLoaderDynamic": JSONLoaderDynamic, - "JSONLoaderLoRA": JSONLoaderLoRA, - "JSONLoaderStandard": JSONLoaderStandard, - "JSONLoaderVACE": JSONLoaderVACE, - "JSONLoaderBatchLoRA": JSONLoaderBatchLoRA, - "JSONLoaderBatchI2V": JSONLoaderBatchI2V, - "JSONLoaderBatchVACE": JSONLoaderBatchVACE, - "JSONLoaderCustom1": JSONLoaderCustom1, - "JSONLoaderCustom3": JSONLoaderCustom3, - "JSONLoaderCustom6": JSONLoaderCustom6 -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "JSONLoaderDynamic": "JSON Loader (Dynamic)", - "JSONLoaderLoRA": "JSON Loader (LoRAs Only)", - "JSONLoaderStandard": "JSON Loader (Standard/I2V)", - "JSONLoaderVACE": "JSON Loader (VACE Full)", - "JSONLoaderBatchLoRA": "JSON Batch Loader (LoRAs)", - "JSONLoaderBatchI2V": "JSON Batch Loader (I2V)", - "JSONLoaderBatchVACE": "JSON Batch Loader (VACE)", - "JSONLoaderCustom1": "JSON Loader (Custom 1)", - "JSONLoaderCustom3": "JSON Loader (Custom 3)", - "JSONLoaderCustom6": "JSON Loader (Custom 6)" -} diff --git a/main.py b/main.py index bf34d82..072fe1e 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import json +import logging from pathlib import Path from nicegui import ui 
@@ -14,6 +15,18 @@ from tab_batch_ng import render_batch_processor from tab_timeline_ng import render_timeline_tab from tab_raw_ng import render_raw_editor from tab_comfy_ng import render_comfy_monitor +from tab_projects_ng import render_projects_tab +from db import ProjectDB +from api_routes import register_api_routes + +logger = logging.getLogger(__name__) + +# Single shared DB instance for both the UI and API routes +_shared_db: ProjectDB | None = None +try: + _shared_db = ProjectDB() +except Exception as e: + logger.warning(f"Failed to initialize ProjectDB: {e}") @ui.page('/') @@ -165,7 +178,13 @@ def index(): config=config, current_dir=Path(config.get('last_dir', str(Path.cwd()))), snippets=load_snippets(), + db_enabled=config.get('db_enabled', False), + current_project=config.get('current_project', ''), ) + + # Use the shared DB instance + state.db = _shared_db + dual_pane = {'active': False, 'state': None} # ------------------------------------------------------------------ @@ -187,6 +206,7 @@ def index(): ui.tab('batch', label='Batch Processor') ui.tab('timeline', label='Timeline') ui.tab('raw', label='Raw Editor') + ui.tab('projects', label='Projects') with ui.tab_panels(tabs, value='batch').classes('w-full'): with ui.tab_panel('batch'): @@ -195,6 +215,8 @@ def index(): render_timeline_tab(state) with ui.tab_panel('raw'): render_raw_editor(state) + with ui.tab_panel('projects'): + render_projects_tab(state) if state.show_comfy_monitor: ui.separator() @@ -490,4 +512,8 @@ def render_sidebar(state: AppState, dual_pane: dict): ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle) +# Register REST API routes for ComfyUI connectivity (uses the shared DB instance) +if _shared_db is not None: + register_api_routes(_shared_db) + ui.run(title='AI Settings Manager', port=8080, reload=True) diff --git a/project_loader.py b/project_loader.py new file mode 100644 index 0000000..6420517 --- /dev/null +++ b/project_loader.py @@ -0,0 +1,215 @@ +import 
json +import logging +import urllib.parse +import urllib.request +import urllib.error +from typing import Any + +logger = logging.getLogger(__name__) + +MAX_DYNAMIC_OUTPUTS = 32 + + +class AnyType(str): + """Universal connector type that matches any ComfyUI type.""" + def __ne__(self, __value: object) -> bool: + return False + +any_type = AnyType("*") + + +try: + from server import PromptServer + from aiohttp import web +except ImportError: + PromptServer = None + + +def to_float(val: Any) -> float: + try: + return float(val) + except (ValueError, TypeError): + return 0.0 + +def to_int(val: Any) -> int: + try: + return int(float(val)) + except (ValueError, TypeError): + return 0 + + +def _fetch_json(url: str) -> dict: + """Fetch JSON from a URL using stdlib urllib. + + On error, returns a dict with an "error" key describing the failure. + """ + try: + with urllib.request.urlopen(url, timeout=5) as resp: + return json.loads(resp.read()) + except urllib.error.HTTPError as e: + # HTTPError is a subclass of URLError — must be caught first + body = "" + try: + raw = e.read() + detail = json.loads(raw) + body = detail.get("detail", str(raw, "utf-8", errors="replace")) + except Exception: + body = str(e) + logger.warning(f"HTTP {e.code} from {url}: {body}") + return {"error": "http_error", "status": e.code, "message": body} + except (urllib.error.URLError, OSError) as e: + reason = str(e.reason) if hasattr(e, "reason") else str(e) + logger.warning(f"Network error fetching {url}: {reason}") + return {"error": "network_error", "message": reason} + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON from {url}: {e}") + return {"error": "parse_error", "message": str(e)} + + +def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: + """Fetch sequence data from the NiceGUI REST API.""" + p = urllib.parse.quote(project, safe='') + f = urllib.parse.quote(file, safe='') + url = 
f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/data?seq={seq}" + return _fetch_json(url) + + +def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict: + """Fetch keys/types from the NiceGUI REST API.""" + p = urllib.parse.quote(project, safe='') + f = urllib.parse.quote(file, safe='') + url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/keys?seq={seq}" + return _fetch_json(url) + + +# --- ComfyUI-side proxy endpoints (for frontend JS) --- +if PromptServer is not None: + @PromptServer.instance.routes.get("/json_manager/list_projects") + async def list_projects_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + url = f"{manager_url.rstrip('/')}/api/projects" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/list_project_files") + async def list_project_files_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = urllib.parse.quote(request.query.get("project", ""), safe='') + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/list_project_sequences") + async def list_project_sequences_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = urllib.parse.quote(request.query.get("project", ""), safe='') + file_name = urllib.parse.quote(request.query.get("file", ""), safe='') + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/get_project_keys") + async def get_project_keys_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = request.query.get("project", "") + file_name = request.query.get("file", "") + try: + seq = int(request.query.get("seq", "1")) + 
except (ValueError, TypeError): + seq = 1 + data = _fetch_keys(manager_url, project, file_name, seq) + if data.get("error") in ("http_error", "network_error", "parse_error"): + status = data.get("status", 502) + return web.json_response(data, status=status) + return web.json_response(data) + + + +# ========================================== +# 0. DYNAMIC NODE (Project-based) +# ========================================== + +class ProjectLoaderDynamic: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }, + "optional": { + "output_keys": ("STRING", {"default": ""}), + "output_types": ("STRING", {"default": ""}), + }, + } + + RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) + RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) + FUNCTION = "load_dynamic" + CATEGORY = "utils/json/project" + OUTPUT_NODE = False + + def load_dynamic(self, manager_url, project_name, file_name, sequence_number, + output_keys="", output_types=""): + # Fetch keys metadata (includes total_sequences count) + keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number) + if keys_meta.get("error") in ("http_error", "network_error", "parse_error"): + msg = keys_meta.get("message", "Unknown error") + raise RuntimeError(f"Failed to fetch project keys: {msg}") + total_sequences = keys_meta.get("total_sequences", 0) + + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + if data.get("error") in ("http_error", "network_error", "parse_error"): + msg = data.get("message", "Unknown error") + raise RuntimeError(f"Failed to fetch sequence data: {msg}") + + # Parse keys — try JSON array first, fall 
back to comma-split for compat + keys = [] + if output_keys: + try: + keys = json.loads(output_keys) + except (json.JSONDecodeError, TypeError): + keys = [k.strip() for k in output_keys.split(",") if k.strip()] + + # Parse types for coercion + types = [] + if output_types: + try: + types = json.loads(output_types) + except (json.JSONDecodeError, TypeError): + types = [t.strip() for t in output_types.split(",")] + + results = [] + for i, key in enumerate(keys): + val = data.get(key, "") + declared_type = types[i] if i < len(types) else "" + # Coerce based on declared output type when possible + if declared_type == "INT": + results.append(to_int(val)) + elif declared_type == "FLOAT": + results.append(to_float(val)) + elif isinstance(val, bool): + results.append(str(val).lower()) + elif isinstance(val, int): + results.append(val) + elif isinstance(val, float): + results.append(val) + else: + results.append(str(val)) + + while len(results) < MAX_DYNAMIC_OUTPUTS: + results.append("") + + return (total_sequences,) + tuple(results) + + +# --- Mappings --- +PROJECT_NODE_CLASS_MAPPINGS = { + "ProjectLoaderDynamic": ProjectLoaderDynamic, +} + +PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { + "ProjectLoaderDynamic": "Project Loader (Dynamic)", +} diff --git a/state.py b/state.py index e4aeab4..bef8818 100644 --- a/state.py +++ b/state.py @@ -17,6 +17,11 @@ class AppState: live_toggles: dict = field(default_factory=dict) show_comfy_monitor: bool = True + # Project DB fields + db: Any = None + current_project: str = "" + db_enabled: bool = False + # Set at runtime by main.py / tab_comfy_ng.py _render_main: Any = None _load_file: Callable | None = None @@ -29,4 +34,7 @@ class AppState: config=self.config, current_dir=self.current_dir, snippets=self.snippets, + db=self.db, + current_project=self.current_project, + db_enabled=self.db_enabled, ) diff --git a/tab_batch_ng.py b/tab_batch_ng.py index da47601..3324169 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -6,7 +6,7 @@ from 
nicegui import ui from state import AppState from utils import ( - DEFAULTS, save_json, load_json, + DEFAULTS, save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, ) from history_tree import HistoryTree @@ -161,6 +161,8 @@ def render_batch_processor(state: AppState): new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {}, KEY_PROMPT_HISTORY: []} save_json(new_path, new_data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, new_path, new_data) ui.notify(f'Created {new_name}', type='positive') ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch) @@ -215,6 +217,8 @@ def render_batch_processor(state: AppState): batch_list.append(new_item) data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) render_sequence_list.refresh() with ui.row().classes('q-mt-sm'): @@ -250,6 +254,8 @@ def render_batch_processor(state: AppState): batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0))) data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Sorted by sequence number!', type='positive') render_sequence_list.refresh() @@ -289,6 +295,8 @@ def render_batch_processor(state: AppState): htree.commit(snapshot_payload, note=note) data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) state.restored_indicator = None commit_input.set_value('') ui.notify('Batch Saved & Snapshot Created!', type='positive') @@ -306,6 +314,8 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, def commit(message=None): 
data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) if message: ui.notify(message, type='positive') refresh_list.refresh() @@ -448,7 +458,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, # --- VACE Settings (full width) --- with ui.expansion('VACE Settings', icon='settings').classes('w-full'): - _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list) + _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list) # --- LoRA Settings --- with ui.expansion('LoRA Settings', icon='style').classes('w-full'): @@ -530,7 +540,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, # VACE Settings sub-section # ====================================================================== -def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list): +def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list): # VACE Schedule (needed early for both columns) sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1)) @@ -568,6 +578,8 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list): shifted += 1 data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive') refresh_list.refresh() @@ -713,6 +725,8 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}") data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify(f'Updated {len(targets)} sequences', 
type='positive') if refresh_list: refresh_list.refresh() diff --git a/tab_projects_ng.py b/tab_projects_ng.py new file mode 100644 index 0000000..32494ac --- /dev/null +++ b/tab_projects_ng.py @@ -0,0 +1,165 @@ +import logging +from pathlib import Path + +from nicegui import ui + +from state import AppState +from db import ProjectDB +from utils import save_config, sync_to_db, KEY_BATCH_DATA + +logger = logging.getLogger(__name__) + + +def render_projects_tab(state: AppState): + """Render the Projects management tab.""" + + # --- DB toggle --- + def on_db_toggle(e): + state.db_enabled = e.value + state.config['db_enabled'] = e.value + save_config(state.current_dir, state.config.get('favorites', []), state.config) + render_project_content.refresh() + + ui.switch('Enable Project Database', value=state.db_enabled, + on_change=on_db_toggle).classes('q-mb-md') + + @ui.refreshable + def render_project_content(): + if not state.db_enabled: + ui.label('Project database is disabled. Enable it above to manage projects.').classes( + 'text-caption q-pa-md') + return + + if not state.db: + ui.label('Database not initialized.').classes('text-warning q-pa-md') + return + + # --- Create project form --- + with ui.card().classes('w-full q-pa-md q-mb-md'): + ui.label('Create New Project').classes('section-header') + name_input = ui.input('Project Name', placeholder='my_project').classes('w-full') + desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full') + + def create_project(): + name = name_input.value.strip() + if not name: + ui.notify('Please enter a project name', type='warning') + return + try: + state.db.create_project(name, str(state.current_dir), desc_input.value.strip()) + name_input.set_value('') + desc_input.set_value('') + ui.notify(f'Created project "{name}"', type='positive') + render_project_list.refresh() + except Exception as e: + ui.notify(f'Error: {e}', type='negative') + + ui.button('Create Project', icon='add', 
on_click=create_project).classes('w-full') + + # --- Active project indicator --- + if state.current_project: + ui.label(f'Active Project: {state.current_project}').classes( + 'text-bold text-primary q-pa-sm') + + # --- Project list --- + @ui.refreshable + def render_project_list(): + projects = state.db.list_projects() + if not projects: + ui.label('No projects yet. Create one above.').classes('text-caption q-pa-md') + return + + for proj in projects: + is_active = proj['name'] == state.current_project + card_style = 'border-left: 3px solid var(--accent);' if is_active else '' + + with ui.card().classes('w-full q-pa-sm q-mb-sm').style(card_style): + with ui.row().classes('w-full items-center'): + with ui.column().classes('col'): + ui.label(proj['name']).classes('text-bold') + if proj['description']: + ui.label(proj['description']).classes('text-caption') + ui.label(f'Path: {proj["folder_path"]}').classes('text-caption') + files = state.db.list_data_files(proj['id']) + ui.label(f'{len(files)} data file(s)').classes('text-caption') + + with ui.row().classes('q-gutter-xs'): + if not is_active: + def activate(name=proj['name']): + state.current_project = name + state.config['current_project'] = name + save_config(state.current_dir, + state.config.get('favorites', []), + state.config) + ui.notify(f'Activated project "{name}"', type='positive') + render_project_list.refresh() + + ui.button('Activate', icon='check_circle', + on_click=activate).props('flat dense color=primary') + else: + def deactivate(): + state.current_project = '' + state.config['current_project'] = '' + save_config(state.current_dir, + state.config.get('favorites', []), + state.config) + ui.notify('Deactivated project', type='info') + render_project_list.refresh() + + ui.button('Deactivate', icon='cancel', + on_click=deactivate).props('flat dense') + + def import_folder(pid=proj['id'], pname=proj['name']): + _import_folder(state, pid, pname, render_project_list) + + ui.button('Import Folder', 
icon='folder_open', + on_click=import_folder).props('flat dense') + + def delete_proj(name=proj['name']): + state.db.delete_project(name) + if state.current_project == name: + state.current_project = '' + state.config['current_project'] = '' + save_config(state.current_dir, + state.config.get('favorites', []), + state.config) + ui.notify(f'Deleted project "{name}"', type='positive') + render_project_list.refresh() + + ui.button(icon='delete', + on_click=delete_proj).props('flat dense color=negative') + + render_project_list() + + render_project_content() + + +def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn): + """Bulk import all .json files from current directory into a project.""" + json_files = sorted(state.current_dir.glob('*.json')) + json_files = [f for f in json_files if f.name not in ( + '.editor_config.json', '.editor_snippets.json')] + + if not json_files: + ui.notify('No JSON files in current directory', type='warning') + return + + imported = 0 + skipped = 0 + for jf in json_files: + file_name = jf.stem + existing = state.db.get_data_file(project_id, file_name) + if existing: + skipped += 1 + continue + try: + state.db.import_json_file(project_id, jf) + imported += 1 + except Exception as e: + logger.warning(f"Failed to import {jf}: {e}") + + msg = f'Imported {imported} file(s)' + if skipped: + msg += f', skipped {skipped} existing' + ui.notify(msg, type='positive') + refresh_fn.refresh() diff --git a/tab_raw_ng.py b/tab_raw_ng.py index 39ec6f3..bdc4933 100644 --- a/tab_raw_ng.py +++ b/tab_raw_ng.py @@ -4,7 +4,7 @@ import json from nicegui import ui from state import AppState -from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY +from utils import save_json, sync_to_db, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY def render_raw_editor(state: AppState): @@ -52,6 +52,8 @@ def render_raw_editor(state: AppState): input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY] 
save_json(file_path, input_data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, input_data) data.clear() data.update(input_data) diff --git a/tab_timeline_ng.py b/tab_timeline_ng.py index daaf460..d0467f4 100644 --- a/tab_timeline_ng.py +++ b/tab_timeline_ng.py @@ -5,7 +5,7 @@ from nicegui import ui from state import AppState from history_tree import HistoryTree -from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE +from utils import save_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE def _delete_nodes(htree, data, file_path, node_ids): @@ -134,6 +134,8 @@ def _render_batch_delete(htree, data, file_path, state, refresh_fn): def do_batch_delete(): current_valid = state.timeline_selected_nodes & set(htree.nodes.keys()) _delete_nodes(htree, data, file_path, current_valid) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) state.timeline_selected_nodes = set() ui.notify( f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!', @@ -179,7 +181,7 @@ def _find_branch_for_node(htree, node_id): def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn, - selected): + selected, state=None): """Render branch-grouped node manager with restore, rename, delete, and preview.""" ui.label('Manage Version').classes('section-header') @@ -291,6 +293,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_ htree.nodes[sel_id]['note'] = rename_input.value data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state and state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Label updated', type='positive') refresh_fn() @@ -304,6 +308,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_ def delete_selected(): if sel_id in htree.nodes: 
_delete_nodes(htree, data, file_path, {sel_id}) + if state and state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Node Deleted', type='positive') refresh_fn() @@ -377,7 +383,7 @@ def render_timeline_tab(state: AppState): _render_node_manager( all_nodes, htree, data, file_path, _restore_and_refresh, render_timeline.refresh, - selected) + selected, state=state) def _toggle_select(nid, checked): if checked: @@ -492,6 +498,8 @@ def _restore_node(data, node, htree, file_path, state: AppState): htree.head_id = node['id'] data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) label = f"{node.get('note', 'Step')} ({node['id'][:4]})" state.restored_indicator = label ui.notify('Restored!', type='positive') diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..bea102f --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,369 @@ +import json +from pathlib import Path + +import pytest + +from db import ProjectDB +from utils import KEY_BATCH_DATA, KEY_HISTORY_TREE + + +@pytest.fixture +def db(tmp_path): + """Create a fresh ProjectDB in a temp directory.""" + db_path = tmp_path / "test.db" + pdb = ProjectDB(db_path) + yield pdb + pdb.close() + + +# ------------------------------------------------------------------ +# Projects CRUD +# ------------------------------------------------------------------ + +class TestProjects: + def test_create_and_get(self, db): + pid = db.create_project("proj1", "/some/path", "A test project") + assert pid > 0 + proj = db.get_project("proj1") + assert proj is not None + assert proj["name"] == "proj1" + assert proj["folder_path"] == "/some/path" + assert proj["description"] == "A test project" + + def test_list_projects(self, db): + db.create_project("beta", "/b") + db.create_project("alpha", "/a") + projects = 
db.list_projects() + assert len(projects) == 2 + assert projects[0]["name"] == "alpha" + assert projects[1]["name"] == "beta" + + def test_get_nonexistent(self, db): + assert db.get_project("nope") is None + + def test_delete_project(self, db): + db.create_project("to_delete", "/x") + assert db.delete_project("to_delete") is True + assert db.get_project("to_delete") is None + + def test_delete_nonexistent(self, db): + assert db.delete_project("nope") is False + + def test_unique_name_constraint(self, db): + db.create_project("dup", "/a") + with pytest.raises(Exception): + db.create_project("dup", "/b") + + +# ------------------------------------------------------------------ +# Data files +# ------------------------------------------------------------------ + +class TestDataFiles: + def test_create_and_list(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch_i2v", "i2v", {"extra": "meta"}) + assert df_id > 0 + files = db.list_data_files(pid) + assert len(files) == 1 + assert files[0]["name"] == "batch_i2v" + assert files[0]["data_type"] == "i2v" + + def test_get_data_file(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v", {"key": "value"}) + df = db.get_data_file(pid, "batch_i2v") + assert df is not None + assert df["top_level"] == {"key": "value"} + + def test_get_data_file_by_names(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v") + df = db.get_data_file_by_names("p1", "batch_i2v") + assert df is not None + assert df["name"] == "batch_i2v" + + def test_get_nonexistent_data_file(self, db): + pid = db.create_project("p1", "/p1") + assert db.get_data_file(pid, "nope") is None + + def test_unique_constraint(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v") + with pytest.raises(Exception): + db.create_data_file(pid, "batch_i2v", "vace") + + def test_cascade_delete(self, db): + pid = 
db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "hello"}) + db.save_history_tree(df_id, {"nodes": {}}) + db.delete_project("p1") + assert db.get_data_file(pid, "batch_i2v") is None + + +# ------------------------------------------------------------------ +# Sequences +# ------------------------------------------------------------------ + +class TestSequences: + def test_upsert_and_get(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"prompt": "hello", "seed": 42}) + data = db.get_sequence(df_id, 1) + assert data == {"prompt": "hello", "seed": 42} + + def test_upsert_updates_existing(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"prompt": "v1"}) + db.upsert_sequence(df_id, 1, {"prompt": "v2"}) + data = db.get_sequence(df_id, 1) + assert data["prompt"] == "v2" + + def test_list_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 3, {"a": 1}) + db.upsert_sequence(df_id, 1, {"b": 2}) + db.upsert_sequence(df_id, 2, {"c": 3}) + seqs = db.list_sequences(df_id) + assert seqs == [1, 2, 3] + + def test_get_nonexistent_sequence(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.get_sequence(df_id, 99) is None + + def test_get_sequence_keys(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, { + "prompt": "hello", + "seed": 42, + "cfg": 1.5, + "flag": True, + }) + keys, types = db.get_sequence_keys(df_id, 1) + assert "prompt" in keys + assert "seed" in keys + idx_prompt = keys.index("prompt") + idx_seed = keys.index("seed") + idx_cfg = keys.index("cfg") + idx_flag = keys.index("flag") + 
assert types[idx_prompt] == "STRING" + assert types[idx_seed] == "INT" + assert types[idx_cfg] == "FLOAT" + assert types[idx_flag] == "STRING" # bools -> STRING + + def test_get_sequence_keys_nonexistent(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + keys, types = db.get_sequence_keys(df_id, 99) + assert keys == [] + assert types == [] + + def test_delete_sequences_for_file(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + db.delete_sequences_for_file(df_id) + assert db.list_sequences(df_id) == [] + + def test_count_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.count_sequences(df_id) == 0 + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + db.upsert_sequence(df_id, 3, {"c": 3}) + assert db.count_sequences(df_id) == 3 + + def test_query_total_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + assert db.query_total_sequences("p1", "batch") == 2 + + def test_query_total_sequences_nonexistent(self, db): + assert db.query_total_sequences("nope", "nope") == 0 + + +# ------------------------------------------------------------------ +# History trees +# ------------------------------------------------------------------ + +class TestHistoryTrees: + def test_save_and_get(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + tree = {"nodes": {"abc": {"id": "abc"}}, "head_id": "abc"} + db.save_history_tree(df_id, tree) + result = db.get_history_tree(df_id) + assert result == tree + + def test_upsert_updates(self, db): + pid = db.create_project("p1", "/p1") + df_id = 
db.create_data_file(pid, "batch", "generic") + db.save_history_tree(df_id, {"v": 1}) + db.save_history_tree(df_id, {"v": 2}) + result = db.get_history_tree(df_id) + assert result == {"v": 2} + + def test_get_nonexistent(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.get_history_tree(df_id) is None + + +# ------------------------------------------------------------------ +# Import +# ------------------------------------------------------------------ + +class TestImport: + def test_import_json_file(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "batch_prompt_i2v.json" + data = { + KEY_BATCH_DATA: [ + {"sequence_number": 1, "prompt": "hello", "seed": 42}, + {"sequence_number": 2, "prompt": "world", "seed": 99}, + ], + KEY_HISTORY_TREE: {"nodes": {}, "head_id": None}, + } + json_path.write_text(json.dumps(data)) + + df_id = db.import_json_file(pid, json_path, "i2v") + assert df_id > 0 + + seqs = db.list_sequences(df_id) + assert seqs == [1, 2] + + s1 = db.get_sequence(df_id, 1) + assert s1["prompt"] == "hello" + assert s1["seed"] == 42 + + tree = db.get_history_tree(df_id) + assert tree == {"nodes": {}, "head_id": None} + + def test_import_file_name_from_stem(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "my_batch.json" + json_path.write_text(json.dumps({KEY_BATCH_DATA: [{"sequence_number": 1}]})) + db.import_json_file(pid, json_path) + df = db.get_data_file(pid, "my_batch") + assert df is not None + + def test_import_no_batch_data(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "simple.json" + json_path.write_text(json.dumps({"prompt": "flat file"})) + df_id = db.import_json_file(pid, json_path) + seqs = db.list_sequences(df_id) + assert seqs == [] + + def test_reimport_updates_existing(self, db, tmp_path): + """Re-importing the same file should update data, not crash.""" + pid = 
db.create_project("p1", "/p1") + json_path = tmp_path / "batch.json" + + # First import + data_v1 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]} + json_path.write_text(json.dumps(data_v1)) + df_id_1 = db.import_json_file(pid, json_path, "i2v") + + # Second import (same file, updated data) + data_v2 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v2"}, {"sequence_number": 2, "prompt": "new"}]} + json_path.write_text(json.dumps(data_v2)) + df_id_2 = db.import_json_file(pid, json_path, "vace") + + # Should reuse the same data_file row + assert df_id_1 == df_id_2 + # Data type should be updated + df = db.get_data_file(pid, "batch") + assert df["data_type"] == "vace" + # Sequences should reflect v2 + seqs = db.list_sequences(df_id_2) + assert seqs == [1, 2] + s1 = db.get_sequence(df_id_2, 1) + assert s1["prompt"] == "v2" + + def test_import_skips_non_dict_batch_items(self, db, tmp_path): + """Non-dict elements in batch_data should be silently skipped, not crash.""" + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "mixed.json" + data = {KEY_BATCH_DATA: [ + {"sequence_number": 1, "prompt": "valid"}, + "not a dict", + 42, + None, + {"sequence_number": 3, "prompt": "also valid"}, + ]} + json_path.write_text(json.dumps(data)) + df_id = db.import_json_file(pid, json_path) + + seqs = db.list_sequences(df_id) + assert seqs == [1, 3] + + def test_import_atomic_on_error(self, db, tmp_path): + """If import fails partway, no partial data should be committed.""" + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "batch.json" + data = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "hello"}]} + json_path.write_text(json.dumps(data)) + db.import_json_file(pid, json_path) + + # Now try to import with bad data that will cause an error + # (overwrite the file with invalid sequence_number that causes int() to fail) + bad_data = {KEY_BATCH_DATA: [{"sequence_number": "not_a_number", "prompt": "bad"}]} + 
json_path.write_text(json.dumps(bad_data)) + with pytest.raises(ValueError): + db.import_json_file(pid, json_path) + + # Original data should still be intact (rollback worked) + df = db.get_data_file(pid, "batch") + assert df is not None + s1 = db.get_sequence(df["id"], 1) + assert s1["prompt"] == "hello" + + +# ------------------------------------------------------------------ +# Query helpers +# ------------------------------------------------------------------ + +class TestQueryHelpers: + def test_query_sequence_data(self, db): + pid = db.create_project("myproject", "/mp") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7}) + result = db.query_sequence_data("myproject", "batch_i2v", 1) + assert result == {"prompt": "test", "seed": 7} + + def test_query_sequence_data_not_found(self, db): + assert db.query_sequence_data("nope", "nope", 1) is None + + def test_query_sequence_keys(self, db): + pid = db.create_project("myproject", "/mp") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7}) + keys, types = db.query_sequence_keys("myproject", "batch_i2v", 1) + assert "prompt" in keys + assert "seed" in keys + + def test_list_project_files(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "file_a", "i2v") + db.create_data_file(pid, "file_b", "vace") + files = db.list_project_files("p1") + assert len(files) == 2 + + def test_list_project_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {}) + db.upsert_sequence(df_id, 2, {}) + seqs = db.list_project_sequences("p1", "batch") + assert seqs == [1, 2] diff --git a/tests/test_json_loader.py b/tests/test_json_loader.py deleted file mode 100644 index 2729e4d..0000000 --- a/tests/test_json_loader.py +++ /dev/null @@ -1,165 +0,0 @@ -import json -import os - -import pytest - -from 
json_loader import ( - to_float, to_int, get_batch_item, read_json_data, - JSONLoaderDynamic, MAX_DYNAMIC_OUTPUTS, -) - - -class TestToFloat: - def test_valid(self): - assert to_float("3.14") == 3.14 - assert to_float(5) == 5.0 - - def test_invalid(self): - assert to_float("abc") == 0.0 - - def test_none(self): - assert to_float(None) == 0.0 - - -class TestToInt: - def test_valid(self): - assert to_int("7") == 7 - assert to_int(3.9) == 3 - - def test_invalid(self): - assert to_int("xyz") == 0 - - def test_none(self): - assert to_int(None) == 0 - - -class TestGetBatchItem: - def test_lookup_by_sequence_number_field(self): - data = {"batch_data": [ - {"sequence_number": 1, "a": "first"}, - {"sequence_number": 5, "a": "fifth"}, - {"sequence_number": 3, "a": "third"}, - ]} - assert get_batch_item(data, 5) == {"sequence_number": 5, "a": "fifth"} - assert get_batch_item(data, 3) == {"sequence_number": 3, "a": "third"} - - def test_fallback_to_index(self): - data = {"batch_data": [{"a": 1}, {"a": 2}, {"a": 3}]} - assert get_batch_item(data, 2) == {"a": 2} - - def test_clamp_high(self): - data = {"batch_data": [{"a": 1}, {"a": 2}]} - assert get_batch_item(data, 99) == {"a": 2} - - def test_clamp_low(self): - data = {"batch_data": [{"a": 1}, {"a": 2}]} - assert get_batch_item(data, 0) == {"a": 1} - - def test_no_batch_data(self): - data = {"key": "val"} - assert get_batch_item(data, 1) == data - - -class TestReadJsonData: - def test_missing_file(self, tmp_path): - assert read_json_data(str(tmp_path / "nope.json")) == {} - - def test_invalid_json(self, tmp_path): - p = tmp_path / "bad.json" - p.write_text("{broken") - assert read_json_data(str(p)) == {} - - def test_non_dict_json(self, tmp_path): - p = tmp_path / "list.json" - p.write_text(json.dumps([1, 2, 3])) - assert read_json_data(str(p)) == {} - - def test_valid(self, tmp_path): - p = tmp_path / "ok.json" - p.write_text(json.dumps({"key": "val"})) - assert read_json_data(str(p)) == {"key": "val"} - - -class 
TestJSONLoaderDynamic: - def _make_json(self, tmp_path, data): - p = tmp_path / "test.json" - p.write_text(json.dumps(data)) - return str(p) - - def test_known_keys(self, tmp_path): - path = self._make_json(tmp_path, {"name": "alice", "age": 30, "score": 9.5}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="name,age,score") - assert result[0] == "alice" - assert result[1] == 30 - assert result[2] == 9.5 - - def test_empty_output_keys(self, tmp_path): - path = self._make_json(tmp_path, {"name": "alice"}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="") - assert len(result) == MAX_DYNAMIC_OUTPUTS - assert all(v == "" for v in result) - - def test_pads_to_max(self, tmp_path): - path = self._make_json(tmp_path, {"a": "1", "b": "2"}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="a,b") - assert len(result) == MAX_DYNAMIC_OUTPUTS - assert result[0] == "1" - assert result[1] == "2" - assert all(v == "" for v in result[2:]) - - def test_type_preservation_int(self, tmp_path): - path = self._make_json(tmp_path, {"count": 42}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="count") - assert result[0] == 42 - assert isinstance(result[0], int) - - def test_type_preservation_float(self, tmp_path): - path = self._make_json(tmp_path, {"rate": 3.14}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="rate") - assert result[0] == 3.14 - assert isinstance(result[0], float) - - def test_type_preservation_str(self, tmp_path): - path = self._make_json(tmp_path, {"label": "hello"}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="label") - assert result[0] == "hello" - assert isinstance(result[0], str) - - def test_bool_becomes_string(self, tmp_path): - path = self._make_json(tmp_path, {"flag": True, "off": False}) - loader = JSONLoaderDynamic() - result = 
loader.load_dynamic(path, 1, output_keys="flag,off") - assert result[0] == "true" - assert result[1] == "false" - assert isinstance(result[0], str) - - def test_missing_key_returns_empty_string(self, tmp_path): - path = self._make_json(tmp_path, {"a": "1"}) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 1, output_keys="a,nonexistent") - assert result[0] == "1" - assert result[1] == "" - - def test_missing_file_returns_all_empty(self, tmp_path): - loader = JSONLoaderDynamic() - result = loader.load_dynamic(str(tmp_path / "nope.json"), 1, output_keys="a,b") - assert len(result) == MAX_DYNAMIC_OUTPUTS - assert result[0] == "" - assert result[1] == "" - - def test_batch_data(self, tmp_path): - path = self._make_json(tmp_path, { - "batch_data": [ - {"sequence_number": 1, "x": "first"}, - {"sequence_number": 2, "x": "second"}, - ] - }) - loader = JSONLoaderDynamic() - result = loader.load_dynamic(path, 2, output_keys="x") - assert result[0] == "second" diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py new file mode 100644 index 0000000..dab80b2 --- /dev/null +++ b/tests/test_project_loader.py @@ -0,0 +1,211 @@ +import json +from unittest.mock import patch, MagicMock +from io import BytesIO + +import pytest + +from project_loader import ( + ProjectLoaderDynamic, + _fetch_json, + _fetch_data, + _fetch_keys, + MAX_DYNAMIC_OUTPUTS, +) + + +def _mock_urlopen(data: dict): + """Create a mock context manager for urllib.request.urlopen.""" + response = MagicMock() + response.read.return_value = json.dumps(data).encode() + response.__enter__ = lambda s: s + response.__exit__ = MagicMock(return_value=False) + return response + + +class TestFetchHelpers: + def test_fetch_json_success(self): + data = {"key": "value"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)): + result = _fetch_json("http://example.com/api") + assert result == data + + def test_fetch_json_network_error(self): + with 
patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")): + result = _fetch_json("http://example.com/api") + assert result["error"] == "network_error" + assert "connection refused" in result["message"] + + def test_fetch_json_http_error(self): + import urllib.error + err = urllib.error.HTTPError( + "http://example.com/api", 404, "Not Found", {}, + BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode()) + ) + with patch("project_loader.urllib.request.urlopen", side_effect=err): + result = _fetch_json("http://example.com/api") + assert result["error"] == "http_error" + assert result["status"] == 404 + assert "not found" in result["message"].lower() + + def test_fetch_data_builds_url(self): + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + result = _fetch_data("http://localhost:8080", "proj1", "batch_i2v", 1) + assert result == data + called_url = mock.call_args[0][0] + assert "/api/projects/proj1/files/batch_i2v/data?seq=1" in called_url + + def test_fetch_keys_builds_url(self): + data = {"keys": ["prompt"], "types": ["STRING"]} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + result = _fetch_keys("http://localhost:8080", "proj1", "batch_i2v", 1) + assert result == data + called_url = mock.call_args[0][0] + assert "/api/projects/proj1/files/batch_i2v/keys?seq=1" in called_url + + def test_fetch_data_strips_trailing_slash(self): + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + _fetch_data("http://localhost:8080/", "proj1", "file1", 1) + called_url = mock.call_args[0][0] + assert "//api" not in called_url + + def test_fetch_data_encodes_special_chars(self): + """Project/file names with spaces or special chars should be percent-encoded.""" + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", 
return_value=_mock_urlopen(data)) as mock: + _fetch_data("http://localhost:8080", "my project", "batch file", 1) + called_url = mock.call_args[0][0] + assert "my%20project" in called_url + assert "batch%20file" in called_url + assert " " not in called_url.split("?")[0] # no raw spaces in path + + +class TestProjectLoaderDynamic: + def _keys_meta(self, total=5): + return {"keys": [], "types": [], "total_sequences": total} + + def test_load_dynamic_with_keys(self): + data = {"prompt": "hello", "seed": 42, "cfg": 1.5} + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="prompt,seed,cfg" + ) + assert result[0] == 5 # total_sequences + assert result[1] == "hello" + assert result[2] == 42 + assert result[3] == 1.5 + assert len(result) == MAX_DYNAMIC_OUTPUTS + 1 + + def test_load_dynamic_with_json_encoded_keys(self): + """JSON-encoded output_keys should be parsed correctly.""" + import json as _json + data = {"my,key": "comma_val", "normal": "ok"} + node = ProjectLoaderDynamic() + keys_json = _json.dumps(["my,key", "normal"]) + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json + ) + assert result[1] == "comma_val" + assert result[2] == "ok" + + def test_load_dynamic_type_coercion(self): + """output_types should coerce values to declared types.""" + import json as _json + data = {"seed": "42", "cfg": "1.5", "prompt": "hello"} + node = ProjectLoaderDynamic() + keys_json = _json.dumps(["seed", "cfg", "prompt"]) + types_json = _json.dumps(["INT", "FLOAT", "STRING"]) + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with 
patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json, output_types=types_json + ) + assert result[1] == 42 # string "42" coerced to int + assert result[2] == 1.5 # string "1.5" coerced to float + assert result[3] == "hello" # string stays string + + def test_load_dynamic_empty_keys(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="" + ) + # Slot 0 is total_sequences (INT), rest are empty strings + assert result[0] == 5 + assert all(v == "" for v in result[1:]) + + def test_load_dynamic_missing_key(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="nonexistent" + ) + assert result[1] == "" + + def test_load_dynamic_bool_becomes_string(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value={"flag": True}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="flag" + ) + assert result[1] == "true" + + def test_load_dynamic_returns_total_sequences(self): + """total_sequences should be the first output from keys metadata.""" + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_keys", return_value={"keys": [], "types": [], "total_sequences": 42}): + with patch("project_loader._fetch_data", return_value={}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="" + ) + assert 
result[0] == 42 + + def test_load_dynamic_raises_on_network_error(self): + """Network errors from _fetch_keys should raise RuntimeError.""" + node = ProjectLoaderDynamic() + error_resp = {"error": "network_error", "message": "Connection refused"} + with patch("project_loader._fetch_keys", return_value=error_resp): + with pytest.raises(RuntimeError, match="Failed to fetch project keys"): + node.load_dynamic("http://localhost:8080", "proj1", "batch", 1) + + def test_load_dynamic_raises_on_data_fetch_error(self): + """Network errors from _fetch_data should raise RuntimeError.""" + node = ProjectLoaderDynamic() + error_resp = {"error": "http_error", "status": 404, "message": "Sequence not found"} + with patch("project_loader._fetch_keys", return_value=self._keys_meta()): + with patch("project_loader._fetch_data", return_value=error_resp): + with pytest.raises(RuntimeError, match="Failed to fetch sequence data"): + node.load_dynamic("http://localhost:8080", "proj1", "batch", 1) + + def test_input_types_has_manager_url(self): + inputs = ProjectLoaderDynamic.INPUT_TYPES() + assert "manager_url" in inputs["required"] + assert "project_name" in inputs["required"] + assert "file_name" in inputs["required"] + assert "sequence_number" in inputs["required"] + + def test_category(self): + assert ProjectLoaderDynamic.CATEGORY == "utils/json/project" + + +class TestNodeMappings: + def test_mappings_exist(self): + from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS + assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS + assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1 + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1 diff --git a/utils.py b/utils.py index 2e49007..af80c60 100644 --- a/utils.py +++ b/utils.py @@ -160,6 +160,80 @@ def get_file_mtime(path: str | Path) -> float: return path.stat().st_mtime return 0 +def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: + """Dual-write helper: sync JSON data to the 
project database. + + Resolves (or creates) the data_file, upserts all sequences from batch_data, + and saves the history_tree. All writes happen in a single transaction. + """ + if not db or not project_name: + return + try: + proj = db.get_project(project_name) + if not proj: + return + file_name = Path(file_path).stem + + # Use a single transaction for atomicity + db.conn.execute("BEGIN IMMEDIATE") + try: + df = db.get_data_file(proj["id"], file_name) + top_level = {k: v for k, v in data.items() + if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} + if not df: + now = __import__('time').time() + cur = db.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (proj["id"], file_name, "generic", json.dumps(top_level), now, now), + ) + df_id = cur.lastrowid + else: + df_id = df["id"] + # Update top_level metadata + now = __import__('time').time() + db.conn.execute( + "UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?", + (json.dumps(top_level), now, df_id), + ) + + # Sync sequences + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) + for item in batch_data: + if not isinstance(item, dict): + continue + seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) + now = __import__('time').time() + db.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?)", + (df_id, seq_num, json.dumps(item), now), + ) + + # Sync history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + now = __import__('time').time() + db.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) 
" + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (df_id, json.dumps(history_tree), now), + ) + + db.conn.execute("COMMIT") + except Exception: + try: + db.conn.execute("ROLLBACK") + except Exception: + pass + raise + except Exception as e: + logger.warning(f"sync_to_db failed: {e}") + + def generate_templates(current_dir: Path) -> None: """Creates batch template files if folder is empty.""" first = DEFAULTS.copy() diff --git a/web/json_dynamic.js b/web/json_dynamic.js deleted file mode 100644 index 81e11f1..0000000 --- a/web/json_dynamic.js +++ /dev/null @@ -1,140 +0,0 @@ -import { app } from "../../scripts/app.js"; -import { api } from "../../scripts/api.js"; - -app.registerExtension({ - name: "json.manager.dynamic", - - async beforeRegisterNodeDef(nodeType, nodeData, app) { - if (nodeData.name !== "JSONLoaderDynamic") return; - - const origOnNodeCreated = nodeType.prototype.onNodeCreated; - nodeType.prototype.onNodeCreated = function () { - origOnNodeCreated?.apply(this, arguments); - - // Hide internal widgets (managed by JS) - for (const name of ["output_keys", "output_types"]) { - const w = this.widgets?.find(w => w.name === name); - if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } - } - - // Remove all 32 default outputs from Python RETURN_TYPES - while (this.outputs.length > 0) { - this.removeOutput(0); - } - - // Add Refresh button - this.addWidget("button", "Refresh Outputs", null, () => { - this.refreshDynamicOutputs(); - }); - - this.setSize(this.computeSize()); - }; - - nodeType.prototype.refreshDynamicOutputs = async function () { - const pathWidget = this.widgets?.find(w => w.name === "json_path"); - const seqWidget = this.widgets?.find(w => w.name === "sequence_number"); - if (!pathWidget?.value) return; - - try { - const resp = await api.fetchApi( - `/json_manager/get_keys?path=${encodeURIComponent(pathWidget.value)}&sequence_number=${seqWidget?.value || 1}` - ); - const { 
keys, types } = await resp.json(); - - // Store keys and types in hidden widgets for persistence - const okWidget = this.widgets?.find(w => w.name === "output_keys"); - if (okWidget) okWidget.value = keys.join(","); - const otWidget = this.widgets?.find(w => w.name === "output_types"); - if (otWidget) otWidget.value = types.join(","); - - // Build a map of current output names to slot indices - const oldSlots = {}; - for (let i = 0; i < this.outputs.length; i++) { - oldSlots[this.outputs[i].name] = i; - } - - // Build new outputs, reusing existing slots to preserve links - const newOutputs = []; - for (let k = 0; k < keys.length; k++) { - const key = keys[k]; - const type = types[k] || "*"; - if (key in oldSlots) { - // Reuse existing slot object (keeps links intact) - const slot = this.outputs[oldSlots[key]]; - slot.type = type; - newOutputs.push(slot); - delete oldSlots[key]; - } else { - // New key — create a fresh slot - newOutputs.push({ name: key, type: type, links: null }); - } - } - - // Disconnect links on slots that are being removed - for (const name in oldSlots) { - const idx = oldSlots[name]; - if (this.outputs[idx]?.links?.length) { - for (const linkId of [...this.outputs[idx].links]) { - this.graph?.removeLink(linkId); - } - } - } - - // Reassign the outputs array and fix link slot indices - this.outputs = newOutputs; - // Update link origin_slot to match new positions - if (this.graph) { - for (let i = 0; i < this.outputs.length; i++) { - const links = this.outputs[i].links; - if (!links) continue; - for (const linkId of links) { - const link = this.graph.links[linkId]; - if (link) link.origin_slot = i; - } - } - } - - this.setSize(this.computeSize()); - app.graph.setDirtyCanvas(true, true); - } catch (e) { - console.error("[JSONLoaderDynamic] Refresh failed:", e); - } - }; - - // Restore state on workflow load - const origOnConfigure = nodeType.prototype.onConfigure; - nodeType.prototype.onConfigure = function (info) { - 
origOnConfigure?.apply(this, arguments); - - // Hide internal widgets - for (const name of ["output_keys", "output_types"]) { - const w = this.widgets?.find(w => w.name === name); - if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } - } - - const okWidget = this.widgets?.find(w => w.name === "output_keys"); - const otWidget = this.widgets?.find(w => w.name === "output_types"); - - const keys = okWidget?.value - ? okWidget.value.split(",").filter(k => k.trim()) - : []; - const types = otWidget?.value - ? otWidget.value.split(",") - : []; - - // On load, LiteGraph already restored serialized outputs with links. - // Rename and set types to match stored state (preserves links). - for (let i = 0; i < this.outputs.length && i < keys.length; i++) { - this.outputs[i].name = keys[i].trim(); - if (types[i]) this.outputs[i].type = types[i]; - } - - // Remove any extra outputs beyond the key count - while (this.outputs.length > keys.length) { - this.removeOutput(this.outputs.length - 1); - } - - this.setSize(this.computeSize()); - }; - }, -}); diff --git a/web/project_dynamic.js b/web/project_dynamic.js new file mode 100644 index 0000000..7f2de6e --- /dev/null +++ b/web/project_dynamic.js @@ -0,0 +1,255 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; + +app.registerExtension({ + name: "json.manager.project.dynamic", + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name !== "ProjectLoaderDynamic") return; + + const origOnNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + origOnNodeCreated?.apply(this, arguments); + + // Hide internal widgets (managed by JS) + for (const name of ["output_keys", "output_types"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } + } + + // Do NOT remove default outputs synchronously here. 
+ // During graph loading, ComfyUI creates all nodes (firing onNodeCreated) + // before configuring them. Other nodes (e.g. Kijai Set/Get) may resolve + // links to our outputs during their configure step. If we remove outputs + // here, those nodes find no output slot and error out. + // + // Instead, defer cleanup: for loaded workflows onConfigure sets _configured + // before this runs; for new nodes the defaults are cleaned up. + this._configured = false; + + // Add Refresh button + this.addWidget("button", "Refresh Outputs", null, () => { + this.refreshDynamicOutputs(); + }); + + // Auto-refresh with 500ms debounce on widget changes + this._refreshTimer = null; + const autoRefreshWidgets = ["project_name", "file_name", "sequence_number"]; + for (const widgetName of autoRefreshWidgets) { + const w = this.widgets?.find(w => w.name === widgetName); + if (w) { + const origCallback = w.callback; + const node = this; + w.callback = function (...args) { + origCallback?.apply(this, args); + clearTimeout(node._refreshTimer); + node._refreshTimer = setTimeout(() => { + node.refreshDynamicOutputs(); + }, 500); + }; + } + } + + queueMicrotask(() => { + if (!this._configured) { + // New node (not loading) — remove the Python default outputs + // and add only the fixed total_sequences slot + while (this.outputs.length > 0) { + this.removeOutput(0); + } + this.addOutput("total_sequences", "INT"); + this.setSize(this.computeSize()); + app.graph?.setDirtyCanvas(true, true); + } + }); + }; + + nodeType.prototype._setStatus = function (status, message) { + const baseTitle = "Project Loader (Dynamic)"; + if (status === "ok") { + this.title = baseTitle; + this.color = undefined; + this.bgcolor = undefined; + } else if (status === "error") { + this.title = baseTitle + " - ERROR"; + this.color = "#ff4444"; + this.bgcolor = "#331111"; + if (message) this.title = baseTitle + ": " + message; + } else if (status === "loading") { + this.title = baseTitle + " - Loading..."; + } + 
app.graph?.setDirtyCanvas(true, true); + }; + + nodeType.prototype.refreshDynamicOutputs = async function () { + const urlWidget = this.widgets?.find(w => w.name === "manager_url"); + const projectWidget = this.widgets?.find(w => w.name === "project_name"); + const fileWidget = this.widgets?.find(w => w.name === "file_name"); + const seqWidget = this.widgets?.find(w => w.name === "sequence_number"); + + if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return; + + this._setStatus("loading"); + + try { + const resp = await api.fetchApi( + `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` + ); + + if (!resp.ok) { + let errorMsg = `HTTP ${resp.status}`; + try { + const errData = await resp.json(); + if (errData.message) errorMsg = errData.message; + } catch (_) {} + this._setStatus("error", errorMsg); + return; + } + + const data = await resp.json(); + const keys = data.keys; + const types = data.types; + + // If the API returned an error or missing data, keep existing outputs and links intact + if (data.error || !Array.isArray(keys) || !Array.isArray(types)) { + const errMsg = data.error ? 
data.message || data.error : "Missing keys/types"; + this._setStatus("error", errMsg); + return; + } + + // Store keys and types in hidden widgets for persistence (comma-separated) + const okWidget = this.widgets?.find(w => w.name === "output_keys"); + if (okWidget) okWidget.value = keys.join(","); + const otWidget = this.widgets?.find(w => w.name === "output_types"); + if (otWidget) otWidget.value = types.join(","); + + // Slot 0 is always total_sequences (INT) — ensure it exists + if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { + this.outputs.unshift({ name: "total_sequences", type: "INT", links: null }); + } + this.outputs[0].type = "INT"; + + // Build a map of current dynamic output names to slot indices (skip slot 0) + const oldSlots = {}; + for (let i = 1; i < this.outputs.length; i++) { + oldSlots[this.outputs[i].name] = i; + } + + // Build new dynamic outputs, reusing existing slots to preserve links + const newOutputs = [this.outputs[0]]; // Keep total_sequences at slot 0 + for (let k = 0; k < keys.length; k++) { + const key = keys[k]; + const type = types[k] || "*"; + if (key in oldSlots) { + const slot = this.outputs[oldSlots[key]]; + slot.type = type; + slot.label = key; + newOutputs.push(slot); + delete oldSlots[key]; + } else { + newOutputs.push({ name: key, label: key, type: type, links: null }); + } + } + + // Disconnect links on slots that are being removed + for (const name in oldSlots) { + const idx = oldSlots[name]; + if (this.outputs[idx]?.links?.length) { + for (const linkId of [...this.outputs[idx].links]) { + this.graph?.removeLink(linkId); + } + } + } + + // Reassign the outputs array and fix link slot indices + this.outputs = newOutputs; + if (this.graph) { + for (let i = 0; i < this.outputs.length; i++) { + const links = this.outputs[i].links; + if (!links) continue; + for (const linkId of links) { + const link = this.graph.links[linkId]; + if (link) link.origin_slot = i; + } + } + } + + 
this._setStatus("ok"); + this.setSize(this.computeSize()); + app.graph?.setDirtyCanvas(true, true); + } catch (e) { + console.error("[ProjectLoaderDynamic] Refresh failed:", e); + this._setStatus("error", "Server unreachable"); + } + }; + + // Restore state on workflow load + const origOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + origOnConfigure?.apply(this, arguments); + this._configured = true; + + // Hide internal widgets + for (const name of ["output_keys", "output_types"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } + } + + const okWidget = this.widgets?.find(w => w.name === "output_keys"); + const otWidget = this.widgets?.find(w => w.name === "output_types"); + + const keys = okWidget?.value + ? okWidget.value.split(",").filter(k => k.trim()) + : []; + const types = otWidget?.value + ? otWidget.value.split(",") + : []; + + // Ensure slot 0 is total_sequences (INT) + if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") { + this.outputs.unshift({ name: "total_sequences", type: "INT", links: null }); + const node = this; + queueMicrotask(() => { + if (!node.graph) return; + for (const output of node.outputs) { + output.links = null; + } + for (const linkId in node.graph.links) { + const link = node.graph.links[linkId]; + if (!link || link.origin_id !== node.id) continue; + link.origin_slot += 1; + const output = node.outputs[link.origin_slot]; + if (output) { + if (!output.links) output.links = []; + output.links.push(link.id); + } + } + app.graph?.setDirtyCanvas(true, true); + }); + } + this.outputs[0].type = "INT"; + this.outputs[0].name = "total_sequences"; + + if (keys.length > 0) { + for (let i = 0; i < keys.length; i++) { + const slotIdx = i + 1; + if (slotIdx < this.outputs.length) { + this.outputs[slotIdx].name = keys[i].trim(); + this.outputs[slotIdx].label = keys[i].trim(); + if (types[i]) 
this.outputs[slotIdx].type = types[i]; + } + } + while (this.outputs.length > keys.length + 1) { + this.removeOutput(this.outputs.length - 1); + } + } else if (this.outputs.length > 1) { + // Widget values empty but serialized dynamic outputs exist — sync widgets + const dynamicOutputs = this.outputs.slice(1); + if (okWidget) okWidget.value = dynamicOutputs.map(o => o.name).join(","); + if (otWidget) otWidget.value = dynamicOutputs.map(o => o.type).join(","); + } + + this.setSize(this.computeSize()); + }; + }, +});