Add SQLite project database + ComfyUI connector nodes

- db.py: ProjectDB class with SQLite schema (projects, data_files,
  sequences, history_trees), WAL mode, CRUD, import, and query helpers
- api_routes.py: REST API endpoints on NiceGUI/FastAPI for ComfyUI
  to query project data over the network
- project_loader.py: ComfyUI nodes (ProjectLoaderDynamic, Standard,
  VACE, LoRA) that fetch data from NiceGUI REST API via HTTP
- web/project_dynamic.js: Frontend JS for dynamic project loader node
- tab_projects_ng.py: Projects management tab in NiceGUI UI
- state.py: Added db, current_project, db_enabled fields
- main.py: DB init, API route registration, projects tab
- utils.py: sync_to_db() dual-write helper
- tab_batch_ng.py, tab_raw_ng.py, tab_timeline_ng.py: dual-write
  sync calls after save_json when project DB is enabled
- __init__.py: Merged project node class mappings
- tests/test_db.py: 30 tests for database layer
- tests/test_project_loader.py: 17 tests for ComfyUI connector nodes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-28 21:12:05 +01:00
parent 0d8e84ea36
commit c15bec98ce
14 changed files with 1495 additions and 5 deletions

View File

@@ -1,4 +1,8 @@
from .json_loader import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
from .project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
NODE_CLASS_MAPPINGS.update(PROJECT_NODE_CLASS_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(PROJECT_NODE_DISPLAY_NAME_MAPPINGS)
WEB_DIRECTORY = "./web"

67
api_routes.py Normal file
View File

@@ -0,0 +1,67 @@
"""REST API endpoints for ComfyUI to query project data from SQLite.
All endpoints are read-only. Mounted on the NiceGUI/FastAPI server.
"""
import logging
from typing import Any
from fastapi import HTTPException, Query
from nicegui import app
from db import ProjectDB
logger = logging.getLogger(__name__)
# The DB instance is set by register_api_routes()
_db: ProjectDB | None = None
def register_api_routes(db: ProjectDB) -> None:
"""Register all REST API routes with the NiceGUI/FastAPI app."""
global _db
_db = db
app.add_api_route("/api/projects", _list_projects, methods=["GET"])
app.add_api_route("/api/projects/{name}/files", _list_files, methods=["GET"])
app.add_api_route("/api/projects/{name}/files/{file_name}/sequences", _list_sequences, methods=["GET"])
app.add_api_route("/api/projects/{name}/files/{file_name}/data", _get_data, methods=["GET"])
app.add_api_route("/api/projects/{name}/files/{file_name}/keys", _get_keys, methods=["GET"])
def _get_db() -> ProjectDB:
if _db is None:
raise HTTPException(status_code=503, detail="Database not initialized")
return _db
async def _list_projects() -> dict[str, Any]:
db = _get_db()
projects = db.list_projects()
return {"projects": [p["name"] for p in projects]}
async def _list_files(name: str) -> dict[str, Any]:
db = _get_db()
files = db.list_project_files(name)
return {"files": [{"name": f["name"], "data_type": f["data_type"]} for f in files]}
async def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
db = _get_db()
seqs = db.list_project_sequences(name, file_name)
return {"sequences": seqs}
async def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
db = _get_db()
data = db.query_sequence_data(name, file_name, seq)
if data is None:
raise HTTPException(status_code=404, detail="Sequence not found")
return data
async def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
db = _get_db()
keys, types = db.query_sequence_keys(name, file_name, seq)
return {"keys": keys, "types": types}

285
db.py Normal file
View File

@@ -0,0 +1,285 @@
import json
import logging
import sqlite3
import time
from pathlib import Path
from typing import Any
from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
logger = logging.getLogger(__name__)
DEFAULT_DB_PATH = Path.home() / ".comfyui_json_manager" / "projects.db"
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS projects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
folder_path TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
created_at REAL NOT NULL,
updated_at REAL NOT NULL
);
CREATE TABLE IF NOT EXISTS data_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE,
name TEXT NOT NULL,
data_type TEXT NOT NULL DEFAULT 'generic',
top_level TEXT NOT NULL DEFAULT '{}',
created_at REAL NOT NULL,
updated_at REAL NOT NULL,
UNIQUE(project_id, name)
);
CREATE TABLE IF NOT EXISTS sequences (
id INTEGER PRIMARY KEY AUTOINCREMENT,
data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE,
sequence_number INTEGER NOT NULL,
data TEXT NOT NULL DEFAULT '{}',
updated_at REAL NOT NULL,
UNIQUE(data_file_id, sequence_number)
);
CREATE TABLE IF NOT EXISTS history_trees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE,
tree_data TEXT NOT NULL DEFAULT '{}',
updated_at REAL NOT NULL
);
"""
class ProjectDB:
"""SQLite database for project-based data management."""
def __init__(self, db_path: str | Path | None = None):
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
self.conn.row_factory = sqlite3.Row
self.conn.execute("PRAGMA journal_mode=WAL")
self.conn.execute("PRAGMA foreign_keys=ON")
self.conn.executescript(SCHEMA_SQL)
self.conn.commit()
def close(self):
self.conn.close()
# ------------------------------------------------------------------
# Projects CRUD
# ------------------------------------------------------------------
def create_project(self, name: str, folder_path: str, description: str = "") -> int:
now = time.time()
cur = self.conn.execute(
"INSERT INTO projects (name, folder_path, description, created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?)",
(name, folder_path, description, now, now),
)
self.conn.commit()
return cur.lastrowid
def list_projects(self) -> list[dict]:
rows = self.conn.execute(
"SELECT id, name, folder_path, description, created_at, updated_at "
"FROM projects ORDER BY name"
).fetchall()
return [dict(r) for r in rows]
def get_project(self, name: str) -> dict | None:
row = self.conn.execute(
"SELECT id, name, folder_path, description, created_at, updated_at "
"FROM projects WHERE name = ?",
(name,),
).fetchone()
return dict(row) if row else None
def delete_project(self, name: str) -> bool:
cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,))
self.conn.commit()
return cur.rowcount > 0
# ------------------------------------------------------------------
# Data files
# ------------------------------------------------------------------
def create_data_file(
self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None
) -> int:
now = time.time()
tl = json.dumps(top_level or {})
cur = self.conn.execute(
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?)",
(project_id, name, data_type, tl, now, now),
)
self.conn.commit()
return cur.lastrowid
def list_data_files(self, project_id: int) -> list[dict]:
rows = self.conn.execute(
"SELECT id, project_id, name, data_type, created_at, updated_at "
"FROM data_files WHERE project_id = ? ORDER BY name",
(project_id,),
).fetchall()
return [dict(r) for r in rows]
def get_data_file(self, project_id: int, name: str) -> dict | None:
row = self.conn.execute(
"SELECT id, project_id, name, data_type, top_level, created_at, updated_at "
"FROM data_files WHERE project_id = ? AND name = ?",
(project_id, name),
).fetchone()
if row is None:
return None
d = dict(row)
d["top_level"] = json.loads(d["top_level"])
return d
def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None:
row = self.conn.execute(
"SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, "
"df.created_at, df.updated_at "
"FROM data_files df JOIN projects p ON df.project_id = p.id "
"WHERE p.name = ? AND df.name = ?",
(project_name, file_name),
).fetchone()
if row is None:
return None
d = dict(row)
d["top_level"] = json.loads(d["top_level"])
return d
# ------------------------------------------------------------------
# Sequences
# ------------------------------------------------------------------
def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None:
now = time.time()
self.conn.execute(
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
"VALUES (?, ?, ?, ?) "
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
(data_file_id, sequence_number, json.dumps(data), now),
)
self.conn.commit()
def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None:
row = self.conn.execute(
"SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?",
(data_file_id, sequence_number),
).fetchone()
return json.loads(row["data"]) if row else None
def list_sequences(self, data_file_id: int) -> list[int]:
rows = self.conn.execute(
"SELECT sequence_number FROM sequences WHERE data_file_id = ? ORDER BY sequence_number",
(data_file_id,),
).fetchall()
return [r["sequence_number"] for r in rows]
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
"""Returns (keys, types) for a sequence's data dict."""
data = self.get_sequence(data_file_id, sequence_number)
if not data:
return [], []
keys = []
types = []
for k, v in data.items():
keys.append(k)
if isinstance(v, bool):
types.append("STRING")
elif isinstance(v, int):
types.append("INT")
elif isinstance(v, float):
types.append("FLOAT")
else:
types.append("STRING")
return keys, types
def delete_sequences_for_file(self, data_file_id: int) -> None:
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,))
self.conn.commit()
# ------------------------------------------------------------------
# History trees
# ------------------------------------------------------------------
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
now = time.time()
self.conn.execute(
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
"VALUES (?, ?, ?) "
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
(data_file_id, json.dumps(tree_data), now),
)
self.conn.commit()
def get_history_tree(self, data_file_id: int) -> dict | None:
row = self.conn.execute(
"SELECT tree_data FROM history_trees WHERE data_file_id = ?",
(data_file_id,),
).fetchone()
return json.loads(row["tree_data"]) if row else None
# ------------------------------------------------------------------
# Import
# ------------------------------------------------------------------
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
"""Import a JSON file into the database, splitting batch_data into sequences."""
json_path = Path(json_path)
data, _ = load_json(json_path)
file_name = json_path.stem
# Extract top-level keys that aren't batch_data or history_tree
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
df_id = self.create_data_file(project_id, file_name, data_type, top_level)
# Import sequences from batch_data
batch_data = data.get(KEY_BATCH_DATA, [])
if isinstance(batch_data, list):
for item in batch_data:
seq_num = int(item.get("sequence_number", 0))
self.upsert_sequence(df_id, seq_num, item)
# Import history tree
history_tree = data.get(KEY_HISTORY_TREE)
if history_tree and isinstance(history_tree, dict):
self.save_history_tree(df_id, history_tree)
return df_id
# ------------------------------------------------------------------
# Query helpers (for REST API)
# ------------------------------------------------------------------
def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None:
"""Query a single sequence by project name, file name, and sequence number."""
df = self.get_data_file_by_names(project_name, file_name)
if not df:
return None
return self.get_sequence(df["id"], sequence_number)
def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]:
"""Query keys and types for a sequence."""
df = self.get_data_file_by_names(project_name, file_name)
if not df:
return [], []
return self.get_sequence_keys(df["id"], sequence_number)
def list_project_files(self, project_name: str) -> list[dict]:
"""List data files for a project by name."""
proj = self.get_project(project_name)
if not proj:
return []
return self.list_data_files(proj["id"])
def list_project_sequences(self, project_name: str, file_name: str) -> list[int]:
"""List sequence numbers for a file in a project."""
df = self.get_data_file_by_names(project_name, file_name)
if not df:
return []
return self.list_sequences(df["id"])

26
main.py
View File

@@ -1,4 +1,5 @@
import json
import logging
from pathlib import Path
from nicegui import ui
@@ -14,6 +15,11 @@ from tab_batch_ng import render_batch_processor
from tab_timeline_ng import render_timeline_tab
from tab_raw_ng import render_raw_editor
from tab_comfy_ng import render_comfy_monitor
from tab_projects_ng import render_projects_tab
from db import ProjectDB
from api_routes import register_api_routes
logger = logging.getLogger(__name__)
@ui.page('/')
@@ -156,7 +162,17 @@ def index():
config=config,
current_dir=Path(config.get('last_dir', str(Path.cwd()))),
snippets=load_snippets(),
db_enabled=config.get('db_enabled', False),
current_project=config.get('current_project', ''),
)
# Initialize project database
try:
state.db = ProjectDB()
except Exception as e:
logger.warning(f"Failed to initialize ProjectDB: {e}")
state.db = None
dual_pane = {'active': False, 'state': None}
# ------------------------------------------------------------------
@@ -178,6 +194,7 @@ def index():
ui.tab('batch', label='Batch Processor')
ui.tab('timeline', label='Timeline')
ui.tab('raw', label='Raw Editor')
ui.tab('projects', label='Projects')
with ui.tab_panels(tabs, value='batch').classes('w-full'):
with ui.tab_panel('batch'):
@@ -186,6 +203,8 @@ def index():
render_timeline_tab(state)
with ui.tab_panel('raw'):
render_raw_editor(state)
with ui.tab_panel('projects'):
render_projects_tab(state)
if state.show_comfy_monitor:
ui.separator()
@@ -481,4 +500,11 @@ def render_sidebar(state: AppState, dual_pane: dict):
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
# Register REST API routes for ComfyUI connectivity
try:
_api_db = ProjectDB()
register_api_routes(_api_db)
except Exception as e:
logger.warning(f"Failed to register API routes: {e}")
ui.run(title='AI Settings Manager', port=8080, reload=True)

255
project_loader.py Normal file
View File

@@ -0,0 +1,255 @@
import json
import logging
import urllib.request
import urllib.error
from typing import Any
logger = logging.getLogger(__name__)
MAX_DYNAMIC_OUTPUTS = 32
class AnyType(str):
"""Universal connector type that matches any ComfyUI type."""
def __ne__(self, __value: object) -> bool:
return False
any_type = AnyType("*")
try:
from server import PromptServer
from aiohttp import web
except ImportError:
PromptServer = None
def to_float(val: Any) -> float:
try:
return float(val)
except (ValueError, TypeError):
return 0.0
def to_int(val: Any) -> int:
try:
return int(float(val))
except (ValueError, TypeError):
return 0
def _fetch_json(url: str) -> dict:
"""Fetch JSON from a URL using stdlib urllib."""
try:
with urllib.request.urlopen(url, timeout=5) as resp:
return json.loads(resp.read())
except (urllib.error.URLError, json.JSONDecodeError, OSError) as e:
logger.warning(f"Failed to fetch {url}: {e}")
return {}
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
"""Fetch sequence data from the NiceGUI REST API."""
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/data?seq={seq}"
return _fetch_json(url)
def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict:
"""Fetch keys/types from the NiceGUI REST API."""
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/keys?seq={seq}"
return _fetch_json(url)
# --- ComfyUI-side proxy endpoints (for frontend JS) ---
if PromptServer is not None:
@PromptServer.instance.routes.get("/json_manager/list_projects")
async def list_projects_proxy(request):
manager_url = request.query.get("url", "http://localhost:8080")
url = f"{manager_url.rstrip('/')}/api/projects"
data = _fetch_json(url)
return web.json_response(data)
@PromptServer.instance.routes.get("/json_manager/list_project_files")
async def list_project_files_proxy(request):
manager_url = request.query.get("url", "http://localhost:8080")
project = request.query.get("project", "")
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
data = _fetch_json(url)
return web.json_response(data)
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
async def list_project_sequences_proxy(request):
manager_url = request.query.get("url", "http://localhost:8080")
project = request.query.get("project", "")
file_name = request.query.get("file", "")
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
data = _fetch_json(url)
return web.json_response(data)
@PromptServer.instance.routes.get("/json_manager/get_project_keys")
async def get_project_keys_proxy(request):
manager_url = request.query.get("url", "http://localhost:8080")
project = request.query.get("project", "")
file_name = request.query.get("file", "")
try:
seq = int(request.query.get("seq", "1"))
except (ValueError, TypeError):
seq = 1
data = _fetch_keys(manager_url, project, file_name, seq)
return web.json_response(data)
# ==========================================
# 0. DYNAMIC NODE (Project-based)
# ==========================================
class ProjectLoaderDynamic:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
"project_name": ("STRING", {"default": "", "multiline": False}),
"file_name": ("STRING", {"default": "", "multiline": False}),
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
},
"optional": {
"output_keys": ("STRING", {"default": ""}),
"output_types": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
FUNCTION = "load_dynamic"
CATEGORY = "utils/json/project"
OUTPUT_NODE = False
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
output_keys="", output_types=""):
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
keys = [k.strip() for k in output_keys.split(",") if k.strip()] if output_keys else []
results = []
for key in keys:
val = data.get(key, "")
if isinstance(val, bool):
results.append(str(val).lower())
elif isinstance(val, int):
results.append(val)
elif isinstance(val, float):
results.append(val)
else:
results.append(str(val))
while len(results) < MAX_DYNAMIC_OUTPUTS:
results.append("")
return tuple(results)
# ==========================================
# 1. STANDARD NODE (Project-based I2V)
# ==========================================
class ProjectLoaderStandard:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
"project_name": ("STRING", {"default": "", "multiline": False}),
"file_name": ("STRING", {"default": "", "multiline": False}),
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
}}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING")
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path")
FUNCTION = "load_standard"
CATEGORY = "utils/json/project"
def load_standard(self, manager_url, project_name, file_name, sequence_number):
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
return (
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
str(data.get("current_prompt", "")), str(data.get("negative", "")),
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
to_int(data.get("seed", 0)), str(data.get("video file path", "")),
str(data.get("reference image path", "")), str(data.get("flf image path", ""))
)
# ==========================================
# 2. VACE NODE (Project-based)
# ==========================================
class ProjectLoaderVACE:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
"project_name": ("STRING", {"default": "", "multiline": False}),
"file_name": ("STRING", {"default": "", "multiline": False}),
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
}}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING")
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path")
FUNCTION = "load_vace"
CATEGORY = "utils/json/project"
def load_vace(self, manager_url, project_name, file_name, sequence_number):
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
return (
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
str(data.get("current_prompt", "")), str(data.get("negative", "")),
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
to_int(data.get("seed", 0)), to_int(data.get("frame_to_skip", 81)),
to_int(data.get("input_a_frames", 16)), to_int(data.get("input_b_frames", 16)),
str(data.get("reference path", "")), to_int(data.get("reference switch", 1)),
to_int(data.get("vace schedule", 1)), str(data.get("video file path", "")),
str(data.get("reference image path", ""))
)
# ==========================================
# 3. LoRA NODE (Project-based)
# ==========================================
class ProjectLoaderLoRA:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
"project_name": ("STRING", {"default": "", "multiline": False}),
"file_name": ("STRING", {"default": "", "multiline": False}),
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
}}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low")
FUNCTION = "load_loras"
CATEGORY = "utils/json/project"
def load_loras(self, manager_url, project_name, file_name, sequence_number):
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
return (
str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")),
str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")),
str(data.get("lora 3 high", "")), str(data.get("lora 3 low", ""))
)
# --- Mappings ---
PROJECT_NODE_CLASS_MAPPINGS = {
"ProjectLoaderDynamic": ProjectLoaderDynamic,
"ProjectLoaderStandard": ProjectLoaderStandard,
"ProjectLoaderVACE": ProjectLoaderVACE,
"ProjectLoaderLoRA": ProjectLoaderLoRA,
}
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
"ProjectLoaderStandard": "Project Loader (Standard/I2V)",
"ProjectLoaderVACE": "Project Loader (VACE Full)",
"ProjectLoaderLoRA": "Project Loader (LoRAs)",
}

View File

@@ -17,6 +17,11 @@ class AppState:
live_toggles: dict = field(default_factory=dict)
show_comfy_monitor: bool = True
# Project DB fields
db: Any = None
current_project: str = ""
db_enabled: bool = False
# Set at runtime by main.py / tab_comfy_ng.py
_render_main: Any = None
_load_file: Callable | None = None

View File

@@ -6,7 +6,7 @@ from nicegui import ui
from state import AppState
from utils import (
DEFAULTS, save_json, load_json,
DEFAULTS, save_json, load_json, sync_to_db,
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
)
from history_tree import HistoryTree
@@ -161,6 +161,8 @@ def render_batch_processor(state: AppState):
new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {},
KEY_PROMPT_HISTORY: []}
save_json(new_path, new_data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, new_path, new_data)
ui.notify(f'Created {new_name}', type='positive')
ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch)
@@ -215,6 +217,8 @@ def render_batch_processor(state: AppState):
batch_list.append(new_item)
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
render_sequence_list.refresh()
with ui.row().classes('q-mt-sm'):
@@ -250,6 +254,8 @@ def render_batch_processor(state: AppState):
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
ui.notify('Sorted by sequence number!', type='positive')
render_sequence_list.refresh()
@@ -289,6 +295,8 @@ def render_batch_processor(state: AppState):
htree.commit(snapshot_payload, note=note)
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
state.restored_indicator = None
commit_input.set_value('')
ui.notify('Batch Saved & Snapshot Created!', type='positive')
@@ -306,6 +314,8 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
def commit(message=None):
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
if message:
ui.notify(message, type='positive')
refresh_list.refresh()
@@ -567,6 +577,8 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list):
shifted += 1
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive')
refresh_list.refresh()
@@ -712,6 +724,8 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}")
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
ui.notify(f'Updated {len(targets)} sequences', type='positive')
if refresh_list:
refresh_list.refresh()

161
tab_projects_ng.py Normal file
View File

@@ -0,0 +1,161 @@
import logging
from pathlib import Path
from nicegui import ui
from state import AppState
from db import ProjectDB
from utils import save_config, sync_to_db, KEY_BATCH_DATA
logger = logging.getLogger(__name__)
def render_projects_tab(state: AppState):
"""Render the Projects management tab."""
# --- DB toggle ---
def on_db_toggle(e):
state.db_enabled = e.value
state.config['db_enabled'] = e.value
save_config(state.current_dir, state.config.get('favorites', []), state.config)
render_project_content.refresh()
ui.switch('Enable Project Database', value=state.db_enabled,
on_change=on_db_toggle).classes('q-mb-md')
@ui.refreshable
def render_project_content():
if not state.db_enabled:
ui.label('Project database is disabled. Enable it above to manage projects.').classes(
'text-caption q-pa-md')
return
if not state.db:
ui.label('Database not initialized.').classes('text-warning q-pa-md')
return
# --- Create project form ---
with ui.card().classes('w-full q-pa-md q-mb-md'):
ui.label('Create New Project').classes('section-header')
name_input = ui.input('Project Name', placeholder='my_project').classes('w-full')
desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full')
def create_project():
name = name_input.value.strip()
if not name:
ui.notify('Please enter a project name', type='warning')
return
try:
state.db.create_project(name, str(state.current_dir), desc_input.value.strip())
name_input.set_value('')
desc_input.set_value('')
ui.notify(f'Created project "{name}"', type='positive')
render_project_list.refresh()
except Exception as e:
ui.notify(f'Error: {e}', type='negative')
ui.button('Create Project', icon='add', on_click=create_project).classes('w-full')
# --- Active project indicator ---
if state.current_project:
ui.label(f'Active Project: {state.current_project}').classes(
'text-bold text-primary q-pa-sm')
# --- Project list ---
@ui.refreshable
def render_project_list():
projects = state.db.list_projects()
if not projects:
ui.label('No projects yet. Create one above.').classes('text-caption q-pa-md')
return
for proj in projects:
is_active = proj['name'] == state.current_project
card_style = 'border-left: 3px solid var(--accent);' if is_active else ''
with ui.card().classes('w-full q-pa-sm q-mb-sm').style(card_style):
with ui.row().classes('w-full items-center'):
with ui.column().classes('col'):
ui.label(proj['name']).classes('text-bold')
if proj['description']:
ui.label(proj['description']).classes('text-caption')
ui.label(f'Path: {proj["folder_path"]}').classes('text-caption')
files = state.db.list_data_files(proj['id'])
ui.label(f'{len(files)} data file(s)').classes('text-caption')
with ui.row().classes('q-gutter-xs'):
if not is_active:
def activate(name=proj['name']):
state.current_project = name
state.config['current_project'] = name
save_config(state.current_dir,
state.config.get('favorites', []),
state.config)
ui.notify(f'Activated project "{name}"', type='positive')
render_project_list.refresh()
ui.button('Activate', icon='check_circle',
on_click=activate).props('flat dense color=primary')
else:
def deactivate():
state.current_project = ''
state.config['current_project'] = ''
save_config(state.current_dir,
state.config.get('favorites', []),
state.config)
ui.notify('Deactivated project', type='info')
render_project_list.refresh()
ui.button('Deactivate', icon='cancel',
on_click=deactivate).props('flat dense')
def import_folder(pid=proj['id'], pname=proj['name']):
_import_folder(state, pid, pname, render_project_list)
ui.button('Import Folder', icon='folder_open',
on_click=import_folder).props('flat dense')
def delete_proj(name=proj['name']):
state.db.delete_project(name)
if state.current_project == name:
state.current_project = ''
ui.notify(f'Deleted project "{name}"', type='positive')
render_project_list.refresh()
ui.button(icon='delete',
on_click=delete_proj).props('flat dense color=negative')
render_project_list()
render_project_content()
def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn):
"""Bulk import all .json files from current directory into a project."""
json_files = sorted(state.current_dir.glob('*.json'))
json_files = [f for f in json_files if f.name not in (
'.editor_config.json', '.editor_snippets.json')]
if not json_files:
ui.notify('No JSON files in current directory', type='warning')
return
imported = 0
skipped = 0
for jf in json_files:
file_name = jf.stem
existing = state.db.get_data_file(project_id, file_name)
if existing:
skipped += 1
continue
try:
state.db.import_json_file(project_id, jf)
imported += 1
except Exception as e:
logger.warning(f"Failed to import {jf}: {e}")
msg = f'Imported {imported} file(s)'
if skipped:
msg += f', skipped {skipped} existing'
ui.notify(msg, type='positive')
refresh_fn.refresh()

View File

@@ -4,7 +4,7 @@ import json
from nicegui import ui
from state import AppState
from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
from utils import save_json, sync_to_db, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
def render_raw_editor(state: AppState):
@@ -52,6 +52,8 @@ def render_raw_editor(state: AppState):
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
save_json(file_path, input_data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, input_data)
data.clear()
data.update(input_data)

View File

@@ -5,7 +5,7 @@ from nicegui import ui
from state import AppState
from history_tree import HistoryTree
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
from utils import save_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE
def _delete_nodes(htree, data, file_path, node_ids):
@@ -134,6 +134,8 @@ def _render_batch_delete(htree, data, file_path, state, refresh_fn):
def do_batch_delete():
current_valid = state.timeline_selected_nodes & set(htree.nodes.keys())
_delete_nodes(htree, data, file_path, current_valid)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
state.timeline_selected_nodes = set()
ui.notify(
f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!',
@@ -179,7 +181,7 @@ def _find_branch_for_node(htree, node_id):
def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn,
selected):
selected, state=None):
"""Render branch-grouped node manager with restore, rename, delete, and preview."""
ui.label('Manage Version').classes('section-header')
@@ -291,6 +293,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_
htree.nodes[sel_id]['note'] = rename_input.value
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
if state and state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
ui.notify('Label updated', type='positive')
refresh_fn()
@@ -304,6 +308,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_
def delete_selected():
if sel_id in htree.nodes:
_delete_nodes(htree, data, file_path, {sel_id})
if state and state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
ui.notify('Node Deleted', type='positive')
refresh_fn()
@@ -377,7 +383,7 @@ def render_timeline_tab(state: AppState):
_render_node_manager(
all_nodes, htree, data, file_path,
_restore_and_refresh, render_timeline.refresh,
selected)
selected, state=state)
def _toggle_select(nid, checked):
if checked:
@@ -492,6 +498,8 @@ def _restore_node(data, node, htree, file_path, state: AppState):
htree.head_id = node['id']
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data)
label = f"{node.get('note', 'Step')} ({node['id'][:4]})"
state.restored_indicator = label
ui.notify('Restored!', type='positive')

286
tests/test_db.py Normal file
View File

@@ -0,0 +1,286 @@
import json
from pathlib import Path
import pytest
from db import ProjectDB
from utils import KEY_BATCH_DATA, KEY_HISTORY_TREE
@pytest.fixture
def db(tmp_path):
"""Create a fresh ProjectDB in a temp directory."""
db_path = tmp_path / "test.db"
pdb = ProjectDB(db_path)
yield pdb
pdb.close()
# ------------------------------------------------------------------
# Projects CRUD
# ------------------------------------------------------------------
class TestProjects:
def test_create_and_get(self, db):
pid = db.create_project("proj1", "/some/path", "A test project")
assert pid > 0
proj = db.get_project("proj1")
assert proj is not None
assert proj["name"] == "proj1"
assert proj["folder_path"] == "/some/path"
assert proj["description"] == "A test project"
def test_list_projects(self, db):
db.create_project("beta", "/b")
db.create_project("alpha", "/a")
projects = db.list_projects()
assert len(projects) == 2
assert projects[0]["name"] == "alpha"
assert projects[1]["name"] == "beta"
def test_get_nonexistent(self, db):
assert db.get_project("nope") is None
def test_delete_project(self, db):
db.create_project("to_delete", "/x")
assert db.delete_project("to_delete") is True
assert db.get_project("to_delete") is None
def test_delete_nonexistent(self, db):
assert db.delete_project("nope") is False
def test_unique_name_constraint(self, db):
db.create_project("dup", "/a")
with pytest.raises(Exception):
db.create_project("dup", "/b")
# ------------------------------------------------------------------
# Data files
# ------------------------------------------------------------------
class TestDataFiles:
def test_create_and_list(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch_i2v", "i2v", {"extra": "meta"})
assert df_id > 0
files = db.list_data_files(pid)
assert len(files) == 1
assert files[0]["name"] == "batch_i2v"
assert files[0]["data_type"] == "i2v"
def test_get_data_file(self, db):
pid = db.create_project("p1", "/p1")
db.create_data_file(pid, "batch_i2v", "i2v", {"key": "value"})
df = db.get_data_file(pid, "batch_i2v")
assert df is not None
assert df["top_level"] == {"key": "value"}
def test_get_data_file_by_names(self, db):
pid = db.create_project("p1", "/p1")
db.create_data_file(pid, "batch_i2v", "i2v")
df = db.get_data_file_by_names("p1", "batch_i2v")
assert df is not None
assert df["name"] == "batch_i2v"
def test_get_nonexistent_data_file(self, db):
pid = db.create_project("p1", "/p1")
assert db.get_data_file(pid, "nope") is None
def test_unique_constraint(self, db):
pid = db.create_project("p1", "/p1")
db.create_data_file(pid, "batch_i2v", "i2v")
with pytest.raises(Exception):
db.create_data_file(pid, "batch_i2v", "vace")
def test_cascade_delete(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
db.upsert_sequence(df_id, 1, {"prompt": "hello"})
db.save_history_tree(df_id, {"nodes": {}})
db.delete_project("p1")
assert db.get_data_file(pid, "batch_i2v") is None
# ------------------------------------------------------------------
# Sequences
# ------------------------------------------------------------------
class TestSequences:
def test_upsert_and_get(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
db.upsert_sequence(df_id, 1, {"prompt": "hello", "seed": 42})
data = db.get_sequence(df_id, 1)
assert data == {"prompt": "hello", "seed": 42}
def test_upsert_updates_existing(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
db.upsert_sequence(df_id, 1, {"prompt": "v1"})
db.upsert_sequence(df_id, 1, {"prompt": "v2"})
data = db.get_sequence(df_id, 1)
assert data["prompt"] == "v2"
def test_list_sequences(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
db.upsert_sequence(df_id, 3, {"a": 1})
db.upsert_sequence(df_id, 1, {"b": 2})
db.upsert_sequence(df_id, 2, {"c": 3})
seqs = db.list_sequences(df_id)
assert seqs == [1, 2, 3]
def test_get_nonexistent_sequence(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
assert db.get_sequence(df_id, 99) is None
def test_get_sequence_keys(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
db.upsert_sequence(df_id, 1, {
"prompt": "hello",
"seed": 42,
"cfg": 1.5,
"flag": True,
})
keys, types = db.get_sequence_keys(df_id, 1)
assert "prompt" in keys
assert "seed" in keys
idx_prompt = keys.index("prompt")
idx_seed = keys.index("seed")
idx_cfg = keys.index("cfg")
idx_flag = keys.index("flag")
assert types[idx_prompt] == "STRING"
assert types[idx_seed] == "INT"
assert types[idx_cfg] == "FLOAT"
assert types[idx_flag] == "STRING" # bools -> STRING
def test_get_sequence_keys_nonexistent(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
keys, types = db.get_sequence_keys(df_id, 99)
assert keys == []
assert types == []
def test_delete_sequences_for_file(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
db.upsert_sequence(df_id, 1, {"a": 1})
db.upsert_sequence(df_id, 2, {"b": 2})
db.delete_sequences_for_file(df_id)
assert db.list_sequences(df_id) == []
# ------------------------------------------------------------------
# History trees
# ------------------------------------------------------------------
class TestHistoryTrees:
def test_save_and_get(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
tree = {"nodes": {"abc": {"id": "abc"}}, "head_id": "abc"}
db.save_history_tree(df_id, tree)
result = db.get_history_tree(df_id)
assert result == tree
def test_upsert_updates(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
db.save_history_tree(df_id, {"v": 1})
db.save_history_tree(df_id, {"v": 2})
result = db.get_history_tree(df_id)
assert result == {"v": 2}
def test_get_nonexistent(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
assert db.get_history_tree(df_id) is None
# ------------------------------------------------------------------
# Import
# ------------------------------------------------------------------
class TestImport:
def test_import_json_file(self, db, tmp_path):
pid = db.create_project("p1", "/p1")
json_path = tmp_path / "batch_prompt_i2v.json"
data = {
KEY_BATCH_DATA: [
{"sequence_number": 1, "prompt": "hello", "seed": 42},
{"sequence_number": 2, "prompt": "world", "seed": 99},
],
KEY_HISTORY_TREE: {"nodes": {}, "head_id": None},
}
json_path.write_text(json.dumps(data))
df_id = db.import_json_file(pid, json_path, "i2v")
assert df_id > 0
seqs = db.list_sequences(df_id)
assert seqs == [1, 2]
s1 = db.get_sequence(df_id, 1)
assert s1["prompt"] == "hello"
assert s1["seed"] == 42
tree = db.get_history_tree(df_id)
assert tree == {"nodes": {}, "head_id": None}
def test_import_file_name_from_stem(self, db, tmp_path):
pid = db.create_project("p1", "/p1")
json_path = tmp_path / "my_batch.json"
json_path.write_text(json.dumps({KEY_BATCH_DATA: [{"sequence_number": 1}]}))
db.import_json_file(pid, json_path)
df = db.get_data_file(pid, "my_batch")
assert df is not None
def test_import_no_batch_data(self, db, tmp_path):
pid = db.create_project("p1", "/p1")
json_path = tmp_path / "simple.json"
json_path.write_text(json.dumps({"prompt": "flat file"}))
df_id = db.import_json_file(pid, json_path)
seqs = db.list_sequences(df_id)
assert seqs == []
# ------------------------------------------------------------------
# Query helpers
# ------------------------------------------------------------------
class TestQueryHelpers:
def test_query_sequence_data(self, db):
pid = db.create_project("myproject", "/mp")
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7})
result = db.query_sequence_data("myproject", "batch_i2v", 1)
assert result == {"prompt": "test", "seed": 7}
def test_query_sequence_data_not_found(self, db):
assert db.query_sequence_data("nope", "nope", 1) is None
def test_query_sequence_keys(self, db):
pid = db.create_project("myproject", "/mp")
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7})
keys, types = db.query_sequence_keys("myproject", "batch_i2v", 1)
assert "prompt" in keys
assert "seed" in keys
def test_list_project_files(self, db):
pid = db.create_project("p1", "/p1")
db.create_data_file(pid, "file_a", "i2v")
db.create_data_file(pid, "file_b", "vace")
files = db.list_project_files("p1")
assert len(files) == 2
def test_list_project_sequences(self, db):
pid = db.create_project("p1", "/p1")
df_id = db.create_data_file(pid, "batch", "generic")
db.upsert_sequence(df_id, 1, {})
db.upsert_sequence(df_id, 2, {})
seqs = db.list_project_sequences("p1", "batch")
assert seqs == [1, 2]

View File

@@ -0,0 +1,201 @@
import json
from unittest.mock import patch, MagicMock
from io import BytesIO
import pytest
from project_loader import (
ProjectLoaderDynamic,
ProjectLoaderStandard,
ProjectLoaderVACE,
ProjectLoaderLoRA,
_fetch_json,
_fetch_data,
_fetch_keys,
MAX_DYNAMIC_OUTPUTS,
)
def _mock_urlopen(data: dict):
"""Create a mock context manager for urllib.request.urlopen."""
response = MagicMock()
response.read.return_value = json.dumps(data).encode()
response.__enter__ = lambda s: s
response.__exit__ = MagicMock(return_value=False)
return response
class TestFetchHelpers:
def test_fetch_json_success(self):
data = {"key": "value"}
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)):
result = _fetch_json("http://example.com/api")
assert result == data
def test_fetch_json_failure(self):
import urllib.error
with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
result = _fetch_json("http://example.com/api")
assert result == {}
def test_fetch_data_builds_url(self):
data = {"prompt": "hello"}
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
result = _fetch_data("http://localhost:8080", "proj1", "batch_i2v", 1)
assert result == data
called_url = mock.call_args[0][0]
assert "/api/projects/proj1/files/batch_i2v/data?seq=1" in called_url
def test_fetch_keys_builds_url(self):
data = {"keys": ["prompt"], "types": ["STRING"]}
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
result = _fetch_keys("http://localhost:8080", "proj1", "batch_i2v", 1)
assert result == data
called_url = mock.call_args[0][0]
assert "/api/projects/proj1/files/batch_i2v/keys?seq=1" in called_url
def test_fetch_data_strips_trailing_slash(self):
data = {"prompt": "hello"}
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
_fetch_data("http://localhost:8080/", "proj1", "file1", 1)
called_url = mock.call_args[0][0]
assert "//api" not in called_url
class TestProjectLoaderDynamic:
def test_load_dynamic_with_keys(self):
data = {"prompt": "hello", "seed": 42, "cfg": 1.5}
node = ProjectLoaderDynamic()
with patch("project_loader._fetch_data", return_value=data):
result = node.load_dynamic(
"http://localhost:8080", "proj1", "batch_i2v", 1,
output_keys="prompt,seed,cfg"
)
assert result[0] == "hello"
assert result[1] == 42
assert result[2] == 1.5
assert len(result) == MAX_DYNAMIC_OUTPUTS
def test_load_dynamic_empty_keys(self):
node = ProjectLoaderDynamic()
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
result = node.load_dynamic(
"http://localhost:8080", "proj1", "batch_i2v", 1,
output_keys=""
)
assert all(v == "" for v in result)
def test_load_dynamic_missing_key(self):
node = ProjectLoaderDynamic()
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
result = node.load_dynamic(
"http://localhost:8080", "proj1", "batch_i2v", 1,
output_keys="nonexistent"
)
assert result[0] == ""
def test_load_dynamic_bool_becomes_string(self):
node = ProjectLoaderDynamic()
with patch("project_loader._fetch_data", return_value={"flag": True}):
result = node.load_dynamic(
"http://localhost:8080", "proj1", "batch_i2v", 1,
output_keys="flag"
)
assert result[0] == "true"
def test_input_types_has_manager_url(self):
inputs = ProjectLoaderDynamic.INPUT_TYPES()
assert "manager_url" in inputs["required"]
assert "project_name" in inputs["required"]
assert "file_name" in inputs["required"]
assert "sequence_number" in inputs["required"]
def test_category(self):
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
class TestProjectLoaderStandard:
def test_load_standard(self):
data = {
"general_prompt": "hello",
"general_negative": "bad",
"current_prompt": "specific",
"negative": "neg",
"camera": "pan",
"flf": 0.5,
"seed": 42,
"video file path": "/v.mp4",
"reference image path": "/r.png",
"flf image path": "/f.png",
}
node = ProjectLoaderStandard()
with patch("project_loader._fetch_data", return_value=data):
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
assert result == ("hello", "bad", "specific", "neg", "pan", 0.5, 42, "/v.mp4", "/r.png", "/f.png")
def test_load_standard_defaults(self):
node = ProjectLoaderStandard()
with patch("project_loader._fetch_data", return_value={}):
result = node.load_standard("http://localhost:8080", "proj1", "batch", 1)
assert result[0] == "" # general_prompt
assert result[5] == 0.0 # flf
assert result[6] == 0 # seed
class TestProjectLoaderVACE:
def test_load_vace(self):
data = {
"general_prompt": "hello",
"general_negative": "bad",
"current_prompt": "specific",
"negative": "neg",
"camera": "pan",
"flf": 0.5,
"seed": 42,
"frame_to_skip": 81,
"input_a_frames": 16,
"input_b_frames": 16,
"reference path": "/ref",
"reference switch": 1,
"vace schedule": 2,
"video file path": "/v.mp4",
"reference image path": "/r.png",
}
node = ProjectLoaderVACE()
with patch("project_loader._fetch_data", return_value=data):
result = node.load_vace("http://localhost:8080", "proj1", "batch", 1)
assert result[7] == 81 # frame_to_skip
assert result[12] == 2 # vace_schedule
class TestProjectLoaderLoRA:
def test_load_loras(self):
data = {
"lora 1 high": "<lora:model1:1.0>",
"lora 1 low": "<lora:model1:0.5>",
"lora 2 high": "",
"lora 2 low": "",
"lora 3 high": "",
"lora 3 low": "",
}
node = ProjectLoaderLoRA()
with patch("project_loader._fetch_data", return_value=data):
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
assert result[0] == "<lora:model1:1.0>"
assert result[1] == "<lora:model1:0.5>"
def test_load_loras_empty(self):
node = ProjectLoaderLoRA()
with patch("project_loader._fetch_data", return_value={}):
result = node.load_loras("http://localhost:8080", "proj1", "batch", 1)
assert all(v == "" for v in result)
class TestNodeMappings:
def test_mappings_exist(self):
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
assert "ProjectLoaderStandard" in PROJECT_NODE_CLASS_MAPPINGS
assert "ProjectLoaderVACE" in PROJECT_NODE_CLASS_MAPPINGS
assert "ProjectLoaderLoRA" in PROJECT_NODE_CLASS_MAPPINGS
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4

View File

@@ -160,6 +160,43 @@ def get_file_mtime(path: str | Path) -> float:
return path.stat().st_mtime
return 0
def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
"""Dual-write helper: sync JSON data to the project database.
Resolves (or creates) the data_file, upserts all sequences from batch_data,
and saves the history_tree.
"""
if not db or not project_name:
return
try:
proj = db.get_project(project_name)
if not proj:
return
file_name = Path(file_path).stem
df = db.get_data_file(proj["id"], file_name)
if not df:
top_level = {k: v for k, v in data.items()
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
df_id = db.create_data_file(proj["id"], file_name, "generic", top_level)
else:
df_id = df["id"]
# Sync sequences
batch_data = data.get(KEY_BATCH_DATA, [])
if isinstance(batch_data, list):
db.delete_sequences_for_file(df_id)
for item in batch_data:
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
db.upsert_sequence(df_id, seq_num, item)
# Sync history tree
history_tree = data.get(KEY_HISTORY_TREE)
if history_tree and isinstance(history_tree, dict):
db.save_history_tree(df_id, history_tree)
except Exception as e:
logger.warning(f"sync_to_db failed: {e}")
def generate_templates(current_dir: Path) -> None:
"""Creates batch template files if folder is empty."""
first = DEFAULTS.copy()

139
web/project_dynamic.js Normal file
View File

@@ -0,0 +1,139 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
app.registerExtension({
name: "json.manager.project.dynamic",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name !== "ProjectLoaderDynamic") return;
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function () {
origOnNodeCreated?.apply(this, arguments);
// Hide internal widgets (managed by JS)
for (const name of ["output_keys", "output_types"]) {
const w = this.widgets?.find(w => w.name === name);
if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
}
// Remove all 32 default outputs from Python RETURN_TYPES
while (this.outputs.length > 0) {
this.removeOutput(0);
}
// Add Refresh button
this.addWidget("button", "Refresh Outputs", null, () => {
this.refreshDynamicOutputs();
});
this.setSize(this.computeSize());
};
nodeType.prototype.refreshDynamicOutputs = async function () {
const urlWidget = this.widgets?.find(w => w.name === "manager_url");
const projectWidget = this.widgets?.find(w => w.name === "project_name");
const fileWidget = this.widgets?.find(w => w.name === "file_name");
const seqWidget = this.widgets?.find(w => w.name === "sequence_number");
if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return;
try {
const resp = await api.fetchApi(
`/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}`
);
const { keys, types } = await resp.json();
// Store keys and types in hidden widgets for persistence
const okWidget = this.widgets?.find(w => w.name === "output_keys");
if (okWidget) okWidget.value = keys.join(",");
const otWidget = this.widgets?.find(w => w.name === "output_types");
if (otWidget) otWidget.value = types.join(",");
// Build a map of current output names to slot indices
const oldSlots = {};
for (let i = 0; i < this.outputs.length; i++) {
oldSlots[this.outputs[i].name] = i;
}
// Build new outputs, reusing existing slots to preserve links
const newOutputs = [];
for (let k = 0; k < keys.length; k++) {
const key = keys[k];
const type = types[k] || "*";
if (key in oldSlots) {
const slot = this.outputs[oldSlots[key]];
slot.type = type;
newOutputs.push(slot);
delete oldSlots[key];
} else {
newOutputs.push({ name: key, type: type, links: null });
}
}
// Disconnect links on slots that are being removed
for (const name in oldSlots) {
const idx = oldSlots[name];
if (this.outputs[idx]?.links?.length) {
for (const linkId of [...this.outputs[idx].links]) {
this.graph?.removeLink(linkId);
}
}
}
// Reassign the outputs array and fix link slot indices
this.outputs = newOutputs;
if (this.graph) {
for (let i = 0; i < this.outputs.length; i++) {
const links = this.outputs[i].links;
if (!links) continue;
for (const linkId of links) {
const link = this.graph.links[linkId];
if (link) link.origin_slot = i;
}
}
}
this.setSize(this.computeSize());
app.graph.setDirtyCanvas(true, true);
} catch (e) {
console.error("[ProjectLoaderDynamic] Refresh failed:", e);
}
};
// Restore state on workflow load
const origOnConfigure = nodeType.prototype.onConfigure;
nodeType.prototype.onConfigure = function (info) {
origOnConfigure?.apply(this, arguments);
// Hide internal widgets
for (const name of ["output_keys", "output_types"]) {
const w = this.widgets?.find(w => w.name === name);
if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
}
const okWidget = this.widgets?.find(w => w.name === "output_keys");
const otWidget = this.widgets?.find(w => w.name === "output_types");
const keys = okWidget?.value
? okWidget.value.split(",").filter(k => k.trim())
: [];
const types = otWidget?.value
? otWidget.value.split(",")
: [];
// Rename and set types to match stored state (preserves links)
for (let i = 0; i < this.outputs.length && i < keys.length; i++) {
this.outputs[i].name = keys[i].trim();
if (types[i]) this.outputs[i].type = types[i];
}
// Remove any extra outputs beyond the key count
while (this.outputs.length > keys.length) {
this.removeOutput(this.outputs.length - 1);
}
this.setSize(this.computeSize());
};
},
});