Compare commits
27 Commits
nicegui-mi
...
feature/sq
| Author | SHA1 | Date | |
|---|---|---|---|
| 187b85b054 | |||
| a0d8cb8bbf | |||
| d55b3198e8 | |||
| bf2fca53e0 | |||
| 5b71d1b276 | |||
| 027ef8e78a | |||
| 86693f608a | |||
| 4b5fff5c6e | |||
| d07a308865 | |||
| c4d107206f | |||
| b499eb4dfd | |||
| ba8f104bc1 | |||
| 6b7e9ea682 | |||
| c15bec98ce | |||
| 0d8e84ea36 | |||
| e2f30b0332 | |||
| 24f9b7d955 | |||
| d56f6d8170 | |||
| f2980a9f94 | |||
| 4e3ff63f6a | |||
| 6e01cab5cd | |||
| 16ed81f0db | |||
| d98cee8015 | |||
| 2ebf3a4fcd | |||
| a4cb979131 | |||
| 9a3f7b7b94 | |||
| d8597f201a |
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
__pycache__/
|
||||||
|
.pytest_cache/
|
||||||
|
.worktrees/
|
||||||
@@ -1,4 +1,7 @@
|
|||||||
from .json_loader import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
from .project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = PROJECT_NODE_CLASS_MAPPINGS
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
|
|
||||||
WEB_DIRECTORY = "./web"
|
WEB_DIRECTORY = "./web"
|
||||||
|
|
||||||
|
|||||||
80
api_routes.py
Normal file
80
api_routes.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""REST API endpoints for ComfyUI to query project data from SQLite.
|
||||||
|
|
||||||
|
All endpoints are read-only. Mounted on the NiceGUI/FastAPI server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Query
|
||||||
|
from nicegui import app
|
||||||
|
|
||||||
|
from db import ProjectDB
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# The DB instance is set by register_api_routes()
|
||||||
|
_db: ProjectDB | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def register_api_routes(db: ProjectDB) -> None:
    """Attach the read-only project REST endpoints to the NiceGUI/FastAPI app.

    Stores *db* in the module-level handle used by every handler below.
    """
    global _db
    _db = db

    # Route table: path -> handler. All endpoints are GET-only.
    route_table = [
        ("/api/projects", _list_projects),
        ("/api/projects/{name}/files", _list_files),
        ("/api/projects/{name}/files/{file_name}/sequences", _list_sequences),
        ("/api/projects/{name}/files/{file_name}/data", _get_data),
        ("/api/projects/{name}/files/{file_name}/keys", _get_keys),
    ]
    for path, handler in route_table:
        app.add_api_route(path, handler, methods=["GET"])
|
|
||||||
|
|
||||||
|
def _get_db() -> ProjectDB:
    """Return the module-level DB handle; 503 until register_api_routes() ran."""
    if _db is not None:
        return _db
    raise HTTPException(status_code=503, detail="Database not initialized")
|
||||||
|
|
||||||
|
|
||||||
|
def _list_projects() -> dict[str, Any]:
    """GET /api/projects — return every project name."""
    rows = _get_db().list_projects()
    return {"projects": [row["name"] for row in rows]}
|
||||||
|
|
||||||
|
|
||||||
|
def _list_files(name: str) -> dict[str, Any]:
    """GET /api/projects/{name}/files — list a project's data files (name + type)."""
    records = _get_db().list_project_files(name)
    return {
        "files": [
            {"name": rec["name"], "data_type": rec["data_type"]} for rec in records
        ]
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
    """GET .../sequences — return the sequence numbers stored for a file."""
    sequence_numbers = _get_db().list_project_sequences(name, file_name)
    return {"sequences": sequence_numbers}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
    """GET .../data — return one sequence's data dict.

    Raises a distinct 404 for each missing level (project, file, sequence)
    so clients can tell them apart.
    """
    db = _get_db()
    if not db.get_project(name):
        raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
    data_file = db.get_data_file_by_names(name, file_name)
    if not data_file:
        raise HTTPException(
            status_code=404,
            detail=f"File '{file_name}' not found in project '{name}'",
        )
    payload = db.get_sequence(data_file["id"], seq)
    if payload is None:
        raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
    """GET .../keys — return key names, ComfyUI type strings, and sequence count."""
    db = _get_db()
    if not db.get_project(name):
        raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
    data_file = db.get_data_file_by_names(name, file_name)
    if not data_file:
        raise HTTPException(
            status_code=404,
            detail=f"File '{file_name}' not found in project '{name}'",
        )
    keys, types = db.get_sequence_keys(data_file["id"], seq)
    return {
        "keys": keys,
        "types": types,
        "total_sequences": db.count_sequences(data_file["id"]),
    }
|
||||||
224
app.py
224
app.py
@@ -1,224 +0,0 @@
|
|||||||
import streamlit as st
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# --- Import Custom Modules ---
|
|
||||||
from utils import (
|
|
||||||
load_config, save_config, load_snippets, save_snippets,
|
|
||||||
load_json, save_json, generate_templates, DEFAULTS,
|
|
||||||
KEY_BATCH_DATA, KEY_SEQUENCE_NUMBER,
|
|
||||||
resolve_path_case_insensitive,
|
|
||||||
)
|
|
||||||
from tab_batch import render_batch_processor
|
|
||||||
from tab_timeline import render_timeline_tab
|
|
||||||
from tab_comfy import render_comfy_monitor
|
|
||||||
from tab_raw import render_raw_editor
|
|
||||||
|
|
||||||
# ==========================================
# 1. PAGE CONFIGURATION
# ==========================================
st.set_page_config(layout="wide", page_title="AI Settings Manager")

# ==========================================
# 2. SESSION STATE INITIALIZATION
# ==========================================
# Maps each session key to a ZERO-ARG FACTORY so defaults are built lazily:
# load_snippets only runs if "snippets" is actually missing from the session.
_SESSION_DEFAULTS = {
    "snippets": load_snippets,
    "loaded_file": lambda: None,
    "last_mtime": lambda: 0,
    "ui_reset_token": lambda: 0,
    "active_tab_name": lambda: "🚀 Batch Processor",
}

# Config is seeded once per session; current_dir falls back to the process
# CWD when no "last_dir" was persisted in the config.
if 'config' not in st.session_state:
    st.session_state.config = load_config()
    st.session_state.current_dir = Path(st.session_state.config.get("last_dir", Path.cwd()))

# Fill in any missing session keys from the factory table above.
for key, factory in _SESSION_DEFAULTS.items():
    if key not in st.session_state:
        st.session_state[key] = factory()
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 3. SIDEBAR (NAVIGATOR & TOOLS)
|
|
||||||
# ==========================================
|
|
||||||
with st.sidebar:
    st.header("📂 Navigator")

    # --- Path Navigator ---
    # Sync widget to current_dir on first load or after an external change
    # (the _sync_nav_path flag is raised by callbacks below, then consumed
    # here on the next rerun — Streamlit forbids mutating a widget's state
    # from inside its own callback).
    if "nav_path_input" not in st.session_state or st.session_state.get("_sync_nav_path"):
        st.session_state.nav_path_input = str(st.session_state.current_dir)
        st.session_state._sync_nav_path = False

    def _on_path_change():
        """Navigate to the typed path if it resolves to an existing directory."""
        new_path = st.session_state.nav_path_input
        p = resolve_path_case_insensitive(new_path)
        if p is not None and p.is_dir():
            st.session_state.current_dir = p
            st.session_state.config['last_dir'] = str(p)
            save_config(st.session_state.current_dir, st.session_state.config['favorites'])
            # Force a reload of whichever JSON file was open in the old folder.
            st.session_state.loaded_file = None
        # Always resync widget to canonical path form
        st.session_state._sync_nav_path = True

    st.text_input("Current Path", key="nav_path_input", on_change=_on_path_change)

    # --- Favorites System ---
    if st.button("📌 Pin Folder", use_container_width=True):
        # Dedupe: only pin (and persist) when not already a favorite.
        if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
            st.session_state.config['favorites'].append(str(st.session_state.current_dir))
            save_config(st.session_state.current_dir, st.session_state.config['favorites'])
            st.rerun()

    favorites = st.session_state.config['favorites']
    if favorites:
        def _on_fav_jump():
            """Jump to the selected favorite unless it's already the current dir."""
            sel = st.session_state._fav_radio
            if sel != "Select..." and sel != str(st.session_state.current_dir):
                st.session_state.current_dir = Path(sel)
                st.session_state._sync_nav_path = True

        st.radio(
            "Jump to:",
            ["Select..."] + favorites,
            index=0,
            key="_fav_radio",
            label_visibility="collapsed",
            on_change=_on_fav_jump,
        )

        # Unpin buttons for each favorite
        for fav in favorites:
            fc1, fc2 = st.columns([4, 1])
            fc1.caption(fav)
            if fc2.button("❌", key=f"unpin_{fav}"):
                st.session_state.config['favorites'].remove(fav)
                save_config(st.session_state.current_dir, st.session_state.config['favorites'])
                st.rerun()

    st.markdown("---")

    # --- Snippet Library ---
    st.subheader("🧩 Snippet Library")
    with st.expander("Add New Snippet"):
        snip_name = st.text_input("Name", placeholder="e.g. Cinematic")
        snip_content = st.text_area("Content", placeholder="4k, high quality...")
        if st.button("Save Snippet"):
            if snip_name and snip_content:
                st.session_state.snippets[snip_name] = snip_content
                save_snippets(st.session_state.snippets)
                st.success(f"Saved '{snip_name}'")
                st.rerun()

    if st.session_state.snippets:
        st.caption("Click to Append to Prompt:")
        for name, content in st.session_state.snippets.items():
            col_s1, col_s2 = st.columns([4, 1])
            # NOTE(review): this button only reruns — nothing appends `content`
            # to any prompt despite the caption above; confirm whether the
            # append action was lost or happens elsewhere via rerun.
            if col_s1.button(f"➕ {name}", use_container_width=True):
                st.rerun()
            if col_s2.button("🗑️", key=f"del_snip_{name}"):
                del st.session_state.snippets[name]
                save_snippets(st.session_state.snippets)
                st.rerun()

    st.markdown("---")

    # --- File List & Creation ---
    # The editor's own dotfiles are hidden from the pickable JSON list.
    json_files = sorted(list(st.session_state.current_dir.glob("*.json")))
    json_files = [f for f in json_files if f.name != ".editor_config.json" and f.name != ".editor_snippets.json"]

    if not json_files:
        if st.button("Generate Templates"):
            generate_templates(st.session_state.current_dir)
            st.rerun()

    with st.expander("Create New JSON"):
        new_filename = st.text_input("Filename", placeholder="my_prompt_vace")
        if st.button("Create"):
            if not new_filename.endswith(".json"): new_filename += ".json"
            path = st.session_state.current_dir / new_filename
            # Seed the file with a single default batch item, sequence 1.
            first_item = DEFAULTS.copy()
            first_item[KEY_SEQUENCE_NUMBER] = 1
            data = {KEY_BATCH_DATA: [first_item]}
            save_json(path, data)
            st.rerun()

    # --- File Selector ---
    selected_file_name = None
    if json_files:
        file_names = [f.name for f in json_files]
        # Keep the selection valid: default to the first file, and reset if
        # the previously-selected file vanished (deleted / folder changed).
        if 'file_selector' not in st.session_state:
            st.session_state.file_selector = file_names[0]
        if st.session_state.file_selector not in file_names:
            st.session_state.file_selector = file_names[0]

        selected_file_name = st.radio("Select File", file_names, key="file_selector")
    else:
        st.info("No JSON files in this folder.")
        # Drop stale selection state so the widget re-initializes cleanly
        # once files exist again.
        if 'file_selector' in st.session_state:
            del st.session_state.file_selector
        st.session_state.loaded_file = None

    # --- GLOBAL MONITOR TOGGLE (NEW) ---
    st.markdown("---")
    show_monitor = st.checkbox("Show Comfy Monitor", value=True)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 4. MAIN APP LOGIC
|
|
||||||
# ==========================================
|
|
||||||
if selected_file_name:
    file_path = st.session_state.current_dir / selected_file_name

    # --- FILE LOADING & AUTO-SWITCH LOGIC ---
    # Reload from disk only when a different file was picked; otherwise the
    # cached parse from the previous rerun is reused.
    if st.session_state.loaded_file != str(file_path):
        data, mtime = load_json(file_path)
        st.session_state.data_cache = data
        st.session_state.last_mtime = mtime
        st.session_state.loaded_file = str(file_path)

        # Clear transient states
        if 'restored_indicator' in st.session_state: del st.session_state.restored_indicator

        # --- AUTO-SWITCH TAB LOGIC ---
        # Opening a new file always lands on the Batch Processor tab.
        st.session_state.active_tab_name = "🚀 Batch Processor"

    else:
        data = st.session_state.data_cache

    st.title(f"Editing: {selected_file_name}")

    # --- CONTROLLED NAVIGATION ---
    # Removed "🔌 Comfy Monitor" from this list
    tabs_list = [
        "🚀 Batch Processor",
        "🕒 Timeline",
        "💻 Raw Editor"
    ]

    # Guard against a stale tab name persisted from an older tab set.
    if st.session_state.active_tab_name not in tabs_list:
        st.session_state.active_tab_name = tabs_list[0]

    # key="active_tab_name" binds the radio directly to session state, so
    # the auto-switch above takes effect on the next rerun.
    current_tab = st.radio(
        "Navigation",
        tabs_list,
        key="active_tab_name",
        horizontal=True,
        label_visibility="collapsed"
    )

    st.markdown("---")

    # --- RENDER EDITOR TABS ---
    if current_tab == "🚀 Batch Processor":
        render_batch_processor(data, file_path, json_files, st.session_state.current_dir, selected_file_name)

    elif current_tab == "🕒 Timeline":
        render_timeline_tab(data, file_path)

    elif current_tab == "💻 Raw Editor":
        render_raw_editor(data, file_path)

# --- GLOBAL PERSISTENT MONITOR ---
# Rendered outside the file-selected branch so the monitor stays visible
# even when no JSON file is open.
if show_monitor:
    st.markdown("---")
    with st.expander("🔌 ComfyUI Monitor", expanded=True):
        render_comfy_monitor()
|
|
||||||
349
db.py
Normal file
349
db.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default on-disk location; the parent directory is created by ProjectDB.__init__.
DEFAULT_DB_PATH = Path.home() / ".comfyui_json_manager" / "projects.db"

# Schema: projects 1-* data_files 1-* sequences, plus at most one
# history_trees row per data_file. All foreign keys cascade on delete,
# which only takes effect with PRAGMA foreign_keys=ON (set per connection).
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS projects (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL UNIQUE,
    folder_path TEXT NOT NULL,
    description TEXT NOT NULL DEFAULT '',
    created_at REAL NOT NULL,
    updated_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS data_files (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE,
    name TEXT NOT NULL,
    data_type TEXT NOT NULL DEFAULT 'generic',
    top_level TEXT NOT NULL DEFAULT '{}',
    created_at REAL NOT NULL,
    updated_at REAL NOT NULL,
    UNIQUE(project_id, name)
);

CREATE TABLE IF NOT EXISTS sequences (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE,
    sequence_number INTEGER NOT NULL,
    data TEXT NOT NULL DEFAULT '{}',
    updated_at REAL NOT NULL,
    UNIQUE(data_file_id, sequence_number)
);

CREATE TABLE IF NOT EXISTS history_trees (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE,
    tree_data TEXT NOT NULL DEFAULT '{}',
    updated_at REAL NOT NULL
);
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectDB:
    """SQLite database for project-based data management.

    Hierarchy (see SCHEMA_SQL): projects -> data_files -> sequences, plus at
    most one history_trees row per data file; foreign keys cascade on delete.
    A single connection is shared across threads (check_same_thread=False)
    and runs in autocommit mode; only import_json_file() opens an explicit
    transaction.
    """

    def __init__(self, db_path: str | Path | None = None):
        """Open (creating if necessary) the database at *db_path* or DEFAULT_DB_PATH."""
        self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(
            str(self.db_path),
            check_same_thread=False,
            isolation_level=None,  # autocommit — explicit BEGIN/COMMIT only
        )
        self.conn.row_factory = sqlite3.Row  # column access by name
        self.conn.execute("PRAGMA journal_mode=WAL")  # readers don't block the writer
        self.conn.execute("PRAGMA foreign_keys=ON")  # required for ON DELETE CASCADE
        self.conn.executescript(SCHEMA_SQL)  # idempotent: CREATE TABLE IF NOT EXISTS

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    # ------------------------------------------------------------------
    # Projects CRUD
    # ------------------------------------------------------------------

    def create_project(self, name: str, folder_path: str, description: str = "") -> int:
        """Insert a project and return its new row id.

        Raises sqlite3.IntegrityError if *name* already exists (UNIQUE column).
        """
        now = time.time()
        cur = self.conn.execute(
            "INSERT INTO projects (name, folder_path, description, created_at, updated_at) "
            "VALUES (?, ?, ?, ?, ?)",
            (name, folder_path, description, now, now),
        )
        # commit() is a harmless no-op under autocommit; kept for clarity.
        self.conn.commit()
        return cur.lastrowid

    def list_projects(self) -> list[dict]:
        """Return every project as a plain dict, ordered by name."""
        rows = self.conn.execute(
            "SELECT id, name, folder_path, description, created_at, updated_at "
            "FROM projects ORDER BY name"
        ).fetchall()
        return [dict(r) for r in rows]

    def get_project(self, name: str) -> dict | None:
        """Return the project with the given unique *name*, or None."""
        row = self.conn.execute(
            "SELECT id, name, folder_path, description, created_at, updated_at "
            "FROM projects WHERE name = ?",
            (name,),
        ).fetchone()
        return dict(row) if row else None

    def delete_project(self, name: str) -> bool:
        """Delete a project by name (cascades to files/sequences/trees).

        Returns True if a row was actually removed.
        """
        cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,))
        self.conn.commit()
        return cur.rowcount > 0

    # ------------------------------------------------------------------
    # Data files
    # ------------------------------------------------------------------

    def create_data_file(
        self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None
    ) -> int:
        """Insert a data-file row and return its id.

        *top_level* (the file's non-sequence JSON keys) is stored as a JSON
        string. Raises sqlite3.IntegrityError on duplicate (project_id, name).
        """
        now = time.time()
        tl = json.dumps(top_level or {})
        cur = self.conn.execute(
            "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (project_id, name, data_type, tl, now, now),
        )
        self.conn.commit()
        return cur.lastrowid

    def list_data_files(self, project_id: int) -> list[dict]:
        """Return a project's data files (without top_level), ordered by name."""
        rows = self.conn.execute(
            "SELECT id, project_id, name, data_type, created_at, updated_at "
            "FROM data_files WHERE project_id = ? ORDER BY name",
            (project_id,),
        ).fetchall()
        return [dict(r) for r in rows]

    def get_data_file(self, project_id: int, name: str) -> dict | None:
        """Return one data file with top_level decoded back to a dict, or None."""
        row = self.conn.execute(
            "SELECT id, project_id, name, data_type, top_level, created_at, updated_at "
            "FROM data_files WHERE project_id = ? AND name = ?",
            (project_id, name),
        ).fetchone()
        if row is None:
            return None
        d = dict(row)
        d["top_level"] = json.loads(d["top_level"])
        return d

    def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None:
        """Like get_data_file(), but addressed by project/file names via a JOIN."""
        row = self.conn.execute(
            "SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, "
            "df.created_at, df.updated_at "
            "FROM data_files df JOIN projects p ON df.project_id = p.id "
            "WHERE p.name = ? AND df.name = ?",
            (project_name, file_name),
        ).fetchone()
        if row is None:
            return None
        d = dict(row)
        d["top_level"] = json.loads(d["top_level"])
        return d

    # ------------------------------------------------------------------
    # Sequences
    # ------------------------------------------------------------------

    def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None:
        """Insert or overwrite the JSON payload of one sequence."""
        now = time.time()
        self.conn.execute(
            "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
            "VALUES (?, ?, ?, ?) "
            "ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
            (data_file_id, sequence_number, json.dumps(data), now),
        )
        self.conn.commit()

    def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None:
        """Return one sequence's decoded data dict, or None if absent."""
        row = self.conn.execute(
            "SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?",
            (data_file_id, sequence_number),
        ).fetchone()
        return json.loads(row["data"]) if row else None

    def list_sequences(self, data_file_id: int) -> list[int]:
        """Return the sorted sequence numbers stored for a data file."""
        rows = self.conn.execute(
            "SELECT sequence_number FROM sequences WHERE data_file_id = ? ORDER BY sequence_number",
            (data_file_id,),
        ).fetchall()
        return [r["sequence_number"] for r in rows]

    def count_sequences(self, data_file_id: int) -> int:
        """Return the number of sequences for a data file."""
        row = self.conn.execute(
            "SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?",
            (data_file_id,),
        ).fetchone()
        return row["cnt"]

    def query_total_sequences(self, project_name: str, file_name: str) -> int:
        """Return total sequence count by project and file names (0 if unknown)."""
        df = self.get_data_file_by_names(project_name, file_name)
        if not df:
            return 0
        return self.count_sequences(df["id"])

    def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
        """Returns (keys, types) for a sequence's data dict.

        Types are socket-type strings: INT, FLOAT, or STRING (the catch-all).
        Returns ([], []) when the sequence is missing or empty.
        """
        data = self.get_sequence(data_file_id, sequence_number)
        if not data:
            return [], []
        keys = []
        types = []
        for k, v in data.items():
            keys.append(k)
            if isinstance(v, bool):
                # bool is checked before int (bool subclasses int); booleans
                # are deliberately reported as STRING.
                types.append("STRING")
            elif isinstance(v, int):
                types.append("INT")
            elif isinstance(v, float):
                types.append("FLOAT")
            else:
                types.append("STRING")
        return keys, types

    def delete_sequences_for_file(self, data_file_id: int) -> None:
        """Remove every sequence row belonging to a data file."""
        self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,))
        self.conn.commit()

    # ------------------------------------------------------------------
    # History trees
    # ------------------------------------------------------------------

    def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
        """Insert or replace the (single) history tree of a data file."""
        now = time.time()
        self.conn.execute(
            "INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
            "VALUES (?, ?, ?) "
            "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
            (data_file_id, json.dumps(tree_data), now),
        )
        self.conn.commit()

    def get_history_tree(self, data_file_id: int) -> dict | None:
        """Return the decoded history tree for a data file, or None."""
        row = self.conn.execute(
            "SELECT tree_data FROM history_trees WHERE data_file_id = ?",
            (data_file_id,),
        ).fetchone()
        return json.loads(row["tree_data"]) if row else None

    # ------------------------------------------------------------------
    # Import
    # ------------------------------------------------------------------

    def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
        """Import a JSON file into the database, splitting batch_data into sequences.

        Safe to call repeatedly — existing data_file is updated, sequences are
        replaced, and history_tree is upserted. Atomic: all-or-nothing.
        Returns the data_file row id.
        """
        json_path = Path(json_path)
        data, _ = load_json(json_path)
        file_name = json_path.stem  # data_file name = filename without extension

        # Every top-level key other than batch_data/history_tree is preserved verbatim.
        top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}

        # IMMEDIATE takes the write lock up front so the whole import is atomic.
        self.conn.execute("BEGIN IMMEDIATE")
        try:
            existing = self.conn.execute(
                "SELECT id FROM data_files WHERE project_id = ? AND name = ?",
                (project_id, file_name),
            ).fetchone()

            if existing:
                df_id = existing["id"]
                now = time.time()
                self.conn.execute(
                    "UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?",
                    (data_type, json.dumps(top_level), now, df_id),
                )
                # Re-import replaces all sequences rather than merging.
                self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
            else:
                now = time.time()
                cur = self.conn.execute(
                    "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
                    "VALUES (?, ?, ?, ?, ?, ?)",
                    (project_id, file_name, data_type, json.dumps(top_level), now, now),
                )
                df_id = cur.lastrowid

            # Import sequences from batch_data
            batch_data = data.get(KEY_BATCH_DATA, [])
            if isinstance(batch_data, list):
                for item in batch_data:
                    if not isinstance(item, dict):
                        continue
                    # NOTE(review): items missing "sequence_number" all default
                    # to 0 and overwrite each other via ON CONFLICT — confirm
                    # inputs always carry the field.
                    seq_num = int(item.get("sequence_number", 0))
                    now = time.time()
                    self.conn.execute(
                        "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
                        "VALUES (?, ?, ?, ?) "
                        "ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
                        (df_id, seq_num, json.dumps(item), now),
                    )

            # Import history tree
            history_tree = data.get(KEY_HISTORY_TREE)
            if history_tree and isinstance(history_tree, dict):
                now = time.time()
                self.conn.execute(
                    "INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
                    "VALUES (?, ?, ?) "
                    "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
                    (df_id, json.dumps(history_tree), now),
                )

            self.conn.execute("COMMIT")
            return df_id
        except Exception:
            # Best-effort rollback: the ROLLBACK itself may fail if the
            # connection is already broken, in which case the original
            # exception still propagates.
            try:
                self.conn.execute("ROLLBACK")
            except Exception:
                pass
            raise

    # ------------------------------------------------------------------
    # Query helpers (for REST API)
    # ------------------------------------------------------------------

    def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None:
        """Query a single sequence by project name, file name, and sequence number."""
        df = self.get_data_file_by_names(project_name, file_name)
        if not df:
            return None
        return self.get_sequence(df["id"], sequence_number)

    def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]:
        """Query keys and types for a sequence; ([], []) when the file is unknown."""
        df = self.get_data_file_by_names(project_name, file_name)
        if not df:
            return [], []
        return self.get_sequence_keys(df["id"], sequence_number)

    def list_project_files(self, project_name: str) -> list[dict]:
        """List data files for a project by name ([] if the project is unknown)."""
        proj = self.get_project(project_name)
        if not proj:
            return []
        return self.list_data_files(proj["id"])

    def list_project_sequences(self, project_name: str, file_name: str) -> list[int]:
        """List sequence numbers for a file in a project ([] if unknown)."""
        df = self.get_data_file_by_names(project_name, file_name)
        if not df:
            return []
        return self.list_sequences(df["id"])
|
||||||
@@ -104,13 +104,37 @@ class HistoryTree:
|
|||||||
'digraph History {',
|
'digraph History {',
|
||||||
f' rankdir={direction};',
|
f' rankdir={direction};',
|
||||||
' bgcolor="white";',
|
' bgcolor="white";',
|
||||||
' splines=ortho;',
|
' splines=polyline;',
|
||||||
f' nodesep={nodesep};',
|
f' nodesep={nodesep};',
|
||||||
f' ranksep={ranksep};',
|
f' ranksep={ranksep};',
|
||||||
' node [shape=plain, fontname="Arial"];',
|
' node [shape=plain, fontname="Arial"];',
|
||||||
' edge [color="#888888", arrowsize=0.6, penwidth=1.0];'
|
' edge [color="#888888", arrowsize=0.6, penwidth=1.0];'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Build reverse lookup: node_id -> branch name (walk each branch ancestry)
|
||||||
|
node_to_branch: dict[str, str] = {}
|
||||||
|
for b_name, tip_id in self.branches.items():
|
||||||
|
current = tip_id
|
||||||
|
while current and current in self.nodes:
|
||||||
|
if current not in node_to_branch:
|
||||||
|
node_to_branch[current] = b_name
|
||||||
|
current = self.nodes[current].get('parent')
|
||||||
|
|
||||||
|
# Per-branch color palette (bg, border) — cycles for many branches
|
||||||
|
_branch_palette = [
|
||||||
|
('#f9f9f9', '#999999'), # grey (default/main)
|
||||||
|
('#eef4ff', '#6699cc'), # blue
|
||||||
|
('#f5eeff', '#9977cc'), # purple
|
||||||
|
('#fff0ee', '#cc7766'), # coral
|
||||||
|
('#eefff5', '#66aa88'), # teal
|
||||||
|
('#fff8ee', '#ccaa55'), # sand
|
||||||
|
]
|
||||||
|
branch_names = list(self.branches.keys())
|
||||||
|
branch_colors = {
|
||||||
|
b: _branch_palette[i % len(_branch_palette)]
|
||||||
|
for i, b in enumerate(branch_names)
|
||||||
|
}
|
||||||
|
|
||||||
sorted_nodes = sorted(self.nodes.values(), key=lambda x: x["timestamp"])
|
sorted_nodes = sorted(self.nodes.values(), key=lambda x: x["timestamp"])
|
||||||
|
|
||||||
# Font sizes and padding - smaller for vertical
|
# Font sizes and padding - smaller for vertical
|
||||||
@@ -138,9 +162,10 @@ class HistoryTree:
|
|||||||
if nid in tip_to_branches:
|
if nid in tip_to_branches:
|
||||||
branch_label = ", ".join(tip_to_branches[nid])
|
branch_label = ", ".join(tip_to_branches[nid])
|
||||||
|
|
||||||
# COLORS
|
# COLORS — per-branch tint, overridden for HEAD and tips
|
||||||
bg_color = "#f9f9f9"
|
b_name = node_to_branch.get(nid)
|
||||||
border_color = "#999999"
|
bg_color, border_color = branch_colors.get(
|
||||||
|
b_name, _branch_palette[0])
|
||||||
border_width = "1"
|
border_width = "1"
|
||||||
|
|
||||||
if nid == self.head_id:
|
if nid == self.head_id:
|
||||||
|
|||||||
384
json_loader.py
384
json_loader.py
@@ -1,384 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
KEY_BATCH_DATA = "batch_data"
|
|
||||||
MAX_DYNAMIC_OUTPUTS = 32
|
|
||||||
|
|
||||||
|
|
||||||
class AnyType(str):
|
|
||||||
"""Universal connector type that matches any ComfyUI type."""
|
|
||||||
def __ne__(self, __value: object) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
any_type = AnyType("*")
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from server import PromptServer
|
|
||||||
from aiohttp import web
|
|
||||||
except ImportError:
|
|
||||||
PromptServer = None
|
|
||||||
|
|
||||||
|
|
||||||
def to_float(val: Any) -> float:
|
|
||||||
try:
|
|
||||||
return float(val)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
def to_int(val: Any) -> int:
|
|
||||||
try:
|
|
||||||
return int(float(val))
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def get_batch_item(data: dict[str, Any], sequence_number: int) -> dict[str, Any]:
|
|
||||||
"""Resolve batch item by sequence_number field, falling back to array index."""
|
|
||||||
if KEY_BATCH_DATA in data and isinstance(data[KEY_BATCH_DATA], list) and len(data[KEY_BATCH_DATA]) > 0:
|
|
||||||
# Search by sequence_number field first
|
|
||||||
for item in data[KEY_BATCH_DATA]:
|
|
||||||
if int(item.get("sequence_number", 0)) == sequence_number:
|
|
||||||
return item
|
|
||||||
# Fallback to array index
|
|
||||||
idx = max(0, min(sequence_number - 1, len(data[KEY_BATCH_DATA]) - 1))
|
|
||||||
logger.warning(f"No item with sequence_number={sequence_number}, falling back to index {idx}")
|
|
||||||
return data[KEY_BATCH_DATA][idx]
|
|
||||||
return data
|
|
||||||
|
|
||||||
# --- Shared Helper ---
|
|
||||||
def read_json_data(json_path: str) -> dict[str, Any]:
|
|
||||||
if not os.path.exists(json_path):
|
|
||||||
logger.warning(f"File not found at {json_path}")
|
|
||||||
return {}
|
|
||||||
try:
|
|
||||||
with open(json_path, 'r') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
except (json.JSONDecodeError, IOError) as e:
|
|
||||||
logger.warning(f"Error reading {json_path}: {e}")
|
|
||||||
return {}
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
logger.warning(f"Expected dict from {json_path}, got {type(data).__name__}")
|
|
||||||
return {}
|
|
||||||
return data
|
|
||||||
|
|
||||||
# --- API Route ---
|
|
||||||
if PromptServer is not None:
|
|
||||||
@PromptServer.instance.routes.get("/json_manager/get_keys")
|
|
||||||
async def get_keys_route(request):
|
|
||||||
json_path = request.query.get("path", "")
|
|
||||||
try:
|
|
||||||
seq = int(request.query.get("sequence_number", "1"))
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
seq = 1
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
target = get_batch_item(data, seq)
|
|
||||||
keys = []
|
|
||||||
types = []
|
|
||||||
if isinstance(target, dict):
|
|
||||||
for k, v in target.items():
|
|
||||||
keys.append(k)
|
|
||||||
if isinstance(v, bool):
|
|
||||||
types.append("STRING")
|
|
||||||
elif isinstance(v, int):
|
|
||||||
types.append("INT")
|
|
||||||
elif isinstance(v, float):
|
|
||||||
types.append("FLOAT")
|
|
||||||
else:
|
|
||||||
types.append("STRING")
|
|
||||||
return web.json_response({"keys": keys, "types": types})
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 0. DYNAMIC NODE
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
class JSONLoaderDynamic:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"json_path": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"output_keys": ("STRING", {"default": ""}),
|
|
||||||
"output_types": ("STRING", {"default": ""}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
|
|
||||||
RETURN_NAMES = tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
|
|
||||||
FUNCTION = "load_dynamic"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
OUTPUT_NODE = False
|
|
||||||
|
|
||||||
def load_dynamic(self, json_path, sequence_number, output_keys="", output_types=""):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
target = get_batch_item(data, sequence_number)
|
|
||||||
|
|
||||||
keys = [k.strip() for k in output_keys.split(",") if k.strip()] if output_keys else []
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for key in keys:
|
|
||||||
val = target.get(key, "")
|
|
||||||
if isinstance(val, bool):
|
|
||||||
results.append(str(val).lower())
|
|
||||||
elif isinstance(val, int):
|
|
||||||
results.append(val)
|
|
||||||
elif isinstance(val, float):
|
|
||||||
results.append(val)
|
|
||||||
else:
|
|
||||||
results.append(str(val))
|
|
||||||
|
|
||||||
# Pad to MAX_DYNAMIC_OUTPUTS
|
|
||||||
while len(results) < MAX_DYNAMIC_OUTPUTS:
|
|
||||||
results.append("")
|
|
||||||
|
|
||||||
return tuple(results)
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 1. STANDARD NODES (Single File)
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
class JSONLoaderLoRA:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low")
|
|
||||||
FUNCTION = "load_loras"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_loras(self, json_path):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
return (
|
|
||||||
str(data.get("lora 1 high", "")), str(data.get("lora 1 low", "")),
|
|
||||||
str(data.get("lora 2 high", "")), str(data.get("lora 2 low", "")),
|
|
||||||
str(data.get("lora 3 high", "")), str(data.get("lora 3 low", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
class JSONLoaderStandard:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path")
|
|
||||||
FUNCTION = "load_standard"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_standard(self, json_path):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
return (
|
|
||||||
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
|
|
||||||
str(data.get("current_prompt", "")), str(data.get("negative", "")),
|
|
||||||
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
|
|
||||||
to_int(data.get("seed", 0)), str(data.get("video file path", "")),
|
|
||||||
str(data.get("reference image path", "")), str(data.get("flf image path", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
class JSONLoaderVACE:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False})}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path")
|
|
||||||
FUNCTION = "load_vace"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_vace(self, json_path):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
return (
|
|
||||||
str(data.get("general_prompt", "")), str(data.get("general_negative", "")),
|
|
||||||
str(data.get("current_prompt", "")), str(data.get("negative", "")),
|
|
||||||
str(data.get("camera", "")), to_float(data.get("flf", 0.0)),
|
|
||||||
to_int(data.get("seed", 0)),
|
|
||||||
to_int(data.get("frame_to_skip", 81)), to_int(data.get("input_a_frames", 16)),
|
|
||||||
to_int(data.get("input_b_frames", 16)), str(data.get("reference path", "")),
|
|
||||||
to_int(data.get("reference switch", 1)), to_int(data.get("vace schedule", 1)),
|
|
||||||
str(data.get("video file path", "")), str(data.get("reference image path", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 2. BATCH NODES
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
class JSONLoaderBatchLoRA:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}}
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("lora_1_high", "lora_1_low", "lora_2_high", "lora_2_low", "lora_3_high", "lora_3_low")
|
|
||||||
FUNCTION = "load_batch_loras"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_batch_loras(self, json_path, sequence_number):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
target_data = get_batch_item(data, sequence_number)
|
|
||||||
return (
|
|
||||||
str(target_data.get("lora 1 high", "")), str(target_data.get("lora 1 low", "")),
|
|
||||||
str(target_data.get("lora 2 high", "")), str(target_data.get("lora 2 low", "")),
|
|
||||||
str(target_data.get("lora 3 high", "")), str(target_data.get("lora 3 low", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
class JSONLoaderBatchI2V:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}}
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "STRING", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "video_file_path", "reference_image_path", "flf_image_path")
|
|
||||||
FUNCTION = "load_batch_i2v"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_batch_i2v(self, json_path, sequence_number):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
target_data = get_batch_item(data, sequence_number)
|
|
||||||
|
|
||||||
return (
|
|
||||||
str(target_data.get("general_prompt", "")), str(target_data.get("general_negative", "")),
|
|
||||||
str(target_data.get("current_prompt", "")), str(target_data.get("negative", "")),
|
|
||||||
str(target_data.get("camera", "")), to_float(target_data.get("flf", 0.0)),
|
|
||||||
to_int(target_data.get("seed", 0)), str(target_data.get("video file path", "")),
|
|
||||||
str(target_data.get("reference image path", "")), str(target_data.get("flf image path", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
class JSONLoaderBatchVACE:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {"required": {"json_path": ("STRING", {"default": "", "multiline": False}), "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999})}}
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "FLOAT", "INT", "INT", "INT", "INT", "STRING", "INT", "INT", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("general_prompt", "general_negative", "current_prompt", "negative", "camera", "flf", "seed", "frame_to_skip", "input_a_frames", "input_b_frames", "reference_path", "reference_switch", "vace_schedule", "video_file_path", "reference_image_path")
|
|
||||||
FUNCTION = "load_batch_vace"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_batch_vace(self, json_path, sequence_number):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
target_data = get_batch_item(data, sequence_number)
|
|
||||||
|
|
||||||
return (
|
|
||||||
str(target_data.get("general_prompt", "")), str(target_data.get("general_negative", "")),
|
|
||||||
str(target_data.get("current_prompt", "")), str(target_data.get("negative", "")),
|
|
||||||
str(target_data.get("camera", "")), to_float(target_data.get("flf", 0.0)),
|
|
||||||
to_int(target_data.get("seed", 0)), to_int(target_data.get("frame_to_skip", 81)),
|
|
||||||
to_int(target_data.get("input_a_frames", 16)), to_int(target_data.get("input_b_frames", 16)),
|
|
||||||
str(target_data.get("reference path", "")), to_int(target_data.get("reference switch", 1)),
|
|
||||||
to_int(target_data.get("vace schedule", 1)), str(target_data.get("video file path", "")),
|
|
||||||
str(target_data.get("reference image path", ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
# ==========================================
|
|
||||||
# 3. UNIVERSAL CUSTOM NODES (1, 3, 6 Slots)
|
|
||||||
# ==========================================
|
|
||||||
|
|
||||||
class JSONLoaderCustom1:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"json_path": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
},
|
|
||||||
"optional": { "key_1": ("STRING", {"default": "", "multiline": False}) }
|
|
||||||
}
|
|
||||||
RETURN_TYPES = ("STRING",)
|
|
||||||
RETURN_NAMES = ("val_1",)
|
|
||||||
FUNCTION = "load_custom"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_custom(self, json_path, sequence_number, key_1=""):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
target_data = get_batch_item(data, sequence_number)
|
|
||||||
return (str(target_data.get(key_1, "")),)
|
|
||||||
|
|
||||||
class JSONLoaderCustom3:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"json_path": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"key_1": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_2": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_3": ("STRING", {"default": "", "multiline": False})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("val_1", "val_2", "val_3")
|
|
||||||
FUNCTION = "load_custom"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_custom(self, json_path, sequence_number, key_1="", key_2="", key_3=""):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
target_data = get_batch_item(data, sequence_number)
|
|
||||||
return (
|
|
||||||
str(target_data.get(key_1, "")),
|
|
||||||
str(target_data.get(key_2, "")),
|
|
||||||
str(target_data.get(key_3, ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
class JSONLoaderCustom6:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(s):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"json_path": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"key_1": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_2": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_3": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_4": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_5": ("STRING", {"default": "", "multiline": False}),
|
|
||||||
"key_6": ("STRING", {"default": "", "multiline": False})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING")
|
|
||||||
RETURN_NAMES = ("val_1", "val_2", "val_3", "val_4", "val_5", "val_6")
|
|
||||||
FUNCTION = "load_custom"
|
|
||||||
CATEGORY = "utils/json"
|
|
||||||
|
|
||||||
def load_custom(self, json_path, sequence_number, key_1="", key_2="", key_3="", key_4="", key_5="", key_6=""):
|
|
||||||
data = read_json_data(json_path)
|
|
||||||
target_data = get_batch_item(data, sequence_number)
|
|
||||||
return (
|
|
||||||
str(target_data.get(key_1, "")), str(target_data.get(key_2, "")),
|
|
||||||
str(target_data.get(key_3, "")), str(target_data.get(key_4, "")),
|
|
||||||
str(target_data.get(key_5, "")), str(target_data.get(key_6, ""))
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Mappings ---
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"JSONLoaderDynamic": JSONLoaderDynamic,
|
|
||||||
"JSONLoaderLoRA": JSONLoaderLoRA,
|
|
||||||
"JSONLoaderStandard": JSONLoaderStandard,
|
|
||||||
"JSONLoaderVACE": JSONLoaderVACE,
|
|
||||||
"JSONLoaderBatchLoRA": JSONLoaderBatchLoRA,
|
|
||||||
"JSONLoaderBatchI2V": JSONLoaderBatchI2V,
|
|
||||||
"JSONLoaderBatchVACE": JSONLoaderBatchVACE,
|
|
||||||
"JSONLoaderCustom1": JSONLoaderCustom1,
|
|
||||||
"JSONLoaderCustom3": JSONLoaderCustom3,
|
|
||||||
"JSONLoaderCustom6": JSONLoaderCustom6
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"JSONLoaderDynamic": "JSON Loader (Dynamic)",
|
|
||||||
"JSONLoaderLoRA": "JSON Loader (LoRAs Only)",
|
|
||||||
"JSONLoaderStandard": "JSON Loader (Standard/I2V)",
|
|
||||||
"JSONLoaderVACE": "JSON Loader (VACE Full)",
|
|
||||||
"JSONLoaderBatchLoRA": "JSON Batch Loader (LoRAs)",
|
|
||||||
"JSONLoaderBatchI2V": "JSON Batch Loader (I2V)",
|
|
||||||
"JSONLoaderBatchVACE": "JSON Batch Loader (VACE)",
|
|
||||||
"JSONLoaderCustom1": "JSON Loader (Custom 1)",
|
|
||||||
"JSONLoaderCustom3": "JSON Loader (Custom 3)",
|
|
||||||
"JSONLoaderCustom6": "JSON Loader (Custom 6)"
|
|
||||||
}
|
|
||||||
27
main.py
27
main.py
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from nicegui import ui
|
from nicegui import ui
|
||||||
@@ -14,11 +15,22 @@ from tab_batch_ng import render_batch_processor
|
|||||||
from tab_timeline_ng import render_timeline_tab
|
from tab_timeline_ng import render_timeline_tab
|
||||||
from tab_raw_ng import render_raw_editor
|
from tab_raw_ng import render_raw_editor
|
||||||
from tab_comfy_ng import render_comfy_monitor
|
from tab_comfy_ng import render_comfy_monitor
|
||||||
|
from tab_projects_ng import render_projects_tab
|
||||||
|
from db import ProjectDB
|
||||||
|
from api_routes import register_api_routes
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Single shared DB instance for both the UI and API routes
|
||||||
|
_shared_db: ProjectDB | None = None
|
||||||
|
try:
|
||||||
|
_shared_db = ProjectDB()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to initialize ProjectDB: {e}")
|
||||||
|
|
||||||
|
|
||||||
@ui.page('/')
|
@ui.page('/')
|
||||||
def index():
|
def index():
|
||||||
# -- Streamlit dark theme --
|
|
||||||
ui.dark_mode(True)
|
ui.dark_mode(True)
|
||||||
ui.colors(primary='#F59E0B')
|
ui.colors(primary='#F59E0B')
|
||||||
ui.add_head_html(
|
ui.add_head_html(
|
||||||
@@ -157,7 +169,13 @@ def index():
|
|||||||
config=config,
|
config=config,
|
||||||
current_dir=Path(config.get('last_dir', str(Path.cwd()))),
|
current_dir=Path(config.get('last_dir', str(Path.cwd()))),
|
||||||
snippets=load_snippets(),
|
snippets=load_snippets(),
|
||||||
|
db_enabled=config.get('db_enabled', False),
|
||||||
|
current_project=config.get('current_project', ''),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use the shared DB instance
|
||||||
|
state.db = _shared_db
|
||||||
|
|
||||||
dual_pane = {'active': False, 'state': None}
|
dual_pane = {'active': False, 'state': None}
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -179,6 +197,7 @@ def index():
|
|||||||
ui.tab('batch', label='Batch Processor')
|
ui.tab('batch', label='Batch Processor')
|
||||||
ui.tab('timeline', label='Timeline')
|
ui.tab('timeline', label='Timeline')
|
||||||
ui.tab('raw', label='Raw Editor')
|
ui.tab('raw', label='Raw Editor')
|
||||||
|
ui.tab('projects', label='Projects')
|
||||||
|
|
||||||
with ui.tab_panels(tabs, value='batch').classes('w-full'):
|
with ui.tab_panels(tabs, value='batch').classes('w-full'):
|
||||||
with ui.tab_panel('batch'):
|
with ui.tab_panel('batch'):
|
||||||
@@ -187,6 +206,8 @@ def index():
|
|||||||
render_timeline_tab(state)
|
render_timeline_tab(state)
|
||||||
with ui.tab_panel('raw'):
|
with ui.tab_panel('raw'):
|
||||||
render_raw_editor(state)
|
render_raw_editor(state)
|
||||||
|
with ui.tab_panel('projects'):
|
||||||
|
render_projects_tab(state)
|
||||||
|
|
||||||
if state.show_comfy_monitor:
|
if state.show_comfy_monitor:
|
||||||
ui.separator()
|
ui.separator()
|
||||||
@@ -482,4 +503,8 @@ def render_sidebar(state: AppState, dual_pane: dict):
|
|||||||
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
|
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
|
||||||
|
|
||||||
|
|
||||||
|
# Register REST API routes for ComfyUI connectivity (uses the shared DB instance)
|
||||||
|
if _shared_db is not None:
|
||||||
|
register_api_routes(_shared_db)
|
||||||
|
|
||||||
ui.run(title='AI Settings Manager', port=8080, reload=True)
|
ui.run(title='AI Settings Manager', port=8080, reload=True)
|
||||||
|
|||||||
215
project_loader.py
Normal file
215
project_loader.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_DYNAMIC_OUTPUTS = 32
|
||||||
|
|
||||||
|
|
||||||
|
class AnyType(str):
|
||||||
|
"""Universal connector type that matches any ComfyUI type."""
|
||||||
|
def __ne__(self, __value: object) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
any_type = AnyType("*")
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from server import PromptServer
|
||||||
|
from aiohttp import web
|
||||||
|
except ImportError:
|
||||||
|
PromptServer = None
|
||||||
|
|
||||||
|
|
||||||
|
def to_float(val: Any) -> float:
|
||||||
|
try:
|
||||||
|
return float(val)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def to_int(val: Any) -> int:
|
||||||
|
try:
|
||||||
|
return int(float(val))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_json(url: str) -> dict:
|
||||||
|
"""Fetch JSON from a URL using stdlib urllib.
|
||||||
|
|
||||||
|
On error, returns a dict with an "error" key describing the failure.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(url, timeout=5) as resp:
|
||||||
|
return json.loads(resp.read())
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
# HTTPError is a subclass of URLError — must be caught first
|
||||||
|
body = ""
|
||||||
|
try:
|
||||||
|
raw = e.read()
|
||||||
|
detail = json.loads(raw)
|
||||||
|
body = detail.get("detail", str(raw, "utf-8", errors="replace"))
|
||||||
|
except Exception:
|
||||||
|
body = str(e)
|
||||||
|
logger.warning(f"HTTP {e.code} from {url}: {body}")
|
||||||
|
return {"error": "http_error", "status": e.code, "message": body}
|
||||||
|
except (urllib.error.URLError, OSError) as e:
|
||||||
|
reason = str(e.reason) if hasattr(e, "reason") else str(e)
|
||||||
|
logger.warning(f"Network error fetching {url}: {reason}")
|
||||||
|
return {"error": "network_error", "message": reason}
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"Invalid JSON from {url}: {e}")
|
||||||
|
return {"error": "parse_error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
|
||||||
|
"""Fetch sequence data from the NiceGUI REST API."""
|
||||||
|
p = urllib.parse.quote(project, safe='')
|
||||||
|
f = urllib.parse.quote(file, safe='')
|
||||||
|
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/data?seq={seq}"
|
||||||
|
return _fetch_json(url)
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict:
|
||||||
|
"""Fetch keys/types from the NiceGUI REST API."""
|
||||||
|
p = urllib.parse.quote(project, safe='')
|
||||||
|
f = urllib.parse.quote(file, safe='')
|
||||||
|
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/keys?seq={seq}"
|
||||||
|
return _fetch_json(url)
|
||||||
|
|
||||||
|
|
||||||
|
# --- ComfyUI-side proxy endpoints (for frontend JS) ---
|
||||||
|
if PromptServer is not None:
|
||||||
|
@PromptServer.instance.routes.get("/json_manager/list_projects")
|
||||||
|
async def list_projects_proxy(request):
|
||||||
|
manager_url = request.query.get("url", "http://localhost:8080")
|
||||||
|
url = f"{manager_url.rstrip('/')}/api/projects"
|
||||||
|
data = _fetch_json(url)
|
||||||
|
return web.json_response(data)
|
||||||
|
|
||||||
|
@PromptServer.instance.routes.get("/json_manager/list_project_files")
|
||||||
|
async def list_project_files_proxy(request):
|
||||||
|
manager_url = request.query.get("url", "http://localhost:8080")
|
||||||
|
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
||||||
|
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
|
||||||
|
data = _fetch_json(url)
|
||||||
|
return web.json_response(data)
|
||||||
|
|
||||||
|
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
|
||||||
|
async def list_project_sequences_proxy(request):
|
||||||
|
manager_url = request.query.get("url", "http://localhost:8080")
|
||||||
|
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
||||||
|
file_name = urllib.parse.quote(request.query.get("file", ""), safe='')
|
||||||
|
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
|
||||||
|
data = _fetch_json(url)
|
||||||
|
return web.json_response(data)
|
||||||
|
|
||||||
|
@PromptServer.instance.routes.get("/json_manager/get_project_keys")
|
||||||
|
async def get_project_keys_proxy(request):
|
||||||
|
manager_url = request.query.get("url", "http://localhost:8080")
|
||||||
|
project = request.query.get("project", "")
|
||||||
|
file_name = request.query.get("file", "")
|
||||||
|
try:
|
||||||
|
seq = int(request.query.get("seq", "1"))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
seq = 1
|
||||||
|
data = _fetch_keys(manager_url, project, file_name, seq)
|
||||||
|
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
|
status = data.get("status", 502)
|
||||||
|
return web.json_response(data, status=status)
|
||||||
|
return web.json_response(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 0. DYNAMIC NODE (Project-based)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
class ProjectLoaderDynamic:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
||||||
|
"project_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"file_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"output_keys": ("STRING", {"default": ""}),
|
||||||
|
"output_types": ("STRING", {"default": ""}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS))
|
||||||
|
RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS))
|
||||||
|
FUNCTION = "load_dynamic"
|
||||||
|
CATEGORY = "utils/json/project"
|
||||||
|
OUTPUT_NODE = False
|
||||||
|
|
||||||
|
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
|
||||||
|
output_keys="", output_types=""):
|
||||||
|
# Fetch keys metadata (includes total_sequences count)
|
||||||
|
keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number)
|
||||||
|
if keys_meta.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
|
msg = keys_meta.get("message", "Unknown error")
|
||||||
|
raise RuntimeError(f"Failed to fetch project keys: {msg}")
|
||||||
|
total_sequences = keys_meta.get("total_sequences", 0)
|
||||||
|
|
||||||
|
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
||||||
|
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
|
msg = data.get("message", "Unknown error")
|
||||||
|
raise RuntimeError(f"Failed to fetch sequence data: {msg}")
|
||||||
|
|
||||||
|
# Parse keys — try JSON array first, fall back to comma-split for compat
|
||||||
|
keys = []
|
||||||
|
if output_keys:
|
||||||
|
try:
|
||||||
|
keys = json.loads(output_keys)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
keys = [k.strip() for k in output_keys.split(",") if k.strip()]
|
||||||
|
|
||||||
|
# Parse types for coercion
|
||||||
|
types = []
|
||||||
|
if output_types:
|
||||||
|
try:
|
||||||
|
types = json.loads(output_types)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
types = [t.strip() for t in output_types.split(",")]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i, key in enumerate(keys):
|
||||||
|
val = data.get(key, "")
|
||||||
|
declared_type = types[i] if i < len(types) else ""
|
||||||
|
# Coerce based on declared output type when possible
|
||||||
|
if declared_type == "INT":
|
||||||
|
results.append(to_int(val))
|
||||||
|
elif declared_type == "FLOAT":
|
||||||
|
results.append(to_float(val))
|
||||||
|
elif isinstance(val, bool):
|
||||||
|
results.append(str(val).lower())
|
||||||
|
elif isinstance(val, int):
|
||||||
|
results.append(val)
|
||||||
|
elif isinstance(val, float):
|
||||||
|
results.append(val)
|
||||||
|
else:
|
||||||
|
results.append(str(val))
|
||||||
|
|
||||||
|
while len(results) < MAX_DYNAMIC_OUTPUTS:
|
||||||
|
results.append("")
|
||||||
|
|
||||||
|
return (total_sequences,) + tuple(results)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Mappings ---
|
||||||
|
PROJECT_NODE_CLASS_MAPPINGS = {
|
||||||
|
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
||||||
|
}
|
||||||
|
|
||||||
|
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
||||||
|
}
|
||||||
8
state.py
8
state.py
@@ -17,6 +17,11 @@ class AppState:
|
|||||||
live_toggles: dict = field(default_factory=dict)
|
live_toggles: dict = field(default_factory=dict)
|
||||||
show_comfy_monitor: bool = True
|
show_comfy_monitor: bool = True
|
||||||
|
|
||||||
|
# Project DB fields
|
||||||
|
db: Any = None
|
||||||
|
current_project: str = ""
|
||||||
|
db_enabled: bool = False
|
||||||
|
|
||||||
# Set at runtime by main.py / tab_comfy_ng.py
|
# Set at runtime by main.py / tab_comfy_ng.py
|
||||||
_render_main: Any = None
|
_render_main: Any = None
|
||||||
_load_file: Callable | None = None
|
_load_file: Callable | None = None
|
||||||
@@ -29,4 +34,7 @@ class AppState:
|
|||||||
config=self.config,
|
config=self.config,
|
||||||
current_dir=self.current_dir,
|
current_dir=self.current_dir,
|
||||||
snippets=self.snippets,
|
snippets=self.snippets,
|
||||||
|
db=self.db,
|
||||||
|
current_project=self.current_project,
|
||||||
|
db_enabled=self.db_enabled,
|
||||||
)
|
)
|
||||||
|
|||||||
594
tab_batch.py
594
tab_batch.py
@@ -1,594 +0,0 @@
|
|||||||
import streamlit as st
|
|
||||||
import random
|
|
||||||
import copy
|
|
||||||
from pathlib import Path
|
|
||||||
from utils import DEFAULTS, save_json, load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
|
|
||||||
from history_tree import HistoryTree
|
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}
|
|
||||||
|
|
||||||
SUB_SEGMENT_MULTIPLIER = 1000
|
|
||||||
|
|
||||||
def is_subsegment(seq_num):
|
|
||||||
"""Return True if seq_num is a sub-segment (>= 1000)."""
|
|
||||||
return int(seq_num) >= SUB_SEGMENT_MULTIPLIER
|
|
||||||
|
|
||||||
def parent_of(seq_num):
|
|
||||||
"""Return the parent segment number (or self if already a parent)."""
|
|
||||||
seq_num = int(seq_num)
|
|
||||||
return seq_num // SUB_SEGMENT_MULTIPLIER if is_subsegment(seq_num) else seq_num
|
|
||||||
|
|
||||||
def sub_index_of(seq_num):
|
|
||||||
"""Return the sub-index (0 if parent)."""
|
|
||||||
seq_num = int(seq_num)
|
|
||||||
return seq_num % SUB_SEGMENT_MULTIPLIER if is_subsegment(seq_num) else 0
|
|
||||||
|
|
||||||
def format_seq_label(seq_num):
|
|
||||||
"""Return display label: 'Sequence #3' or 'Sub #2.1'."""
|
|
||||||
seq_num = int(seq_num)
|
|
||||||
if is_subsegment(seq_num):
|
|
||||||
return f"Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)}"
|
|
||||||
return f"Sequence #{seq_num}"
|
|
||||||
|
|
||||||
def next_sub_segment_number(batch_list, parent_seq_num):
|
|
||||||
"""Find the next available sub-segment number under a parent."""
|
|
||||||
parent_seq_num = int(parent_seq_num)
|
|
||||||
max_sub = 0
|
|
||||||
for s in batch_list:
|
|
||||||
sn = int(s.get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
if is_subsegment(sn) and parent_of(sn) == parent_seq_num:
|
|
||||||
max_sub = max(max_sub, sub_index_of(sn))
|
|
||||||
return parent_seq_num * SUB_SEGMENT_MULTIPLIER + max_sub + 1
|
|
||||||
|
|
||||||
def find_insert_position(batch_list, parent_index, parent_seq_num):
|
|
||||||
"""Find the insert position after the parent's last existing sub-segment."""
|
|
||||||
parent_seq_num = int(parent_seq_num)
|
|
||||||
pos = parent_index + 1
|
|
||||||
while pos < len(batch_list):
|
|
||||||
sn = int(batch_list[pos].get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
if is_subsegment(sn) and parent_of(sn) == parent_seq_num:
|
|
||||||
pos += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return pos
|
|
||||||
|
|
||||||
def _render_mass_update(batch_list, data, file_path, key_prefix):
|
|
||||||
"""Render the mass update UI section."""
|
|
||||||
with st.expander("🔄 Mass Update", expanded=False):
|
|
||||||
if len(batch_list) < 2:
|
|
||||||
st.info("Need at least 2 sequences for mass update.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Source sequence selector
|
|
||||||
source_idx = st.selectbox(
|
|
||||||
"Copy from sequence:",
|
|
||||||
range(len(batch_list)),
|
|
||||||
format_func=lambda i: format_seq_label(batch_list[i].get('sequence_number', i+1)),
|
|
||||||
key=f"{key_prefix}_mass_src"
|
|
||||||
)
|
|
||||||
source_seq = batch_list[source_idx]
|
|
||||||
|
|
||||||
# Field multi-select (exclude sequence_number)
|
|
||||||
available_keys = [k for k in source_seq.keys() if k != "sequence_number"]
|
|
||||||
selected_keys = st.multiselect("Fields to copy:", available_keys, key=f"{key_prefix}_mass_fields")
|
|
||||||
|
|
||||||
if not selected_keys:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Target sequence checkboxes
|
|
||||||
st.write("Apply to:")
|
|
||||||
select_all = st.checkbox("Select All", key=f"{key_prefix}_mass_all")
|
|
||||||
|
|
||||||
target_indices = []
|
|
||||||
target_cols = st.columns(min(4, len(batch_list) - 1)) if len(batch_list) > 1 else [st]
|
|
||||||
col_idx = 0
|
|
||||||
for i, seq in enumerate(batch_list):
|
|
||||||
if i == source_idx:
|
|
||||||
continue
|
|
||||||
seq_num = seq.get("sequence_number", i + 1)
|
|
||||||
with target_cols[col_idx % len(target_cols)]:
|
|
||||||
checked = select_all or st.checkbox(format_seq_label(seq_num), key=f"{key_prefix}_mass_t{i}")
|
|
||||||
if checked:
|
|
||||||
target_indices.append(i)
|
|
||||||
col_idx += 1
|
|
||||||
|
|
||||||
# Preview
|
|
||||||
if target_indices and selected_keys:
|
|
||||||
with st.expander("Preview changes", expanded=True):
|
|
||||||
for key in selected_keys:
|
|
||||||
val = source_seq.get(key, "")
|
|
||||||
display_val = str(val)[:100] + "..." if len(str(val)) > 100 else str(val)
|
|
||||||
st.caption(f"**{key}**: {display_val}")
|
|
||||||
|
|
||||||
# Apply button
|
|
||||||
if st.button("Apply Changes", type="primary", key=f"{key_prefix}_mass_apply"):
|
|
||||||
for i in target_indices:
|
|
||||||
for key in selected_keys:
|
|
||||||
batch_list[i][key] = copy.deepcopy(source_seq.get(key))
|
|
||||||
|
|
||||||
# Save with history snapshot
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
htree = HistoryTree(data.get(KEY_HISTORY_TREE, {}))
|
|
||||||
snapshot_payload = copy.deepcopy(data)
|
|
||||||
if KEY_HISTORY_TREE in snapshot_payload:
|
|
||||||
del snapshot_payload[KEY_HISTORY_TREE]
|
|
||||||
htree.commit(snapshot_payload, f"Mass update: {', '.join(selected_keys)}")
|
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.data_cache = data
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.toast(f"Updated {len(target_indices)} sequences", icon="✅")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
|
|
||||||
def create_batch_callback(original_filename, current_data, current_dir):
|
|
||||||
new_name = f"batch_{original_filename}"
|
|
||||||
new_path = current_dir / new_name
|
|
||||||
|
|
||||||
if new_path.exists():
|
|
||||||
st.toast(f"File {new_name} already exists!", icon="⚠️")
|
|
||||||
return
|
|
||||||
|
|
||||||
first_item = current_data.copy()
|
|
||||||
if KEY_PROMPT_HISTORY in first_item: del first_item[KEY_PROMPT_HISTORY]
|
|
||||||
if KEY_HISTORY_TREE in first_item: del first_item[KEY_HISTORY_TREE]
|
|
||||||
|
|
||||||
first_item[KEY_SEQUENCE_NUMBER] = 1
|
|
||||||
|
|
||||||
new_data = {
|
|
||||||
KEY_BATCH_DATA: [first_item],
|
|
||||||
KEY_HISTORY_TREE: {},
|
|
||||||
KEY_PROMPT_HISTORY: []
|
|
||||||
}
|
|
||||||
|
|
||||||
save_json(new_path, new_data)
|
|
||||||
st.toast(f"Created {new_name}", icon="✨")
|
|
||||||
st.session_state.file_selector = new_name
|
|
||||||
|
|
||||||
|
|
||||||
def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
|
|
||||||
is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
|
|
||||||
|
|
||||||
if not is_batch_file:
|
|
||||||
st.warning("This is a Single file. To use Batch mode, create a copy.")
|
|
||||||
st.button("✨ Create Batch Copy", on_click=create_batch_callback, args=(selected_file_name, data, current_dir))
|
|
||||||
return
|
|
||||||
|
|
||||||
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
|
|
||||||
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
|
|
||||||
|
|
||||||
batch_list = data.get(KEY_BATCH_DATA, [])
|
|
||||||
|
|
||||||
# --- ADD NEW SEQUENCE AREA ---
|
|
||||||
st.subheader("Add New Sequence")
|
|
||||||
ac1, ac2 = st.columns(2)
|
|
||||||
|
|
||||||
with ac1:
|
|
||||||
file_options = [f.name for f in json_files]
|
|
||||||
d_idx = file_options.index(selected_file_name) if selected_file_name in file_options else 0
|
|
||||||
src_name = st.selectbox("Source File:", file_options, index=d_idx, key="batch_src_file")
|
|
||||||
src_data, _ = load_json(current_dir / src_name)
|
|
||||||
|
|
||||||
with ac2:
|
|
||||||
src_batch = src_data.get(KEY_BATCH_DATA, [])
|
|
||||||
if src_batch:
|
|
||||||
seq_opts = list(range(len(src_batch)))
|
|
||||||
sel_seq_idx = st.selectbox(
|
|
||||||
"Source Sequence:",
|
|
||||||
seq_opts,
|
|
||||||
format_func=lambda i: format_seq_label(src_batch[i].get(KEY_SEQUENCE_NUMBER, i + 1)),
|
|
||||||
key="batch_src_seq"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
st.caption("Single file (no sequences)")
|
|
||||||
sel_seq_idx = None
|
|
||||||
|
|
||||||
bc1, bc2 = st.columns(2)
|
|
||||||
|
|
||||||
def add_sequence(new_item):
|
|
||||||
max_seq = 0
|
|
||||||
for s in batch_list:
|
|
||||||
sn = int(s.get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
if not is_subsegment(sn):
|
|
||||||
max_seq = max(max_seq, sn)
|
|
||||||
new_item[KEY_SEQUENCE_NUMBER] = max_seq + 1
|
|
||||||
|
|
||||||
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, "note", "loras"]:
|
|
||||||
if k in new_item: del new_item[k]
|
|
||||||
|
|
||||||
batch_list.append(new_item)
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
if bc1.button("➕ Add Empty", use_container_width=True):
|
|
||||||
add_sequence(DEFAULTS.copy())
|
|
||||||
|
|
||||||
if bc2.button("➕ From Source", use_container_width=True, help=f"Import from {src_name}"):
|
|
||||||
item = DEFAULTS.copy()
|
|
||||||
if src_batch and sel_seq_idx is not None:
|
|
||||||
item.update(src_batch[sel_seq_idx])
|
|
||||||
else:
|
|
||||||
item.update(src_data)
|
|
||||||
add_sequence(item)
|
|
||||||
|
|
||||||
# --- RENDER LIST ---
|
|
||||||
st.markdown("---")
|
|
||||||
info_col, reorder_col = st.columns([3, 1])
|
|
||||||
info_col.info(f"Batch contains {len(batch_list)} sequences.")
|
|
||||||
if reorder_col.button("🔢 Sort by Number", use_container_width=True, help="Reorder sequences by sequence number"):
|
|
||||||
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.toast("Sorted by sequence number!", icon="🔢")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# --- MASS UPDATE SECTION ---
|
|
||||||
ui_reset_token = st.session_state.get("ui_reset_token", 0)
|
|
||||||
_render_mass_update(batch_list, data, file_path, f"{selected_file_name}_v{ui_reset_token}")
|
|
||||||
|
|
||||||
# Updated LoRA keys to match new logic
|
|
||||||
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
|
|
||||||
standard_keys = {
|
|
||||||
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed", "cfg",
|
|
||||||
"camera", "flf", KEY_SEQUENCE_NUMBER
|
|
||||||
}
|
|
||||||
standard_keys.update(lora_keys)
|
|
||||||
standard_keys.update([
|
|
||||||
"frame_to_skip", "end_frame", "transition", "vace_length",
|
|
||||||
"input_a_frames", "input_b_frames", "reference switch", "vace schedule",
|
|
||||||
"reference path", "video file path", "reference image path", "flf image path"
|
|
||||||
])
|
|
||||||
|
|
||||||
VACE_MODES = [
|
|
||||||
"End Extend", "Pre Extend", "Middle Extend", "Edge Extend",
|
|
||||||
"Join Extend", "Bidirectional Extend", "Frame Interpolation",
|
|
||||||
"Replace/Inpaint", "Video Inpaint", "Keyframe",
|
|
||||||
]
|
|
||||||
VACE_FORMULAS = [
|
|
||||||
"base + A", # 0 End Extend
|
|
||||||
"base + B", # 1 Pre Extend
|
|
||||||
"base + A + B", # 2 Middle Extend
|
|
||||||
"base + A + B", # 3 Edge Extend
|
|
||||||
"base + A + B", # 4 Join Extend
|
|
||||||
"base + A + B", # 5 Bidirectional
|
|
||||||
"(B-1) * step", # 6 Frame Interpolation
|
|
||||||
"snap(source)", # 7 Replace/Inpaint
|
|
||||||
"snap(source)", # 8 Video Inpaint
|
|
||||||
"base + A + B", # 9 Keyframe
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, seq in enumerate(batch_list):
|
|
||||||
seq_num = seq.get(KEY_SEQUENCE_NUMBER, i+1)
|
|
||||||
prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}"
|
|
||||||
|
|
||||||
if is_subsegment(seq_num):
|
|
||||||
expander_label = f"🔗 ↳ Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)} ({int(seq_num)})"
|
|
||||||
else:
|
|
||||||
expander_label = f"🎬 Sequence #{seq_num}"
|
|
||||||
|
|
||||||
with st.expander(expander_label, expanded=False):
|
|
||||||
# --- ACTION ROW ---
|
|
||||||
act_c1, act_c2, act_c3, act_c4 = st.columns([1.2, 1.8, 1.2, 0.5])
|
|
||||||
|
|
||||||
# 1. Copy Source
|
|
||||||
with act_c1:
|
|
||||||
if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True):
|
|
||||||
item = DEFAULTS.copy()
|
|
||||||
if src_batch and sel_seq_idx is not None:
|
|
||||||
item.update(src_batch[sel_seq_idx])
|
|
||||||
else:
|
|
||||||
item.update(src_data)
|
|
||||||
item[KEY_SEQUENCE_NUMBER] = seq_num
|
|
||||||
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE]:
|
|
||||||
if k in item: del item[k]
|
|
||||||
batch_list[i] = item
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.toast("Copied!", icon="📥")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# 2. Cloning Tools
|
|
||||||
with act_c2:
|
|
||||||
cl_1, cl_2, cl_3 = st.columns(3)
|
|
||||||
if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True):
|
|
||||||
new_seq = copy.deepcopy(seq)
|
|
||||||
max_sn = 0
|
|
||||||
for s in batch_list:
|
|
||||||
sn = int(s.get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
if not is_subsegment(sn):
|
|
||||||
max_sn = max(max_sn, sn)
|
|
||||||
new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
|
|
||||||
if not is_subsegment(seq_num):
|
|
||||||
insert_pos = find_insert_position(batch_list, i, int(seq_num))
|
|
||||||
else:
|
|
||||||
insert_pos = i + 1
|
|
||||||
batch_list.insert(insert_pos, new_seq)
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.toast("Cloned to Next!", icon="👯")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True):
|
|
||||||
new_seq = copy.deepcopy(seq)
|
|
||||||
max_sn = 0
|
|
||||||
for s in batch_list:
|
|
||||||
sn = int(s.get(KEY_SEQUENCE_NUMBER, 0))
|
|
||||||
if not is_subsegment(sn):
|
|
||||||
max_sn = max(max_sn, sn)
|
|
||||||
new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
|
|
||||||
batch_list.append(new_seq)
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.toast("Cloned to End!", icon="⏬")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
if cl_3.button("🔗 Sub", key=f"{prefix}_c_sub", help="Clone as sub-segment", use_container_width=True):
|
|
||||||
new_seq = copy.deepcopy(seq)
|
|
||||||
p_seq_num = parent_of(seq_num)
|
|
||||||
# Find the parent's index in batch_list
|
|
||||||
p_idx = i
|
|
||||||
if is_subsegment(seq_num):
|
|
||||||
for pi, ps in enumerate(batch_list):
|
|
||||||
if int(ps.get(KEY_SEQUENCE_NUMBER, 0)) == p_seq_num:
|
|
||||||
p_idx = pi
|
|
||||||
break
|
|
||||||
new_seq[KEY_SEQUENCE_NUMBER] = next_sub_segment_number(batch_list, p_seq_num)
|
|
||||||
insert_pos = find_insert_position(batch_list, p_idx, p_seq_num)
|
|
||||||
batch_list.insert(insert_pos, new_seq)
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.toast(f"Created {format_seq_label(new_seq[KEY_SEQUENCE_NUMBER])}!", icon="🔗")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# 3. Promote
|
|
||||||
with act_c3:
|
|
||||||
if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True):
|
|
||||||
single_data = seq.copy()
|
|
||||||
single_data[KEY_PROMPT_HISTORY] = data.get(KEY_PROMPT_HISTORY, [])
|
|
||||||
single_data[KEY_HISTORY_TREE] = data.get(KEY_HISTORY_TREE, {})
|
|
||||||
if KEY_SEQUENCE_NUMBER in single_data: del single_data[KEY_SEQUENCE_NUMBER]
|
|
||||||
save_json(file_path, single_data)
|
|
||||||
st.session_state.data_cache = single_data
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.toast("Converted to Single!", icon="✅")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# 4. Remove
|
|
||||||
with act_c4:
|
|
||||||
if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
|
|
||||||
batch_list.pop(i)
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
c1, c2 = st.columns([2, 1])
|
|
||||||
with c1:
|
|
||||||
seq["general_prompt"] = st.text_area("General Prompt", value=seq.get("general_prompt", ""), height=60, key=f"{prefix}_gp")
|
|
||||||
seq["general_negative"] = st.text_area("General Negative", value=seq.get("general_negative", ""), height=60, key=f"{prefix}_gn")
|
|
||||||
seq["current_prompt"] = st.text_area("Specific Prompt", value=seq.get("current_prompt", ""), height=300, key=f"{prefix}_sp")
|
|
||||||
seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn")
|
|
||||||
|
|
||||||
with c2:
|
|
||||||
sn_label = f"Sequence Number (↳ Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)})" if is_subsegment(seq_num) else "Sequence Number"
|
|
||||||
seq[KEY_SEQUENCE_NUMBER] = st.number_input(sn_label, value=int(seq_num), key=f"{prefix}_sn_val")
|
|
||||||
|
|
||||||
s_row1, s_row2 = st.columns([3, 1])
|
|
||||||
seed_key = f"{prefix}_seed"
|
|
||||||
with s_row2:
|
|
||||||
st.write("")
|
|
||||||
st.write("")
|
|
||||||
if st.button("🎲", key=f"{prefix}_rand"):
|
|
||||||
st.session_state[seed_key] = random.randint(0, 999999999999)
|
|
||||||
st.rerun()
|
|
||||||
with s_row1:
|
|
||||||
current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0)))
|
|
||||||
val = st.number_input("Seed", value=current_seed, key=seed_key)
|
|
||||||
seq["seed"] = val
|
|
||||||
|
|
||||||
seq["cfg"] = st.number_input("CFG", value=float(seq.get("cfg", DEFAULTS["cfg"])), step=0.5, format="%.1f", key=f"{prefix}_cfg")
|
|
||||||
seq["camera"] = st.text_input("Camera", value=seq.get("camera", ""), key=f"{prefix}_cam")
|
|
||||||
seq["flf"] = st.text_input("FLF", value=str(seq.get("flf", DEFAULTS["flf"])), key=f"{prefix}_flf")
|
|
||||||
|
|
||||||
seq["end_frame"] = st.number_input("End Frame", value=int(seq.get("end_frame", 0)), key=f"{prefix}_ef")
|
|
||||||
seq["video file path"] = st.text_input("Video File Path", value=seq.get("video file path", ""), key=f"{prefix}_vid")
|
|
||||||
for img_label, img_key, img_suffix in [
|
|
||||||
("Reference Image Path", "reference image path", "rip"),
|
|
||||||
("Reference Path", "reference path", "rp"),
|
|
||||||
("FLF Image Path", "flf image path", "flfi"),
|
|
||||||
]:
|
|
||||||
img_col, prev_col = st.columns([5, 1])
|
|
||||||
seq[img_key] = img_col.text_input(img_label, value=seq.get(img_key, ""), key=f"{prefix}_{img_suffix}")
|
|
||||||
img_path = Path(seq[img_key]) if seq[img_key] else None
|
|
||||||
if img_path and img_path.exists() and img_path.suffix.lower() in IMAGE_EXTENSIONS:
|
|
||||||
with prev_col.popover("👁"):
|
|
||||||
st.image(str(img_path), use_container_width=True)
|
|
||||||
with st.expander("VACE Settings"):
|
|
||||||
fts_col, fts_btn = st.columns([3, 1])
|
|
||||||
saved_fts_key = f"{prefix}_fts_saved"
|
|
||||||
if saved_fts_key not in st.session_state:
|
|
||||||
st.session_state[saved_fts_key] = int(seq.get("frame_to_skip", 81))
|
|
||||||
old_fts = st.session_state[saved_fts_key]
|
|
||||||
seq["frame_to_skip"] = fts_col.number_input("Frame to Skip", value=old_fts, key=f"{prefix}_fts")
|
|
||||||
delta = int(seq["frame_to_skip"]) - old_fts
|
|
||||||
delta_label = f"Shift ↓ ({delta:+d})" if delta != 0 else "Shift ↓ (0)"
|
|
||||||
fts_btn.write("")
|
|
||||||
fts_btn.write("")
|
|
||||||
if fts_btn.button(delta_label, key=f"{prefix}_fts_shift", help="Apply delta to all following sequences", disabled=(delta == 0)):
|
|
||||||
if delta != 0:
|
|
||||||
shifted = 0
|
|
||||||
for j in range(i + 1, len(batch_list)):
|
|
||||||
batch_list[j]["frame_to_skip"] = int(batch_list[j].get("frame_to_skip", 81)) + delta
|
|
||||||
shifted += 1
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.toast(f"Shifted {shifted} sequences by {delta:+d}", icon="⏬")
|
|
||||||
st.rerun()
|
|
||||||
else:
|
|
||||||
st.toast("No change to shift", icon="ℹ️")
|
|
||||||
seq["transition"] = st.text_input("Transition", value=str(seq.get("transition", "1-2")), key=f"{prefix}_trans")
|
|
||||||
|
|
||||||
vs_col, vs_label = st.columns([3, 1])
|
|
||||||
sched_val = int(seq.get("vace schedule", 1))
|
|
||||||
seq["vace schedule"] = vs_col.number_input("VACE Schedule", value=sched_val, min_value=0, max_value=len(VACE_MODES) - 1, key=f"{prefix}_vsc")
|
|
||||||
mode_idx = int(seq["vace schedule"])
|
|
||||||
vs_label.write("")
|
|
||||||
vs_label.write("")
|
|
||||||
vs_label.caption(VACE_MODES[mode_idx])
|
|
||||||
|
|
||||||
with st.popover("📋 Mode Reference"):
|
|
||||||
st.markdown(
|
|
||||||
"| # | Mode | Formula |\n"
|
|
||||||
"|:--|:-----|:--------|\n"
|
|
||||||
+ "\n".join(
|
|
||||||
f"| **{j}** | {VACE_MODES[j]} | `{VACE_FORMULAS[j]}` |"
|
|
||||||
for j in range(len(VACE_MODES))
|
|
||||||
)
|
|
||||||
+ "\n\n*All totals snapped to 4n+1 (1,5,9,…,49,…,81,…)*"
|
|
||||||
)
|
|
||||||
|
|
||||||
seq["input_a_frames"] = st.number_input("Input A Frames", value=int(seq.get("input_a_frames", 16)), key=f"{prefix}_ia")
|
|
||||||
seq["input_b_frames"] = st.number_input("Input B Frames", value=int(seq.get("input_b_frames", 16)), key=f"{prefix}_ib")
|
|
||||||
input_a = int(seq.get("input_a_frames", 16))
|
|
||||||
input_b = int(seq.get("input_b_frames", 16))
|
|
||||||
stored_total = int(seq.get("vace_length", 49))
|
|
||||||
# Reverse using same mode formula that was used to store
|
|
||||||
if mode_idx == 0:
|
|
||||||
base_length = max(stored_total - input_a, 1)
|
|
||||||
elif mode_idx == 1:
|
|
||||||
base_length = max(stored_total - input_b, 1)
|
|
||||||
else:
|
|
||||||
base_length = max(stored_total - input_a - input_b, 1)
|
|
||||||
vl_col, vl_out = st.columns([3, 1])
|
|
||||||
new_base = vl_col.number_input("VACE Length", value=base_length, min_value=1, key=f"{prefix}_vl")
|
|
||||||
if mode_idx == 0: # End Extend: base + A
|
|
||||||
raw_total = new_base + input_a
|
|
||||||
elif mode_idx == 1: # Pre Extend: base + B
|
|
||||||
raw_total = new_base + input_b
|
|
||||||
else: # Most modes: base + A + B
|
|
||||||
raw_total = new_base + input_a + input_b
|
|
||||||
# Snap to 4n+1 (1,5,9,13,...,81,...) to match VACE sampler
|
|
||||||
seq["vace_length"] = ((raw_total + 2) // 4) * 4 + 1
|
|
||||||
vl_out.metric("Output", seq["vace_length"])
|
|
||||||
seq["reference switch"] = st.number_input("Reference Switch", value=int(seq.get("reference switch", 1)), key=f"{prefix}_rsw")
|
|
||||||
|
|
||||||
# --- UPDATED: LoRA Settings with Tag Wrapping ---
|
|
||||||
with st.expander("💊 LoRA Settings"):
|
|
||||||
lc1, lc2, lc3 = st.columns(3)
|
|
||||||
|
|
||||||
# Helper to render the tag wrapper UI
|
|
||||||
def render_lora_col(col_obj, lora_idx):
|
|
||||||
with col_obj:
|
|
||||||
st.caption(f"**LoRA {lora_idx}**")
|
|
||||||
|
|
||||||
# --- HIGH ---
|
|
||||||
k_high = f"lora {lora_idx} high"
|
|
||||||
raw_h = str(seq.get(k_high, ""))
|
|
||||||
# Strip tags for display
|
|
||||||
disp_h = raw_h.replace("<lora:", "").replace(">", "")
|
|
||||||
|
|
||||||
st.write("High:")
|
|
||||||
rh1, rh2, rh3 = st.columns([0.25, 1, 0.1])
|
|
||||||
rh1.markdown("<div style='text-align: right; padding-top: 8px;'><code><lora:</code></div>", unsafe_allow_html=True)
|
|
||||||
val_h = rh2.text_input(f"L{lora_idx}H", value=disp_h, key=f"{prefix}_l{lora_idx}h", label_visibility="collapsed")
|
|
||||||
rh3.markdown("<div style='padding-top: 8px;'><code>></code></div>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
if val_h:
|
|
||||||
seq[k_high] = f"<lora:{val_h}>"
|
|
||||||
else:
|
|
||||||
seq[k_high] = ""
|
|
||||||
|
|
||||||
# --- LOW ---
|
|
||||||
k_low = f"lora {lora_idx} low"
|
|
||||||
raw_l = str(seq.get(k_low, ""))
|
|
||||||
# Strip tags for display
|
|
||||||
disp_l = raw_l.replace("<lora:", "").replace(">", "")
|
|
||||||
|
|
||||||
st.write("Low:")
|
|
||||||
rl1, rl2, rl3 = st.columns([0.25, 1, 0.1])
|
|
||||||
rl1.markdown("<div style='text-align: right; padding-top: 8px;'><code><lora:</code></div>", unsafe_allow_html=True)
|
|
||||||
val_l = rl2.text_input(f"L{lora_idx}L", value=disp_l, key=f"{prefix}_l{lora_idx}l", label_visibility="collapsed")
|
|
||||||
rl3.markdown("<div style='padding-top: 8px;'><code>></code></div>", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
if val_l:
|
|
||||||
seq[k_low] = f"<lora:{val_l}>"
|
|
||||||
else:
|
|
||||||
seq[k_low] = ""
|
|
||||||
|
|
||||||
render_lora_col(lc1, 1)
|
|
||||||
render_lora_col(lc2, 2)
|
|
||||||
render_lora_col(lc3, 3)
|
|
||||||
|
|
||||||
# --- CUSTOM PARAMETERS ---
|
|
||||||
st.markdown("---")
|
|
||||||
st.caption("🔧 Custom Parameters")
|
|
||||||
|
|
||||||
custom_keys = [k for k in seq.keys() if k not in standard_keys]
|
|
||||||
keys_to_remove = []
|
|
||||||
|
|
||||||
if custom_keys:
|
|
||||||
for k in custom_keys:
|
|
||||||
ck1, ck2, ck3 = st.columns([1, 2, 0.5])
|
|
||||||
ck1.text_input("Key", value=k, disabled=True, key=f"{prefix}_ck_lbl_{k}", label_visibility="collapsed")
|
|
||||||
val = ck2.text_input("Value", value=str(seq[k]), key=f"{prefix}_cv_{k}", label_visibility="collapsed")
|
|
||||||
seq[k] = val
|
|
||||||
|
|
||||||
if ck3.button("🗑️", key=f"{prefix}_cdel_{k}"):
|
|
||||||
keys_to_remove.append(k)
|
|
||||||
|
|
||||||
with st.expander("➕ Add Parameter"):
|
|
||||||
nk_col, nv_col = st.columns(2)
|
|
||||||
new_k = nk_col.text_input("Key", key=f"{prefix}_new_k")
|
|
||||||
new_v = nv_col.text_input("Value", key=f"{prefix}_new_v")
|
|
||||||
|
|
||||||
if st.button("Add", key=f"{prefix}_add_cust"):
|
|
||||||
if new_k and new_k not in seq:
|
|
||||||
seq[new_k] = new_v
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
if keys_to_remove:
|
|
||||||
for k in keys_to_remove:
|
|
||||||
del seq[k]
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# --- SAVE ACTIONS WITH HISTORY COMMIT ---
|
|
||||||
col_save, col_note = st.columns([1, 2])
|
|
||||||
|
|
||||||
with col_note:
|
|
||||||
commit_msg = st.text_input("Change Note (Optional)", placeholder="e.g. Added sequence 3")
|
|
||||||
|
|
||||||
with col_save:
|
|
||||||
if st.button("💾 Save & Snap", use_container_width=True):
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
|
||||||
|
|
||||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
|
||||||
htree = HistoryTree(tree_data)
|
|
||||||
|
|
||||||
snapshot_payload = copy.deepcopy(data)
|
|
||||||
if KEY_HISTORY_TREE in snapshot_payload: del snapshot_payload[KEY_HISTORY_TREE]
|
|
||||||
|
|
||||||
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
|
|
||||||
|
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
|
||||||
save_json(file_path, data)
|
|
||||||
|
|
||||||
if 'restored_indicator' in st.session_state:
|
|
||||||
del st.session_state.restored_indicator
|
|
||||||
|
|
||||||
st.toast("Batch Saved & Snapshot Created!", icon="🚀")
|
|
||||||
st.rerun()
|
|
||||||
@@ -6,7 +6,7 @@ from nicegui import ui
|
|||||||
|
|
||||||
from state import AppState
|
from state import AppState
|
||||||
from utils import (
|
from utils import (
|
||||||
DEFAULTS, save_json, load_json,
|
DEFAULTS, save_json, load_json, sync_to_db,
|
||||||
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
|
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
|
||||||
)
|
)
|
||||||
from history_tree import HistoryTree
|
from history_tree import HistoryTree
|
||||||
@@ -161,6 +161,8 @@ def render_batch_processor(state: AppState):
|
|||||||
new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {},
|
new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {},
|
||||||
KEY_PROMPT_HISTORY: []}
|
KEY_PROMPT_HISTORY: []}
|
||||||
save_json(new_path, new_data)
|
save_json(new_path, new_data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, new_path, new_data)
|
||||||
ui.notify(f'Created {new_name}', type='positive')
|
ui.notify(f'Created {new_name}', type='positive')
|
||||||
|
|
||||||
ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch)
|
ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch)
|
||||||
@@ -215,6 +217,8 @@ def render_batch_processor(state: AppState):
|
|||||||
batch_list.append(new_item)
|
batch_list.append(new_item)
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
render_sequence_list.refresh()
|
render_sequence_list.refresh()
|
||||||
|
|
||||||
with ui.row().classes('q-mt-sm'):
|
with ui.row().classes('q-mt-sm'):
|
||||||
@@ -250,6 +254,8 @@ def render_batch_processor(state: AppState):
|
|||||||
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
ui.notify('Sorted by sequence number!', type='positive')
|
ui.notify('Sorted by sequence number!', type='positive')
|
||||||
render_sequence_list.refresh()
|
render_sequence_list.refresh()
|
||||||
|
|
||||||
@@ -289,6 +295,8 @@ def render_batch_processor(state: AppState):
|
|||||||
htree.commit(snapshot_payload, note=note)
|
htree.commit(snapshot_payload, note=note)
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
state.restored_indicator = None
|
state.restored_indicator = None
|
||||||
commit_input.set_value('')
|
commit_input.set_value('')
|
||||||
ui.notify('Batch Saved & Snapshot Created!', type='positive')
|
ui.notify('Batch Saved & Snapshot Created!', type='positive')
|
||||||
@@ -306,6 +314,8 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
def commit(message=None):
|
def commit(message=None):
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
if message:
|
if message:
|
||||||
ui.notify(message, type='positive')
|
ui.notify(message, type='positive')
|
||||||
refresh_list.refresh()
|
refresh_list.refresh()
|
||||||
@@ -387,7 +397,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
|
|
||||||
ui.separator()
|
ui.separator()
|
||||||
|
|
||||||
# --- Prompts + Settings (2-column like Streamlit) ---
|
# --- Prompts + Settings (2-column) ---
|
||||||
with ui.splitter(value=66).classes('w-full') as splitter:
|
with ui.splitter(value=66).classes('w-full') as splitter:
|
||||||
with splitter.before:
|
with splitter.before:
|
||||||
dict_textarea('General Prompt', seq, 'general_prompt').classes(
|
dict_textarea('General Prompt', seq, 'general_prompt').classes(
|
||||||
@@ -447,7 +457,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
|
|
||||||
# --- VACE Settings (full width) ---
|
# --- VACE Settings (full width) ---
|
||||||
with ui.expansion('VACE Settings', icon='settings').classes('w-full'):
|
with ui.expansion('VACE Settings', icon='settings').classes('w-full'):
|
||||||
_render_vace_settings(i, seq, batch_list, data, file_path, refresh_list)
|
_render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list)
|
||||||
|
|
||||||
# --- LoRA Settings ---
|
# --- LoRA Settings ---
|
||||||
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
||||||
@@ -529,7 +539,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
# VACE Settings sub-section
|
# VACE Settings sub-section
|
||||||
# ======================================================================
|
# ======================================================================
|
||||||
|
|
||||||
def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list):
|
def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list):
|
||||||
# VACE Schedule (needed early for both columns)
|
# VACE Schedule (needed early for both columns)
|
||||||
sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1))
|
sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1))
|
||||||
|
|
||||||
@@ -567,6 +577,8 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list):
|
|||||||
shifted += 1
|
shifted += 1
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive')
|
ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive')
|
||||||
refresh_list.refresh()
|
refresh_list.refresh()
|
||||||
|
|
||||||
@@ -712,6 +724,8 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
|
|||||||
htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
ui.notify(f'Updated {len(targets)} sequences', type='positive')
|
ui.notify(f'Updated {len(targets)} sequences', type='positive')
|
||||||
if refresh_list:
|
if refresh_list:
|
||||||
refresh_list.refresh()
|
refresh_list.refresh()
|
||||||
|
|||||||
249
tab_comfy.py
249
tab_comfy.py
@@ -1,249 +0,0 @@
|
|||||||
import streamlit as st
|
|
||||||
import requests
|
|
||||||
from PIL import Image
|
|
||||||
from io import BytesIO
|
|
||||||
import urllib.parse
|
|
||||||
import html
|
|
||||||
import time # <--- NEW IMPORT
|
|
||||||
from utils import save_config
|
|
||||||
|
|
||||||
def render_single_instance(instance_config, index, all_instances, timeout_minutes):
|
|
||||||
url = instance_config.get("url", "http://127.0.0.1:8188")
|
|
||||||
name = instance_config.get("name", f"Server {index+1}")
|
|
||||||
|
|
||||||
COMFY_URL = url.rstrip("/")
|
|
||||||
|
|
||||||
# --- TIMEOUT LOGIC ---
|
|
||||||
# Generate unique keys for session state
|
|
||||||
toggle_key = f"live_toggle_{index}"
|
|
||||||
start_time_key = f"live_start_{index}"
|
|
||||||
|
|
||||||
# Check if we need to auto-close
|
|
||||||
if st.session_state.get(toggle_key, False) and timeout_minutes > 0:
|
|
||||||
start_time = st.session_state.get(start_time_key, 0)
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
if elapsed > (timeout_minutes * 60):
|
|
||||||
st.session_state[toggle_key] = False
|
|
||||||
# We don't need st.rerun() here because the fragment loop will pick up the state change on the next pass
|
|
||||||
# but an explicit rerun makes it snappy.
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
c_head, c_set = st.columns([3, 1])
|
|
||||||
c_head.markdown(f"### 🔌 {name}")
|
|
||||||
|
|
||||||
with c_set.popover("⚙️ Settings"):
|
|
||||||
st.caption("Press Update to apply changes!")
|
|
||||||
new_name = st.text_input("Name", value=name, key=f"name_{index}")
|
|
||||||
new_url = st.text_input("URL", value=url, key=f"url_{index}")
|
|
||||||
|
|
||||||
if new_url != url:
|
|
||||||
st.warning("⚠️ Unsaved URL! Click Update below.")
|
|
||||||
|
|
||||||
if st.button("💾 Update & Save", key=f"save_{index}", type="primary"):
|
|
||||||
all_instances[index]["name"] = new_name
|
|
||||||
all_instances[index]["url"] = new_url
|
|
||||||
st.session_state.config["comfy_instances"] = all_instances
|
|
||||||
|
|
||||||
save_config(
|
|
||||||
st.session_state.current_dir,
|
|
||||||
st.session_state.config['favorites'],
|
|
||||||
st.session_state.config
|
|
||||||
)
|
|
||||||
st.toast("Server config saved!", icon="💾")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
st.divider()
|
|
||||||
if st.button("🗑️ Remove Server", key=f"del_{index}"):
|
|
||||||
all_instances.pop(index)
|
|
||||||
st.session_state.config["comfy_instances"] = all_instances
|
|
||||||
save_config(
|
|
||||||
st.session_state.current_dir,
|
|
||||||
st.session_state.config['favorites'],
|
|
||||||
st.session_state.config
|
|
||||||
)
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# --- 1. STATUS DASHBOARD ---
|
|
||||||
with st.expander("📊 Server Status", expanded=True):
|
|
||||||
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
|
|
||||||
try:
|
|
||||||
res = requests.get(f"{COMFY_URL}/queue", timeout=1.5)
|
|
||||||
queue_data = res.json()
|
|
||||||
running_cnt = len(queue_data.get("queue_running", []))
|
|
||||||
pending_cnt = len(queue_data.get("queue_pending", []))
|
|
||||||
|
|
||||||
col1.metric("Status", "🟢 Online" if running_cnt > 0 else "💤 Idle")
|
|
||||||
col2.metric("Pending", pending_cnt)
|
|
||||||
col3.metric("Running", running_cnt)
|
|
||||||
|
|
||||||
if col4.button("🔄 Check Img", key=f"refresh_{index}", use_container_width=True):
|
|
||||||
st.session_state[f"force_img_refresh_{index}"] = True
|
|
||||||
except Exception:
|
|
||||||
col1.metric("Status", "🔴 Offline")
|
|
||||||
col2.metric("Pending", "-")
|
|
||||||
col3.metric("Running", "-")
|
|
||||||
st.error(f"Could not connect to API at {COMFY_URL}")
|
|
||||||
|
|
||||||
# --- 2. LIVE VIEW (VIA REMOTE BROWSER) ---
|
|
||||||
st.write("")
|
|
||||||
c_label, c_ctrl = st.columns([1, 2])
|
|
||||||
c_label.subheader("📺 Live View")
|
|
||||||
|
|
||||||
# Capture the toggle interaction to set start time
|
|
||||||
def on_toggle_change():
|
|
||||||
if st.session_state[toggle_key]:
|
|
||||||
st.session_state[start_time_key] = time.time()
|
|
||||||
|
|
||||||
enable_preview = c_ctrl.checkbox(
|
|
||||||
"Enable Live Preview",
|
|
||||||
value=False,
|
|
||||||
key=toggle_key,
|
|
||||||
on_change=on_toggle_change
|
|
||||||
)
|
|
||||||
|
|
||||||
if enable_preview:
|
|
||||||
# Display Countdown if timeout is active
|
|
||||||
if timeout_minutes > 0:
|
|
||||||
elapsed = time.time() - st.session_state.get(start_time_key, time.time())
|
|
||||||
remaining = (timeout_minutes * 60) - elapsed
|
|
||||||
st.caption(f"⏱️ Auto-off in: **{int(remaining)}s**")
|
|
||||||
|
|
||||||
# Height Slider
|
|
||||||
iframe_h = st.slider(
|
|
||||||
"Height (px)",
|
|
||||||
min_value=600, max_value=2500, value=1000, step=50,
|
|
||||||
key=f"h_slider_{index}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get Configured Viewer URL
|
|
||||||
viewer_base = st.session_state.config.get("viewer_url", "")
|
|
||||||
final_src = viewer_base.strip()
|
|
||||||
|
|
||||||
# Validate URL scheme before embedding
|
|
||||||
parsed = urllib.parse.urlparse(final_src)
|
|
||||||
if final_src and parsed.scheme in ("http", "https"):
|
|
||||||
safe_src = html.escape(final_src, quote=True)
|
|
||||||
st.info(f"Viewing via Remote Browser: `{final_src}`")
|
|
||||||
st.markdown(
|
|
||||||
f"""
|
|
||||||
<iframe src="{safe_src}" width="100%" height="{iframe_h}px"
|
|
||||||
style="border: 2px solid #666; border-radius: 8px; box-shadow: 0 4px 6px rgba(0,0,0,0.3);">
|
|
||||||
</iframe>
|
|
||||||
""",
|
|
||||||
unsafe_allow_html=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
st.warning("No valid viewer URL configured. Set one in Monitor Settings below.")
|
|
||||||
else:
|
|
||||||
st.info("Live Preview is disabled.")
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# --- 3. LATEST OUTPUT ---
|
|
||||||
if st.session_state.get(f"force_img_refresh_{index}", False):
|
|
||||||
st.caption("🖼️ Most Recent Output")
|
|
||||||
try:
|
|
||||||
hist_res = requests.get(f"{COMFY_URL}/history", timeout=2)
|
|
||||||
history = hist_res.json()
|
|
||||||
if history:
|
|
||||||
last_prompt_id = list(history.keys())[-1]
|
|
||||||
outputs = history[last_prompt_id].get("outputs", {})
|
|
||||||
found_img = None
|
|
||||||
for node_id, node_output in outputs.items():
|
|
||||||
if "images" in node_output:
|
|
||||||
for img_info in node_output["images"]:
|
|
||||||
if img_info["type"] == "output":
|
|
||||||
found_img = img_info
|
|
||||||
break
|
|
||||||
if found_img: break
|
|
||||||
|
|
||||||
if found_img:
|
|
||||||
img_name = found_img['filename']
|
|
||||||
folder = found_img['subfolder']
|
|
||||||
img_type = found_img['type']
|
|
||||||
img_url = f"{COMFY_URL}/view?filename={img_name}&subfolder={folder}&type={img_type}"
|
|
||||||
img_res = requests.get(img_url)
|
|
||||||
image = Image.open(BytesIO(img_res.content))
|
|
||||||
st.image(image, caption=f"Last Output: {img_name}")
|
|
||||||
else:
|
|
||||||
st.warning("Last run had no image output.")
|
|
||||||
else:
|
|
||||||
st.info("No history found.")
|
|
||||||
st.session_state[f"force_img_refresh_{index}"] = False
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Error fetching image: {e}")
|
|
||||||
|
|
||||||
# Check for fragment support (Streamlit 1.37+)
|
|
||||||
if hasattr(st, "fragment"):
|
|
||||||
# This decorator ensures this function re-runs every 10 seconds automatically
|
|
||||||
# allowing it to catch the timeout even if you are away from the keyboard.
|
|
||||||
@st.fragment(run_every=300)
|
|
||||||
def _monitor_fragment():
|
|
||||||
_render_content()
|
|
||||||
else:
|
|
||||||
# Fallback for older Streamlit versions (Won't auto-refresh while idle)
|
|
||||||
def _monitor_fragment():
|
|
||||||
_render_content()
|
|
||||||
|
|
||||||
def _render_content():
|
|
||||||
# --- GLOBAL SETTINGS FOR MONITOR ---
|
|
||||||
with st.expander("🔧 Monitor Settings", expanded=False):
|
|
||||||
c_set1, c_set2 = st.columns(2)
|
|
||||||
|
|
||||||
current_viewer = st.session_state.config.get("viewer_url", "")
|
|
||||||
new_viewer = c_set1.text_input("Remote Browser URL", value=current_viewer, help="e.g., http://localhost:5800")
|
|
||||||
|
|
||||||
# New Timeout Slider
|
|
||||||
current_timeout = st.session_state.config.get("monitor_timeout", 0)
|
|
||||||
new_timeout = c_set2.slider("Live Preview Timeout (Minutes)", 0, 60, value=current_timeout, help="0 = Always On. Sets how long the preview stays open before auto-closing.")
|
|
||||||
|
|
||||||
if st.button("💾 Save Monitor Settings"):
|
|
||||||
st.session_state.config["viewer_url"] = new_viewer
|
|
||||||
st.session_state.config["monitor_timeout"] = new_timeout
|
|
||||||
save_config(
|
|
||||||
st.session_state.current_dir,
|
|
||||||
st.session_state.config['favorites'],
|
|
||||||
st.session_state.config
|
|
||||||
)
|
|
||||||
st.success("Settings saved!")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# --- INSTANCE MANAGEMENT ---
|
|
||||||
if "comfy_instances" not in st.session_state.config:
|
|
||||||
st.session_state.config["comfy_instances"] = [
|
|
||||||
{"name": "Main Server", "url": "http://192.168.1.100:8188"}
|
|
||||||
]
|
|
||||||
|
|
||||||
instances = st.session_state.config["comfy_instances"]
|
|
||||||
tab_names = [i["name"] for i in instances] + ["➕ Add Server"]
|
|
||||||
tabs = st.tabs(tab_names)
|
|
||||||
|
|
||||||
timeout_val = st.session_state.config.get("monitor_timeout", 0)
|
|
||||||
|
|
||||||
for i, tab in enumerate(tabs[:-1]):
|
|
||||||
with tab:
|
|
||||||
render_single_instance(instances[i], i, instances, timeout_val)
|
|
||||||
|
|
||||||
with tabs[-1]:
|
|
||||||
st.header("Add New ComfyUI Instance")
|
|
||||||
with st.form("add_server_form"):
|
|
||||||
new_name = st.text_input("Server Name", placeholder="e.g. Render Node 2")
|
|
||||||
new_url = st.text_input("URL", placeholder="http://192.168.1.50:8188")
|
|
||||||
if st.form_submit_button("Add Instance"):
|
|
||||||
if new_name and new_url:
|
|
||||||
instances.append({"name": new_name, "url": new_url})
|
|
||||||
st.session_state.config["comfy_instances"] = instances
|
|
||||||
|
|
||||||
save_config(
|
|
||||||
st.session_state.current_dir,
|
|
||||||
st.session_state.config['favorites'],
|
|
||||||
st.session_state.config
|
|
||||||
)
|
|
||||||
st.success("Server Added!")
|
|
||||||
st.rerun()
|
|
||||||
else:
|
|
||||||
st.error("Please fill in both Name and URL.")
|
|
||||||
|
|
||||||
def render_comfy_monitor():
|
|
||||||
# We call the wrapper which decides if it's a fragment or not
|
|
||||||
_monitor_fragment()
|
|
||||||
165
tab_projects_ng.py
Normal file
165
tab_projects_ng.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from nicegui import ui
|
||||||
|
|
||||||
|
from state import AppState
|
||||||
|
from db import ProjectDB
|
||||||
|
from utils import save_config, sync_to_db, KEY_BATCH_DATA
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def render_projects_tab(state: AppState):
|
||||||
|
"""Render the Projects management tab."""
|
||||||
|
|
||||||
|
# --- DB toggle ---
|
||||||
|
def on_db_toggle(e):
|
||||||
|
state.db_enabled = e.value
|
||||||
|
state.config['db_enabled'] = e.value
|
||||||
|
save_config(state.current_dir, state.config.get('favorites', []), state.config)
|
||||||
|
render_project_content.refresh()
|
||||||
|
|
||||||
|
ui.switch('Enable Project Database', value=state.db_enabled,
|
||||||
|
on_change=on_db_toggle).classes('q-mb-md')
|
||||||
|
|
||||||
|
@ui.refreshable
|
||||||
|
def render_project_content():
|
||||||
|
if not state.db_enabled:
|
||||||
|
ui.label('Project database is disabled. Enable it above to manage projects.').classes(
|
||||||
|
'text-caption q-pa-md')
|
||||||
|
return
|
||||||
|
|
||||||
|
if not state.db:
|
||||||
|
ui.label('Database not initialized.').classes('text-warning q-pa-md')
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Create project form ---
|
||||||
|
with ui.card().classes('w-full q-pa-md q-mb-md'):
|
||||||
|
ui.label('Create New Project').classes('section-header')
|
||||||
|
name_input = ui.input('Project Name', placeholder='my_project').classes('w-full')
|
||||||
|
desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full')
|
||||||
|
|
||||||
|
def create_project():
|
||||||
|
name = name_input.value.strip()
|
||||||
|
if not name:
|
||||||
|
ui.notify('Please enter a project name', type='warning')
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
state.db.create_project(name, str(state.current_dir), desc_input.value.strip())
|
||||||
|
name_input.set_value('')
|
||||||
|
desc_input.set_value('')
|
||||||
|
ui.notify(f'Created project "{name}"', type='positive')
|
||||||
|
render_project_list.refresh()
|
||||||
|
except Exception as e:
|
||||||
|
ui.notify(f'Error: {e}', type='negative')
|
||||||
|
|
||||||
|
ui.button('Create Project', icon='add', on_click=create_project).classes('w-full')
|
||||||
|
|
||||||
|
# --- Active project indicator ---
|
||||||
|
if state.current_project:
|
||||||
|
ui.label(f'Active Project: {state.current_project}').classes(
|
||||||
|
'text-bold text-primary q-pa-sm')
|
||||||
|
|
||||||
|
# --- Project list ---
|
||||||
|
@ui.refreshable
|
||||||
|
def render_project_list():
|
||||||
|
projects = state.db.list_projects()
|
||||||
|
if not projects:
|
||||||
|
ui.label('No projects yet. Create one above.').classes('text-caption q-pa-md')
|
||||||
|
return
|
||||||
|
|
||||||
|
for proj in projects:
|
||||||
|
is_active = proj['name'] == state.current_project
|
||||||
|
card_style = 'border-left: 3px solid var(--accent);' if is_active else ''
|
||||||
|
|
||||||
|
with ui.card().classes('w-full q-pa-sm q-mb-sm').style(card_style):
|
||||||
|
with ui.row().classes('w-full items-center'):
|
||||||
|
with ui.column().classes('col'):
|
||||||
|
ui.label(proj['name']).classes('text-bold')
|
||||||
|
if proj['description']:
|
||||||
|
ui.label(proj['description']).classes('text-caption')
|
||||||
|
ui.label(f'Path: {proj["folder_path"]}').classes('text-caption')
|
||||||
|
files = state.db.list_data_files(proj['id'])
|
||||||
|
ui.label(f'{len(files)} data file(s)').classes('text-caption')
|
||||||
|
|
||||||
|
with ui.row().classes('q-gutter-xs'):
|
||||||
|
if not is_active:
|
||||||
|
def activate(name=proj['name']):
|
||||||
|
state.current_project = name
|
||||||
|
state.config['current_project'] = name
|
||||||
|
save_config(state.current_dir,
|
||||||
|
state.config.get('favorites', []),
|
||||||
|
state.config)
|
||||||
|
ui.notify(f'Activated project "{name}"', type='positive')
|
||||||
|
render_project_list.refresh()
|
||||||
|
|
||||||
|
ui.button('Activate', icon='check_circle',
|
||||||
|
on_click=activate).props('flat dense color=primary')
|
||||||
|
else:
|
||||||
|
def deactivate():
|
||||||
|
state.current_project = ''
|
||||||
|
state.config['current_project'] = ''
|
||||||
|
save_config(state.current_dir,
|
||||||
|
state.config.get('favorites', []),
|
||||||
|
state.config)
|
||||||
|
ui.notify('Deactivated project', type='info')
|
||||||
|
render_project_list.refresh()
|
||||||
|
|
||||||
|
ui.button('Deactivate', icon='cancel',
|
||||||
|
on_click=deactivate).props('flat dense')
|
||||||
|
|
||||||
|
def import_folder(pid=proj['id'], pname=proj['name']):
|
||||||
|
_import_folder(state, pid, pname, render_project_list)
|
||||||
|
|
||||||
|
ui.button('Import Folder', icon='folder_open',
|
||||||
|
on_click=import_folder).props('flat dense')
|
||||||
|
|
||||||
|
def delete_proj(name=proj['name']):
|
||||||
|
state.db.delete_project(name)
|
||||||
|
if state.current_project == name:
|
||||||
|
state.current_project = ''
|
||||||
|
state.config['current_project'] = ''
|
||||||
|
save_config(state.current_dir,
|
||||||
|
state.config.get('favorites', []),
|
||||||
|
state.config)
|
||||||
|
ui.notify(f'Deleted project "{name}"', type='positive')
|
||||||
|
render_project_list.refresh()
|
||||||
|
|
||||||
|
ui.button(icon='delete',
|
||||||
|
on_click=delete_proj).props('flat dense color=negative')
|
||||||
|
|
||||||
|
render_project_list()
|
||||||
|
|
||||||
|
render_project_content()
|
||||||
|
|
||||||
|
|
||||||
|
def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn):
|
||||||
|
"""Bulk import all .json files from current directory into a project."""
|
||||||
|
json_files = sorted(state.current_dir.glob('*.json'))
|
||||||
|
json_files = [f for f in json_files if f.name not in (
|
||||||
|
'.editor_config.json', '.editor_snippets.json')]
|
||||||
|
|
||||||
|
if not json_files:
|
||||||
|
ui.notify('No JSON files in current directory', type='warning')
|
||||||
|
return
|
||||||
|
|
||||||
|
imported = 0
|
||||||
|
skipped = 0
|
||||||
|
for jf in json_files:
|
||||||
|
file_name = jf.stem
|
||||||
|
existing = state.db.get_data_file(project_id, file_name)
|
||||||
|
if existing:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
state.db.import_json_file(project_id, jf)
|
||||||
|
imported += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to import {jf}: {e}")
|
||||||
|
|
||||||
|
msg = f'Imported {imported} file(s)'
|
||||||
|
if skipped:
|
||||||
|
msg += f', skipped {skipped} existing'
|
||||||
|
ui.notify(msg, type='positive')
|
||||||
|
refresh_fn.refresh()
|
||||||
78
tab_raw.py
78
tab_raw.py
@@ -1,78 +0,0 @@
|
|||||||
import streamlit as st
|
|
||||||
import json
|
|
||||||
import copy
|
|
||||||
from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
|
|
||||||
|
|
||||||
def render_raw_editor(data, file_path):
|
|
||||||
st.subheader(f"💻 Raw Editor: {file_path.name}")
|
|
||||||
|
|
||||||
# Toggle to hide massive history objects
|
|
||||||
# This is crucial because history trees can get huge and make the text area laggy.
|
|
||||||
col_ctrl, col_info = st.columns([1, 2])
|
|
||||||
with col_ctrl:
|
|
||||||
hide_history = st.checkbox(
|
|
||||||
"Hide History (Safe Mode)",
|
|
||||||
value=True,
|
|
||||||
help="Hides 'history_tree' and 'prompt_history' to keep the editor fast and prevent accidental deletion of version control."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare display data
|
|
||||||
if hide_history:
|
|
||||||
display_data = copy.deepcopy(data)
|
|
||||||
# Safely remove heavy keys for the view only
|
|
||||||
if KEY_HISTORY_TREE in display_data: del display_data[KEY_HISTORY_TREE]
|
|
||||||
if KEY_PROMPT_HISTORY in display_data: del display_data[KEY_PROMPT_HISTORY]
|
|
||||||
else:
|
|
||||||
display_data = data
|
|
||||||
|
|
||||||
# Convert to string
|
|
||||||
# ensure_ascii=False ensures emojis and special chars render correctly
|
|
||||||
try:
|
|
||||||
json_str = json.dumps(display_data, indent=4, ensure_ascii=False)
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Error serializing JSON: {e}")
|
|
||||||
json_str = "{}"
|
|
||||||
|
|
||||||
# The Text Editor
|
|
||||||
# We use ui_reset_token in the key to force the text area to reload content on save
|
|
||||||
new_json_str = st.text_area(
|
|
||||||
"JSON Content",
|
|
||||||
value=json_str,
|
|
||||||
height=650,
|
|
||||||
key=f"raw_edit_{file_path.name}_{st.session_state.ui_reset_token}"
|
|
||||||
)
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
if st.button("💾 Save Raw Changes", type="primary", use_container_width=True):
|
|
||||||
try:
|
|
||||||
# 1. Parse the text back to JSON
|
|
||||||
input_data = json.loads(new_json_str)
|
|
||||||
|
|
||||||
# 2. If we were in Safe Mode, we must merge the hidden history back in
|
|
||||||
if hide_history:
|
|
||||||
if KEY_HISTORY_TREE in data:
|
|
||||||
input_data[KEY_HISTORY_TREE] = data[KEY_HISTORY_TREE]
|
|
||||||
if KEY_PROMPT_HISTORY in data:
|
|
||||||
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
|
||||||
|
|
||||||
# 3. Save to Disk
|
|
||||||
save_json(file_path, input_data)
|
|
||||||
|
|
||||||
# 4. Update Session State
|
|
||||||
# We clear and update the existing dictionary object so other tabs see the changes
|
|
||||||
data.clear()
|
|
||||||
data.update(input_data)
|
|
||||||
|
|
||||||
# 5. Update Metadata to prevent conflict warnings
|
|
||||||
st.session_state.last_mtime = get_file_mtime(file_path)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
|
|
||||||
st.toast("Raw JSON Saved Successfully!", icon="✅")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
st.error(f"❌ Invalid JSON Syntax: {e}")
|
|
||||||
st.error("Please fix the formatting errors above before saving.")
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"❌ Unexpected Error: {e}")
|
|
||||||
@@ -4,7 +4,7 @@ import json
|
|||||||
from nicegui import ui
|
from nicegui import ui
|
||||||
|
|
||||||
from state import AppState
|
from state import AppState
|
||||||
from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
|
from utils import save_json, sync_to_db, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
|
||||||
|
|
||||||
|
|
||||||
def render_raw_editor(state: AppState):
|
def render_raw_editor(state: AppState):
|
||||||
@@ -52,6 +52,8 @@ def render_raw_editor(state: AppState):
|
|||||||
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
||||||
|
|
||||||
save_json(file_path, input_data)
|
save_json(file_path, input_data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, input_data)
|
||||||
|
|
||||||
data.clear()
|
data.clear()
|
||||||
data.update(input_data)
|
data.update(input_data)
|
||||||
|
|||||||
390
tab_timeline.py
390
tab_timeline.py
@@ -1,390 +0,0 @@
|
|||||||
import streamlit as st
|
|
||||||
import copy
|
|
||||||
import time
|
|
||||||
from history_tree import HistoryTree
|
|
||||||
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
|
||||||
|
|
||||||
try:
|
|
||||||
from streamlit_agraph import agraph, Node, Edge, Config
|
|
||||||
AGRAPH_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
AGRAPH_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
def render_timeline_tab(data, file_path):
|
|
||||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
|
||||||
if not tree_data:
|
|
||||||
st.info("No history timeline exists. Make some changes in the Editor first!")
|
|
||||||
return
|
|
||||||
|
|
||||||
htree = HistoryTree(tree_data)
|
|
||||||
|
|
||||||
# --- Initialize selection state ---
|
|
||||||
if "timeline_selected_nodes" not in st.session_state:
|
|
||||||
st.session_state.timeline_selected_nodes = set()
|
|
||||||
|
|
||||||
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
|
|
||||||
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
|
|
||||||
|
|
||||||
# --- VIEW SWITCHER + SELECTION MODE ---
|
|
||||||
c_title, c_view, c_toggle = st.columns([2, 1, 0.6])
|
|
||||||
c_title.subheader("🕰️ Version History")
|
|
||||||
|
|
||||||
view_mode = c_view.radio(
|
|
||||||
"View Mode",
|
|
||||||
["🌳 Horizontal", "🌲 Vertical", "📜 Linear Log"],
|
|
||||||
horizontal=True,
|
|
||||||
label_visibility="collapsed"
|
|
||||||
)
|
|
||||||
|
|
||||||
selection_mode = c_toggle.toggle("Select to Delete", key="timeline_selection_mode")
|
|
||||||
if not selection_mode:
|
|
||||||
st.session_state.timeline_selected_nodes = set()
|
|
||||||
|
|
||||||
# --- Build sorted node list (shared by all views) ---
|
|
||||||
all_nodes = list(htree.nodes.values())
|
|
||||||
all_nodes.sort(key=lambda x: x["timestamp"], reverse=True)
|
|
||||||
|
|
||||||
# --- MULTISELECT PICKER (shown when selection mode is on) ---
|
|
||||||
if selection_mode:
|
|
||||||
def _fmt_node_option(nid):
|
|
||||||
n = htree.nodes[nid]
|
|
||||||
ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp']))
|
|
||||||
note = n.get('note', 'Step')
|
|
||||||
head = " (HEAD)" if nid == htree.head_id else ""
|
|
||||||
return f"{note} • {ts} ({nid[:6]}){head}"
|
|
||||||
|
|
||||||
all_ids = [n["id"] for n in all_nodes]
|
|
||||||
current_selection = [nid for nid in all_ids if nid in st.session_state.timeline_selected_nodes]
|
|
||||||
picked = st.multiselect(
|
|
||||||
"Select nodes to delete:",
|
|
||||||
options=all_ids,
|
|
||||||
default=current_selection,
|
|
||||||
format_func=_fmt_node_option,
|
|
||||||
)
|
|
||||||
st.session_state.timeline_selected_nodes = set(picked)
|
|
||||||
|
|
||||||
c_all, c_none, _ = st.columns([1, 1, 4])
|
|
||||||
if c_all.button("Select All", use_container_width=True):
|
|
||||||
st.session_state.timeline_selected_nodes = set(all_ids)
|
|
||||||
st.rerun()
|
|
||||||
if c_none.button("Deselect All", use_container_width=True):
|
|
||||||
st.session_state.timeline_selected_nodes = set()
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# --- RENDER GRAPH VIEWS ---
|
|
||||||
if view_mode in ["🌳 Horizontal", "🌲 Vertical"]:
|
|
||||||
direction = "LR" if view_mode == "🌳 Horizontal" else "TB"
|
|
||||||
|
|
||||||
if AGRAPH_AVAILABLE:
|
|
||||||
# Interactive graph with streamlit-agraph
|
|
||||||
selected_set = st.session_state.timeline_selected_nodes if selection_mode else set()
|
|
||||||
clicked_node = _render_interactive_graph(htree, direction, selected_set)
|
|
||||||
if clicked_node and clicked_node in htree.nodes:
|
|
||||||
if selection_mode:
|
|
||||||
# Toggle node in selection set
|
|
||||||
if clicked_node in st.session_state.timeline_selected_nodes:
|
|
||||||
st.session_state.timeline_selected_nodes.discard(clicked_node)
|
|
||||||
else:
|
|
||||||
st.session_state.timeline_selected_nodes.add(clicked_node)
|
|
||||||
st.rerun()
|
|
||||||
else:
|
|
||||||
node = htree.nodes[clicked_node]
|
|
||||||
if clicked_node != htree.head_id:
|
|
||||||
_restore_node(data, node, htree, file_path)
|
|
||||||
else:
|
|
||||||
# Fallback to static graphviz
|
|
||||||
try:
|
|
||||||
graph_dot = htree.generate_graph(direction=direction)
|
|
||||||
if direction == "LR":
|
|
||||||
st.graphviz_chart(graph_dot, use_container_width=True)
|
|
||||||
else:
|
|
||||||
_, col_center, _ = st.columns([1, 2, 1])
|
|
||||||
with col_center:
|
|
||||||
st.graphviz_chart(graph_dot, use_container_width=True)
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Graph Error: {e}")
|
|
||||||
st.caption("💡 Install `streamlit-agraph` for interactive click-to-restore")
|
|
||||||
|
|
||||||
# --- RENDER LINEAR LOG VIEW ---
|
|
||||||
elif view_mode == "📜 Linear Log":
|
|
||||||
st.caption("A simple chronological list of all snapshots.")
|
|
||||||
|
|
||||||
for n in all_nodes:
|
|
||||||
is_head = (n["id"] == htree.head_id)
|
|
||||||
with st.container():
|
|
||||||
if selection_mode:
|
|
||||||
c0, c1, c2, c3 = st.columns([0.3, 0.5, 4, 1])
|
|
||||||
with c0:
|
|
||||||
is_selected = n["id"] in st.session_state.timeline_selected_nodes
|
|
||||||
if st.checkbox("", value=is_selected, key=f"log_sel_{n['id']}", label_visibility="collapsed"):
|
|
||||||
st.session_state.timeline_selected_nodes.add(n["id"])
|
|
||||||
else:
|
|
||||||
st.session_state.timeline_selected_nodes.discard(n["id"])
|
|
||||||
else:
|
|
||||||
c1, c2, c3 = st.columns([0.5, 4, 1])
|
|
||||||
with c1:
|
|
||||||
st.markdown("### 📍" if is_head else "### ⚫")
|
|
||||||
with c2:
|
|
||||||
note_txt = n.get('note', 'Step')
|
|
||||||
ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp']))
|
|
||||||
if is_head:
|
|
||||||
st.markdown(f"**{note_txt}** (Current)")
|
|
||||||
else:
|
|
||||||
st.write(f"**{note_txt}**")
|
|
||||||
st.caption(f"ID: {n['id'][:6]} • {ts}")
|
|
||||||
with c3:
|
|
||||||
if not is_head and not selection_mode:
|
|
||||||
if st.button("⏪", key=f"log_rst_{n['id']}", help="Restore this version"):
|
|
||||||
_restore_node(data, n, htree, file_path)
|
|
||||||
st.divider()
|
|
||||||
|
|
||||||
# --- BATCH DELETE UI ---
|
|
||||||
if selection_mode and st.session_state.timeline_selected_nodes:
|
|
||||||
# Prune any selected IDs that no longer exist in the tree
|
|
||||||
valid_selected = st.session_state.timeline_selected_nodes & set(htree.nodes.keys())
|
|
||||||
st.session_state.timeline_selected_nodes = valid_selected
|
|
||||||
count = len(valid_selected)
|
|
||||||
if count > 0:
|
|
||||||
st.warning(f"**{count}** node{'s' if count != 1 else ''} selected for deletion.")
|
|
||||||
if st.button(f"🗑️ Delete {count} Node{'s' if count != 1 else ''}", type="primary"):
|
|
||||||
# Backup
|
|
||||||
if "history_tree_backup" not in data:
|
|
||||||
data["history_tree_backup"] = []
|
|
||||||
data["history_tree_backup"].append(copy.deepcopy(htree.to_dict()))
|
|
||||||
# Delete all selected nodes
|
|
||||||
for nid in valid_selected:
|
|
||||||
if nid in htree.nodes:
|
|
||||||
del htree.nodes[nid]
|
|
||||||
# Clean up branch tips
|
|
||||||
for b, tip in list(htree.branches.items()):
|
|
||||||
if tip in valid_selected:
|
|
||||||
del htree.branches[b]
|
|
||||||
# Reassign HEAD if deleted
|
|
||||||
if htree.head_id in valid_selected:
|
|
||||||
if htree.nodes:
|
|
||||||
fallback = sorted(htree.nodes.values(), key=lambda x: x["timestamp"])[-1]
|
|
||||||
htree.head_id = fallback["id"]
|
|
||||||
else:
|
|
||||||
htree.head_id = None
|
|
||||||
# Save and reset
|
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.timeline_selected_nodes = set()
|
|
||||||
st.toast(f"Deleted {count} node{'s' if count != 1 else ''}!", icon="🗑️")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
st.markdown("---")
|
|
||||||
|
|
||||||
# --- NODE SELECTOR ---
|
|
||||||
col_sel, col_act = st.columns([3, 1])
|
|
||||||
|
|
||||||
def fmt_node(n):
|
|
||||||
ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp']))
|
|
||||||
return f"{n.get('note', 'Step')} • {ts} ({n['id'][:6]})"
|
|
||||||
|
|
||||||
with col_sel:
|
|
||||||
current_idx = 0
|
|
||||||
for i, n in enumerate(all_nodes):
|
|
||||||
if n["id"] == htree.head_id:
|
|
||||||
current_idx = i
|
|
||||||
break
|
|
||||||
|
|
||||||
selected_node = st.selectbox(
|
|
||||||
"Select Version to Manage:",
|
|
||||||
all_nodes,
|
|
||||||
format_func=fmt_node,
|
|
||||||
index=current_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
if not selected_node:
|
|
||||||
return
|
|
||||||
|
|
||||||
node_data = selected_node["data"]
|
|
||||||
|
|
||||||
# --- RESTORE ---
|
|
||||||
with col_act:
|
|
||||||
st.write(""); st.write("")
|
|
||||||
if st.button("⏪ Restore Version", type="primary", use_container_width=True):
|
|
||||||
_restore_node(data, selected_node, htree, file_path)
|
|
||||||
|
|
||||||
# --- RENAME ---
|
|
||||||
rn_col1, rn_col2 = st.columns([3, 1])
|
|
||||||
new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", ""))
|
|
||||||
if rn_col2.button("Update Label"):
|
|
||||||
selected_node["note"] = new_label
|
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# --- DANGER ZONE ---
|
|
||||||
st.markdown("---")
|
|
||||||
with st.expander("⚠️ Danger Zone (Delete)"):
|
|
||||||
st.warning("Deleting a node cannot be undone.")
|
|
||||||
if st.button("🗑️ Delete This Node", type="primary"):
|
|
||||||
if selected_node['id'] in htree.nodes:
|
|
||||||
if "history_tree_backup" not in data:
|
|
||||||
data["history_tree_backup"] = []
|
|
||||||
data["history_tree_backup"].append(copy.deepcopy(htree.to_dict()))
|
|
||||||
del htree.nodes[selected_node['id']]
|
|
||||||
for b, tip in list(htree.branches.items()):
|
|
||||||
if tip == selected_node['id']:
|
|
||||||
del htree.branches[b]
|
|
||||||
if htree.head_id == selected_node['id']:
|
|
||||||
if htree.nodes:
|
|
||||||
fallback = sorted(htree.nodes.values(), key=lambda x: x["timestamp"])[-1]
|
|
||||||
htree.head_id = fallback["id"]
|
|
||||||
else:
|
|
||||||
htree.head_id = None
|
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.toast("Node Deleted", icon="🗑️")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# --- DATA PREVIEW ---
|
|
||||||
st.markdown("---")
|
|
||||||
with st.expander("🔍 Data Preview", expanded=False):
|
|
||||||
batch_list = node_data.get(KEY_BATCH_DATA, [])
|
|
||||||
|
|
||||||
if batch_list and isinstance(batch_list, list) and len(batch_list) > 0:
|
|
||||||
st.info(f"📚 This snapshot contains {len(batch_list)} sequences.")
|
|
||||||
for i, seq_data in enumerate(batch_list):
|
|
||||||
seq_num = seq_data.get("sequence_number", i + 1)
|
|
||||||
with st.expander(f"🎬 Sequence #{seq_num}", expanded=(i == 0)):
|
|
||||||
prefix = f"p_{selected_node['id']}_s{i}"
|
|
||||||
_render_preview_fields(seq_data, prefix)
|
|
||||||
else:
|
|
||||||
prefix = f"p_{selected_node['id']}_single"
|
|
||||||
_render_preview_fields(node_data, prefix)
|
|
||||||
|
|
||||||
|
|
||||||
def _render_interactive_graph(htree, direction, selected_nodes=None):
|
|
||||||
"""Render an interactive graph using streamlit-agraph. Returns clicked node id."""
|
|
||||||
if selected_nodes is None:
|
|
||||||
selected_nodes = set()
|
|
||||||
|
|
||||||
# Build reverse lookup: branch tip -> branch name(s)
|
|
||||||
tip_to_branches = {}
|
|
||||||
for b_name, tip_id in htree.branches.items():
|
|
||||||
if tip_id:
|
|
||||||
tip_to_branches.setdefault(tip_id, []).append(b_name)
|
|
||||||
|
|
||||||
sorted_nodes_list = sorted(htree.nodes.values(), key=lambda x: x["timestamp"])
|
|
||||||
|
|
||||||
nodes = []
|
|
||||||
edges = []
|
|
||||||
|
|
||||||
for n in sorted_nodes_list:
|
|
||||||
nid = n["id"]
|
|
||||||
full_note = n.get('note', 'Step')
|
|
||||||
display_note = (full_note[:20] + '..') if len(full_note) > 20 else full_note
|
|
||||||
ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp']))
|
|
||||||
|
|
||||||
# Branch label
|
|
||||||
branch_label = ""
|
|
||||||
if nid in tip_to_branches:
|
|
||||||
branch_label = f"\n[{', '.join(tip_to_branches[nid])}]"
|
|
||||||
|
|
||||||
label = f"{display_note}\n{ts}{branch_label}"
|
|
||||||
|
|
||||||
# Colors - selected nodes override to red
|
|
||||||
if nid in selected_nodes:
|
|
||||||
color = "#ff5555" # Selected for deletion - red
|
|
||||||
elif nid == htree.head_id:
|
|
||||||
color = "#ffdd44" # Current head - bright yellow
|
|
||||||
elif nid in htree.branches.values():
|
|
||||||
color = "#66dd66" # Branch tip - bright green
|
|
||||||
else:
|
|
||||||
color = "#aaccff" # Normal - light blue
|
|
||||||
|
|
||||||
nodes.append(Node(
|
|
||||||
id=nid,
|
|
||||||
label=label,
|
|
||||||
size=20,
|
|
||||||
color=color,
|
|
||||||
font={"size": 10, "color": "#ffffff"}
|
|
||||||
))
|
|
||||||
|
|
||||||
if n["parent"] and n["parent"] in htree.nodes:
|
|
||||||
edges.append(Edge(source=n["parent"], target=nid, color="#888888"))
|
|
||||||
|
|
||||||
# Config based on direction
|
|
||||||
is_horizontal = direction == "LR"
|
|
||||||
config = Config(
|
|
||||||
width="100%",
|
|
||||||
height=400 if is_horizontal else 600,
|
|
||||||
directed=True,
|
|
||||||
hierarchical=True,
|
|
||||||
physics=False,
|
|
||||||
nodeHighlightBehavior=True,
|
|
||||||
highlightColor="#ffcc00",
|
|
||||||
collapsible=False,
|
|
||||||
layout={
|
|
||||||
"hierarchical": {
|
|
||||||
"enabled": True,
|
|
||||||
"direction": "LR" if is_horizontal else "UD",
|
|
||||||
"sortMethod": "directed",
|
|
||||||
"levelSeparation": 150 if is_horizontal else 80,
|
|
||||||
"nodeSpacing": 100 if is_horizontal else 60,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return agraph(nodes=nodes, edges=edges, config=config)
|
|
||||||
|
|
||||||
|
|
||||||
def _restore_node(data, node, htree, file_path):
|
|
||||||
"""Restore a history node as the current version."""
|
|
||||||
node_data = node["data"]
|
|
||||||
if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
|
|
||||||
del data[KEY_BATCH_DATA]
|
|
||||||
data.update(node_data)
|
|
||||||
htree.head_id = node['id']
|
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
|
||||||
save_json(file_path, data)
|
|
||||||
st.session_state.ui_reset_token += 1
|
|
||||||
label = f"{node.get('note')} ({node['id'][:4]})"
|
|
||||||
st.session_state.restored_indicator = label
|
|
||||||
st.toast("Restored!", icon="🔄")
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
|
|
||||||
def _render_preview_fields(item_data, prefix):
|
|
||||||
"""Render a read-only preview of prompts, settings, and LoRAs."""
|
|
||||||
# Prompts
|
|
||||||
p_col1, p_col2 = st.columns(2)
|
|
||||||
with p_col1:
|
|
||||||
st.text_area("General Positive", value=item_data.get("general_prompt", ""), height=80, disabled=True, key=f"{prefix}_gp")
|
|
||||||
val_sp = item_data.get("current_prompt", "") or item_data.get("prompt", "")
|
|
||||||
st.text_area("Specific Positive", value=val_sp, height=80, disabled=True, key=f"{prefix}_sp")
|
|
||||||
with p_col2:
|
|
||||||
st.text_area("General Negative", value=item_data.get("general_negative", ""), height=80, disabled=True, key=f"{prefix}_gn")
|
|
||||||
st.text_area("Specific Negative", value=item_data.get("negative", ""), height=80, disabled=True, key=f"{prefix}_sn")
|
|
||||||
|
|
||||||
# Settings
|
|
||||||
s_col1, s_col2, s_col3 = st.columns(3)
|
|
||||||
s_col1.text_input("Camera", value=str(item_data.get("camera", "static")), disabled=True, key=f"{prefix}_cam")
|
|
||||||
s_col2.text_input("FLF", value=str(item_data.get("flf", "0.0")), disabled=True, key=f"{prefix}_flf")
|
|
||||||
s_col3.text_input("Seed", value=str(item_data.get("seed", "-1")), disabled=True, key=f"{prefix}_seed")
|
|
||||||
|
|
||||||
# LoRAs
|
|
||||||
with st.expander("💊 LoRA Configuration", expanded=False):
|
|
||||||
l1, l2, l3 = st.columns(3)
|
|
||||||
with l1:
|
|
||||||
st.text_input("L1 Name", value=item_data.get("lora 1 high", ""), disabled=True, key=f"{prefix}_l1h")
|
|
||||||
st.text_input("L1 Str", value=str(item_data.get("lora 1 low", "")), disabled=True, key=f"{prefix}_l1l")
|
|
||||||
with l2:
|
|
||||||
st.text_input("L2 Name", value=item_data.get("lora 2 high", ""), disabled=True, key=f"{prefix}_l2h")
|
|
||||||
st.text_input("L2 Str", value=str(item_data.get("lora 2 low", "")), disabled=True, key=f"{prefix}_l2l")
|
|
||||||
with l3:
|
|
||||||
st.text_input("L3 Name", value=item_data.get("lora 3 high", ""), disabled=True, key=f"{prefix}_l3h")
|
|
||||||
st.text_input("L3 Str", value=str(item_data.get("lora 3 low", "")), disabled=True, key=f"{prefix}_l3l")
|
|
||||||
|
|
||||||
# VACE
|
|
||||||
vace_keys = ["frame_to_skip", "vace schedule", "video file path"]
|
|
||||||
if any(k in item_data for k in vace_keys):
|
|
||||||
with st.expander("🎞️ VACE / I2V Settings", expanded=False):
|
|
||||||
v1, v2, v3 = st.columns(3)
|
|
||||||
v1.text_input("Skip Frames", value=str(item_data.get("frame_to_skip", "")), disabled=True, key=f"{prefix}_fts")
|
|
||||||
v2.text_input("Schedule", value=str(item_data.get("vace schedule", "")), disabled=True, key=f"{prefix}_vsc")
|
|
||||||
v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid")
|
|
||||||
@@ -5,7 +5,7 @@ from nicegui import ui
|
|||||||
|
|
||||||
from state import AppState
|
from state import AppState
|
||||||
from history_tree import HistoryTree
|
from history_tree import HistoryTree
|
||||||
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
from utils import save_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||||
|
|
||||||
|
|
||||||
def _delete_nodes(htree, data, file_path, node_ids):
|
def _delete_nodes(htree, data, file_path, node_ids):
|
||||||
@@ -64,14 +64,16 @@ def _render_selection_picker(all_nodes, htree, state, refresh_fn):
|
|||||||
|
|
||||||
|
|
||||||
def _render_graph_or_log(mode, all_nodes, htree, selected_nodes,
|
def _render_graph_or_log(mode, all_nodes, htree, selected_nodes,
|
||||||
selection_mode_on, toggle_select_fn, restore_fn):
|
selection_mode_on, toggle_select_fn, restore_fn,
|
||||||
|
selected=None):
|
||||||
"""Render graph visualization or linear log view."""
|
"""Render graph visualization or linear log view."""
|
||||||
if mode in ('Horizontal', 'Vertical'):
|
if mode in ('Horizontal', 'Vertical'):
|
||||||
direction = 'LR' if mode == 'Horizontal' else 'TB'
|
direction = 'LR' if mode == 'Horizontal' else 'TB'
|
||||||
with ui.card().classes('w-full q-pa-md'):
|
with ui.card().classes('w-full q-pa-md'):
|
||||||
try:
|
try:
|
||||||
graph_dot = htree.generate_graph(direction=direction)
|
graph_dot = htree.generate_graph(direction=direction)
|
||||||
_render_graphviz(graph_dot)
|
sel_id = selected.get('node_id') if selected else None
|
||||||
|
_render_graphviz(graph_dot, selected_node_id=sel_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ui.label(f'Graph Error: {e}').classes('text-negative')
|
ui.label(f'Graph Error: {e}').classes('text-negative')
|
||||||
|
|
||||||
@@ -132,6 +134,8 @@ def _render_batch_delete(htree, data, file_path, state, refresh_fn):
|
|||||||
def do_batch_delete():
|
def do_batch_delete():
|
||||||
current_valid = state.timeline_selected_nodes & set(htree.nodes.keys())
|
current_valid = state.timeline_selected_nodes & set(htree.nodes.keys())
|
||||||
_delete_nodes(htree, data, file_path, current_valid)
|
_delete_nodes(htree, data, file_path, current_valid)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
state.timeline_selected_nodes = set()
|
state.timeline_selected_nodes = set()
|
||||||
ui.notify(
|
ui.notify(
|
||||||
f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!',
|
f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!',
|
||||||
@@ -165,28 +169,24 @@ def _find_active_branch(htree):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn):
|
def _find_branch_for_node(htree, node_id):
|
||||||
|
"""Return the branch name whose ancestry contains node_id, or None."""
|
||||||
|
for b_name, tip_id in htree.branches.items():
|
||||||
|
current = tip_id
|
||||||
|
while current and current in htree.nodes:
|
||||||
|
if current == node_id:
|
||||||
|
return b_name
|
||||||
|
current = htree.nodes[current].get('parent')
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn,
|
||||||
|
selected, state=None):
|
||||||
"""Render branch-grouped node manager with restore, rename, delete, and preview."""
|
"""Render branch-grouped node manager with restore, rename, delete, and preview."""
|
||||||
ui.label('Manage Version').classes('section-header')
|
ui.label('Manage Version').classes('section-header')
|
||||||
|
|
||||||
# --- State that survives @ui.refreshable ---
|
|
||||||
active_branch = _find_active_branch(htree)
|
active_branch = _find_active_branch(htree)
|
||||||
|
|
||||||
# Default branch: active branch, or branch whose ancestry contains HEAD
|
|
||||||
default_branch = active_branch
|
|
||||||
if not default_branch and htree.head_id:
|
|
||||||
for b_name, tip_id in htree.branches.items():
|
|
||||||
for n in _walk_branch_nodes(htree, tip_id):
|
|
||||||
if n['id'] == htree.head_id:
|
|
||||||
default_branch = b_name
|
|
||||||
break
|
|
||||||
if default_branch:
|
|
||||||
break
|
|
||||||
if not default_branch and htree.branches:
|
|
||||||
default_branch = next(iter(htree.branches))
|
|
||||||
|
|
||||||
selected = {'node_id': htree.head_id, 'branch': default_branch}
|
|
||||||
|
|
||||||
# --- (a) Branch selector ---
|
# --- (a) Branch selector ---
|
||||||
def fmt_branch(b_name):
|
def fmt_branch(b_name):
|
||||||
count = len(_walk_branch_nodes(htree, htree.branches.get(b_name)))
|
count = len(_walk_branch_nodes(htree, htree.branches.get(b_name)))
|
||||||
@@ -293,6 +293,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_
|
|||||||
htree.nodes[sel_id]['note'] = rename_input.value
|
htree.nodes[sel_id]['note'] = rename_input.value
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
if state and state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
ui.notify('Label updated', type='positive')
|
ui.notify('Label updated', type='positive')
|
||||||
refresh_fn()
|
refresh_fn()
|
||||||
|
|
||||||
@@ -306,6 +308,8 @@ def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_
|
|||||||
def delete_selected():
|
def delete_selected():
|
||||||
if sel_id in htree.nodes:
|
if sel_id in htree.nodes:
|
||||||
_delete_nodes(htree, data, file_path, {sel_id})
|
_delete_nodes(htree, data, file_path, {sel_id})
|
||||||
|
if state and state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
ui.notify('Node Deleted', type='positive')
|
ui.notify('Node Deleted', type='positive')
|
||||||
refresh_fn()
|
refresh_fn()
|
||||||
|
|
||||||
@@ -331,6 +335,21 @@ def render_timeline_tab(state: AppState):
|
|||||||
|
|
||||||
htree = HistoryTree(tree_data)
|
htree = HistoryTree(tree_data)
|
||||||
|
|
||||||
|
# --- Shared selected-node state (survives refreshes, shared by graph + manager) ---
|
||||||
|
active_branch = _find_active_branch(htree)
|
||||||
|
default_branch = active_branch
|
||||||
|
if not default_branch and htree.head_id:
|
||||||
|
for b_name, tip_id in htree.branches.items():
|
||||||
|
for n in _walk_branch_nodes(htree, tip_id):
|
||||||
|
if n['id'] == htree.head_id:
|
||||||
|
default_branch = b_name
|
||||||
|
break
|
||||||
|
if default_branch:
|
||||||
|
break
|
||||||
|
if not default_branch and htree.branches:
|
||||||
|
default_branch = next(iter(htree.branches))
|
||||||
|
selected = {'node_id': htree.head_id, 'branch': default_branch}
|
||||||
|
|
||||||
if state.restored_indicator:
|
if state.restored_indicator:
|
||||||
ui.label(f'Editing Restored Version: {state.restored_indicator}').classes(
|
ui.label(f'Editing Restored Version: {state.restored_indicator}').classes(
|
||||||
'text-info q-pa-sm')
|
'text-info q-pa-sm')
|
||||||
@@ -354,7 +373,8 @@ def render_timeline_tab(state: AppState):
|
|||||||
|
|
||||||
_render_graph_or_log(
|
_render_graph_or_log(
|
||||||
view_mode.value, all_nodes, htree, selected_nodes,
|
view_mode.value, all_nodes, htree, selected_nodes,
|
||||||
selection_mode.value, _toggle_select, _restore_and_refresh)
|
selection_mode.value, _toggle_select, _restore_and_refresh,
|
||||||
|
selected=selected)
|
||||||
|
|
||||||
if selection_mode.value and state.timeline_selected_nodes:
|
if selection_mode.value and state.timeline_selected_nodes:
|
||||||
_render_batch_delete(htree, data, file_path, state, render_timeline.refresh)
|
_render_batch_delete(htree, data, file_path, state, render_timeline.refresh)
|
||||||
@@ -362,7 +382,8 @@ def render_timeline_tab(state: AppState):
|
|||||||
with ui.card().classes('w-full q-pa-md q-mt-md'):
|
with ui.card().classes('w-full q-pa-md q-mt-md'):
|
||||||
_render_node_manager(
|
_render_node_manager(
|
||||||
all_nodes, htree, data, file_path,
|
all_nodes, htree, data, file_path,
|
||||||
_restore_and_refresh, render_timeline.refresh)
|
_restore_and_refresh, render_timeline.refresh,
|
||||||
|
selected, state=state)
|
||||||
|
|
||||||
def _toggle_select(nid, checked):
|
def _toggle_select(nid, checked):
|
||||||
if checked:
|
if checked:
|
||||||
@@ -380,14 +401,87 @@ def render_timeline_tab(state: AppState):
|
|||||||
selection_mode.on_value_change(lambda _: render_timeline.refresh())
|
selection_mode.on_value_change(lambda _: render_timeline.refresh())
|
||||||
render_timeline()
|
render_timeline()
|
||||||
|
|
||||||
|
# --- Poll for graph node clicks (JS → Python bridge) ---
|
||||||
|
async def _poll_graph_click():
|
||||||
|
if view_mode.value == 'Linear Log':
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
result = await ui.run_javascript(
|
||||||
|
'const v = window.graphSelectedNode;'
|
||||||
|
'window.graphSelectedNode = null; v;'
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
if not result:
|
||||||
|
return
|
||||||
|
node_id = str(result)
|
||||||
|
if node_id not in htree.nodes:
|
||||||
|
return
|
||||||
|
branch = _find_branch_for_node(htree, node_id)
|
||||||
|
if branch:
|
||||||
|
selected['branch'] = branch
|
||||||
|
selected['node_id'] = node_id
|
||||||
|
render_timeline.refresh()
|
||||||
|
|
||||||
def _render_graphviz(dot_source: str):
|
ui.timer(0.2, _poll_graph_click)
|
||||||
"""Render graphviz DOT source as SVG using ui.html."""
|
|
||||||
|
|
||||||
|
def _render_graphviz(dot_source: str, selected_node_id: str | None = None):
|
||||||
|
"""Render graphviz DOT source as interactive SVG with click-to-select."""
|
||||||
try:
|
try:
|
||||||
import graphviz
|
import graphviz
|
||||||
src = graphviz.Source(dot_source)
|
src = graphviz.Source(dot_source)
|
||||||
svg = src.pipe(format='svg').decode('utf-8')
|
svg = src.pipe(format='svg').decode('utf-8')
|
||||||
ui.html(f'<div style="overflow-x: auto;">{svg}</div>')
|
|
||||||
|
sel_escaped = selected_node_id.replace("'", "\\'") if selected_node_id else ''
|
||||||
|
|
||||||
|
# CSS inline (allowed), JS via run_javascript (script tags blocked)
|
||||||
|
css = '''<style>
|
||||||
|
.timeline-graph g.node { cursor: pointer; }
|
||||||
|
.timeline-graph g.node:hover { filter: brightness(1.3); }
|
||||||
|
.timeline-graph g.node.selected ellipse,
|
||||||
|
.timeline-graph g.node.selected polygon[stroke]:not([stroke="none"]) {
|
||||||
|
stroke: #f59e0b !important;
|
||||||
|
stroke-width: 3px !important;
|
||||||
|
}
|
||||||
|
</style>'''
|
||||||
|
|
||||||
|
ui.html(
|
||||||
|
f'{css}<div class="timeline-graph"'
|
||||||
|
f' style="overflow: auto; max-height: 500px; width: 100%;">'
|
||||||
|
f'{svg}</div>'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find container by class with retry for Vue async render
|
||||||
|
ui.run_javascript(f'''
|
||||||
|
(function attempt(tries) {{
|
||||||
|
var container = document.querySelector('.timeline-graph');
|
||||||
|
if (!container || !container.querySelector('g.node')) {{
|
||||||
|
if (tries < 20) setTimeout(function() {{ attempt(tries + 1); }}, 100);
|
||||||
|
return;
|
||||||
|
}}
|
||||||
|
container.querySelectorAll('g.node').forEach(function(g) {{
|
||||||
|
g.addEventListener('click', function() {{
|
||||||
|
var title = g.querySelector('title');
|
||||||
|
if (title) {{
|
||||||
|
window.graphSelectedNode = title.textContent.trim();
|
||||||
|
container.querySelectorAll('g.node.selected').forEach(
|
||||||
|
function(el) {{ el.classList.remove('selected'); }});
|
||||||
|
g.classList.add('selected');
|
||||||
|
}}
|
||||||
|
}});
|
||||||
|
}});
|
||||||
|
var selId = '{sel_escaped}';
|
||||||
|
if (selId) {{
|
||||||
|
container.querySelectorAll('g.node').forEach(function(g) {{
|
||||||
|
var title = g.querySelector('title');
|
||||||
|
if (title && title.textContent.trim() === selId) {{
|
||||||
|
g.classList.add('selected');
|
||||||
|
}}
|
||||||
|
}});
|
||||||
|
}}
|
||||||
|
}})(0);
|
||||||
|
''')
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ui.label('Install graphviz Python package for graph rendering.').classes('text-warning')
|
ui.label('Install graphviz Python package for graph rendering.').classes('text-warning')
|
||||||
ui.code(dot_source).classes('w-full')
|
ui.code(dot_source).classes('w-full')
|
||||||
@@ -404,6 +498,8 @@ def _restore_node(data, node, htree, file_path, state: AppState):
|
|||||||
htree.head_id = node['id']
|
htree.head_id = node['id']
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
if state.db_enabled and state.current_project and state.db:
|
||||||
|
sync_to_db(state.db, state.current_project, file_path, data)
|
||||||
label = f"{node.get('note', 'Step')} ({node['id'][:4]})"
|
label = f"{node.get('note', 'Step')} ({node['id'][:4]})"
|
||||||
state.restored_indicator = label
|
state.restored_indicator = label
|
||||||
ui.notify('Restored!', type='positive')
|
ui.notify('Restored!', type='positive')
|
||||||
|
|||||||
369
tests/test_db.py
Normal file
369
tests/test_db.py
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from db import ProjectDB
|
||||||
|
from utils import KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db(tmp_path):
|
||||||
|
"""Create a fresh ProjectDB in a temp directory."""
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
pdb = ProjectDB(db_path)
|
||||||
|
yield pdb
|
||||||
|
pdb.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Projects CRUD
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestProjects:
|
||||||
|
def test_create_and_get(self, db):
|
||||||
|
pid = db.create_project("proj1", "/some/path", "A test project")
|
||||||
|
assert pid > 0
|
||||||
|
proj = db.get_project("proj1")
|
||||||
|
assert proj is not None
|
||||||
|
assert proj["name"] == "proj1"
|
||||||
|
assert proj["folder_path"] == "/some/path"
|
||||||
|
assert proj["description"] == "A test project"
|
||||||
|
|
||||||
|
def test_list_projects(self, db):
|
||||||
|
db.create_project("beta", "/b")
|
||||||
|
db.create_project("alpha", "/a")
|
||||||
|
projects = db.list_projects()
|
||||||
|
assert len(projects) == 2
|
||||||
|
assert projects[0]["name"] == "alpha"
|
||||||
|
assert projects[1]["name"] == "beta"
|
||||||
|
|
||||||
|
def test_get_nonexistent(self, db):
|
||||||
|
assert db.get_project("nope") is None
|
||||||
|
|
||||||
|
def test_delete_project(self, db):
|
||||||
|
db.create_project("to_delete", "/x")
|
||||||
|
assert db.delete_project("to_delete") is True
|
||||||
|
assert db.get_project("to_delete") is None
|
||||||
|
|
||||||
|
def test_delete_nonexistent(self, db):
|
||||||
|
assert db.delete_project("nope") is False
|
||||||
|
|
||||||
|
def test_unique_name_constraint(self, db):
|
||||||
|
db.create_project("dup", "/a")
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
db.create_project("dup", "/b")
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Data files
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDataFiles:
|
||||||
|
def test_create_and_list(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch_i2v", "i2v", {"extra": "meta"})
|
||||||
|
assert df_id > 0
|
||||||
|
files = db.list_data_files(pid)
|
||||||
|
assert len(files) == 1
|
||||||
|
assert files[0]["name"] == "batch_i2v"
|
||||||
|
assert files[0]["data_type"] == "i2v"
|
||||||
|
|
||||||
|
def test_get_data_file(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
db.create_data_file(pid, "batch_i2v", "i2v", {"key": "value"})
|
||||||
|
df = db.get_data_file(pid, "batch_i2v")
|
||||||
|
assert df is not None
|
||||||
|
assert df["top_level"] == {"key": "value"}
|
||||||
|
|
||||||
|
def test_get_data_file_by_names(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
db.create_data_file(pid, "batch_i2v", "i2v")
|
||||||
|
df = db.get_data_file_by_names("p1", "batch_i2v")
|
||||||
|
assert df is not None
|
||||||
|
assert df["name"] == "batch_i2v"
|
||||||
|
|
||||||
|
def test_get_nonexistent_data_file(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
assert db.get_data_file(pid, "nope") is None
|
||||||
|
|
||||||
|
def test_unique_constraint(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
db.create_data_file(pid, "batch_i2v", "i2v")
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
db.create_data_file(pid, "batch_i2v", "vace")
|
||||||
|
|
||||||
|
def test_cascade_delete(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch_i2v", "i2v")
|
||||||
|
db.upsert_sequence(df_id, 1, {"prompt": "hello"})
|
||||||
|
db.save_history_tree(df_id, {"nodes": {}})
|
||||||
|
db.delete_project("p1")
|
||||||
|
assert db.get_data_file(pid, "batch_i2v") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Sequences
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSequences:
|
||||||
|
def test_upsert_and_get(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
db.upsert_sequence(df_id, 1, {"prompt": "hello", "seed": 42})
|
||||||
|
data = db.get_sequence(df_id, 1)
|
||||||
|
assert data == {"prompt": "hello", "seed": 42}
|
||||||
|
|
||||||
|
def test_upsert_updates_existing(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
db.upsert_sequence(df_id, 1, {"prompt": "v1"})
|
||||||
|
db.upsert_sequence(df_id, 1, {"prompt": "v2"})
|
||||||
|
data = db.get_sequence(df_id, 1)
|
||||||
|
assert data["prompt"] == "v2"
|
||||||
|
|
||||||
|
def test_list_sequences(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
db.upsert_sequence(df_id, 3, {"a": 1})
|
||||||
|
db.upsert_sequence(df_id, 1, {"b": 2})
|
||||||
|
db.upsert_sequence(df_id, 2, {"c": 3})
|
||||||
|
seqs = db.list_sequences(df_id)
|
||||||
|
assert seqs == [1, 2, 3]
|
||||||
|
|
||||||
|
def test_get_nonexistent_sequence(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
assert db.get_sequence(df_id, 99) is None
|
||||||
|
|
||||||
|
def test_get_sequence_keys(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
db.upsert_sequence(df_id, 1, {
|
||||||
|
"prompt": "hello",
|
||||||
|
"seed": 42,
|
||||||
|
"cfg": 1.5,
|
||||||
|
"flag": True,
|
||||||
|
})
|
||||||
|
keys, types = db.get_sequence_keys(df_id, 1)
|
||||||
|
assert "prompt" in keys
|
||||||
|
assert "seed" in keys
|
||||||
|
idx_prompt = keys.index("prompt")
|
||||||
|
idx_seed = keys.index("seed")
|
||||||
|
idx_cfg = keys.index("cfg")
|
||||||
|
idx_flag = keys.index("flag")
|
||||||
|
assert types[idx_prompt] == "STRING"
|
||||||
|
assert types[idx_seed] == "INT"
|
||||||
|
assert types[idx_cfg] == "FLOAT"
|
||||||
|
assert types[idx_flag] == "STRING" # bools -> STRING
|
||||||
|
|
||||||
|
def test_get_sequence_keys_nonexistent(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
keys, types = db.get_sequence_keys(df_id, 99)
|
||||||
|
assert keys == []
|
||||||
|
assert types == []
|
||||||
|
|
||||||
|
def test_delete_sequences_for_file(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
db.upsert_sequence(df_id, 1, {"a": 1})
|
||||||
|
db.upsert_sequence(df_id, 2, {"b": 2})
|
||||||
|
db.delete_sequences_for_file(df_id)
|
||||||
|
assert db.list_sequences(df_id) == []
|
||||||
|
|
||||||
|
def test_count_sequences(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
assert db.count_sequences(df_id) == 0
|
||||||
|
db.upsert_sequence(df_id, 1, {"a": 1})
|
||||||
|
db.upsert_sequence(df_id, 2, {"b": 2})
|
||||||
|
db.upsert_sequence(df_id, 3, {"c": 3})
|
||||||
|
assert db.count_sequences(df_id) == 3
|
||||||
|
|
||||||
|
def test_query_total_sequences(self, db):
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
|
db.upsert_sequence(df_id, 1, {"a": 1})
|
||||||
|
db.upsert_sequence(df_id, 2, {"b": 2})
|
||||||
|
assert db.query_total_sequences("p1", "batch") == 2
|
||||||
|
|
||||||
|
def test_query_total_sequences_nonexistent(self, db):
|
||||||
|
assert db.query_total_sequences("nope", "nope") == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# History trees
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestHistoryTrees:
    """Round-trip behaviour of save_history_tree / get_history_tree."""

    def _new_file(self, db):
        # Helper: fresh project + data file for each test case.
        project_id = db.create_project("p1", "/p1")
        return db.create_data_file(project_id, "batch", "generic")

    def test_save_and_get(self, db):
        file_id = self._new_file(db)
        tree = {"nodes": {"abc": {"id": "abc"}}, "head_id": "abc"}
        db.save_history_tree(file_id, tree)
        assert db.get_history_tree(file_id) == tree

    def test_upsert_updates(self, db):
        # Saving twice overwrites rather than duplicating.
        file_id = self._new_file(db)
        db.save_history_tree(file_id, {"v": 1})
        db.save_history_tree(file_id, {"v": 2})
        assert db.get_history_tree(file_id) == {"v": 2}

    def test_get_nonexistent(self, db):
        # A data file with no saved tree returns None.
        file_id = self._new_file(db)
        assert db.get_history_tree(file_id) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Import
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestImport:
    """Behaviour of ProjectDB.import_json_file."""

    def test_import_json_file(self, db, tmp_path):
        project_id = db.create_project("p1", "/p1")
        json_path = tmp_path / "batch_prompt_i2v.json"
        payload = {
            KEY_BATCH_DATA: [
                {"sequence_number": 1, "prompt": "hello", "seed": 42},
                {"sequence_number": 2, "prompt": "world", "seed": 99},
            ],
            KEY_HISTORY_TREE: {"nodes": {}, "head_id": None},
        }
        json_path.write_text(json.dumps(payload))

        file_id = db.import_json_file(project_id, json_path, "i2v")
        assert file_id > 0
        assert db.list_sequences(file_id) == [1, 2]

        first = db.get_sequence(file_id, 1)
        assert first["prompt"] == "hello"
        assert first["seed"] == 42
        assert db.get_history_tree(file_id) == {"nodes": {}, "head_id": None}

    def test_import_file_name_from_stem(self, db, tmp_path):
        # The data_file name is derived from the JSON file's stem.
        project_id = db.create_project("p1", "/p1")
        json_path = tmp_path / "my_batch.json"
        json_path.write_text(json.dumps({KEY_BATCH_DATA: [{"sequence_number": 1}]}))
        db.import_json_file(project_id, json_path)
        assert db.get_data_file(project_id, "my_batch") is not None

    def test_import_no_batch_data(self, db, tmp_path):
        # A flat JSON object (no batch_data key) imports with zero sequences.
        project_id = db.create_project("p1", "/p1")
        json_path = tmp_path / "simple.json"
        json_path.write_text(json.dumps({"prompt": "flat file"}))
        file_id = db.import_json_file(project_id, json_path)
        assert db.list_sequences(file_id) == []

    def test_reimport_updates_existing(self, db, tmp_path):
        """Re-importing the same file should update data, not crash."""
        project_id = db.create_project("p1", "/p1")
        json_path = tmp_path / "batch.json"

        # First import.
        json_path.write_text(json.dumps(
            {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]}))
        first_id = db.import_json_file(project_id, json_path, "i2v")

        # Second import: same file, updated contents and data type.
        json_path.write_text(json.dumps({KEY_BATCH_DATA: [
            {"sequence_number": 1, "prompt": "v2"},
            {"sequence_number": 2, "prompt": "new"},
        ]}))
        second_id = db.import_json_file(project_id, json_path, "vace")

        # The same data_file row is reused and its type updated.
        assert first_id == second_id
        assert db.get_data_file(project_id, "batch")["data_type"] == "vace"
        # Sequence contents reflect the second import.
        assert db.list_sequences(second_id) == [1, 2]
        assert db.get_sequence(second_id, 1)["prompt"] == "v2"

    def test_import_skips_non_dict_batch_items(self, db, tmp_path):
        """Non-dict elements in batch_data should be silently skipped, not crash."""
        project_id = db.create_project("p1", "/p1")
        json_path = tmp_path / "mixed.json"
        json_path.write_text(json.dumps({KEY_BATCH_DATA: [
            {"sequence_number": 1, "prompt": "valid"},
            "not a dict",
            42,
            None,
            {"sequence_number": 3, "prompt": "also valid"},
        ]}))
        file_id = db.import_json_file(project_id, json_path)
        assert db.list_sequences(file_id) == [1, 3]

    def test_import_atomic_on_error(self, db, tmp_path):
        """If import fails partway, no partial data should be committed."""
        project_id = db.create_project("p1", "/p1")
        json_path = tmp_path / "batch.json"
        json_path.write_text(json.dumps(
            {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "hello"}]}))
        db.import_json_file(project_id, json_path)

        # Overwrite with a sequence_number that int() cannot parse.
        json_path.write_text(json.dumps(
            {KEY_BATCH_DATA: [{"sequence_number": "not_a_number", "prompt": "bad"}]}))
        with pytest.raises(ValueError):
            db.import_json_file(project_id, json_path)

        # Original data survives -- the failed import rolled back.
        data_file = db.get_data_file(project_id, "batch")
        assert data_file is not None
        assert db.get_sequence(data_file["id"], 1)["prompt"] == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Query helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestQueryHelpers:
    """Name-based query helpers used by the REST layer."""

    def test_query_sequence_data(self, db):
        project_id = db.create_project("myproject", "/mp")
        file_id = db.create_data_file(project_id, "batch_i2v", "i2v")
        db.upsert_sequence(file_id, 1, {"prompt": "test", "seed": 7})
        result = db.query_sequence_data("myproject", "batch_i2v", 1)
        assert result == {"prompt": "test", "seed": 7}

    def test_query_sequence_data_not_found(self, db):
        # Unknown names resolve to None rather than raising.
        assert db.query_sequence_data("nope", "nope", 1) is None

    def test_query_sequence_keys(self, db):
        project_id = db.create_project("myproject", "/mp")
        file_id = db.create_data_file(project_id, "batch_i2v", "i2v")
        db.upsert_sequence(file_id, 1, {"prompt": "test", "seed": 7})
        keys, types = db.query_sequence_keys("myproject", "batch_i2v", 1)
        assert "prompt" in keys
        assert "seed" in keys

    def test_list_project_files(self, db):
        project_id = db.create_project("p1", "/p1")
        db.create_data_file(project_id, "file_a", "i2v")
        db.create_data_file(project_id, "file_b", "vace")
        assert len(db.list_project_files("p1")) == 2

    def test_list_project_sequences(self, db):
        project_id = db.create_project("p1", "/p1")
        file_id = db.create_data_file(project_id, "batch", "generic")
        db.upsert_sequence(file_id, 1, {})
        db.upsert_sequence(file_id, 2, {})
        assert db.list_project_sequences("p1", "batch") == [1, 2]
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from json_loader import (
|
|
||||||
to_float, to_int, get_batch_item, read_json_data,
|
|
||||||
JSONLoaderDynamic, MAX_DYNAMIC_OUTPUTS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestToFloat:
    """to_float coerces loosely-typed values, defaulting to 0.0 on failure."""

    def test_valid(self):
        for raw, expected in (("3.14", 3.14), (5, 5.0)):
            assert to_float(raw) == expected

    def test_invalid(self):
        assert to_float("abc") == 0.0

    def test_none(self):
        assert to_float(None) == 0.0
|
|
||||||
|
|
||||||
|
|
||||||
class TestToInt:
    """to_int coerces loosely-typed values, defaulting to 0 on failure."""

    def test_valid(self):
        # Floats are truncated toward zero, numeric strings are parsed.
        for raw, expected in (("7", 7), (3.9, 3)):
            assert to_int(raw) == expected

    def test_invalid(self):
        assert to_int("xyz") == 0

    def test_none(self):
        assert to_int(None) == 0
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetBatchItem:
    """Sequence-lookup semantics of get_batch_item."""

    def test_lookup_by_sequence_number_field(self):
        # Items carrying sequence_number are matched by that field, not position.
        data = {"batch_data": [
            {"sequence_number": 1, "a": "first"},
            {"sequence_number": 5, "a": "fifth"},
            {"sequence_number": 3, "a": "third"},
        ]}
        assert get_batch_item(data, 5) == {"sequence_number": 5, "a": "fifth"}
        assert get_batch_item(data, 3) == {"sequence_number": 3, "a": "third"}

    def test_fallback_to_index(self):
        # Without sequence_number fields, lookup is 1-based positional.
        items = {"batch_data": [{"a": 1}, {"a": 2}, {"a": 3}]}
        assert get_batch_item(items, 2) == {"a": 2}

    def test_clamp_high(self):
        # Out-of-range high indices clamp to the last item.
        items = {"batch_data": [{"a": 1}, {"a": 2}]}
        assert get_batch_item(items, 99) == {"a": 2}

    def test_clamp_low(self):
        # Out-of-range low indices clamp to the first item.
        items = {"batch_data": [{"a": 1}, {"a": 2}]}
        assert get_batch_item(items, 0) == {"a": 1}

    def test_no_batch_data(self):
        # A flat dict (no batch_data key) is returned unchanged.
        flat = {"key": "val"}
        assert get_batch_item(flat, 1) == flat
|
|
||||||
|
|
||||||
|
|
||||||
class TestReadJsonData:
    """read_json_data returns {} for anything that is not a valid JSON object."""

    def test_missing_file(self, tmp_path):
        assert read_json_data(str(tmp_path / "nope.json")) == {}

    def test_invalid_json(self, tmp_path):
        bad = tmp_path / "bad.json"
        bad.write_text("{broken")
        assert read_json_data(str(bad)) == {}

    def test_non_dict_json(self, tmp_path):
        # Top-level arrays are rejected: only dicts are accepted.
        arr = tmp_path / "list.json"
        arr.write_text(json.dumps([1, 2, 3]))
        assert read_json_data(str(arr)) == {}

    def test_valid(self, tmp_path):
        ok = tmp_path / "ok.json"
        ok.write_text(json.dumps({"key": "val"}))
        assert read_json_data(str(ok)) == {"key": "val"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestJSONLoaderDynamic:
    """End-to-end behaviour of the JSONLoaderDynamic node."""

    def _make_json(self, tmp_path, data):
        # Serialize *data* to a temp file and return its path as a string.
        target = tmp_path / "test.json"
        target.write_text(json.dumps(data))
        return str(target)

    def test_known_keys(self, tmp_path):
        path = self._make_json(tmp_path, {"name": "alice", "age": 30, "score": 9.5})
        result = JSONLoaderDynamic().load_dynamic(path, 1, output_keys="name,age,score")
        assert result[0] == "alice"
        assert result[1] == 30
        assert result[2] == 9.5

    def test_empty_output_keys(self, tmp_path):
        path = self._make_json(tmp_path, {"name": "alice"})
        result = JSONLoaderDynamic().load_dynamic(path, 1, output_keys="")
        assert len(result) == MAX_DYNAMIC_OUTPUTS
        assert all(v == "" for v in result)

    def test_pads_to_max(self, tmp_path):
        path = self._make_json(tmp_path, {"a": "1", "b": "2"})
        result = JSONLoaderDynamic().load_dynamic(path, 1, output_keys="a,b")
        assert len(result) == MAX_DYNAMIC_OUTPUTS
        assert result[0] == "1"
        assert result[1] == "2"
        # Unused slots are padded with empty strings.
        assert all(v == "" for v in result[2:])

    def test_type_preservation_int(self, tmp_path):
        path = self._make_json(tmp_path, {"count": 42})
        result = JSONLoaderDynamic().load_dynamic(path, 1, output_keys="count")
        assert result[0] == 42
        assert isinstance(result[0], int)

    def test_type_preservation_float(self, tmp_path):
        path = self._make_json(tmp_path, {"rate": 3.14})
        result = JSONLoaderDynamic().load_dynamic(path, 1, output_keys="rate")
        assert result[0] == 3.14
        assert isinstance(result[0], float)

    def test_type_preservation_str(self, tmp_path):
        path = self._make_json(tmp_path, {"label": "hello"})
        result = JSONLoaderDynamic().load_dynamic(path, 1, output_keys="label")
        assert result[0] == "hello"
        assert isinstance(result[0], str)

    def test_bool_becomes_string(self, tmp_path):
        # Booleans are emitted as lowercase JSON-style strings.
        path = self._make_json(tmp_path, {"flag": True, "off": False})
        result = JSONLoaderDynamic().load_dynamic(path, 1, output_keys="flag,off")
        assert result[0] == "true"
        assert result[1] == "false"
        assert isinstance(result[0], str)

    def test_missing_key_returns_empty_string(self, tmp_path):
        path = self._make_json(tmp_path, {"a": "1"})
        result = JSONLoaderDynamic().load_dynamic(path, 1, output_keys="a,nonexistent")
        assert result[0] == "1"
        assert result[1] == ""

    def test_missing_file_returns_all_empty(self, tmp_path):
        result = JSONLoaderDynamic().load_dynamic(
            str(tmp_path / "nope.json"), 1, output_keys="a,b")
        assert len(result) == MAX_DYNAMIC_OUTPUTS
        assert result[0] == ""
        assert result[1] == ""

    def test_batch_data(self, tmp_path):
        # sequence_number selects the matching batch_data entry.
        path = self._make_json(tmp_path, {
            "batch_data": [
                {"sequence_number": 1, "x": "first"},
                {"sequence_number": 2, "x": "second"},
            ]
        })
        result = JSONLoaderDynamic().load_dynamic(path, 2, output_keys="x")
        assert result[0] == "second"
|
|
||||||
211
tests/test_project_loader.py
Normal file
211
tests/test_project_loader.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
import json
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from project_loader import (
|
||||||
|
ProjectLoaderDynamic,
|
||||||
|
_fetch_json,
|
||||||
|
_fetch_data,
|
||||||
|
_fetch_keys,
|
||||||
|
MAX_DYNAMIC_OUTPUTS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_urlopen(data: dict):
|
||||||
|
"""Create a mock context manager for urllib.request.urlopen."""
|
||||||
|
response = MagicMock()
|
||||||
|
response.read.return_value = json.dumps(data).encode()
|
||||||
|
response.__enter__ = lambda s: s
|
||||||
|
response.__exit__ = MagicMock(return_value=False)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class TestFetchHelpers:
    """URL construction and error mapping in the _fetch_* helpers."""

    def test_fetch_json_success(self):
        payload = {"key": "value"}
        with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(payload)):
            assert _fetch_json("http://example.com/api") == payload

    def test_fetch_json_network_error(self):
        # OS-level failures map to an {"error": "network_error", ...} dict.
        with patch("project_loader.urllib.request.urlopen", side_effect=OSError("connection refused")):
            result = _fetch_json("http://example.com/api")
        assert result["error"] == "network_error"
        assert "connection refused" in result["message"]

    def test_fetch_json_http_error(self):
        # HTTP errors surface status code and the server's detail message.
        import urllib.error
        body = BytesIO(json.dumps({"detail": "Project 'x' not found"}).encode())
        err = urllib.error.HTTPError(
            "http://example.com/api", 404, "Not Found", {}, body)
        with patch("project_loader.urllib.request.urlopen", side_effect=err):
            result = _fetch_json("http://example.com/api")
        assert result["error"] == "http_error"
        assert result["status"] == 404
        assert "not found" in result["message"].lower()

    def test_fetch_data_builds_url(self):
        payload = {"prompt": "hello"}
        with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(payload)) as mock_open:
            assert _fetch_data("http://localhost:8080", "proj1", "batch_i2v", 1) == payload
        url = mock_open.call_args[0][0]
        assert "/api/projects/proj1/files/batch_i2v/data?seq=1" in url

    def test_fetch_keys_builds_url(self):
        payload = {"keys": ["prompt"], "types": ["STRING"]}
        with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(payload)) as mock_open:
            assert _fetch_keys("http://localhost:8080", "proj1", "batch_i2v", 1) == payload
        url = mock_open.call_args[0][0]
        assert "/api/projects/proj1/files/batch_i2v/keys?seq=1" in url

    def test_fetch_data_strips_trailing_slash(self):
        # A trailing slash on the base URL must not produce "//api".
        with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen({"prompt": "hello"})) as mock_open:
            _fetch_data("http://localhost:8080/", "proj1", "file1", 1)
        assert "//api" not in mock_open.call_args[0][0]

    def test_fetch_data_encodes_special_chars(self):
        """Project/file names with spaces or special chars should be percent-encoded."""
        with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen({"prompt": "hello"})) as mock_open:
            _fetch_data("http://localhost:8080", "my project", "batch file", 1)
        url = mock_open.call_args[0][0]
        assert "my%20project" in url
        assert "batch%20file" in url
        # No raw spaces may remain in the path portion.
        assert " " not in url.split("?")[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectLoaderDynamic:
    """Output assembly and error handling of the ProjectLoaderDynamic node."""

    def _keys_meta(self, total=5):
        # Minimal metadata payload as served by the /keys endpoint.
        return {"keys": [], "types": [], "total_sequences": total}

    def _load(self, data, keys_meta=None, **kwargs):
        # Drive load_dynamic with both HTTP helpers patched out.
        meta = self._keys_meta() if keys_meta is None else keys_meta
        node = ProjectLoaderDynamic()
        with patch("project_loader._fetch_keys", return_value=meta), \
             patch("project_loader._fetch_data", return_value=data):
            return node.load_dynamic(
                "http://localhost:8080", "proj1", "batch_i2v", 1, **kwargs)

    def test_load_dynamic_with_keys(self):
        result = self._load(
            {"prompt": "hello", "seed": 42, "cfg": 1.5},
            output_keys="prompt,seed,cfg",
        )
        assert result[0] == 5  # slot 0 carries total_sequences
        assert result[1] == "hello"
        assert result[2] == 42
        assert result[3] == 1.5
        assert len(result) == MAX_DYNAMIC_OUTPUTS + 1

    def test_load_dynamic_with_json_encoded_keys(self):
        """JSON-encoded output_keys should be parsed correctly."""
        import json as _json
        result = self._load(
            {"my,key": "comma_val", "normal": "ok"},
            output_keys=_json.dumps(["my,key", "normal"]),
        )
        assert result[1] == "comma_val"
        assert result[2] == "ok"

    def test_load_dynamic_type_coercion(self):
        """output_types should coerce values to declared types."""
        import json as _json
        result = self._load(
            {"seed": "42", "cfg": "1.5", "prompt": "hello"},
            output_keys=_json.dumps(["seed", "cfg", "prompt"]),
            output_types=_json.dumps(["INT", "FLOAT", "STRING"]),
        )
        assert result[1] == 42       # string "42" coerced to int
        assert result[2] == 1.5      # string "1.5" coerced to float
        assert result[3] == "hello"  # string stays string

    def test_load_dynamic_empty_keys(self):
        result = self._load({"prompt": "hello"}, output_keys="")
        # Slot 0 is total_sequences (INT); the rest are empty strings.
        assert result[0] == 5
        assert all(v == "" for v in result[1:])

    def test_load_dynamic_missing_key(self):
        result = self._load({"prompt": "hello"}, output_keys="nonexistent")
        assert result[1] == ""

    def test_load_dynamic_bool_becomes_string(self):
        result = self._load({"flag": True}, output_keys="flag")
        assert result[1] == "true"

    def test_load_dynamic_returns_total_sequences(self):
        """total_sequences should be the first output from keys metadata."""
        result = self._load(
            {},
            keys_meta={"keys": [], "types": [], "total_sequences": 42},
            output_keys="",
        )
        assert result[0] == 42

    def test_load_dynamic_raises_on_network_error(self):
        """Network errors from _fetch_keys should raise RuntimeError."""
        failure = {"error": "network_error", "message": "Connection refused"}
        node = ProjectLoaderDynamic()
        with patch("project_loader._fetch_keys", return_value=failure):
            with pytest.raises(RuntimeError, match="Failed to fetch project keys"):
                node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)

    def test_load_dynamic_raises_on_data_fetch_error(self):
        """Network errors from _fetch_data should raise RuntimeError."""
        failure = {"error": "http_error", "status": 404, "message": "Sequence not found"}
        node = ProjectLoaderDynamic()
        with patch("project_loader._fetch_keys", return_value=self._keys_meta()), \
             patch("project_loader._fetch_data", return_value=failure):
            with pytest.raises(RuntimeError, match="Failed to fetch sequence data"):
                node.load_dynamic("http://localhost:8080", "proj1", "batch", 1)

    def test_input_types_has_manager_url(self):
        inputs = ProjectLoaderDynamic.INPUT_TYPES()
        for field in ("manager_url", "project_name", "file_name", "sequence_number"):
            assert field in inputs["required"]

    def test_category(self):
        assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
|
||||||
|
|
||||||
|
|
||||||
|
class TestNodeMappings:
    """The module must export exactly one node class and one display name."""

    def test_mappings_exist(self):
        from project_loader import (
            PROJECT_NODE_CLASS_MAPPINGS,
            PROJECT_NODE_DISPLAY_NAME_MAPPINGS,
        )
        assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
        assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1
        assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1
|
||||||
74
utils.py
74
utils.py
@@ -160,6 +160,80 @@ def get_file_mtime(path: str | Path) -> float:
|
|||||||
return path.stat().st_mtime
|
return path.stat().st_mtime
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
    """Dual-write helper: sync JSON data to the project database.

    Resolves (or creates) the data_file, upserts all sequences from batch_data,
    and saves the history_tree. All writes happen in a single transaction.

    Args:
        db: ProjectDB-like object exposing ``conn``, ``get_project`` and
            ``get_data_file``; a falsy value disables the sync.
        project_name: Name of an existing project; unknown names are ignored.
        file_path: Path of the JSON file being synced; its stem becomes the
            data_file name.
        data: Parsed top-level JSON dict.

    Failures are logged and swallowed -- this is a best-effort dual write.
    """
    import time  # local import: avoids the repeated __import__('time') hack

    if not db or not project_name:
        return
    try:
        proj = db.get_project(project_name)
        if not proj:
            return
        file_name = Path(file_path).stem
        # One timestamp for the whole sync so every row written here agrees.
        now = time.time()

        # Use a single transaction for atomicity.
        db.conn.execute("BEGIN IMMEDIATE")
        try:
            df = db.get_data_file(proj["id"], file_name)
            # Everything except batch_data / history_tree is file-level metadata.
            top_level = {k: v for k, v in data.items()
                         if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
            if not df:
                cur = db.conn.execute(
                    "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
                    "VALUES (?, ?, ?, ?, ?, ?)",
                    (proj["id"], file_name, "generic", json.dumps(top_level), now, now),
                )
                df_id = cur.lastrowid
            else:
                df_id = df["id"]
                # Refresh top_level metadata on re-sync.
                db.conn.execute(
                    "UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?",
                    (json.dumps(top_level), now, df_id),
                )

            # Replace all sequences with the current batch_data contents.
            batch_data = data.get(KEY_BATCH_DATA, [])
            if isinstance(batch_data, list):
                db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
                # Non-dict items are skipped; a bad sequence_number raises
                # ValueError here, before any row is inserted, and the
                # transaction rolls back below.
                rows = [
                    (df_id, int(item.get(KEY_SEQUENCE_NUMBER, 0)), json.dumps(item), now)
                    for item in batch_data
                    if isinstance(item, dict)
                ]
                db.conn.executemany(
                    "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
                    "VALUES (?, ?, ?, ?)",
                    rows,
                )

            # Upsert the history tree (at most one row per data_file).
            history_tree = data.get(KEY_HISTORY_TREE)
            if history_tree and isinstance(history_tree, dict):
                db.conn.execute(
                    "INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
                    "VALUES (?, ?, ?) "
                    "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
                    (df_id, json.dumps(history_tree), now),
                )

            db.conn.execute("COMMIT")
        except Exception:
            # Undo the partial transaction, then re-raise so the outer
            # handler logs the failure.
            try:
                db.conn.execute("ROLLBACK")
            except Exception:
                pass
            raise
    except Exception as e:
        logger.warning(f"sync_to_db failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
def generate_templates(current_dir: Path) -> None:
|
def generate_templates(current_dir: Path) -> None:
|
||||||
"""Creates batch template files if folder is empty."""
|
"""Creates batch template files if folder is empty."""
|
||||||
first = DEFAULTS.copy()
|
first = DEFAULTS.copy()
|
||||||
|
|||||||
@@ -1,140 +0,0 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
|
||||||
import { api } from "../../scripts/api.js";
|
|
||||||
|
|
||||||
// ComfyUI frontend extension for the JSONLoaderDynamic node.
// It hides the two JS-managed widgets, replaces the node's static outputs
// with outputs discovered from the JSON file's keys, and re-applies the
// stored key/type state when a workflow is loaded.
// NOTE(review): the slot-reuse logic below is order-sensitive — it preserves
// existing LiteGraph links by reusing slot objects and then patching
// origin_slot indices; do not reorder these steps.
app.registerExtension({
    name: "json.manager.dynamic",

    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        if (nodeData.name !== "JSONLoaderDynamic") return;

        const origOnNodeCreated = nodeType.prototype.onNodeCreated;
        nodeType.prototype.onNodeCreated = function () {
            origOnNodeCreated?.apply(this, arguments);

            // Hide internal widgets (managed by JS)
            for (const name of ["output_keys", "output_types"]) {
                const w = this.widgets?.find(w => w.name === name);
                // computeSize of [0, -4] collapses the hidden widget's row.
                if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
            }

            // Remove all 32 default outputs from Python RETURN_TYPES
            while (this.outputs.length > 0) {
                this.removeOutput(0);
            }

            // Add Refresh button
            this.addWidget("button", "Refresh Outputs", null, () => {
                this.refreshDynamicOutputs();
            });

            this.setSize(this.computeSize());
        };

        // Queries the backend for the current file's keys/types and rebuilds
        // the node's output slots to match, preserving existing links.
        nodeType.prototype.refreshDynamicOutputs = async function () {
            const pathWidget = this.widgets?.find(w => w.name === "json_path");
            const seqWidget = this.widgets?.find(w => w.name === "sequence_number");
            if (!pathWidget?.value) return;

            try {
                const resp = await api.fetchApi(
                    `/json_manager/get_keys?path=${encodeURIComponent(pathWidget.value)}&sequence_number=${seqWidget?.value || 1}`
                );
                const { keys, types } = await resp.json();

                // Store keys and types in hidden widgets for persistence
                const okWidget = this.widgets?.find(w => w.name === "output_keys");
                if (okWidget) okWidget.value = keys.join(",");
                const otWidget = this.widgets?.find(w => w.name === "output_types");
                if (otWidget) otWidget.value = types.join(",");

                // Build a map of current output names to slot indices
                const oldSlots = {};
                for (let i = 0; i < this.outputs.length; i++) {
                    oldSlots[this.outputs[i].name] = i;
                }

                // Build new outputs, reusing existing slots to preserve links
                const newOutputs = [];
                for (let k = 0; k < keys.length; k++) {
                    const key = keys[k];
                    const type = types[k] || "*";
                    if (key in oldSlots) {
                        // Reuse existing slot object (keeps links intact)
                        const slot = this.outputs[oldSlots[key]];
                        slot.type = type;
                        newOutputs.push(slot);
                        // Remove from oldSlots so it is not treated as stale below.
                        delete oldSlots[key];
                    } else {
                        // New key — create a fresh slot
                        newOutputs.push({ name: key, type: type, links: null });
                    }
                }

                // Disconnect links on slots that are being removed
                for (const name in oldSlots) {
                    const idx = oldSlots[name];
                    if (this.outputs[idx]?.links?.length) {
                        // Iterate a copy: removeLink mutates the links array.
                        for (const linkId of [...this.outputs[idx].links]) {
                            this.graph?.removeLink(linkId);
                        }
                    }
                }

                // Reassign the outputs array and fix link slot indices
                this.outputs = newOutputs;
                // Update link origin_slot to match new positions
                if (this.graph) {
                    for (let i = 0; i < this.outputs.length; i++) {
                        const links = this.outputs[i].links;
                        if (!links) continue;
                        for (const linkId of links) {
                            const link = this.graph.links[linkId];
                            if (link) link.origin_slot = i;
                        }
                    }
                }

                this.setSize(this.computeSize());
                app.graph.setDirtyCanvas(true, true);
            } catch (e) {
                console.error("[JSONLoaderDynamic] Refresh failed:", e);
            }
        };

        // Restore state on workflow load
        const origOnConfigure = nodeType.prototype.onConfigure;
        nodeType.prototype.onConfigure = function (info) {
            origOnConfigure?.apply(this, arguments);

            // Hide internal widgets
            for (const name of ["output_keys", "output_types"]) {
                const w = this.widgets?.find(w => w.name === name);
                if (w) { w.type = "hidden"; w.computeSize = () => [0, -4]; }
            }

            const okWidget = this.widgets?.find(w => w.name === "output_keys");
            const otWidget = this.widgets?.find(w => w.name === "output_types");

            const keys = okWidget?.value
                ? okWidget.value.split(",").filter(k => k.trim())
                : [];
            const types = otWidget?.value
                ? otWidget.value.split(",")
                : [];

            // On load, LiteGraph already restored serialized outputs with links.
            // Rename and set types to match stored state (preserves links).
            for (let i = 0; i < this.outputs.length && i < keys.length; i++) {
                this.outputs[i].name = keys[i].trim();
                if (types[i]) this.outputs[i].type = types[i];
            }

            // Remove any extra outputs beyond the key count
            while (this.outputs.length > keys.length) {
                this.removeOutput(this.outputs.length - 1);
            }

            this.setSize(this.computeSize());
        };
    },
});
|
|
||||||
web/project_dynamic.js — 255 lines (new file)
@@ -0,0 +1,255 @@
|
|||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
import { api } from "../../scripts/api.js";
|
||||||
|
|
||||||
|
app.registerExtension({
    name: "json.manager.project.dynamic",

    async beforeRegisterNodeDef(nodeType, nodeData, app) {
        // Only decorate our dynamic project-loader node type.
        if (nodeData.name !== "ProjectLoaderDynamic") return;

        // ---- Node creation -------------------------------------------------
        const origOnNodeCreated = nodeType.prototype.onNodeCreated;
        nodeType.prototype.onNodeCreated = function () {
            origOnNodeCreated?.apply(this, arguments);

            const findWidget = (name) => this.widgets?.find(w => w.name === name);

            // Hide the internal bookkeeping widgets (managed by JS).
            for (const name of ["output_keys", "output_types"]) {
                const widget = findWidget(name);
                if (widget) {
                    widget.type = "hidden";
                    widget.computeSize = () => [0, -4];
                }
            }

            // Do NOT remove default outputs synchronously here.
            // During graph loading, ComfyUI creates all nodes (firing
            // onNodeCreated) before configuring them. Other nodes (e.g.
            // Kijai Set/Get) may resolve links to our outputs during their
            // configure step; removing outputs now would leave them with no
            // slot to attach to and error out.
            //
            // Cleanup is therefore deferred: for loaded workflows,
            // onConfigure flips _configured before the microtask below runs;
            // only brand-new nodes get their defaults cleaned up.
            this._configured = false;

            // Manual refresh button.
            this.addWidget("button", "Refresh Outputs", null, () => {
                this.refreshDynamicOutputs();
            });

            // Auto-refresh with a 500 ms debounce whenever a source widget changes.
            this._refreshTimer = null;
            const node = this;
            for (const widgetName of ["project_name", "file_name", "sequence_number"]) {
                const widget = findWidget(widgetName);
                if (!widget) continue;
                const origCallback = widget.callback;
                widget.callback = function (...args) {
                    origCallback?.apply(this, args);
                    clearTimeout(node._refreshTimer);
                    node._refreshTimer = setTimeout(() => node.refreshDynamicOutputs(), 500);
                };
            }

            queueMicrotask(() => {
                if (this._configured) return;
                // Fresh node (not a workflow load): drop the Python-side
                // default outputs and keep only the fixed total_sequences slot.
                while (this.outputs.length > 0) {
                    this.removeOutput(0);
                }
                this.addOutput("total_sequences", "INT");
                this.setSize(this.computeSize());
                app.graph?.setDirtyCanvas(true, true);
            });
        };

        // ---- Title/colour status indicator ---------------------------------
        nodeType.prototype._setStatus = function (status, message) {
            const baseTitle = "Project Loader (Dynamic)";
            if (status === "ok") {
                this.title = baseTitle;
                this.color = undefined;
                this.bgcolor = undefined;
            } else if (status === "error") {
                // Red highlight; include the message in the title when given.
                this.title = message ? baseTitle + ": " + message : baseTitle + " - ERROR";
                this.color = "#ff4444";
                this.bgcolor = "#331111";
            } else if (status === "loading") {
                this.title = baseTitle + " - Loading...";
            }
            app.graph?.setDirtyCanvas(true, true);
        };

        // ---- Dynamic output refresh ----------------------------------------
        // Asks the manager REST API for the keys/types of the selected
        // project file and rebuilds the node's output slots to match,
        // reusing existing slot objects so established links survive.
        nodeType.prototype.refreshDynamicOutputs = async function () {
            const findWidget = (name) => this.widgets?.find(w => w.name === name);
            const urlWidget = findWidget("manager_url");
            const projectWidget = findWidget("project_name");
            const fileWidget = findWidget("file_name");
            const seqWidget = findWidget("sequence_number");

            // Nothing to query without a URL, project, and file.
            if (!urlWidget?.value || !projectWidget?.value || !fileWidget?.value) return;

            this._setStatus("loading");

            try {
                // NOTE(review): a sequence_number of 0 falls back to 1 here —
                // presumably sequences are 1-based; confirm against the server.
                const query =
                    `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}` +
                    `&project=${encodeURIComponent(projectWidget.value)}` +
                    `&file=${encodeURIComponent(fileWidget.value)}` +
                    `&seq=${seqWidget?.value || 1}`;
                const resp = await api.fetchApi(query);

                if (!resp.ok) {
                    let errorMsg = `HTTP ${resp.status}`;
                    try {
                        const errData = await resp.json();
                        if (errData.message) errorMsg = errData.message;
                    } catch (_) {}
                    this._setStatus("error", errorMsg);
                    return;
                }

                const data = await resp.json();
                const keys = data.keys;
                const types = data.types;

                // On API error or malformed payload, keep existing outputs
                // and links untouched.
                if (data.error || !Array.isArray(keys) || !Array.isArray(types)) {
                    const errMsg = data.error ? data.message || data.error : "Missing keys/types";
                    this._setStatus("error", errMsg);
                    return;
                }

                // Persist keys/types in the hidden widgets (comma-separated).
                const okWidget = findWidget("output_keys");
                if (okWidget) okWidget.value = keys.join(",");
                const otWidget = findWidget("output_types");
                if (otWidget) otWidget.value = types.join(",");

                // Slot 0 is always total_sequences (INT) — ensure it exists.
                if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
                    this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
                }
                this.outputs[0].type = "INT";

                // Map current dynamic output names -> slot index (slot 0 excluded).
                const oldSlots = {};
                for (let i = 1; i < this.outputs.length; i++) {
                    oldSlots[this.outputs[i].name] = i;
                }

                // Rebuild the dynamic outputs, reusing slots to keep links.
                const newOutputs = [this.outputs[0]]; // total_sequences stays at slot 0
                for (let k = 0; k < keys.length; k++) {
                    const key = keys[k];
                    const type = types[k] || "*";
                    if (key in oldSlots) {
                        // Reuse the existing slot object (keeps links intact).
                        const slot = this.outputs[oldSlots[key]];
                        slot.type = type;
                        slot.label = key;
                        newOutputs.push(slot);
                        delete oldSlots[key];
                    } else {
                        newOutputs.push({ name: key, label: key, type: type, links: null });
                    }
                }

                // Any name left in oldSlots is being removed — sever its links.
                for (const name in oldSlots) {
                    const idx = oldSlots[name];
                    if (this.outputs[idx]?.links?.length) {
                        for (const linkId of [...this.outputs[idx].links]) {
                            this.graph?.removeLink(linkId);
                        }
                    }
                }

                // Swap in the new output array and repoint link slot indices.
                this.outputs = newOutputs;
                if (this.graph) {
                    for (let i = 0; i < this.outputs.length; i++) {
                        const links = this.outputs[i].links;
                        if (!links) continue;
                        for (const linkId of links) {
                            const link = this.graph.links[linkId];
                            if (link) link.origin_slot = i;
                        }
                    }
                }

                this._setStatus("ok");
                this.setSize(this.computeSize());
                app.graph?.setDirtyCanvas(true, true);
            } catch (e) {
                console.error("[ProjectLoaderDynamic] Refresh failed:", e);
                this._setStatus("error", "Server unreachable");
            }
        };

        // ---- Workflow-load restore -----------------------------------------
        const origOnConfigure = nodeType.prototype.onConfigure;
        nodeType.prototype.onConfigure = function (info) {
            origOnConfigure?.apply(this, arguments);
            this._configured = true;

            const findWidget = (name) => this.widgets?.find(w => w.name === name);

            // Hide the internal bookkeeping widgets.
            for (const name of ["output_keys", "output_types"]) {
                const widget = findWidget(name);
                if (widget) {
                    widget.type = "hidden";
                    widget.computeSize = () => [0, -4];
                }
            }

            const okWidget = findWidget("output_keys");
            const otWidget = findWidget("output_types");

            const keys = okWidget?.value
                ? okWidget.value.split(",").filter(k => k.trim())
                : [];
            const types = otWidget?.value
                ? otWidget.value.split(",")
                : [];

            // Ensure slot 0 is total_sequences (INT). Workflows serialized
            // before this slot existed lack it; insert it and then shift all
            // of this node's link origin_slot indices by one.
            if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
                this.outputs.unshift({ name: "total_sequences", type: "INT", links: null });
                const node = this;
                queueMicrotask(() => {
                    if (!node.graph) return;
                    // Rebuild each output's link list from the graph's links,
                    // bumping origin_slot to account for the inserted slot.
                    for (const output of node.outputs) {
                        output.links = null;
                    }
                    for (const linkId in node.graph.links) {
                        const link = node.graph.links[linkId];
                        if (!link || link.origin_id !== node.id) continue;
                        link.origin_slot += 1;
                        const output = node.outputs[link.origin_slot];
                        if (output) {
                            if (!output.links) output.links = [];
                            output.links.push(link.id);
                        }
                    }
                    app.graph?.setDirtyCanvas(true, true);
                });
            }
            this.outputs[0].type = "INT";
            this.outputs[0].name = "total_sequences";

            if (keys.length > 0) {
                // LiteGraph already restored serialized outputs with links;
                // rename/retype the dynamic slots to the stored state.
                for (let i = 0; i < keys.length; i++) {
                    const slotIdx = i + 1;
                    if (slotIdx < this.outputs.length) {
                        this.outputs[slotIdx].name = keys[i].trim();
                        this.outputs[slotIdx].label = keys[i].trim();
                        if (types[i]) this.outputs[slotIdx].type = types[i];
                    }
                }
                // Drop any extra outputs beyond the stored key count.
                while (this.outputs.length > keys.length + 1) {
                    this.removeOutput(this.outputs.length - 1);
                }
            } else if (this.outputs.length > 1) {
                // Widget values empty but serialized dynamic outputs exist —
                // sync the hidden widgets back from the outputs.
                const dynamicOutputs = this.outputs.slice(1);
                if (okWidget) okWidget.value = dynamicOutputs.map(o => o.name).join(",");
                if (otWidget) otWidget.value = dynamicOutputs.map(o => o.type).join(",");
            }

            this.setSize(this.computeSize());
        };
    },
});
|
||||||
Reference in New Issue
Block a user