Compare commits
8 Commits
e2f30b0332
...
4b5fff5c6e
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b5fff5c6e | |||
| d07a308865 | |||
| c4d107206f | |||
| b499eb4dfd | |||
| ba8f104bc1 | |||
| 6b7e9ea682 | |||
| c15bec98ce | |||
| 0d8e84ea36 |
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
__pycache__/
|
||||
.pytest_cache/
|
||||
.worktrees/
|
||||
@@ -1,4 +1,8 @@
|
||||
from .json_loader import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
||||
from .project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||
|
||||
NODE_CLASS_MAPPINGS.update(PROJECT_NODE_CLASS_MAPPINGS)
|
||||
NODE_DISPLAY_NAME_MAPPINGS.update(PROJECT_NODE_DISPLAY_NAME_MAPPINGS)
|
||||
|
||||
WEB_DIRECTORY = "./web"
|
||||
|
||||
|
||||
80
api_routes.py
Normal file
80
api_routes.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""REST API endpoints for ComfyUI to query project data from SQLite.
|
||||
|
||||
All endpoints are read-only. Mounted on the NiceGUI/FastAPI server.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Query
|
||||
from nicegui import app
|
||||
|
||||
from db import ProjectDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The DB instance is set by register_api_routes()
|
||||
_db: ProjectDB | None = None
|
||||
|
||||
|
||||
def register_api_routes(db: ProjectDB) -> None:
|
||||
"""Register all REST API routes with the NiceGUI/FastAPI app."""
|
||||
global _db
|
||||
_db = db
|
||||
|
||||
app.add_api_route("/api/projects", _list_projects, methods=["GET"])
|
||||
app.add_api_route("/api/projects/{name}/files", _list_files, methods=["GET"])
|
||||
app.add_api_route("/api/projects/{name}/files/{file_name}/sequences", _list_sequences, methods=["GET"])
|
||||
app.add_api_route("/api/projects/{name}/files/{file_name}/data", _get_data, methods=["GET"])
|
||||
app.add_api_route("/api/projects/{name}/files/{file_name}/keys", _get_keys, methods=["GET"])
|
||||
|
||||
|
||||
def _get_db() -> ProjectDB:
|
||||
if _db is None:
|
||||
raise HTTPException(status_code=503, detail="Database not initialized")
|
||||
return _db
|
||||
|
||||
|
||||
def _list_projects() -> dict[str, Any]:
|
||||
db = _get_db()
|
||||
projects = db.list_projects()
|
||||
return {"projects": [p["name"] for p in projects]}
|
||||
|
||||
|
||||
def _list_files(name: str) -> dict[str, Any]:
|
||||
db = _get_db()
|
||||
files = db.list_project_files(name)
|
||||
return {"files": [{"name": f["name"], "data_type": f["data_type"]} for f in files]}
|
||||
|
||||
|
||||
def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
|
||||
db = _get_db()
|
||||
seqs = db.list_project_sequences(name, file_name)
|
||||
return {"sequences": seqs}
|
||||
|
||||
|
||||
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
||||
db = _get_db()
|
||||
proj = db.get_project(name)
|
||||
if not proj:
|
||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||
df = db.get_data_file_by_names(name, file_name)
|
||||
if not df:
|
||||
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
||||
data = db.get_sequence(df["id"], seq)
|
||||
if data is None:
|
||||
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
||||
return data
|
||||
|
||||
|
||||
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
||||
db = _get_db()
|
||||
proj = db.get_project(name)
|
||||
if not proj:
|
||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||
df = db.get_data_file_by_names(name, file_name)
|
||||
if not df:
|
||||
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
||||
keys, types = db.get_sequence_keys(df["id"], seq)
|
||||
total = db.count_sequences(df["id"])
|
||||
return {"keys": keys, "types": types, "total_sequences": total}
|
||||
349
db.py
Normal file
349
db.py
Normal file
@@ -0,0 +1,349 @@
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_DB_PATH = Path.home() / ".comfyui_json_manager" / "projects.db"
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS projects (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
folder_path TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
created_at REAL NOT NULL,
|
||||
updated_at REAL NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS data_files (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
data_type TEXT NOT NULL DEFAULT 'generic',
|
||||
top_level TEXT NOT NULL DEFAULT '{}',
|
||||
created_at REAL NOT NULL,
|
||||
updated_at REAL NOT NULL,
|
||||
UNIQUE(project_id, name)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sequences (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE,
|
||||
sequence_number INTEGER NOT NULL,
|
||||
data TEXT NOT NULL DEFAULT '{}',
|
||||
updated_at REAL NOT NULL,
|
||||
UNIQUE(data_file_id, sequence_number)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS history_trees (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE,
|
||||
tree_data TEXT NOT NULL DEFAULT '{}',
|
||||
updated_at REAL NOT NULL
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
class ProjectDB:
|
||||
"""SQLite database for project-based data management."""
|
||||
|
||||
def __init__(self, db_path: str | Path | None = None):
|
||||
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
check_same_thread=False,
|
||||
isolation_level=None, # autocommit — explicit BEGIN/COMMIT only
|
||||
)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
||||
self.conn.execute("PRAGMA foreign_keys=ON")
|
||||
self.conn.executescript(SCHEMA_SQL)
|
||||
|
||||
def close(self):
|
||||
self.conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Projects CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_project(self, name: str, folder_path: str, description: str = "") -> int:
|
||||
now = time.time()
|
||||
cur = self.conn.execute(
|
||||
"INSERT INTO projects (name, folder_path, description, created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
(name, folder_path, description, now, now),
|
||||
)
|
||||
self.conn.commit()
|
||||
return cur.lastrowid
|
||||
|
||||
def list_projects(self) -> list[dict]:
|
||||
rows = self.conn.execute(
|
||||
"SELECT id, name, folder_path, description, created_at, updated_at "
|
||||
"FROM projects ORDER BY name"
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def get_project(self, name: str) -> dict | None:
|
||||
row = self.conn.execute(
|
||||
"SELECT id, name, folder_path, description, created_at, updated_at "
|
||||
"FROM projects WHERE name = ?",
|
||||
(name,),
|
||||
).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def delete_project(self, name: str) -> bool:
|
||||
cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,))
|
||||
self.conn.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Data files
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_data_file(
|
||||
self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None
|
||||
) -> int:
|
||||
now = time.time()
|
||||
tl = json.dumps(top_level or {})
|
||||
cur = self.conn.execute(
|
||||
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(project_id, name, data_type, tl, now, now),
|
||||
)
|
||||
self.conn.commit()
|
||||
return cur.lastrowid
|
||||
|
||||
def list_data_files(self, project_id: int) -> list[dict]:
|
||||
rows = self.conn.execute(
|
||||
"SELECT id, project_id, name, data_type, created_at, updated_at "
|
||||
"FROM data_files WHERE project_id = ? ORDER BY name",
|
||||
(project_id,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
def get_data_file(self, project_id: int, name: str) -> dict | None:
|
||||
row = self.conn.execute(
|
||||
"SELECT id, project_id, name, data_type, top_level, created_at, updated_at "
|
||||
"FROM data_files WHERE project_id = ? AND name = ?",
|
||||
(project_id, name),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
d = dict(row)
|
||||
d["top_level"] = json.loads(d["top_level"])
|
||||
return d
|
||||
|
||||
def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None:
|
||||
row = self.conn.execute(
|
||||
"SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, "
|
||||
"df.created_at, df.updated_at "
|
||||
"FROM data_files df JOIN projects p ON df.project_id = p.id "
|
||||
"WHERE p.name = ? AND df.name = ?",
|
||||
(project_name, file_name),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
d = dict(row)
|
||||
d["top_level"] = json.loads(d["top_level"])
|
||||
return d
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sequences
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None:
|
||||
now = time.time()
|
||||
self.conn.execute(
|
||||
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
|
||||
(data_file_id, sequence_number, json.dumps(data), now),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None:
|
||||
row = self.conn.execute(
|
||||
"SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?",
|
||||
(data_file_id, sequence_number),
|
||||
).fetchone()
|
||||
return json.loads(row["data"]) if row else None
|
||||
|
||||
def list_sequences(self, data_file_id: int) -> list[int]:
|
||||
rows = self.conn.execute(
|
||||
"SELECT sequence_number FROM sequences WHERE data_file_id = ? ORDER BY sequence_number",
|
||||
(data_file_id,),
|
||||
).fetchall()
|
||||
return [r["sequence_number"] for r in rows]
|
||||
|
||||
def count_sequences(self, data_file_id: int) -> int:
|
||||
"""Return the number of sequences for a data file."""
|
||||
row = self.conn.execute(
|
||||
"SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?",
|
||||
(data_file_id,),
|
||||
).fetchone()
|
||||
return row["cnt"]
|
||||
|
||||
def query_total_sequences(self, project_name: str, file_name: str) -> int:
|
||||
"""Return total sequence count by project and file names."""
|
||||
df = self.get_data_file_by_names(project_name, file_name)
|
||||
if not df:
|
||||
return 0
|
||||
return self.count_sequences(df["id"])
|
||||
|
||||
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
|
||||
"""Returns (keys, types) for a sequence's data dict."""
|
||||
data = self.get_sequence(data_file_id, sequence_number)
|
||||
if not data:
|
||||
return [], []
|
||||
keys = []
|
||||
types = []
|
||||
for k, v in data.items():
|
||||
keys.append(k)
|
||||
if isinstance(v, bool):
|
||||
types.append("STRING")
|
||||
elif isinstance(v, int):
|
||||
types.append("INT")
|
||||
elif isinstance(v, float):
|
||||
types.append("FLOAT")
|
||||
else:
|
||||
types.append("STRING")
|
||||
return keys, types
|
||||
|
||||
def delete_sequences_for_file(self, data_file_id: int) -> None:
|
||||
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,))
|
||||
self.conn.commit()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# History trees
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
|
||||
now = time.time()
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
"VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||
(data_file_id, json.dumps(tree_data), now),
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def get_history_tree(self, data_file_id: int) -> dict | None:
|
||||
row = self.conn.execute(
|
||||
"SELECT tree_data FROM history_trees WHERE data_file_id = ?",
|
||||
(data_file_id,),
|
||||
).fetchone()
|
||||
return json.loads(row["tree_data"]) if row else None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Import
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
|
||||
"""Import a JSON file into the database, splitting batch_data into sequences.
|
||||
|
||||
Safe to call repeatedly — existing data_file is updated, sequences are
|
||||
replaced, and history_tree is upserted. Atomic: all-or-nothing.
|
||||
"""
|
||||
json_path = Path(json_path)
|
||||
data, _ = load_json(json_path)
|
||||
file_name = json_path.stem
|
||||
|
||||
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
||||
|
||||
self.conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
existing = self.conn.execute(
|
||||
"SELECT id FROM data_files WHERE project_id = ? AND name = ?",
|
||||
(project_id, file_name),
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
df_id = existing["id"]
|
||||
now = time.time()
|
||||
self.conn.execute(
|
||||
"UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?",
|
||||
(data_type, json.dumps(top_level), now, df_id),
|
||||
)
|
||||
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
||||
else:
|
||||
now = time.time()
|
||||
cur = self.conn.execute(
|
||||
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(project_id, file_name, data_type, json.dumps(top_level), now, now),
|
||||
)
|
||||
df_id = cur.lastrowid
|
||||
|
||||
# Import sequences from batch_data
|
||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
||||
if isinstance(batch_data, list):
|
||||
for item in batch_data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
seq_num = int(item.get("sequence_number", 0))
|
||||
now = time.time()
|
||||
self.conn.execute(
|
||||
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
|
||||
(df_id, seq_num, json.dumps(item), now),
|
||||
)
|
||||
|
||||
# Import history tree
|
||||
history_tree = data.get(KEY_HISTORY_TREE)
|
||||
if history_tree and isinstance(history_tree, dict):
|
||||
now = time.time()
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
"VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||
(df_id, json.dumps(history_tree), now),
|
||||
)
|
||||
|
||||
self.conn.execute("COMMIT")
|
||||
return df_id
|
||||
except Exception:
|
||||
try:
|
||||
self.conn.execute("ROLLBACK")
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query helpers (for REST API)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None:
|
||||
"""Query a single sequence by project name, file name, and sequence number."""
|
||||
df = self.get_data_file_by_names(project_name, file_name)
|
||||
if not df:
|
||||
return None
|
||||
return self.get_sequence(df["id"], sequence_number)
|
||||
|
||||
def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]:
|
||||
"""Query keys and types for a sequence."""
|
||||
df = self.get_data_file_by_names(project_name, file_name)
|
||||
if not df:
|
||||
return [], []
|
||||
return self.get_sequence_keys(df["id"], sequence_number)
|
||||
|
||||
def list_project_files(self, project_name: str) -> list[dict]:
|
||||
"""List data files for a project by name."""
|
||||
proj = self.get_project(project_name)
|
||||
if not proj:
|
||||
return []
|
||||
return self.list_data_files(proj["id"])
|
||||
|
||||
def list_project_sequences(self, project_name: str, file_name: str) -> list[int]:
|
||||
"""List sequence numbers for a file in a project."""
|
||||
df = self.get_data_file_by_names(project_name, file_name)
|
||||
if not df:
|
||||
return []
|
||||
return self.list_sequences(df["id"])
|
||||
@@ -75,6 +75,8 @@ if PromptServer is not None:
|
||||
except (ValueError, TypeError):
|
||||
seq = 1
|
||||
data = read_json_data(json_path)
|
||||
if not data:
|
||||
return web.json_response({"keys": [], "types": [], "error": "file_not_found"})
|
||||
target = get_batch_item(data, seq)
|
||||
keys = []
|
||||
types = []
|
||||
|
||||
26
main.py
26
main.py
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from nicegui import ui
|
||||
@@ -14,6 +15,18 @@ from tab_batch_ng import render_batch_processor
|
||||
from tab_timeline_ng import render_timeline_tab
|
||||
from tab_raw_ng import render_raw_editor
|
||||
from tab_comfy_ng import render_comfy_monitor
|
||||
from tab_projects_ng import render_projects_tab
|
||||
from db import ProjectDB
|
||||
from api_routes import register_api_routes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Single shared DB instance for both the UI and API routes
|
||||
_shared_db: ProjectDB | None = None
|
||||
try:
|
||||
_shared_db = ProjectDB()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize ProjectDB: {e}")
|
||||
|
||||
|
||||
@ui.page('/')
|
||||
@@ -156,7 +169,13 @@ def index():
|
||||
config=config,
|
||||
current_dir=Path(config.get('last_dir', str(Path.cwd()))),
|
||||
snippets=load_snippets(),
|
||||
db_enabled=config.get('db_enabled', False),
|
||||
current_project=config.get('current_project', ''),
|
||||
)
|
||||
|
||||
# Use the shared DB instance
|
||||
state.db = _shared_db
|
||||
|
||||
dual_pane = {'active': False, 'state': None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -178,6 +197,7 @@ def index():
|
||||
ui.tab('batch', label='Batch Processor')
|
||||
ui.tab('timeline', label='Timeline')
|
||||
ui.tab('raw', label='Raw Editor')
|
||||
ui.tab('projects', label='Projects')
|
||||
|
||||
with ui.tab_panels(tabs, value='batch').classes('w-full'):
|
||||
with ui.tab_panel('batch'):
|
||||
@@ -186,6 +206,8 @@ def index():
|
||||
render_timeline_tab(state)
|
||||
with ui.tab_panel('raw'):
|
||||
render_raw_editor(state)
|
||||
with ui.tab_panel('projects'):
|
||||
render_projects_tab(state)
|
||||
|
||||
if state.show_comfy_monitor:
|
||||
ui.separator()
|
||||
@@ -481,4 +503,8 @@ def render_sidebar(state: AppState, dual_pane: dict):
|
||||
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
|
||||
|
||||
|
||||
# Register REST API routes for ComfyUI connectivity (uses the shared DB instance)
|
||||
if _shared_db is not None:
|
||||
register_api_routes(_shared_db)
|
||||
|
||||
ui.run(title='AI Settings Manager', port=8080, reload=True)
|
||||
|
||||
215
project_loader.py
Normal file
215
project_loader.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_DYNAMIC_OUTPUTS = 32
|
||||
|
||||
|
||||
class AnyType(str):
|
||||
"""Universal connector type that matches any ComfyUI type."""
|
||||
def __ne__(self, __value: object) -> bool:
|
||||
return False
|
||||
|
||||
any_type = AnyType("*")
|
||||
|
||||
|
||||
try:
|
||||
from server import PromptServer
|
||||
from aiohttp import web
|
||||
except ImportError:
|
||||
PromptServer = None
|
||||
|
||||
|
||||
def to_float(val: Any) -> float:
|
||||
try:
|
||||
return float(val)
|
||||
except (ValueError, TypeError):
|
||||
return 0.0
|
||||
|
||||
def to_int(val: Any) -> int:
|
||||
try:
|
||||
return int(float(val))
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
|
||||
|
||||
def _fetch_json(url: str) -> dict:
|
||||
"""Fetch JSON from a URL using stdlib urllib.
|
||||
|
||||
On error, returns a dict with an "error" key describing the failure.
|
||||
"""
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=5) as resp:
|
||||
return json.loads(resp.read())
|
||||
except urllib.error.HTTPError as e:
|
||||
# HTTPError is a subclass of URLError — must be caught first
|
||||
body = ""
|
||||
try:
|
||||
raw = e.read()
|
||||
detail = json.loads(raw)
|
||||
body = detail.get("detail", str(raw, "utf-8", errors="replace"))
|
||||
except Exception:
|
||||
body = str(e)
|
||||
logger.warning(f"HTTP {e.code} from {url}: {body}")
|
||||
return {"error": "http_error", "status": e.code, "message": body}
|
||||
except (urllib.error.URLError, OSError) as e:
|
||||
reason = str(e.reason) if hasattr(e, "reason") else str(e)
|
||||
logger.warning(f"Network error fetching {url}: {reason}")
|
||||
return {"error": "network_error", "message": reason}
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Invalid JSON from {url}: {e}")
|
||||
return {"error": "parse_error", "message": str(e)}
|
||||
|
||||
|
||||
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
|
||||
"""Fetch sequence data from the NiceGUI REST API."""
|
||||
p = urllib.parse.quote(project, safe='')
|
||||
f = urllib.parse.quote(file, safe='')
|
||||
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/data?seq={seq}"
|
||||
return _fetch_json(url)
|
||||
|
||||
|
||||
def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict:
|
||||
"""Fetch keys/types from the NiceGUI REST API."""
|
||||
p = urllib.parse.quote(project, safe='')
|
||||
f = urllib.parse.quote(file, safe='')
|
||||
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/keys?seq={seq}"
|
||||
return _fetch_json(url)
|
||||
|
||||
|
||||
# --- ComfyUI-side proxy endpoints (for frontend JS) ---
|
||||
if PromptServer is not None:
|
||||
@PromptServer.instance.routes.get("/json_manager/list_projects")
|
||||
async def list_projects_proxy(request):
|
||||
manager_url = request.query.get("url", "http://localhost:8080")
|
||||
url = f"{manager_url.rstrip('/')}/api/projects"
|
||||
data = _fetch_json(url)
|
||||
return web.json_response(data)
|
||||
|
||||
@PromptServer.instance.routes.get("/json_manager/list_project_files")
|
||||
async def list_project_files_proxy(request):
|
||||
manager_url = request.query.get("url", "http://localhost:8080")
|
||||
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
|
||||
data = _fetch_json(url)
|
||||
return web.json_response(data)
|
||||
|
||||
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
|
||||
async def list_project_sequences_proxy(request):
|
||||
manager_url = request.query.get("url", "http://localhost:8080")
|
||||
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
||||
file_name = urllib.parse.quote(request.query.get("file", ""), safe='')
|
||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
|
||||
data = _fetch_json(url)
|
||||
return web.json_response(data)
|
||||
|
||||
@PromptServer.instance.routes.get("/json_manager/get_project_keys")
|
||||
async def get_project_keys_proxy(request):
|
||||
manager_url = request.query.get("url", "http://localhost:8080")
|
||||
project = request.query.get("project", "")
|
||||
file_name = request.query.get("file", "")
|
||||
try:
|
||||
seq = int(request.query.get("seq", "1"))
|
||||
except (ValueError, TypeError):
|
||||
seq = 1
|
||||
data = _fetch_keys(manager_url, project, file_name, seq)
|
||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||
status = data.get("status", 502)
|
||||
return web.json_response(data, status=status)
|
||||
return web.json_response(data)
|
||||
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 0. DYNAMIC NODE (Project-based)
|
||||
# ==========================================
|
||||
|
||||
class ProjectLoaderDynamic:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||
},
|
||||
"optional": {
|
||||
"output_keys": ("STRING", {"default": ""}),
|
||||
"output_types": ("STRING", {"default": ""}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
|
||||
RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
|
||||
FUNCTION = "load_dynamic"
|
||||
CATEGORY = "utils/json/project"
|
||||
OUTPUT_NODE = False
|
||||
|
||||
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
|
||||
output_keys="", output_types=""):
|
||||
# Fetch keys metadata (includes total_sequences count)
|
||||
keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number)
|
||||
if keys_meta.get("error") in ("http_error", "network_error", "parse_error"):
|
||||
msg = keys_meta.get("message", "Unknown error")
|
||||
raise RuntimeError(f"Failed to fetch project keys: {msg}")
|
||||
total_sequences = keys_meta.get("total_sequences", 0)
|
||||
|
||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||
msg = data.get("message", "Unknown error")
|
||||
raise RuntimeError(f"Failed to fetch sequence data: {msg}")
|
||||
|
||||
# Parse keys — try JSON array first, fall back to comma-split for compat
|
||||
keys = []
|
||||
if output_keys:
|
||||
try:
|
||||
keys = json.loads(output_keys)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
keys = [k.strip() for k in output_keys.split(",") if k.strip()]
|
||||
|
||||
# Parse types for coercion
|
||||
types = []
|
||||
if output_types:
|
||||
try:
|
||||
types = json.loads(output_types)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
types = [t.strip() for t in output_types.split(",")]
|
||||
|
||||
results = []
|
||||
for i, key in enumerate(keys):
|
||||
val = data.get(key, "")
|
||||
declared_type = types[i] if i < len(types) else ""
|
||||
# Coerce based on declared output type when possible
|
||||
if declared_type == "INT":
|
||||
results.append(to_int(val))
|
||||
elif declared_type == "FLOAT":
|
||||
results.append(to_float(val))
|
||||
elif isinstance(val, bool):
|
||||
results.append(str(val).lower())
|
||||
elif isinstance(val, int):
|
||||
results.append(val)
|
||||
elif isinstance(val, float):
|
||||
results.append(val)
|
||||
else:
|
||||
results.append(str(val))
|
||||
|
||||
while len(results) < MAX_DYNAMIC_OUTPUTS:
|
||||
results.append("")
|
||||
|
||||
return (total_sequences,) + tuple(results)
|
||||
|
||||
|
||||
# --- Mappings ---
|
||||
PROJECT_NODE_CLASS_MAPPINGS = {
|
||||
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
||||
}
|
||||
|
||||
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
||||
}
|
||||
8
state.py
8
state.py
@@ -17,6 +17,11 @@ class AppState:
|
||||
live_toggles: dict = field(default_factory=dict)
|
||||
show_comfy_monitor: bool = True
|
||||
|
||||
# Project DB fields
|
||||
db: Any = None
|
||||
current_project: str = ""
|
||||
db_enabled: bool = False
|
||||
|
||||
# Set at runtime by main.py / tab_comfy_ng.py
|
||||
_render_main: Any = None
|
||||
_load_file: Callable | None = None
|
||||
@@ -29,4 +34,7 @@ class AppState:
|
||||
config=self.config,
|
||||
current_dir=self.current_dir,
|
||||
snippets=self.snippets,
|
||||
db=self.db,
|
||||
current_project=self.current_project,
|
||||
db_enabled=self.db_enabled,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ from nicegui import ui
|
||||
|
||||
from state import AppState
|
||||
from utils import (
|
||||
DEFAULTS, save_json, load_json,
|
||||
DEFAULTS, save_json, load_json, sync_to_db,
|
||||
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
|
||||
)
|
||||
from history_tree import HistoryTree
|
||||
@@ -161,6 +161,8 @@ def render_batch_processor(state: AppState):
|
||||
new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {},
|
||||
KEY_PROMPT_HISTORY: []}
|
||||
save_json(new_path, new_data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, new_path, new_data)
|
||||
ui.notify(f'Created {new_name}', type='positive')
|
||||
|
||||
ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch)
|
||||
@@ -215,6 +217,8 @@ def render_batch_processor(state: AppState):
|
||||
batch_list.append(new_item)
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
render_sequence_list.refresh()
|
||||
|
||||
with ui.row().classes('q-mt-sm'):
|
||||
@@ -250,6 +254,8 @@ def render_batch_processor(state: AppState):
|
||||
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
ui.notify('Sorted by sequence number!', type='positive')
|
||||
render_sequence_list.refresh()
|
||||
|
||||
@@ -289,6 +295,8 @@ def render_batch_processor(state: AppState):
|
||||
htree.commit(snapshot_payload, note=note)
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
state.restored_indicator = None
|
||||
commit_input.set_value('')
|
||||
ui.notify('Batch Saved & Snapshot Created!', type='positive')
|
||||
@@ -306,6 +314,8 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
||||
def commit(message=None):
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
if message:
|
||||
ui.notify(message, type='positive')
|
||||
refresh_list.refresh()
|
||||
@@ -447,7 +457,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
||||
|
||||
# --- VACE Settings (full width) ---
|
||||
with ui.expansion('VACE Settings', icon='settings').classes('w-full'):
|
||||
_render_vace_settings(i, seq, batch_list, data, file_path, refresh_list)
|
||||
_render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list)
|
||||
|
||||
# --- LoRA Settings ---
|
||||
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
||||
@@ -529,7 +539,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
||||
# VACE Settings sub-section
|
||||
# ======================================================================
|
||||
|
||||
def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list):
|
||||
def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list):
|
||||
# VACE Schedule (needed early for both columns)
|
||||
sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1))
|
||||
|
||||
@@ -567,6 +577,8 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list):
|
||||
shifted += 1
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive')
|
||||
refresh_list.refresh()
|
||||
|
||||
@@ -712,6 +724,8 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
|
||||
htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
ui.notify(f'Updated {len(targets)} sequences', type='positive')
|
||||
if refresh_list:
|
||||
refresh_list.refresh()
|
||||
|
||||
165
tab_projects_ng.py
Normal file
165
tab_projects_ng.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from nicegui import ui
|
||||
|
||||
from state import AppState
|
||||
from db import ProjectDB
|
||||
from utils import save_config, sync_to_db, KEY_BATCH_DATA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def render_projects_tab(state: AppState):
|
||||
"""Render the Projects management tab."""
|
||||
|
||||
# --- DB toggle ---
|
||||
def on_db_toggle(e):
|
||||
state.db_enabled = e.value
|
||||
state.config['db_enabled'] = e.value
|
||||
save_config(state.current_dir, state.config.get('favorites', []), state.config)
|
||||
render_project_content.refresh()
|
||||
|
||||
ui.switch('Enable Project Database', value=state.db_enabled,
|
||||
on_change=on_db_toggle).classes('q-mb-md')
|
||||
|
||||
@ui.refreshable
|
||||
def render_project_content():
|
||||
if not state.db_enabled:
|
||||
ui.label('Project database is disabled. Enable it above to manage projects.').classes(
|
||||
'text-caption q-pa-md')
|
||||
return
|
||||
|
||||
if not state.db:
|
||||
ui.label('Database not initialized.').classes('text-warning q-pa-md')
|
||||
return
|
||||
|
||||
# --- Create project form ---
|
||||
with ui.card().classes('w-full q-pa-md q-mb-md'):
|
||||
ui.label('Create New Project').classes('section-header')
|
||||
name_input = ui.input('Project Name', placeholder='my_project').classes('w-full')
|
||||
desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full')
|
||||
|
||||
def create_project():
|
||||
name = name_input.value.strip()
|
||||
if not name:
|
||||
ui.notify('Please enter a project name', type='warning')
|
||||
return
|
||||
try:
|
||||
state.db.create_project(name, str(state.current_dir), desc_input.value.strip())
|
||||
name_input.set_value('')
|
||||
desc_input.set_value('')
|
||||
ui.notify(f'Created project "{name}"', type='positive')
|
||||
render_project_list.refresh()
|
||||
except Exception as e:
|
||||
ui.notify(f'Error: {e}', type='negative')
|
||||
|
||||
ui.button('Create Project', icon='add', on_click=create_project).classes('w-full')
|
||||
|
||||
# --- Active project indicator ---
|
||||
if state.current_project:
|
||||
ui.label(f'Active Project: {state.current_project}').classes(
|
||||
'text-bold text-primary q-pa-sm')
|
||||
|
||||
# --- Project list ---
|
||||
@ui.refreshable
|
||||
def render_project_list():
|
||||
projects = state.db.list_projects()
|
||||
if not projects:
|
||||
ui.label('No projects yet. Create one above.').classes('text-caption q-pa-md')
|
||||
return
|
||||
|
||||
for proj in projects:
|
||||
is_active = proj['name'] == state.current_project
|
||||
card_style = 'border-left: 3px solid var(--accent);' if is_active else ''
|
||||
|
||||
with ui.card().classes('w-full q-pa-sm q-mb-sm').style(card_style):
|
||||
with ui.row().classes('w-full items-center'):
|
||||
with ui.column().classes('col'):
|
||||
ui.label(proj['name']).classes('text-bold')
|
||||
if proj['description']:
|
||||
ui.label(proj['description']).classes('text-caption')
|
||||
ui.label(f'Path: {proj["folder_path"]}').classes('text-caption')
|
||||
files = state.db.list_data_files(proj['id'])
|
||||
ui.label(f'{len(files)} data file(s)').classes('text-caption')
|
||||
|
||||
with ui.row().classes('q-gutter-xs'):
|
||||
if not is_active:
|
||||
def activate(name=proj['name']):
|
||||
state.current_project = name
|
||||
state.config['current_project'] = name
|
||||
save_config(state.current_dir,
|
||||
state.config.get('favorites', []),
|
||||
state.config)
|
||||
ui.notify(f'Activated project "{name}"', type='positive')
|
||||
render_project_list.refresh()
|
||||
|
||||
ui.button('Activate', icon='check_circle',
|
||||
on_click=activate).props('flat dense color=primary')
|
||||
else:
|
||||
def deactivate():
|
||||
state.current_project = ''
|
||||
state.config['current_project'] = ''
|
||||
save_config(state.current_dir,
|
||||
state.config.get('favorites', []),
|
||||
state.config)
|
||||
ui.notify('Deactivated project', type='info')
|
||||
render_project_list.refresh()
|
||||
|
||||
ui.button('Deactivate', icon='cancel',
|
||||
on_click=deactivate).props('flat dense')
|
||||
|
||||
def import_folder(pid=proj['id'], pname=proj['name']):
|
||||
_import_folder(state, pid, pname, render_project_list)
|
||||
|
||||
ui.button('Import Folder', icon='folder_open',
|
||||
on_click=import_folder).props('flat dense')
|
||||
|
||||
def delete_proj(name=proj['name']):
|
||||
state.db.delete_project(name)
|
||||
if state.current_project == name:
|
||||
state.current_project = ''
|
||||
state.config['current_project'] = ''
|
||||
save_config(state.current_dir,
|
||||
state.config.get('favorites', []),
|
||||
state.config)
|
||||
ui.notify(f'Deleted project "{name}"', type='positive')
|
||||
render_project_list.refresh()
|
||||
|
||||
ui.button(icon='delete',
|
||||
on_click=delete_proj).props('flat dense color=negative')
|
||||
|
||||
render_project_list()
|
||||
|
||||
render_project_content()
|
||||
|
||||
|
||||
def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn):
|
||||
"""Bulk import all .json files from current directory into a project."""
|
||||
json_files = sorted(state.current_dir.glob('*.json'))
|
||||
json_files = [f for f in json_files if f.name not in (
|
||||
'.editor_config.json', '.editor_snippets.json')]
|
||||
|
||||
if not json_files:
|
||||
ui.notify('No JSON files in current directory', type='warning')
|
||||
return
|
||||
|
||||
imported = 0
|
||||
skipped = 0
|
||||
for jf in json_files:
|
||||
file_name = jf.stem
|
||||
existing = state.db.get_data_file(project_id, file_name)
|
||||
if existing:
|
||||
skipped += 1
|
||||
continue
|
||||
try:
|
||||
state.db.import_json_file(project_id, jf)
|
||||
imported += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to import {jf}: {e}")
|
||||
|
||||
msg = f'Imported {imported} file(s)'
|
||||
if skipped:
|
||||
msg += f', skipped {skipped} existing'
|
||||
ui.notify(msg, type='positive')
|
||||
refresh_fn.refresh()
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
from nicegui import ui
|
||||
|
||||
from state import AppState
|
||||
from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
|
||||
from utils import save_json, sync_to_db, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
|
||||
|
||||
|
||||
def render_raw_editor(state: AppState):
|
||||
@@ -52,6 +52,8 @@ def render_raw_editor(state: AppState):
|
||||
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
||||
|
||||
save_json(file_path, input_data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, input_data)
|
||||
|
||||
data.clear()
|
||||
data.update(input_data)
|
||||
|
||||
@@ -5,7 +5,7 @@ from nicegui import ui
|
||||
|
||||
from state import AppState
|
||||
from history_tree import HistoryTree
|
||||
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||
from utils import save_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||
|
||||
|
||||
def _delete_nodes(htree, data, file_path, node_ids):
|
||||
@@ -134,6 +134,8 @@ def _render_batch_delete(htree, data, file_path, state, refresh_fn):
|
||||
def do_batch_delete():
|
||||
current_valid = state.timeline_selected_nodes & set(htree.nodes.keys())
|
||||
_delete_nodes(htree, data, file_path, current_valid)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
state.timeline_selected_nodes = set()
|
||||
ui.notify(
|
||||
f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!',
|
||||
@@ -179,7 +181,7 @@ def _find_branch_for_node(htree, node_id):
|
||||
|
||||
|
||||
def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn,
|
||||
selected):
|
||||
selected, state=None):
|
||||
"""Render branch-grouped node manager with restore, rename, delete, and preview."""
|
||||
ui.label('Manage Version').classes('section-header')
|
||||
|
||||
@@ -291,6 +293,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_
|
||||
htree.nodes[sel_id]['note'] = rename_input.value
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
if state and state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
ui.notify('Label updated', type='positive')
|
||||
refresh_fn()
|
||||
|
||||
@@ -304,6 +308,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_
|
||||
def delete_selected():
|
||||
if sel_id in htree.nodes:
|
||||
_delete_nodes(htree, data, file_path, {sel_id})
|
||||
if state and state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
ui.notify('Node Deleted', type='positive')
|
||||
refresh_fn()
|
||||
|
||||
@@ -377,7 +383,7 @@ def render_timeline_tab(state: AppState):
|
||||
_render_node_manager(
|
||||
all_nodes, htree, data, file_path,
|
||||
_restore_and_refresh, render_timeline.refresh,
|
||||
selected)
|
||||
selected, state=state)
|
||||
|
||||
def _toggle_select(nid, checked):
|
||||
if checked:
|
||||
@@ -492,6 +498,8 @@ def _restore_node(data, node, htree, file_path, state: AppState):
|
||||
htree.head_id = node['id']
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
sync_to_db(state.db, state.current_project, file_path, data)
|
||||
label = f"{node.get('note', 'Step')} ({node['id'][:4]})"
|
||||
state.restored_indicator = label
|
||||
ui.notify('Restored!', type='positive')
|
||||
|
||||
369
tests/test_db.py
Normal file
369
tests/test_db.py
Normal file
@@ -0,0 +1,369 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from db import ProjectDB
|
||||
from utils import KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(tmp_path):
|
||||
"""Create a fresh ProjectDB in a temp directory."""
|
||||
db_path = tmp_path / "test.db"
|
||||
pdb = ProjectDB(db_path)
|
||||
yield pdb
|
||||
pdb.close()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Projects CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class TestProjects:
|
||||
def test_create_and_get(self, db):
|
||||
pid = db.create_project("proj1", "/some/path", "A test project")
|
||||
assert pid > 0
|
||||
proj = db.get_project("proj1")
|
||||
assert proj is not None
|
||||
assert proj["name"] == "proj1"
|
||||
assert proj["folder_path"] == "/some/path"
|
||||
assert proj["description"] == "A test project"
|
||||
|
||||
def test_list_projects(self, db):
|
||||
db.create_project("beta", "/b")
|
||||
db.create_project("alpha", "/a")
|
||||
projects = db.list_projects()
|
||||
assert len(projects) == 2
|
||||
assert projects[0]["name"] == "alpha"
|
||||
assert projects[1]["name"] == "beta"
|
||||
|
||||
def test_get_nonexistent(self, db):
|
||||
assert db.get_project("nope") is None
|
||||
|
||||
def test_delete_project(self, db):
|
||||
db.create_project("to_delete", "/x")
|
||||
assert db.delete_project("to_delete") is True
|
||||
assert db.get_project("to_delete") is None
|
||||
|
||||
def test_delete_nonexistent(self, db):
|
||||
assert db.delete_project("nope") is False
|
||||
|
||||
def test_unique_name_constraint(self, db):
|
||||
db.create_project("dup", "/a")
|
||||
with pytest.raises(Exception):
|
||||
db.create_project("dup", "/b")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Data files
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class TestDataFiles:
|
||||
def test_create_and_list(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch_i2v", "i2v", {"extra": "meta"})
|
||||
assert df_id > 0
|
||||
files = db.list_data_files(pid)
|
||||
assert len(files) == 1
|
||||
assert files[0]["name"] == "batch_i2v"
|
||||
assert files[0]["data_type"] == "i2v"
|
||||
|
||||
def test_get_data_file(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
db.create_data_file(pid, "batch_i2v", "i2v", {"key": "value"})
|
||||
df = db.get_data_file(pid, "batch_i2v")
|
||||
assert df is not None
|
||||
assert df["top_level"] == {"key": "value"}
|
||||
|
||||
def test_get_data_file_by_names(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
db.create_data_file(pid, "batch_i2v", "i2v")
|
||||
df = db.get_data_file_by_names("p1", "batch_i2v")
|
||||
assert df is not None
|
||||
assert df["name"] == "batch_i2v"
|
||||
|
||||
def test_get_nonexistent_data_file(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
assert db.get_data_file(pid, "nope") is None
|
||||
|
||||
def test_unique_constraint(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
db.create_data_file(pid, "batch_i2v", "i2v")
|
||||
with pytest.raises(Exception):
|
||||
db.create_data_file(pid, "batch_i2v", "vace")
|
||||
|
||||
def test_cascade_delete(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
|
||||
db.upsert_sequence(df_id, 1, {"prompt": "hello"})
|
||||
db.save_history_tree(df_id, {"nodes": {}})
|
||||
db.delete_project("p1")
|
||||
assert db.get_data_file(pid, "batch_i2v") is None
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sequences
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class TestSequences:
|
||||
def test_upsert_and_get(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.upsert_sequence(df_id, 1, {"prompt": "hello", "seed": 42})
|
||||
data = db.get_sequence(df_id, 1)
|
||||
assert data == {"prompt": "hello", "seed": 42}
|
||||
|
||||
def test_upsert_updates_existing(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.upsert_sequence(df_id, 1, {"prompt": "v1"})
|
||||
db.upsert_sequence(df_id, 1, {"prompt": "v2"})
|
||||
data = db.get_sequence(df_id, 1)
|
||||
assert data["prompt"] == "v2"
|
||||
|
||||
def test_list_sequences(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.upsert_sequence(df_id, 3, {"a": 1})
|
||||
db.upsert_sequence(df_id, 1, {"b": 2})
|
||||
db.upsert_sequence(df_id, 2, {"c": 3})
|
||||
seqs = db.list_sequences(df_id)
|
||||
assert seqs == [1, 2, 3]
|
||||
|
||||
def test_get_nonexistent_sequence(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
assert db.get_sequence(df_id, 99) is None
|
||||
|
||||
def test_get_sequence_keys(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.upsert_sequence(df_id, 1, {
|
||||
"prompt": "hello",
|
||||
"seed": 42,
|
||||
"cfg": 1.5,
|
||||
"flag": True,
|
||||
})
|
||||
keys, types = db.get_sequence_keys(df_id, 1)
|
||||
assert "prompt" in keys
|
||||
assert "seed" in keys
|
||||
idx_prompt = keys.index("prompt")
|
||||
idx_seed = keys.index("seed")
|
||||
idx_cfg = keys.index("cfg")
|
||||
idx_flag = keys.index("flag")
|
||||
assert types[idx_prompt] == "STRING"
|
||||
assert types[idx_seed] == "INT"
|
||||
assert types[idx_cfg] == "FLOAT"
|
||||
assert types[idx_flag] == "STRING" # bools -> STRING
|
||||
|
||||
def test_get_sequence_keys_nonexistent(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
keys, types = db.get_sequence_keys(df_id, 99)
|
||||
assert keys == []
|
||||
assert types == []
|
||||
|
||||
def test_delete_sequences_for_file(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.upsert_sequence(df_id, 1, {"a": 1})
|
||||
db.upsert_sequence(df_id, 2, {"b": 2})
|
||||
db.delete_sequences_for_file(df_id)
|
||||
assert db.list_sequences(df_id) == []
|
||||
|
||||
def test_count_sequences(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
assert db.count_sequences(df_id) == 0
|
||||
db.upsert_sequence(df_id, 1, {"a": 1})
|
||||
db.upsert_sequence(df_id, 2, {"b": 2})
|
||||
db.upsert_sequence(df_id, 3, {"c": 3})
|
||||
assert db.count_sequences(df_id) == 3
|
||||
|
||||
def test_query_total_sequences(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.upsert_sequence(df_id, 1, {"a": 1})
|
||||
db.upsert_sequence(df_id, 2, {"b": 2})
|
||||
assert db.query_total_sequences("p1", "batch") == 2
|
||||
|
||||
def test_query_total_sequences_nonexistent(self, db):
|
||||
assert db.query_total_sequences("nope", "nope") == 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# History trees
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class TestHistoryTrees:
|
||||
def test_save_and_get(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
tree = {"nodes": {"abc": {"id": "abc"}}, "head_id": "abc"}
|
||||
db.save_history_tree(df_id, tree)
|
||||
result = db.get_history_tree(df_id)
|
||||
assert result == tree
|
||||
|
||||
def test_upsert_updates(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.save_history_tree(df_id, {"v": 1})
|
||||
db.save_history_tree(df_id, {"v": 2})
|
||||
result = db.get_history_tree(df_id)
|
||||
assert result == {"v": 2}
|
||||
|
||||
def test_get_nonexistent(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
assert db.get_history_tree(df_id) is None
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Import
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class TestImport:
|
||||
def test_import_json_file(self, db, tmp_path):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
json_path = tmp_path / "batch_prompt_i2v.json"
|
||||
data = {
|
||||
KEY_BATCH_DATA: [
|
||||
{"sequence_number": 1, "prompt": "hello", "seed": 42},
|
||||
{"sequence_number": 2, "prompt": "world", "seed": 99},
|
||||
],
|
||||
KEY_HISTORY_TREE: {"nodes": {}, "head_id": None},
|
||||
}
|
||||
json_path.write_text(json.dumps(data))
|
||||
|
||||
df_id = db.import_json_file(pid, json_path, "i2v")
|
||||
assert df_id > 0
|
||||
|
||||
seqs = db.list_sequences(df_id)
|
||||
assert seqs == [1, 2]
|
||||
|
||||
s1 = db.get_sequence(df_id, 1)
|
||||
assert s1["prompt"] == "hello"
|
||||
assert s1["seed"] == 42
|
||||
|
||||
tree = db.get_history_tree(df_id)
|
||||
assert tree == {"nodes": {}, "head_id": None}
|
||||
|
||||
def test_import_file_name_from_stem(self, db, tmp_path):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
json_path = tmp_path / "my_batch.json"
|
||||
json_path.write_text(json.dumps({KEY_BATCH_DATA: [{"sequence_number": 1}]}))
|
||||
db.import_json_file(pid, json_path)
|
||||
df = db.get_data_file(pid, "my_batch")
|
||||
assert df is not None
|
||||
|
||||
def test_import_no_batch_data(self, db, tmp_path):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
json_path = tmp_path / "simple.json"
|
||||
json_path.write_text(json.dumps({"prompt": "flat file"}))
|
||||
df_id = db.import_json_file(pid, json_path)
|
||||
seqs = db.list_sequences(df_id)
|
||||
assert seqs == []
|
||||
|
||||
def test_reimport_updates_existing(self, db, tmp_path):
|
||||
"""Re-importing the same file should update data, not crash."""
|
||||
pid = db.create_project("p1", "/p1")
|
||||
json_path = tmp_path / "batch.json"
|
||||
|
||||
# First import
|
||||
data_v1 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]}
|
||||
json_path.write_text(json.dumps(data_v1))
|
||||
df_id_1 = db.import_json_file(pid, json_path, "i2v")
|
||||
|
||||
# Second import (same file, updated data)
|
||||
data_v2 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v2"}, {"sequence_number": 2, "prompt": "new"}]}
|
||||
json_path.write_text(json.dumps(data_v2))
|
||||
df_id_2 = db.import_json_file(pid, json_path, "vace")
|
||||
|
||||
# Should reuse the same data_file row
|
||||
assert df_id_1 == df_id_2
|
||||
# Data type should be updated
|
||||
df = db.get_data_file(pid, "batch")
|
||||
assert df["data_type"] == "vace"
|
||||
# Sequences should reflect v2
|
||||
seqs = db.list_sequences(df_id_2)
|
||||
assert seqs == [1, 2]
|
||||
s1 = db.get_sequence(df_id_2, 1)
|
||||
assert s1["prompt"] == "v2"
|
||||
|
||||
def test_import_skips_non_dict_batch_items(self, db, tmp_path):
|
||||
"""Non-dict elements in batch_data should be silently skipped, not crash."""
|
||||
pid = db.create_project("p1", "/p1")
|
||||
json_path = tmp_path / "mixed.json"
|
||||
data = {KEY_BATCH_DATA: [
|
||||
{"sequence_number": 1, "prompt": "valid"},
|
||||
"not a dict",
|
||||
42,
|
||||
None,
|
||||
{"sequence_number": 3, "prompt": "also valid"},
|
||||
]}
|
||||
json_path.write_text(json.dumps(data))
|
||||
df_id = db.import_json_file(pid, json_path)
|
||||
|
||||
seqs = db.list_sequences(df_id)
|
||||
assert seqs == [1, 3]
|
||||
|
||||
def test_import_atomic_on_error(self, db, tmp_path):
|
||||
"""If import fails partway, no partial data should be committed."""
|
||||
pid = db.create_project("p1", "/p1")
|
||||
json_path = tmp_path / "batch.json"
|
||||
data = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "hello"}]}
|
||||
json_path.write_text(json.dumps(data))
|
||||
db.import_json_file(pid, json_path)
|
||||
|
||||
# Now try to import with bad data that will cause an error
|
||||
# (overwrite the file with invalid sequence_number that causes int() to fail)
|
||||
bad_data = {KEY_BATCH_DATA: [{"sequence_number": "not_a_number", "prompt": "bad"}]}
|
||||
json_path.write_text(json.dumps(bad_data))
|
||||
with pytest.raises(ValueError):
|
||||
db.import_json_file(pid, json_path)
|
||||
|
||||
# Original data should still be intact (rollback worked)
|
||||
df = db.get_data_file(pid, "batch")
|
||||
assert df is not None
|
||||
s1 = db.get_sequence(df["id"], 1)
|
||||
assert s1["prompt"] == "hello"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class TestQueryHelpers:
|
||||
def test_query_sequence_data(self, db):
|
||||
pid = db.create_project("myproject", "/mp")
|
||||
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
|
||||
db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7})
|
||||
result = db.query_sequence_data("myproject", "batch_i2v", 1)
|
||||
assert result == {"prompt": "test", "seed": 7}
|
||||
|
||||
def test_query_sequence_data_not_found(self, db):
|
||||
assert db.query_sequence_data("nope", "nope", 1) is None
|
||||
|
||||
def test_query_sequence_keys(self, db):
|
||||
pid = db.create_project("myproject", "/mp")
|
||||
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
|
||||
db.upsert_sequence(df_id, 1, {"prompt": "test", "seed": 7})
|
||||
keys, types = db.query_sequence_keys("myproject", "batch_i2v", 1)
|
||||
assert "prompt" in keys
|
||||
assert "seed" in keys
|
||||
|
||||
def test_list_project_files(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
db.create_data_file(pid, "file_a", "i2v")
|
||||
db.create_data_file(pid, "file_b", "vace")
|
||||
files = db.list_project_files("p1")
|
||||
assert len(files) == 2
|
||||
|
||||
def test_list_project_sequences(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.upsert_sequence(df_id, 1, {})
|
||||
db.upsert_sequence(df_id, 2, {})
|
||||
seqs = db.list_project_sequences("p1", "batch")
|
||||
assert seqs == [1, 2]
|
||||
211
tests/test_project_loader.py
Normal file
211
tests/test_project_loader.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from project_loader import (
|
||||
ProjectLoaderDynamic,
|
||||
_fetch_json,
|
||||
_fetch_data,
|
||||
_fetch_keys,
|
||||
MAX_DYNAMIC_OUTPUTS,
|
||||
)
|
||||
|
||||
|
||||
def _mock_urlopen(data: dict):
|
||||
"""Create a mock context manager for urllib.request.urlopen."""
|
||||
response = MagicMock()
|
||||
response.read.return_value = json.dumps(data).encode()
|
||||
response.__enter__ = lambda s: s
|
||||
response.__exit__ = MagicMock(return_value=False)
|
||||
return response
|
||||
|
||||
|
||||
class TestFetchHelpers:
|
||||
def test_fetch_json_success(self):
|
||||
data = {"key": "value"}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)):
|
||||
result = _fetch_json("http://example.com/api")
|
||||
assert result == data
|
||||
|
||||
def test_fetch_json_network_error(self):
|
||||
with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
|
||||
result = _fetch_json("http://example.com/api")
|
||||
assert result["error"] == "network_error"
|
||||
assert "connection refused" in result["message"]
|
||||
|
||||
def test_fetch_json_http_error(self):
|
||||
import urllib.error
|
||||
err = urllib.error.HTTPError(
|
||||
"http://example.com/api", 404, "Not Found", {},
|
||||
BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode())
|
||||
)
|
||||
with patch("project_loader.urllib.request.urlopen", side_effect=err):
|
||||
result = _fetch_json("http://example.com/api")
|
||||
assert result["error"] == "http_error"
|
||||
assert result["status"] == 404
|
||||
assert "not found" in result["message"].lower()
|
||||
|
||||
def test_fetch_data_builds_url(self):
|
||||
data = {"prompt": "hello"}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
||||
result = _fetch_data("http://localhost:8080", "proj1", "batch_i2v", 1)
|
||||
assert result == data
|
||||
called_url = mock.call_args[0][0]
|
||||
assert "/api/projects/proj1/files/batch_i2v/data?seq=1" in called_url
|
||||
|
||||
def test_fetch_keys_builds_url(self):
|
||||
data = {"keys": ["prompt"], "types": ["STRING"]}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
||||
result = _fetch_keys("http://localhost:8080", "proj1", "batch_i2v", 1)
|
||||
assert result == data
|
||||
called_url = mock.call_args[0][0]
|
||||
assert "/api/projects/proj1/files/batch_i2v/keys?seq=1" in called_url
|
||||
|
||||
def test_fetch_data_strips_trailing_slash(self):
|
||||
data = {"prompt": "hello"}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
||||
_fetch_data("http://localhost:8080/", "proj1", "file1", 1)
|
||||
called_url = mock.call_args[0][0]
|
||||
assert "//api" not in called_url
|
||||
|
||||
def test_fetch_data_encodes_special_chars(self):
|
||||
"""Project/file names with spaces or special chars should be percent-encoded."""
|
||||
data = {"prompt": "hello"}
|
||||
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
||||
_fetch_data("http://localhost:8080", "my project", "batch file", 1)
|
||||
called_url = mock.call_args[0][0]
|
||||
assert "my%20project" in called_url
|
||||
assert "batch%20file" in called_url
|
||||
assert " " not in called_url.split("?")[0] # no raw spaces in path
|
||||
|
||||
|
||||
class TestProjectLoaderDynamic:
|
||||
def _keys_meta(self, total=5):
|
||||
return {"keys": [], "types": [], "total_sequences": total}
|
||||
|
||||
def test_load_dynamic_with_keys(self):
|
||||
data = {"prompt": "hello", "seed": 42, "cfg": 1.5}
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="prompt,seed,cfg"
|
||||
)
|
||||
assert result[0] == 5 # total_sequences
|
||||
assert result[1] == "hello"
|
||||
assert result[2] == 42
|
||||
assert result[3] == 1.5
|
||||
assert len(result) == MAX_DYNAMIC_OUTPUTS + 1
|
||||
|
||||
def test_load_dynamic_with_json_encoded_keys(self):
|
||||
"""JSON-encoded output_keys should be parsed correctly."""
|
||||
import json as _json
|
||||
data = {"my,key": "comma_val", "normal": "ok"}
|
||||
node = ProjectLoaderDynamic()
|
||||
keys_json = _json.dumps(["my,key", "normal"])
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys=keys_json
|
||||
)
|
||||
assert result[1] == "comma_val"
|
||||
assert result[2] == "ok"
|
||||
|
||||
def test_load_dynamic_type_coercion(self):
|
||||
"""output_types should coerce values to declared types."""
|
||||
import json as _json
|
||||
data = {"seed": "42", "cfg": "1.5", "prompt": "hello"}
|
||||
node = ProjectLoaderDynamic()
|
||||
keys_json = _json.dumps(["seed", "cfg", "prompt"])
|
||||
types_json = _json.dumps(["INT", "FLOAT", "STRING"])
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys=keys_json, output_types=types_json
|
||||
)
|
||||
assert result[1] == 42 # string "42" coerced to int
|
||||
assert result[2] == 1.5 # string "1.5" coerced to float
|
||||
assert result[3] == "hello" # string stays string
|
||||
|
||||
def test_load_dynamic_empty_keys(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys=""
|
||||
)
|
||||
# Slot 0 is total_sequences (INT), rest are empty strings
|
||||
assert result[0] == 5
|
||||
assert all(v == "" for v in result[1:])
|
||||
|
||||
def test_load_dynamic_missing_key(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value={"prompt": "hello"}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="nonexistent"
|
||||
)
|
||||
assert result[1] == ""
|
||||
|
||||
def test_load_dynamic_bool_becomes_string(self):
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value={"flag": True}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys="flag"
|
||||
)
|
||||
assert result[1] == "true"
|
||||
|
||||
def test_load_dynamic_returns_total_sequences(self):
|
||||
"""total_sequences should be the first output from keys metadata."""
|
||||
node = ProjectLoaderDynamic()
|
||||
with patch("project_loader._fetch_keys", return_value={"keys": [], "types": [], "total_sequences": 42}):
|
||||
with patch("project_loader._fetch_data", return_value={}):
|
||||
result = node.load_dynamic(
|
||||
"http://localhost:8080", "proj1", "batch_i2v", 1,
|
||||
output_keys=""
|
||||
)
|
||||
assert result[0] == 42
|
||||
|
||||
def test_load_dynamic_raises_on_network_error(self):
|
||||
"""Network errors from _fetch_keys should raise RuntimeError."""
|
||||
node = ProjectLoaderDynamic()
|
||||
error_resp = {"error": "network_error", "message": "Connection refused"}
|
||||
with patch("project_loader._fetch_keys", return_value=error_resp):
|
||||
with pytest.raises(RuntimeError, match="Failed to fetch project keys"):
|
||||
node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
|
||||
|
||||
def test_load_dynamic_raises_on_data_fetch_error(self):
|
||||
"""Network errors from _fetch_data should raise RuntimeError."""
|
||||
node = ProjectLoaderDynamic()
|
||||
error_resp = {"error": "http_error", "status": 404, "message": "Sequence not found"}
|
||||
with patch("project_loader._fetch_keys", return_value=self._keys_meta()):
|
||||
with patch("project_loader._fetch_data", return_value=error_resp):
|
||||
with pytest.raises(RuntimeError, match="Failed to fetch sequence data"):
|
||||
node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)
|
||||
|
||||
def test_input_types_has_manager_url(self):
|
||||
inputs = ProjectLoaderDynamic.INPUT_TYPES()
|
||||
assert "manager_url" in inputs["required"]
|
||||
assert "project_name" in inputs["required"]
|
||||
assert "file_name" in inputs["required"]
|
||||
assert "sequence_number" in inputs["required"]
|
||||
|
||||
def test_category(self):
|
||||
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
|
||||
|
||||
|
||||
class TestNodeMappings:
|
||||
def test_mappings_exist(self):
|
||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1
|
||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1
|
||||
74
utils.py
74
utils.py
@@ -160,6 +160,80 @@ def get_file_mtime(path: str | Path) -> float:
|
||||
return path.stat().st_mtime
|
||||
return 0
|
||||
|
||||
def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
||||
"""Dual-write helper: sync JSON data to the project database.
|
||||
|
||||
Resolves (or creates) the data_file, upserts all sequences from batch_data,
|
||||
and saves the history_tree. All writes happen in a single transaction.
|
||||
"""
|
||||
if not db or not project_name:
|
||||
return
|
||||
try:
|
||||
proj = db.get_project(project_name)
|
||||
if not proj:
|
||||
return
|
||||
file_name = Path(file_path).stem
|
||||
|
||||
# Use a single transaction for atomicity
|
||||
db.conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
df = db.get_data_file(proj["id"], file_name)
|
||||
top_level = {k: v for k, v in data.items()
|
||||
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
||||
if not df:
|
||||
now = __import__('time').time()
|
||||
cur = db.conn.execute(
|
||||
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(proj["id"], file_name, "generic", json.dumps(top_level), now, now),
|
||||
)
|
||||
df_id = cur.lastrowid
|
||||
else:
|
||||
df_id = df["id"]
|
||||
# Update top_level metadata
|
||||
now = __import__('time').time()
|
||||
db.conn.execute(
|
||||
"UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?",
|
||||
(json.dumps(top_level), now, df_id),
|
||||
)
|
||||
|
||||
# Sync sequences
|
||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
||||
if isinstance(batch_data, list):
|
||||
db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
||||
for item in batch_data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
|
||||
now = __import__('time').time()
|
||||
db.conn.execute(
|
||||
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?)",
|
||||
(df_id, seq_num, json.dumps(item), now),
|
||||
)
|
||||
|
||||
# Sync history tree
|
||||
history_tree = data.get(KEY_HISTORY_TREE)
|
||||
if history_tree and isinstance(history_tree, dict):
|
||||
now = __import__('time').time()
|
||||
db.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
"VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||
(df_id, json.dumps(history_tree), now),
|
||||
)
|
||||
|
||||
db.conn.execute("COMMIT")
|
||||
except Exception:
|
||||
try:
|
||||
db.conn.execute("ROLLBACK")
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"sync_to_db failed: {e}")
|
||||
|
||||
|
||||
def generate_templates(current_dir: Path) -> None:
|
||||
"""Creates batch template files if folder is empty."""
|
||||
first = DEFAULTS.copy()
|
||||
|
||||
@@ -17,17 +17,31 @@ app.registerExtension({
|
||||
if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
|
||||
}
|
||||
|
||||
// Remove all 32 default outputs from Python RETURN_TYPES
|
||||
while (this.outputs.length > 0) {
|
||||
this.removeOutput(0);
|
||||
}
|
||||
// Do NOT remove default outputs synchronously here.
|
||||
// During graph loading, ComfyUI creates all nodes (firing onNodeCreated)
|
||||
// before configuring them. Other nodes (e.g. Kijai Set/Get) may resolve
|
||||
// links to our outputs during their configure step. If we remove outputs
|
||||
// here, those nodes find no output slot and error out.
|
||||
//
|
||||
// Instead, defer cleanup: for loaded workflows onConfigure sets _configured
|
||||
// before this runs; for new nodes the defaults are cleaned up.
|
||||
this._configured = false;
|
||||
|
||||
// Add Refresh button
|
||||
this.addWidget("button", "Refresh Outputs", null, () => {
|
||||
this.refreshDynamicOutputs();
|
||||
});
|
||||
|
||||
queueMicrotask(() => {
|
||||
if (!this._configured) {
|
||||
// New node (not loading) — remove the 32 Python default outputs
|
||||
while (this.outputs.length > 0) {
|
||||
this.removeOutput(0);
|
||||
}
|
||||
this.setSize(this.computeSize());
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
nodeType.prototype.refreshDynamicOutputs = async function () {
|
||||
@@ -39,7 +53,14 @@ app.registerExtension({
|
||||
const resp = await api.fetchApi(
|
||||
`/json_manager/get_keys?path=${encodeURIComponent(pathWidget.value)}&sequence_number=${seqWidget?.value || 1}`
|
||||
);
|
||||
const { keys, types } = await resp.json();
|
||||
const data = await resp.json();
|
||||
const { keys, types } = data;
|
||||
|
||||
// If the file wasn't found, keep existing outputs and links intact
|
||||
if (data.error === "file_not_found") {
|
||||
console.warn("[JSONLoaderDynamic] File not found, keeping existing outputs:", pathWidget.value);
|
||||
return;
|
||||
}
|
||||
|
||||
// Store keys and types in hidden widgets for persistence
|
||||
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
||||
@@ -82,7 +103,6 @@ app.registerExtension({
|
||||
|
||||
// Reassign the outputs array and fix link slot indices
|
||||
this.outputs = newOutputs;
|
||||
// Update link origin_slot to match new positions
|
||||
if (this.graph) {
|
||||
for (let i = 0; i < this.outputs.length; i++) {
|
||||
const links = this.outputs[i].links;
|
||||
@@ -105,6 +125,7 @@ app.registerExtension({
|
||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
||||
nodeType.prototype.onConfigure = function (info) {
|
||||
origOnConfigure?.apply(this, arguments);
|
||||
this._configured = true;
|
||||
|
||||
// Hide internal widgets
|
||||
for (const name of ["output_keys", "output_types"]) {
|
||||
@@ -122,6 +143,7 @@ app.registerExtension({
|
||||
? otWidget.value.split(",")
|
||||
: [];
|
||||
|
||||
if (keys.length > 0) {
|
||||
// On load, LiteGraph already restored serialized outputs with links.
|
||||
// Rename and set types to match stored state (preserves links).
|
||||
for (let i = 0; i < this.outputs.length && i < keys.length; i++) {
|
||||
@@ -133,6 +155,12 @@ app.registerExtension({
|
||||
while (this.outputs.length > keys.length) {
|
||||
this.removeOutput(this.outputs.length - 1);
|
||||
}
|
||||
} else if (this.outputs.length > 0) {
|
||||
// Widget values empty but serialized outputs exist — sync widgets
|
||||
// from the outputs LiteGraph already restored (fallback).
|
||||
if (okWidget) okWidget.value = this.outputs.map(o => o.name).join(",");
|
||||
if (otWidget) otWidget.value = this.outputs.map(o => o.type).join(",");
|
||||
}
|
||||
|
||||
this.setSize(this.computeSize());
|
||||
};
|
||||
|
||||
272
web/project_dynamic.js
Normal file
272
web/project_dynamic.js
Normal file
@@ -0,0 +1,272 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
|
||||
app.registerExtension({
|
||||
name: "json.manager.project.dynamic",
|
||||
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||
if (nodeData.name !== "ProjectLoaderDynamic") return;
|
||||
|
||||
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
||||
nodeType.prototype.onNodeCreated = function () {
|
||||
origOnNodeCreated?.apply(this, arguments);
|
||||
|
||||
// Hide internal widgets (managed by JS)
|
||||
for (const name of ["output_keys", "output_types"]) {
|
||||
const w = this.widgets?.find(w => w.name === name);
|
||||
if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
|
||||
}
|
||||
|
||||
// Do NOT remove default outputs synchronously here.
|
||||
// During graph loading, ComfyUI creates all nodes (firing onNodeCreated)
|
||||
// before configuring them. Other nodes (e.g. Kijai Set/Get) may resolve
|
||||
// links to our outputs during their configure step. If we remove outputs
|
||||
// here, those nodes find no output slot and error out.
|
||||
//
|
||||
// Instead, defer cleanup: for loaded workflows onConfigure sets _configured
|
||||
// before this runs; for new nodes the defaults are cleaned up.
|
||||
this._configured = false;
|
||||
|
||||
// Add Refresh button
|
||||
this.addWidget("button", "Refresh Outputs", null, () => {
|
||||
this.refreshDynamicOutputs();
|
||||
});
|
||||
|
||||
// Auto-refresh with 500ms debounce on widget changes
|
||||
this._refreshTimer = null;
|
||||
const autoRefreshWidgets = ["project_name", "file_name", "sequence_number"];
|
||||
for (const widgetName of autoRefreshWidgets) {
|
||||
const w = this.widgets?.find(w => w.name === widgetName);
|
||||
if (w) {
|
||||
const origCallback = w.callback;
|
||||
const node = this;
|
||||
w.callback = function (...args) {
|
||||
origCallback?.apply(this, args);
|
||||
clearTimeout(node._refreshTimer);
|
||||
node._refreshTimer = setTimeout(() => {
|
||||
node.refreshDynamicOutputs();
|
||||
}, 500);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
queueMicrotask(() => {
|
||||
if (!this._configured) {
|
||||
// New node (not loading) — remove the Python default outputs
|
||||
// and add only the fixed total_sequences slot
|
||||
while (this.outputs.length > 0) {
|
||||
this.removeOutput(0);
|
||||
}
|
||||
this.addOutput("total_sequences", "INT");
|
||||
this.setSize(this.computeSize());
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
nodeType.prototype._setStatus = function (status, message) {
|
||||
const baseTitle = "Project Loader (Dynamic)";
|
||||
if (status === "ok") {
|
||||
this.title = baseTitle;
|
||||
this.color = undefined;
|
||||
this.bgcolor = undefined;
|
||||
} else if (status === "error") {
|
||||
this.title = baseTitle + " - ERROR";
|
||||
this.color = "#ff4444";
|
||||
this.bgcolor = "#331111";
|
||||
if (message) this.title = baseTitle + ": " + message;
|
||||
} else if (status === "loading") {
|
||||
this.title = baseTitle + " - Loading...";
|
||||
}
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
};
|
||||
|
||||
nodeType.prototype.refreshDynamicOutputs = async function () {
|
||||
const urlWidget = this.widgets?.find(w => w.name === "manager_url");
|
||||
const projectWidget = this.widgets?.find(w => w.name === "project_name");
|
||||
const fileWidget = this.widgets?.find(w => w.name === "file_name");
|
||||
const seqWidget = this.widgets?.find(w => w.name === "sequence_number");
|
||||
|
||||
if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return;
|
||||
|
||||
this._setStatus("loading");
|
||||
|
||||
try {
|
||||
const resp = await api.fetchApi(
|
||||
`/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}`
|
||||
);
|
||||
|
||||
if (!resp.ok) {
|
||||
let errorMsg = `HTTP ${resp.status}`;
|
||||
try {
|
||||
const errData = await resp.json();
|
||||
if (errData.message) errorMsg = errData.message;
|
||||
} catch (_) {}
|
||||
this._setStatus("error", errorMsg);
|
||||
return;
|
||||
}
|
||||
|
||||
const data = await resp.json();
|
||||
const keys = data.keys;
|
||||
const types = data.types;
|
||||
|
||||
// If the API returned an error or missing data, keep existing outputs and links intact
|
||||
if (data.error || !Array.isArray(keys) || !Array.isArray(types)) {
|
||||
const errMsg = data.error ? data.message || data.error : "Missing keys/types";
|
||||
this._setStatus("error", errMsg);
|
||||
return;
|
||||
}
|
||||
|
||||
// Store keys and types in hidden widgets for persistence (JSON-encoded)
|
||||
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
||||
if (okWidget) okWidget.value = JSON.stringify(keys);
|
||||
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
||||
if (otWidget) otWidget.value = JSON.stringify(types);
|
||||
|
||||
// Slot 0 is always total_sequences (INT) — ensure it exists
|
||||
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
||||
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
|
||||
}
|
||||
this.outputs[0].type = "INT";
|
||||
|
||||
// Build a map of current dynamic output names to slot indices (skip slot 0)
|
||||
const oldSlots = {};
|
||||
for (let i = 1; i < this.outputs.length; i++) {
|
||||
oldSlots[this.outputs[i].name] = i;
|
||||
}
|
||||
|
||||
// Build new dynamic outputs, reusing existing slots to preserve links
|
||||
const newOutputs = [this.outputs[0]]; // Keep total_sequences at slot 0
|
||||
for (let k = 0; k < keys.length; k++) {
|
||||
const key = keys[k];
|
||||
const type = types[k] || "*";
|
||||
if (key in oldSlots) {
|
||||
const slot = this.outputs[oldSlots[key]];
|
||||
slot.type = type;
|
||||
newOutputs.push(slot);
|
||||
delete oldSlots[key];
|
||||
} else {
|
||||
newOutputs.push({ name: key, type: type, links: null });
|
||||
}
|
||||
}
|
||||
|
||||
// Disconnect links on slots that are being removed
|
||||
for (const name in oldSlots) {
|
||||
const idx = oldSlots[name];
|
||||
if (this.outputs[idx]?.links?.length) {
|
||||
for (const linkId of [...this.outputs[idx].links]) {
|
||||
this.graph?.removeLink(linkId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reassign the outputs array and fix link slot indices
|
||||
this.outputs = newOutputs;
|
||||
if (this.graph) {
|
||||
for (let i = 0; i < this.outputs.length; i++) {
|
||||
const links = this.outputs[i].links;
|
||||
if (!links) continue;
|
||||
for (const linkId of links) {
|
||||
const link = this.graph.links[linkId];
|
||||
if (link) link.origin_slot = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this._setStatus("ok");
|
||||
this.setSize(this.computeSize());
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
} catch (e) {
|
||||
console.error("[ProjectLoaderDynamic] Refresh failed:", e);
|
||||
this._setStatus("error", "Server unreachable");
|
||||
}
|
||||
};
|
||||
|
||||
// Restore state on workflow load
|
||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
||||
nodeType.prototype.onConfigure = function (info) {
|
||||
origOnConfigure?.apply(this, arguments);
|
||||
this._configured = true;
|
||||
|
||||
// Hide internal widgets
|
||||
for (const name of ["output_keys", "output_types"]) {
|
||||
const w = this.widgets?.find(w => w.name === name);
|
||||
if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
|
||||
}
|
||||
|
||||
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
||||
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
||||
|
||||
// Parse keys/types — try JSON array first, fall back to comma-split
|
||||
let keys = [];
|
||||
if (okWidget?.value) {
|
||||
try { keys = JSON.parse(okWidget.value); } catch (_) {
|
||||
keys = okWidget.value.split(",").map(k => k.trim()).filter(Boolean);
|
||||
}
|
||||
}
|
||||
let types = [];
|
||||
if (otWidget?.value) {
|
||||
try { types = JSON.parse(otWidget.value); } catch (_) {
|
||||
types = otWidget.value.split(",").map(t => t.trim()).filter(Boolean);
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure slot 0 is total_sequences (INT)
|
||||
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
||||
this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
|
||||
// LiteGraph restores links AFTER onConfigure, so graph.links is
|
||||
// empty here. Defer link fixup to a microtask that runs after the
|
||||
// synchronous graph.configure() finishes (including link restoration).
|
||||
// We must also rebuild output.links arrays because LiteGraph will
|
||||
// place link IDs on the wrong outputs (shifted by the unshift above).
|
||||
const node = this;
|
||||
queueMicrotask(() => {
|
||||
if (!node.graph) return;
|
||||
// Clear all output.links — they were populated at old indices
|
||||
for (const output of node.outputs) {
|
||||
output.links = null;
|
||||
}
|
||||
// Rebuild from graph.links with corrected origin_slot (+1)
|
||||
for (const linkId in node.graph.links) {
|
||||
const link = node.graph.links[linkId];
|
||||
if (!link || link.origin_id !== node.id) continue;
|
||||
link.origin_slot += 1;
|
||||
const output = node.outputs[link.origin_slot];
|
||||
if (output) {
|
||||
if (!output.links) output.links = [];
|
||||
output.links.push(link.id);
|
||||
}
|
||||
}
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
});
|
||||
}
|
||||
this.outputs[0].type = "INT";
|
||||
this.outputs[0].name = "total_sequences";
|
||||
|
||||
if (keys.length > 0) {
|
||||
// On load, LiteGraph already restored serialized outputs with links.
|
||||
// Dynamic outputs start at slot 1. Rename and set types to match stored state.
|
||||
for (let i = 0; i < keys.length; i++) {
|
||||
const slotIdx = i + 1; // offset by 1 for total_sequences
|
||||
if (slotIdx < this.outputs.length) {
|
||||
this.outputs[slotIdx].name = keys[i];
|
||||
if (types[i]) this.outputs[slotIdx].type = types[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Remove any extra outputs beyond keys + total_sequences
|
||||
while (this.outputs.length > keys.length + 1) {
|
||||
this.removeOutput(this.outputs.length - 1);
|
||||
}
|
||||
} else if (this.outputs.length > 1) {
|
||||
// Widget values empty but serialized dynamic outputs exist — sync widgets
|
||||
// from the outputs LiteGraph already restored (fallback, skip slot 0).
|
||||
const dynamicOutputs = this.outputs.slice(1);
|
||||
if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name));
|
||||
if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type));
|
||||
}
|
||||
|
||||
this.setSize(this.computeSize());
|
||||
};
|
||||
},
|
||||
});
|
||||
Reference in New Issue
Block a user