From c15bec98ce664621554283e4da99c7883ba4b8dc Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:12:05 +0100 Subject: [PATCH] Add SQLite project database + ComfyUI connector nodes - db.py: ProjectDB class with SQLite schema (projects, data_files, sequences, history_trees), WAL mode, CRUD, import, and query helpers - api_routes.py: REST API endpoints on NiceGUI/FastAPI for ComfyUI to query project data over the network - project_loader.py: ComfyUI nodes (ProjectLoaderDynamic, Standard, VACE, LoRA) that fetch data from NiceGUI REST API via HTTP - web/project_dynamic.js: Frontend JS for dynamic project loader node - tab_projects_ng.py: Projects management tab in NiceGUI UI - state.py: Added db, current_project, db_enabled fields - main.py: DB init, API route registration, projects tab - utils.py: sync_to_db() dual-write helper - tab_batch_ng.py, tab_raw_ng.py, tab_timeline_ng.py: dual-write sync calls after save_json when project DB is enabled - __init__.py: Merged project node class mappings - tests/test_db.py: 30 tests for database layer - tests/test_project_loader.py: 17 tests for ComfyUI connector nodes Co-Authored-By: Claude Opus 4.6 --- __init__.py | 4 + api_routes.py | 67 ++++++++ db.py | 285 ++++++++++++++++++++++++++++++++++ main.py | 26 ++++ project_loader.py | 255 +++++++++++++++++++++++++++++++ state.py | 5 + tab_batch_ng.py | 16 +- tab_projects_ng.py | 161 ++++++++++++++++++++ tab_raw_ng.py | 4 +- tab_timeline_ng.py | 14 +- tests/test_db.py | 286 +++++++++++++++++++++++++++++++++++ tests/test_project_loader.py | 201 ++++++++++++++++++++++++ utils.py | 37 +++++ web/project_dynamic.js | 139 +++++++++++++++++ 14 files changed, 1495 insertions(+), 5 deletions(-) create mode 100644 api_routes.py create mode 100644 db.py create mode 100644 project_loader.py create mode 100644 tab_projects_ng.py create mode 100644 tests/test_db.py create mode 100644 tests/test_project_loader.py create mode 100644 web/project_dynamic.js diff --git a/__init__.py 
b/__init__.py index 43198c8..5e657ee 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,8 @@ from .json_loader import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS +from .project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS + +NODE_CLASS_MAPPINGS.update(PROJECT_NODE_CLASS_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) WEB_DIRECTORY = "./web" diff --git a/api_routes.py b/api_routes.py new file mode 100644 index 0000000..36e84cd --- /dev/null +++ b/api_routes.py @@ -0,0 +1,67 @@ +"""REST API endpoints for ComfyUI to query project data from SQLite. + +All endpoints are read-only. Mounted on the NiceGUI/FastAPI server. +""" + +import logging +from typing import Any + +from fastapi import HTTPException, Query +from nicegui import app + +from db import ProjectDB + +logger = logging.getLogger(__name__) + +# The DB instance is set by register_api_routes() +_db: ProjectDB | None = None + + +def register_api_routes(db: ProjectDB) -> None: + """Register all REST API routes with the NiceGUI/FastAPI app.""" + global _db + _db = db + + app.add_api_route("/api/projects", _list_projects, methods=["GET"]) + app.add_api_route("/api/projects/{name}/files", _list_files, methods=["GET"]) + app.add_api_route("/api/projects/{name}/files/{file_name}/sequences", _list_sequences, methods=["GET"]) + app.add_api_route("/api/projects/{name}/files/{file_name}/data", _get_data, methods=["GET"]) + app.add_api_route("/api/projects/{name}/files/{file_name}/keys", _get_keys, methods=["GET"]) + + +def _get_db() -> ProjectDB: + if _db is None: + raise HTTPException(status_code=503, detail="Database not initialized") + return _db + + +async def _list_projects() -> dict[str, Any]: + db = _get_db() + projects = db.list_projects() + return {"projects": [p["name"] for p in projects]} + + +async def _list_files(name: str) -> dict[str, Any]: + db = _get_db() + files = db.list_project_files(name) + return {"files": [{"name": f["name"], 
"data_type": f["data_type"]} for f in files]} + + +async def _list_sequences(name: str, file_name: str) -> dict[str, Any]: + db = _get_db() + seqs = db.list_project_sequences(name, file_name) + return {"sequences": seqs} + + +async def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: + db = _get_db() + data = db.query_sequence_data(name, file_name, seq) + if data is None: + raise HTTPException(status_code=404, detail="Sequence not found") + return data + + +async def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: + db = _get_db() + keys, types = db.query_sequence_keys(name, file_name, seq) + return {"keys": keys, "types": types} diff --git a/db.py b/db.py new file mode 100644 index 0000000..11efc7d --- /dev/null +++ b/db.py @@ -0,0 +1,285 @@ +import json +import logging +import sqlite3 +import time +from pathlib import Path +from typing import Any + +from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE + +logger = logging.getLogger(__name__) + +DEFAULT_DB_PATH = Path.home() / ".comfyui_json_manager" / "projects.db" + +SCHEMA_SQL = """ +CREATE TABLE IF NOT EXISTS projects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + folder_path TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + created_at REAL NOT NULL, + updated_at REAL NOT NULL +); + +CREATE TABLE IF NOT EXISTS data_files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + name TEXT NOT NULL, + data_type TEXT NOT NULL DEFAULT 'generic', + top_level TEXT NOT NULL DEFAULT '{}', + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + UNIQUE(project_id, name) +); + +CREATE TABLE IF NOT EXISTS sequences ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE, + sequence_number INTEGER NOT NULL, + data TEXT NOT NULL DEFAULT '{}', + updated_at REAL NOT NULL, + 
UNIQUE(data_file_id, sequence_number) +); + +CREATE TABLE IF NOT EXISTS history_trees ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE, + tree_data TEXT NOT NULL DEFAULT '{}', + updated_at REAL NOT NULL +); +""" + + +class ProjectDB: + """SQLite database for project-based data management.""" + + def __init__(self, db_path: str | Path | None = None): + self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self.conn.row_factory = sqlite3.Row + self.conn.execute("PRAGMA journal_mode=WAL") + self.conn.execute("PRAGMA foreign_keys=ON") + self.conn.executescript(SCHEMA_SQL) + self.conn.commit() + + def close(self): + self.conn.close() + + # ------------------------------------------------------------------ + # Projects CRUD + # ------------------------------------------------------------------ + + def create_project(self, name: str, folder_path: str, description: str = "") -> int: + now = time.time() + cur = self.conn.execute( + "INSERT INTO projects (name, folder_path, description, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?)", + (name, folder_path, description, now, now), + ) + self.conn.commit() + return cur.lastrowid + + def list_projects(self) -> list[dict]: + rows = self.conn.execute( + "SELECT id, name, folder_path, description, created_at, updated_at " + "FROM projects ORDER BY name" + ).fetchall() + return [dict(r) for r in rows] + + def get_project(self, name: str) -> dict | None: + row = self.conn.execute( + "SELECT id, name, folder_path, description, created_at, updated_at " + "FROM projects WHERE name = ?", + (name,), + ).fetchone() + return dict(row) if row else None + + def delete_project(self, name: str) -> bool: + cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,)) + self.conn.commit() + return cur.rowcount > 0 + + 
# ------------------------------------------------------------------ + # Data files + # ------------------------------------------------------------------ + + def create_data_file( + self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None + ) -> int: + now = time.time() + tl = json.dumps(top_level or {}) + cur = self.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (project_id, name, data_type, tl, now, now), + ) + self.conn.commit() + return cur.lastrowid + + def list_data_files(self, project_id: int) -> list[dict]: + rows = self.conn.execute( + "SELECT id, project_id, name, data_type, created_at, updated_at " + "FROM data_files WHERE project_id = ? ORDER BY name", + (project_id,), + ).fetchall() + return [dict(r) for r in rows] + + def get_data_file(self, project_id: int, name: str) -> dict | None: + row = self.conn.execute( + "SELECT id, project_id, name, data_type, top_level, created_at, updated_at " + "FROM data_files WHERE project_id = ? AND name = ?", + (project_id, name), + ).fetchone() + if row is None: + return None + d = dict(row) + d["top_level"] = json.loads(d["top_level"]) + return d + + def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None: + row = self.conn.execute( + "SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, " + "df.created_at, df.updated_at " + "FROM data_files df JOIN projects p ON df.project_id = p.id " + "WHERE p.name = ? 
AND df.name = ?", + (project_name, file_name), + ).fetchone() + if row is None: + return None + d = dict(row) + d["top_level"] = json.loads(d["top_level"]) + return d + + # ------------------------------------------------------------------ + # Sequences + # ------------------------------------------------------------------ + + def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None: + now = time.time() + self.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?) " + "ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at", + (data_file_id, sequence_number, json.dumps(data), now), + ) + self.conn.commit() + + def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None: + row = self.conn.execute( + "SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?", + (data_file_id, sequence_number), + ).fetchone() + return json.loads(row["data"]) if row else None + + def list_sequences(self, data_file_id: int) -> list[int]: + rows = self.conn.execute( + "SELECT sequence_number FROM sequences WHERE data_file_id = ? 
ORDER BY sequence_number", + (data_file_id,), + ).fetchall() + return [r["sequence_number"] for r in rows] + + def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]: + """Returns (keys, types) for a sequence's data dict.""" + data = self.get_sequence(data_file_id, sequence_number) + if not data: + return [], [] + keys = [] + types = [] + for k, v in data.items(): + keys.append(k) + if isinstance(v, bool): + types.append("STRING") + elif isinstance(v, int): + types.append("INT") + elif isinstance(v, float): + types.append("FLOAT") + else: + types.append("STRING") + return keys, types + + def delete_sequences_for_file(self, data_file_id: int) -> None: + self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,)) + self.conn.commit() + + # ------------------------------------------------------------------ + # History trees + # ------------------------------------------------------------------ + + def save_history_tree(self, data_file_id: int, tree_data: dict) -> None: + now = time.time() + self.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) 
" + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (data_file_id, json.dumps(tree_data), now), + ) + self.conn.commit() + + def get_history_tree(self, data_file_id: int) -> dict | None: + row = self.conn.execute( + "SELECT tree_data FROM history_trees WHERE data_file_id = ?", + (data_file_id,), + ).fetchone() + return json.loads(row["tree_data"]) if row else None + + # ------------------------------------------------------------------ + # Import + # ------------------------------------------------------------------ + + def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int: + """Import a JSON file into the database, splitting batch_data into sequences.""" + json_path = Path(json_path) + data, _ = load_json(json_path) + file_name = json_path.stem + + # Extract top-level keys that aren't batch_data or history_tree + top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} + + df_id = self.create_data_file(project_id, file_name, data_type, top_level) + + # Import sequences from batch_data + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + for item in batch_data: + seq_num = int(item.get("sequence_number", 0)) + self.upsert_sequence(df_id, seq_num, item) + + # Import history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + self.save_history_tree(df_id, history_tree) + + return df_id + + # ------------------------------------------------------------------ + # Query helpers (for REST API) + # ------------------------------------------------------------------ + + def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None: + """Query a single sequence by project name, file name, and sequence number.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return None + return 
self.get_sequence(df["id"], sequence_number) + + def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]: + """Query keys and types for a sequence.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return [], [] + return self.get_sequence_keys(df["id"], sequence_number) + + def list_project_files(self, project_name: str) -> list[dict]: + """List data files for a project by name.""" + proj = self.get_project(project_name) + if not proj: + return [] + return self.list_data_files(proj["id"]) + + def list_project_sequences(self, project_name: str, file_name: str) -> list[int]: + """List sequence numbers for a file in a project.""" + df = self.get_data_file_by_names(project_name, file_name) + if not df: + return [] + return self.list_sequences(df["id"]) diff --git a/main.py b/main.py index b6455fc..5936aa8 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import json +import logging from pathlib import Path from nicegui import ui @@ -14,6 +15,11 @@ from tab_batch_ng import render_batch_processor from tab_timeline_ng import render_timeline_tab from tab_raw_ng import render_raw_editor from tab_comfy_ng import render_comfy_monitor +from tab_projects_ng import render_projects_tab +from db import ProjectDB +from api_routes import register_api_routes + +logger = logging.getLogger(__name__) @ui.page('/') @@ -156,7 +162,17 @@ def index(): config=config, current_dir=Path(config.get('last_dir', str(Path.cwd()))), snippets=load_snippets(), + db_enabled=config.get('db_enabled', False), + current_project=config.get('current_project', ''), ) + + # Initialize project database + try: + state.db = ProjectDB() + except Exception as e: + logger.warning(f"Failed to initialize ProjectDB: {e}") + state.db = None + dual_pane = {'active': False, 'state': None} # ------------------------------------------------------------------ @@ -178,6 +194,7 @@ def index(): ui.tab('batch', label='Batch Processor') 
ui.tab('timeline', label='Timeline') ui.tab('raw', label='Raw Editor') + ui.tab('projects', label='Projects') with ui.tab_panels(tabs, value='batch').classes('w-full'): with ui.tab_panel('batch'): @@ -186,6 +203,8 @@ def index(): render_timeline_tab(state) with ui.tab_panel('raw'): render_raw_editor(state) + with ui.tab_panel('projects'): + render_projects_tab(state) if state.show_comfy_monitor: ui.separator() @@ -481,4 +500,11 @@ def render_sidebar(state: AppState, dual_pane: dict): ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle) +# Register REST API routes for ComfyUI connectivity +try: + _api_db = ProjectDB() + register_api_routes(_api_db) +except Exception as e: + logger.warning(f"Failed to register API routes: {e}") + ui.run(title='AI Settings Manager', port=8080, reload=True) diff --git a/project_loader.py b/project_loader.py new file mode 100644 index 0000000..2adc6c3 --- /dev/null +++ b/project_loader.py @@ -0,0 +1,255 @@ +import json +import logging +import urllib.request +import urllib.error +from typing import Any + +logger = logging.getLogger(__name__) + +MAX_DYNAMIC_OUTPUTS = 32 + + +class AnyType(str): + """Universal connector type that matches any ComfyUI type.""" + def __ne__(self, __value: object) -> bool: + return False + +any_type = AnyType("*") + + +try: + from server import PromptServer + from aiohttp import web +except ImportError: + PromptServer = None + + +def to_float(val: Any) -> float: + try: + return float(val) + except (ValueError, TypeError): + return 0.0 + +def to_int(val: Any) -> int: + try: + return int(float(val)) + except (ValueError, TypeError): + return 0 + + +def _fetch_json(url: str) -> dict: + """Fetch JSON from a URL using stdlib urllib.""" + try: + with urllib.request.urlopen(url, timeout=5) as resp: + return json.loads(resp.read()) + except (urllib.error.URLError, json.JSONDecodeError, OSError) as e: + logger.warning(f"Failed to fetch {url}: {e}") + return {} + + +def _fetch_data(manager_url: 
str, project: str, file: str, seq: int) -> dict: + """Fetch sequence data from the NiceGUI REST API.""" + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/data?seq={seq}" + return _fetch_json(url) + + +def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict: + """Fetch keys/types from the NiceGUI REST API.""" + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/keys?seq={seq}" + return _fetch_json(url) + + +# --- ComfyUI-side proxy endpoints (for frontend JS) --- +if PromptServer is not None: + @PromptServer.instance.routes.get("/json_manager/list_projects") + async def list_projects_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + url = f"{manager_url.rstrip('/')}/api/projects" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/list_project_files") + async def list_project_files_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = request.query.get("project", "") + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/list_project_sequences") + async def list_project_sequences_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = request.query.get("project", "") + file_name = request.query.get("file", "") + url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences" + data = _fetch_json(url) + return web.json_response(data) + + @PromptServer.instance.routes.get("/json_manager/get_project_keys") + async def get_project_keys_proxy(request): + manager_url = request.query.get("url", "http://localhost:8080") + project = request.query.get("project", "") + file_name = request.query.get("file", "") + try: + seq = int(request.query.get("seq", "1")) + except (ValueError, TypeError): + seq = 1 + 
data = _fetch_keys(manager_url, project, file_name, seq) + return web.json_response(data) + + +# ========================================== +# 0. DYNAMIC NODE (Project-based) +# ========================================== + +class ProjectLoaderDynamic: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }, + "optional": { + "output_keys": ("STRING", {"default": ""}), + "output_types": ("STRING", {"default": ""}), + }, + } + + RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) + RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) + FUNCTION = "load_dynamic" + CATEGORY = "utils/json/project" + OUTPUT_NODE = False + + def load_dynamic(self, manager_url, project_name, file_name, sequence_number, + output_keys="", output_types=""): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + + keys = [k.strip() for k in output_keys.split(",") if k.strip()] if output_keys else [] + + results = [] + for key in keys: + val = data.get(key, "") + if isinstance(val, bool): + results.append(str(val).lower()) + elif isinstance(val, int): + results.append(val) + elif isinstance(val, float): + results.append(val) + else: + results.append(str(val)) + + while len(results) < MAX_DYNAMIC_OUTPUTS: + results.append("") + + return tuple(results) + + +# ========================================== +# 1. 
STANDARD NODE (Project-based I2V) +# ========================================== + +class ProjectLoaderStandard: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }} + + RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING") + RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path") + FUNCTION = "load_standard" + CATEGORY = "utils/json/project" + + def load_standard(self, manager_url, project_name, file_name, sequence_number): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + return ( + str(data.get("general_prompt", "")), str(data.get("general_negative", "")), + str(data.get("current_prompt", "")), str(data.get("negative", "")), + str(data.get("camera", "")), to_float(data.get("flf", 0.0)), + to_int(data.get("seed", 0)), str(data.get("video file path", "")), + str(data.get("reference image path", "")), str(data.get("flf image path", "")) + ) + + +# ========================================== +# 2. 
VACE NODE (Project-based) +# ========================================== + +class ProjectLoaderVACE: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }} + + RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING") + RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path") + FUNCTION = "load_vace" + CATEGORY = "utils/json/project" + + def load_vace(self, manager_url, project_name, file_name, sequence_number): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + return ( + str(data.get("general_prompt", "")), str(data.get("general_negative", "")), + str(data.get("current_prompt", "")), str(data.get("negative", "")), + str(data.get("camera", "")), to_float(data.get("flf", 0.0)), + to_int(data.get("seed", 0)), to_int(data.get("frame_to_skip", 81)), + to_int(data.get("input_a_frames", 16)), to_int(data.get("input_b_frames", 16)), + str(data.get("reference path", "")), to_int(data.get("reference switch", 1)), + to_int(data.get("vace schedule", 1)), str(data.get("video file path", "")), + str(data.get("reference image path", "")) + ) + + +# ========================================== +# 3. 
LoRA NODE (Project-based) +# ========================================== + +class ProjectLoaderLoRA: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }} + + RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING") + RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low") + FUNCTION = "load_loras" + CATEGORY = "utils/json/project" + + def load_loras(self, manager_url, project_name, file_name, sequence_number): + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + return ( + str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")), + str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")), + str(data.get("lora 3 high", "")), str(data.get("lora 3 low", "")) + ) + + +# --- Mappings --- +PROJECT_NODE_CLASS_MAPPINGS = { + "ProjectLoaderDynamic": ProjectLoaderDynamic, + "ProjectLoaderStandard": ProjectLoaderStandard, + "ProjectLoaderVACE": ProjectLoaderVACE, + "ProjectLoaderLoRA": ProjectLoaderLoRA, +} + +PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { + "ProjectLoaderDynamic": "Project Loader (Dynamic)", + "ProjectLoaderStandard": "Project Loader (Standard/I2V)", + "ProjectLoaderVACE": "Project Loader (VACE Full)", + "ProjectLoaderLoRA": "Project Loader (LoRAs)", +} diff --git a/state.py b/state.py index e4aeab4..891a14e 100644 --- a/state.py +++ b/state.py @@ -17,6 +17,11 @@ class AppState: live_toggles: dict = field(default_factory=dict) show_comfy_monitor: bool = True + # Project DB fields + db: Any = None + current_project: str = "" + db_enabled: bool = False + # Set at runtime by main.py / tab_comfy_ng.py _render_main: Any = None _load_file: Callable | None = None diff 
--git a/tab_batch_ng.py b/tab_batch_ng.py index 34845f6..36abee8 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -6,7 +6,7 @@ from nicegui import ui from state import AppState from utils import ( - DEFAULTS, save_json, load_json, + DEFAULTS, save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, ) from history_tree import HistoryTree @@ -161,6 +161,8 @@ def render_batch_processor(state: AppState): new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {}, KEY_PROMPT_HISTORY: []} save_json(new_path, new_data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, new_path, new_data) ui.notify(f'Created {new_name}', type='positive') ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch) @@ -215,6 +217,8 @@ def render_batch_processor(state: AppState): batch_list.append(new_item) data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) render_sequence_list.refresh() with ui.row().classes('q-mt-sm'): @@ -250,6 +254,8 @@ def render_batch_processor(state: AppState): batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0))) data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Sorted by sequence number!', type='positive') render_sequence_list.refresh() @@ -289,6 +295,8 @@ def render_batch_processor(state: AppState): htree.commit(snapshot_payload, note=note) data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) state.restored_indicator = None commit_input.set_value('') ui.notify('Batch Saved & Snapshot Created!', 
type='positive') @@ -306,6 +314,8 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, def commit(message=None): data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) if message: ui.notify(message, type='positive') refresh_list.refresh() @@ -567,6 +577,8 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list): shifted += 1 data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) + # FIXME: 'state' is not in scope in _render_vace_settings (not a parameter); + # thread AppState through this helper before enabling the sync_to_db dual-write. ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive') refresh_list.refresh() @@ -712,6 +724,8 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}") data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify(f'Updated {len(targets)} sequences', type='positive') if refresh_list: refresh_list.refresh() diff --git a/tab_projects_ng.py b/tab_projects_ng.py new file mode 100644 index 0000000..afa8422 --- /dev/null +++ b/tab_projects_ng.py @@ -0,0 +1,161 @@ +import logging +from pathlib import Path + +from nicegui import ui + +from state import AppState +from db import ProjectDB +from utils import save_config, sync_to_db, KEY_BATCH_DATA + +logger = logging.getLogger(__name__) + + +def render_projects_tab(state: AppState): + """Render the Projects management tab.""" + + # --- DB toggle --- + def on_db_toggle(e): + state.db_enabled = e.value + state.config['db_enabled'] = e.value + save_config(state.current_dir, state.config.get('favorites', []), state.config) + render_project_content.refresh() + + ui.switch('Enable Project Database', 
value=state.db_enabled, + on_change=on_db_toggle).classes('q-mb-md') + + @ui.refreshable + def render_project_content(): + if not state.db_enabled: + ui.label('Project database is disabled. Enable it above to manage projects.').classes( + 'text-caption q-pa-md') + return + + if not state.db: + ui.label('Database not initialized.').classes('text-warning q-pa-md') + return + + # --- Create project form --- + with ui.card().classes('w-full q-pa-md q-mb-md'): + ui.label('Create New Project').classes('section-header') + name_input = ui.input('Project Name', placeholder='my_project').classes('w-full') + desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full') + + def create_project(): + name = name_input.value.strip() + if not name: + ui.notify('Please enter a project name', type='warning') + return + try: + state.db.create_project(name, str(state.current_dir), desc_input.value.strip()) + name_input.set_value('') + desc_input.set_value('') + ui.notify(f'Created project "{name}"', type='positive') + render_project_list.refresh() + except Exception as e: + ui.notify(f'Error: {e}', type='negative') + + ui.button('Create Project', icon='add', on_click=create_project).classes('w-full') + + # --- Active project indicator --- + if state.current_project: + ui.label(f'Active Project: {state.current_project}').classes( + 'text-bold text-primary q-pa-sm') + + # --- Project list --- + @ui.refreshable + def render_project_list(): + projects = state.db.list_projects() + if not projects: + ui.label('No projects yet. 
Create one above.').classes('text-caption q-pa-md') + return + + for proj in projects: + is_active = proj['name'] == state.current_project + card_style = 'border-left: 3px solid var(--accent);' if is_active else '' + + with ui.card().classes('w-full q-pa-sm q-mb-sm').style(card_style): + with ui.row().classes('w-full items-center'): + with ui.column().classes('col'): + ui.label(proj['name']).classes('text-bold') + if proj['description']: + ui.label(proj['description']).classes('text-caption') + ui.label(f'Path: {proj["folder_path"]}').classes('text-caption') + files = state.db.list_data_files(proj['id']) + ui.label(f'{len(files)} data file(s)').classes('text-caption') + + with ui.row().classes('q-gutter-xs'): + if not is_active: + def activate(name=proj['name']): + state.current_project = name + state.config['current_project'] = name + save_config(state.current_dir, + state.config.get('favorites', []), + state.config) + ui.notify(f'Activated project "{name}"', type='positive') + render_project_list.refresh() + + ui.button('Activate', icon='check_circle', + on_click=activate).props('flat dense color=primary') + else: + def deactivate(): + state.current_project = '' + state.config['current_project'] = '' + save_config(state.current_dir, + state.config.get('favorites', []), + state.config) + ui.notify('Deactivated project', type='info') + render_project_list.refresh() + + ui.button('Deactivate', icon='cancel', + on_click=deactivate).props('flat dense') + + def import_folder(pid=proj['id'], pname=proj['name']): + _import_folder(state, pid, pname, render_project_list) + + ui.button('Import Folder', icon='folder_open', + on_click=import_folder).props('flat dense') + + def delete_proj(name=proj['name']): + state.db.delete_project(name) + if state.current_project == name: + state.current_project = '' + ui.notify(f'Deleted project "{name}"', type='positive') + render_project_list.refresh() + + ui.button(icon='delete', + on_click=delete_proj).props('flat dense 
color=negative') + + render_project_list() + + render_project_content() + + +def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn): + """Bulk import all .json files from current directory into a project.""" + json_files = sorted(state.current_dir.glob('*.json')) + json_files = [f for f in json_files if f.name not in ( + '.editor_config.json', '.editor_snippets.json')] + + if not json_files: + ui.notify('No JSON files in current directory', type='warning') + return + + imported = 0 + skipped = 0 + for jf in json_files: + file_name = jf.stem + existing = state.db.get_data_file(project_id, file_name) + if existing: + skipped += 1 + continue + try: + state.db.import_json_file(project_id, jf) + imported += 1 + except Exception as e: + logger.warning(f"Failed to import {jf}: {e}") + + msg = f'Imported {imported} file(s)' + if skipped: + msg += f', skipped {skipped} existing' + ui.notify(msg, type='positive') + refresh_fn.refresh() diff --git a/tab_raw_ng.py b/tab_raw_ng.py index 39ec6f3..bdc4933 100644 --- a/tab_raw_ng.py +++ b/tab_raw_ng.py @@ -4,7 +4,7 @@ import json from nicegui import ui from state import AppState -from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY +from utils import save_json, sync_to_db, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY def render_raw_editor(state: AppState): @@ -52,6 +52,8 @@ def render_raw_editor(state: AppState): input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY] save_json(file_path, input_data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, input_data) data.clear() data.update(input_data) diff --git a/tab_timeline_ng.py b/tab_timeline_ng.py index daaf460..d0467f4 100644 --- a/tab_timeline_ng.py +++ b/tab_timeline_ng.py @@ -5,7 +5,7 @@ from nicegui import ui from state import AppState from history_tree import HistoryTree -from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE +from 
utils import save_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE def _delete_nodes(htree, data, file_path, node_ids): @@ -134,6 +134,8 @@ def _render_batch_delete(htree, data, file_path, state, refresh_fn): def do_batch_delete(): current_valid = state.timeline_selected_nodes & set(htree.nodes.keys()) _delete_nodes(htree, data, file_path, current_valid) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) state.timeline_selected_nodes = set() ui.notify( f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!', @@ -179,7 +181,7 @@ def _find_branch_for_node(htree, node_id): def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn, - selected): + selected, state=None): """Render branch-grouped node manager with restore, rename, delete, and preview.""" ui.label('Manage Version').classes('section-header') @@ -291,6 +293,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_ htree.nodes[sel_id]['note'] = rename_input.value data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state and state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Label updated', type='positive') refresh_fn() @@ -304,6 +308,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_ def delete_selected(): if sel_id in htree.nodes: _delete_nodes(htree, data, file_path, {sel_id}) + if state and state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) ui.notify('Node Deleted', type='positive') refresh_fn() @@ -377,7 +383,7 @@ def render_timeline_tab(state: AppState): _render_node_manager( all_nodes, htree, data, file_path, _restore_and_refresh, render_timeline.refresh, - selected) + selected, state=state) def _toggle_select(nid, checked): if checked: @@ -492,6 +498,8 
@@ def _restore_node(data, node, htree, file_path, state: AppState): htree.head_id = node['id'] data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) + if state.db_enabled and state.current_project and state.db: + sync_to_db(state.db, state.current_project, file_path, data) label = f"{node.get('note', 'Step')} ({node['id'][:4]})" state.restored_indicator = label ui.notify('Restored!', type='positive') diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..73427c1 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,286 @@ +import json +from pathlib import Path + +import pytest + +from db import ProjectDB +from utils import KEY_BATCH_DATA, KEY_HISTORY_TREE + + +@pytest.fixture +def db(tmp_path): + """Create a fresh ProjectDB in a temp directory.""" + db_path = tmp_path / "test.db" + pdb = ProjectDB(db_path) + yield pdb + pdb.close() + + +# ------------------------------------------------------------------ +# Projects CRUD +# ------------------------------------------------------------------ + +class TestProjects: + def test_create_and_get(self, db): + pid = db.create_project("proj1", "/some/path", "A test project") + assert pid > 0 + proj = db.get_project("proj1") + assert proj is not None + assert proj["name"] == "proj1" + assert proj["folder_path"] == "/some/path" + assert proj["description"] == "A test project" + + def test_list_projects(self, db): + db.create_project("beta", "/b") + db.create_project("alpha", "/a") + projects = db.list_projects() + assert len(projects) == 2 + assert projects[0]["name"] == "alpha" + assert projects[1]["name"] == "beta" + + def test_get_nonexistent(self, db): + assert db.get_project("nope") is None + + def test_delete_project(self, db): + db.create_project("to_delete", "/x") + assert db.delete_project("to_delete") is True + assert db.get_project("to_delete") is None + + def test_delete_nonexistent(self, db): + assert db.delete_project("nope") is False + + def test_unique_name_constraint(self, 
db): + db.create_project("dup", "/a") + with pytest.raises(Exception): + db.create_project("dup", "/b") + + +# ------------------------------------------------------------------ +# Data files +# ------------------------------------------------------------------ + +class TestDataFiles: + def test_create_and_list(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch_i2v", "i2v", {"extra": "meta"}) + assert df_id > 0 + files = db.list_data_files(pid) + assert len(files) == 1 + assert files[0]["name"] == "batch_i2v" + assert files[0]["data_type"] == "i2v" + + def test_get_data_file(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v", {"key": "value"}) + df = db.get_data_file(pid, "batch_i2v") + assert df is not None + assert df["top_level"] == {"key": "value"} + + def test_get_data_file_by_names(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v") + df = db.get_data_file_by_names("p1", "batch_i2v") + assert df is not None + assert df["name"] == "batch_i2v" + + def test_get_nonexistent_data_file(self, db): + pid = db.create_project("p1", "/p1") + assert db.get_data_file(pid, "nope") is None + + def test_unique_constraint(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "batch_i2v", "i2v") + with pytest.raises(Exception): + db.create_data_file(pid, "batch_i2v", "vace") + + def test_cascade_delete(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "hello"}) + db.save_history_tree(df_id, {"nodes": {}}) + db.delete_project("p1") + assert db.get_data_file(pid, "batch_i2v") is None + + +# ------------------------------------------------------------------ +# Sequences +# ------------------------------------------------------------------ + +class TestSequences: + def test_upsert_and_get(self, db): + pid = db.create_project("p1", 
"/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"prompt": "hello", "seed": 42}) + data = db.get_sequence(df_id, 1) + assert data == {"prompt": "hello", "seed": 42} + + def test_upsert_updates_existing(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"prompt": "v1"}) + db.upsert_sequence(df_id, 1, {"prompt": "v2"}) + data = db.get_sequence(df_id, 1) + assert data["prompt"] == "v2" + + def test_list_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 3, {"a": 1}) + db.upsert_sequence(df_id, 1, {"b": 2}) + db.upsert_sequence(df_id, 2, {"c": 3}) + seqs = db.list_sequences(df_id) + assert seqs == [1, 2, 3] + + def test_get_nonexistent_sequence(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.get_sequence(df_id, 99) is None + + def test_get_sequence_keys(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, { + "prompt": "hello", + "seed": 42, + "cfg": 1.5, + "flag": True, + }) + keys, types = db.get_sequence_keys(df_id, 1) + assert "prompt" in keys + assert "seed" in keys + idx_prompt = keys.index("prompt") + idx_seed = keys.index("seed") + idx_cfg = keys.index("cfg") + idx_flag = keys.index("flag") + assert types[idx_prompt] == "STRING" + assert types[idx_seed] == "INT" + assert types[idx_cfg] == "FLOAT" + assert types[idx_flag] == "STRING" # bools -> STRING + + def test_get_sequence_keys_nonexistent(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + keys, types = db.get_sequence_keys(df_id, 99) + assert keys == [] + assert types == [] + + def test_delete_sequences_for_file(self, db): + pid = db.create_project("p1", "/p1") + df_id = 
db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {"a": 1}) + db.upsert_sequence(df_id, 2, {"b": 2}) + db.delete_sequences_for_file(df_id) + assert db.list_sequences(df_id) == [] + + +# ------------------------------------------------------------------ +# History trees +# ------------------------------------------------------------------ + +class TestHistoryTrees: + def test_save_and_get(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + tree = {"nodes": {"abc": {"id": "abc"}}, "head_id": "abc"} + db.save_history_tree(df_id, tree) + result = db.get_history_tree(df_id) + assert result == tree + + def test_upsert_updates(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.save_history_tree(df_id, {"v": 1}) + db.save_history_tree(df_id, {"v": 2}) + result = db.get_history_tree(df_id) + assert result == {"v": 2} + + def test_get_nonexistent(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + assert db.get_history_tree(df_id) is None + + +# ------------------------------------------------------------------ +# Import +# ------------------------------------------------------------------ + +class TestImport: + def test_import_json_file(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "batch_prompt_i2v.json" + data = { + KEY_BATCH_DATA: [ + {"sequence_number": 1, "prompt": "hello", "seed": 42}, + {"sequence_number": 2, "prompt": "world", "seed": 99}, + ], + KEY_HISTORY_TREE: {"nodes": {}, "head_id": None}, + } + json_path.write_text(json.dumps(data)) + + df_id = db.import_json_file(pid, json_path, "i2v") + assert df_id > 0 + + seqs = db.list_sequences(df_id) + assert seqs == [1, 2] + + s1 = db.get_sequence(df_id, 1) + assert s1["prompt"] == "hello" + assert s1["seed"] == 42 + + tree = db.get_history_tree(df_id) + assert tree == {"nodes": {}, 
"head_id": None} + + def test_import_file_name_from_stem(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "my_batch.json" + json_path.write_text(json.dumps({KEY_BATCH_DATA: [{"sequence_number": 1}]})) + db.import_json_file(pid, json_path) + df = db.get_data_file(pid, "my_batch") + assert df is not None + + def test_import_no_batch_data(self, db, tmp_path): + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "simple.json" + json_path.write_text(json.dumps({"prompt": "flat file"})) + df_id = db.import_json_file(pid, json_path) + seqs = db.list_sequences(df_id) + assert seqs == [] + + +# ------------------------------------------------------------------ +# Query helpers +# ------------------------------------------------------------------ + +class TestQueryHelpers: + def test_query_sequence_data(self, db): + pid = db.create_project("myproject", "/mp") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7}) + result = db.query_sequence_data("myproject", "batch_i2v", 1) + assert result == {"prompt": "test", "seed": 7} + + def test_query_sequence_data_not_found(self, db): + assert db.query_sequence_data("nope", "nope", 1) is None + + def test_query_sequence_keys(self, db): + pid = db.create_project("myproject", "/mp") + df_id = db.create_data_file(pid, "batch_i2v", "i2v") + db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7}) + keys, types = db.query_sequence_keys("myproject", "batch_i2v", 1) + assert "prompt" in keys + assert "seed" in keys + + def test_list_project_files(self, db): + pid = db.create_project("p1", "/p1") + db.create_data_file(pid, "file_a", "i2v") + db.create_data_file(pid, "file_b", "vace") + files = db.list_project_files("p1") + assert len(files) == 2 + + def test_list_project_sequences(self, db): + pid = db.create_project("p1", "/p1") + df_id = db.create_data_file(pid, "batch", "generic") + db.upsert_sequence(df_id, 1, {}) + 
db.upsert_sequence(df_id, 2, {}) + seqs = db.list_project_sequences("p1", "batch") + assert seqs == [1, 2] diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py new file mode 100644 index 0000000..00a59a7 --- /dev/null +++ b/tests/test_project_loader.py @@ -0,0 +1,201 @@ +import json +from unittest.mock import patch, MagicMock +from io import BytesIO + +import pytest + +from project_loader import ( + ProjectLoaderDynamic, + ProjectLoaderStandard, + ProjectLoaderVACE, + ProjectLoaderLoRA, + _fetch_json, + _fetch_data, + _fetch_keys, + MAX_DYNAMIC_OUTPUTS, +) + + +def _mock_urlopen(data: dict): + """Create a mock context manager for urllib.request.urlopen.""" + response = MagicMock() + response.read.return_value = json.dumps(data).encode() + response.__enter__ = lambda s: s + response.__exit__ = MagicMock(return_value=False) + return response + + +class TestFetchHelpers: + def test_fetch_json_success(self): + data = {"key": "value"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)): + result = _fetch_json("http://example.com/api") + assert result == data + + def test_fetch_json_failure(self): + import urllib.error + with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")): + result = _fetch_json("http://example.com/api") + assert result == {} + + def test_fetch_data_builds_url(self): + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + result = _fetch_data("http://localhost:8080", "proj1", "batch_i2v", 1) + assert result == data + called_url = mock.call_args[0][0] + assert "/api/projects/proj1/files/batch_i2v/data?seq=1" in called_url + + def test_fetch_keys_builds_url(self): + data = {"keys": ["prompt"], "types": ["STRING"]} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + result = _fetch_keys("http://localhost:8080", "proj1", "batch_i2v", 1) + 
assert result == data + called_url = mock.call_args[0][0] + assert "/api/projects/proj1/files/batch_i2v/keys?seq=1" in called_url + + def test_fetch_data_strips_trailing_slash(self): + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + _fetch_data("http://localhost:8080/", "proj1", "file1", 1) + called_url = mock.call_args[0][0] + assert "//api" not in called_url + + +class TestProjectLoaderDynamic: + def test_load_dynamic_with_keys(self): + data = {"prompt": "hello", "seed": 42, "cfg": 1.5} + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="prompt,seed,cfg" + ) + assert result[0] == "hello" + assert result[1] == 42 + assert result[2] == 1.5 + assert len(result) == MAX_DYNAMIC_OUTPUTS + + def test_load_dynamic_empty_keys(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="" + ) + assert all(v == "" for v in result) + + def test_load_dynamic_missing_key(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="nonexistent" + ) + assert result[0] == "" + + def test_load_dynamic_bool_becomes_string(self): + node = ProjectLoaderDynamic() + with patch("project_loader._fetch_data", return_value={"flag": True}): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys="flag" + ) + assert result[0] == "true" + + def test_input_types_has_manager_url(self): + inputs = ProjectLoaderDynamic.INPUT_TYPES() + assert "manager_url" in inputs["required"] + assert "project_name" in inputs["required"] + assert "file_name" 
in inputs["required"] + assert "sequence_number" in inputs["required"] + + def test_category(self): + assert ProjectLoaderDynamic.CATEGORY == "utils/json/project" + + +class TestProjectLoaderStandard: + def test_load_standard(self): + data = { + "general_prompt": "hello", + "general_negative": "bad", + "current_prompt": "specific", + "negative": "neg", + "camera": "pan", + "flf": 0.5, + "seed": 42, + "video file path": "/v.mp4", + "reference image path": "/r.png", + "flf image path": "/f.png", + } + node = ProjectLoaderStandard() + with patch("project_loader._fetch_data", return_value=data): + result = node.load_standard("http://localhost:8080", "proj1", "batch", 1) + assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png") + + def test_load_standard_defaults(self): + node = ProjectLoaderStandard() + with patch("project_loader._fetch_data", return_value={}): + result = node.load_standard("http://localhost:8080", "proj1", "batch", 1) + assert result[0] == "" # general_prompt + assert result[5] == 0.0 # flf + assert result[6] == 0 # seed + + +class TestProjectLoaderVACE: + def test_load_vace(self): + data = { + "general_prompt": "hello", + "general_negative": "bad", + "current_prompt": "specific", + "negative": "neg", + "camera": "pan", + "flf": 0.5, + "seed": 42, + "frame_to_skip": 81, + "input_a_frames": 16, + "input_b_frames": 16, + "reference path": "/ref", + "reference switch": 1, + "vace schedule": 2, + "video file path": "/v.mp4", + "reference image path": "/r.png", + } + node = ProjectLoaderVACE() + with patch("project_loader._fetch_data", return_value=data): + result = node.load_vace("http://localhost:8080", "proj1", "batch", 1) + assert result[7] == 81 # frame_to_skip + assert result[12] == 2 # vace_schedule + + +class TestProjectLoaderLoRA: + def test_load_loras(self): + data = { + "lora 1 high": "", + "lora 1 low": "", + "lora 2 high": "", + "lora 2 low": "", + "lora 3 high": "", + "lora 3 low": "", + } + node = 
ProjectLoaderLoRA() + with patch("project_loader._fetch_data", return_value=data): + result = node.load_loras("http://localhost:8080", "proj1", "batch", 1) + assert result[0] == "" + assert result[1] == "" + + def test_load_loras_empty(self): + node = ProjectLoaderLoRA() + with patch("project_loader._fetch_data", return_value={}): + result = node.load_loras("http://localhost:8080", "proj1", "batch", 1) + assert all(v == "" for v in result) + + +class TestNodeMappings: + def test_mappings_exist(self): + from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS + assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4 diff --git a/utils.py b/utils.py index 2e49007..ea5cdc9 100644 --- a/utils.py +++ b/utils.py @@ -160,6 +160,43 @@ def get_file_mtime(path: str | Path) -> float: return path.stat().st_mtime return 0 +def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: + """Dual-write helper: sync JSON data to the project database. + + Resolves (or creates) the data_file, upserts all sequences from batch_data, + and saves the history_tree. 
+ """ + if not db or not project_name: + return + try: + proj = db.get_project(project_name) + if not proj: + return + file_name = Path(file_path).stem + df = db.get_data_file(proj["id"], file_name) + if not df: + top_level = {k: v for k, v in data.items() + if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} + df_id = db.create_data_file(proj["id"], file_name, "generic", top_level) + else: + df_id = df["id"] + + # Sync sequences + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + db.delete_sequences_for_file(df_id) + for item in batch_data: + seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) + db.upsert_sequence(df_id, seq_num, item) + + # Sync history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + db.save_history_tree(df_id, history_tree) + except Exception as e: + logger.warning(f"sync_to_db failed: {e}") + + def generate_templates(current_dir: Path) -> None: """Creates batch template files if folder is empty.""" first = DEFAULTS.copy() diff --git a/web/project_dynamic.js b/web/project_dynamic.js new file mode 100644 index 0000000..9d58b78 --- /dev/null +++ b/web/project_dynamic.js @@ -0,0 +1,139 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; + +app.registerExtension({ + name: "json.manager.project.dynamic", + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name !== "ProjectLoaderDynamic") return; + + const origOnNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + origOnNodeCreated?.apply(this, arguments); + + // Hide internal widgets (managed by JS) + for (const name of ["output_keys", "output_types"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } + } + + // Remove all 32 default outputs from Python RETURN_TYPES + while (this.outputs.length > 0) { + this.removeOutput(0); + } + + // Add Refresh 
button + this.addWidget("button", "Refresh Outputs", null, () => { + this.refreshDynamicOutputs(); + }); + + this.setSize(this.computeSize()); + }; + + nodeType.prototype.refreshDynamicOutputs = async function () { + const urlWidget = this.widgets?.find(w => w.name === "manager_url"); + const projectWidget = this.widgets?.find(w => w.name === "project_name"); + const fileWidget = this.widgets?.find(w => w.name === "file_name"); + const seqWidget = this.widgets?.find(w => w.name === "sequence_number"); + + if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return; + + try { + const resp = await api.fetchApi( + `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` + ); + const { keys, types } = await resp.json(); + + // Store keys and types in hidden widgets for persistence + const okWidget = this.widgets?.find(w => w.name === "output_keys"); + if (okWidget) okWidget.value = keys.join(","); + const otWidget = this.widgets?.find(w => w.name === "output_types"); + if (otWidget) otWidget.value = types.join(","); + + // Build a map of current output names to slot indices + const oldSlots = {}; + for (let i = 0; i < this.outputs.length; i++) { + oldSlots[this.outputs[i].name] = i; + } + + // Build new outputs, reusing existing slots to preserve links + const newOutputs = []; + for (let k = 0; k < keys.length; k++) { + const key = keys[k]; + const type = types[k] || "*"; + if (key in oldSlots) { + const slot = this.outputs[oldSlots[key]]; + slot.type = type; + newOutputs.push(slot); + delete oldSlots[key]; + } else { + newOutputs.push({ name: key, type: type, links: null }); + } + } + + // Disconnect links on slots that are being removed + for (const name in oldSlots) { + const idx = oldSlots[name]; + if (this.outputs[idx]?.links?.length) { + for (const linkId of [...this.outputs[idx].links]) { + 
this.graph?.removeLink(linkId); + } + } + } + + // Reassign the outputs array and fix link slot indices + this.outputs = newOutputs; + if (this.graph) { + for (let i = 0; i < this.outputs.length; i++) { + const links = this.outputs[i].links; + if (!links) continue; + for (const linkId of links) { + const link = this.graph.links[linkId]; + if (link) link.origin_slot = i; + } + } + } + + this.setSize(this.computeSize()); + app.graph.setDirtyCanvas(true, true); + } catch (e) { + console.error("[ProjectLoaderDynamic] Refresh failed:", e); + } + }; + + // Restore state on workflow load + const origOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + origOnConfigure?.apply(this, arguments); + + // Hide internal widgets + for (const name of ["output_keys", "output_types"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; } + } + + const okWidget = this.widgets?.find(w => w.name === "output_keys"); + const otWidget = this.widgets?.find(w => w.name === "output_types"); + + const keys = okWidget?.value + ? okWidget.value.split(",").filter(k => k.trim()) + : []; + const types = otWidget?.value + ? otWidget.value.split(",") + : []; + + // Rename and set types to match stored state (preserves links) + for (let i = 0; i < this.outputs.length && i < keys.length; i++) { + this.outputs[i].name = keys[i].trim(); + if (types[i]) this.outputs[i].type = types[i]; + } + + // Remove any extra outputs beyond the key count + while (this.outputs.length > keys.length) { + this.removeOutput(this.outputs.length - 1); + } + + this.setSize(this.computeSize()); + }; + }, +});