Compare commits
63 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ba8ce45846 | |||
| 1ec3abf17a | |||
| 686d4687c3 | |||
| a37dd82ae3 | |||
| 3b11a4e974 | |||
| 5eb82f8ff6 | |||
| bf598ebf80 | |||
| 6e232da193 | |||
| ff5802ab63 | |||
| 413e1c09e9 | |||
| 672b28e27f | |||
| 3dc91319a2 | |||
| bd36b4b725 | |||
| 77eb3473ab | |||
| 2cf8cc1f0a | |||
| 545b864c08 | |||
| ad6cd76b08 | |||
| bd7d314ae8 | |||
| 628b256981 | |||
| fb007920ee | |||
| d3955c489b | |||
| e575a78893 | |||
| a1a85ecc4d | |||
| eac4e4f08b | |||
| 79e1426036 | |||
| ba330dd208 | |||
| 9c560ccfd0 | |||
| 480131e327 | |||
| fac5013359 | |||
| 45ce264675 | |||
| 0f134a1a20 | |||
| a9197efacd | |||
| ecb5cdc13f | |||
| 1386043f69 | |||
| c4700c620d | |||
| 589c84fd95 | |||
| 37e9e1001e | |||
| 526af7097d | |||
| c880c16865 | |||
| 82e4ba526c | |||
| 08338746e2 | |||
| 15047016b9 | |||
| 29aa87ee00 | |||
| be9c95ffbd | |||
| 074e36f883 | |||
| b36200faaa | |||
| 5aac1677f7 | |||
| f3ad3e01bc | |||
| efd0a31426 | |||
| b042fe4368 | |||
| 04b9ed0e27 | |||
| 1b8d13f7c4 | |||
| 497e6b06fb | |||
| 993fc86070 | |||
| c9bcc735f4 | |||
| dc8f44f02b | |||
| 2a6b4f5245 | |||
| 60d1162700 | |||
| 204fc4ea85 | |||
| 033b3415c2 | |||
| 2ccc3821d6 | |||
| 615755ba44 | |||
| 4b09491242 |
@@ -4,6 +4,7 @@ All endpoints are read-only. Mounted on the NiceGUI/FastAPI server.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import HTTPException, Query
|
from fastapi import HTTPException, Query
|
||||||
@@ -54,6 +55,7 @@ def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
||||||
|
t0 = time.perf_counter()
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
proj = db.get_project(name)
|
proj = db.get_project(name)
|
||||||
if not proj:
|
if not proj:
|
||||||
@@ -64,10 +66,13 @@ def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[st
|
|||||||
data = db.get_sequence(df["id"], seq)
|
data = db.get_sequence(df["id"], seq)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
||||||
|
logger.info("API _get_data %s/%s seq=%d (%d keys): %.3fs",
|
||||||
|
name, file_name, seq, len(data), time.perf_counter() - t0)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
||||||
|
t0 = time.perf_counter()
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
proj = db.get_project(name)
|
proj = db.get_project(name)
|
||||||
if not proj:
|
if not proj:
|
||||||
@@ -77,4 +82,6 @@ def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[st
|
|||||||
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
||||||
keys, types = db.get_sequence_keys(df["id"], seq)
|
keys, types = db.get_sequence_keys(df["id"], seq)
|
||||||
total = db.count_sequences(df["id"])
|
total = db.count_sequences(df["id"])
|
||||||
|
logger.info("API _get_keys %s/%s seq=%d (%d keys): %.3fs",
|
||||||
|
name, file_name, seq, len(keys), time.perf_counter() - t0)
|
||||||
return {"keys": keys, "types": types, "total_sequences": total}
|
return {"keys": keys, "types": types, "total_sequences": total}
|
||||||
|
|||||||
@@ -47,6 +47,19 @@ CREATE TABLE IF NOT EXISTS history_trees (
|
|||||||
tree_data TEXT NOT NULL DEFAULT '{}',
|
tree_data TEXT NOT NULL DEFAULT '{}',
|
||||||
updated_at REAL NOT NULL
|
updated_at REAL NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS history_snapshots (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE,
|
||||||
|
node_id TEXT NOT NULL,
|
||||||
|
snapshot_data TEXT NOT NULL,
|
||||||
|
updated_at REAL NOT NULL,
|
||||||
|
UNIQUE(data_file_id, node_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_data_files_project_id ON data_files(project_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sequences_data_file_id ON sequences(data_file_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_history_snapshots_df ON history_snapshots(data_file_id);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -65,6 +78,31 @@ class ProjectDB:
|
|||||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
self.conn.execute("PRAGMA journal_mode=WAL")
|
||||||
self.conn.execute("PRAGMA foreign_keys=ON")
|
self.conn.execute("PRAGMA foreign_keys=ON")
|
||||||
self.conn.executescript(SCHEMA_SQL)
|
self.conn.executescript(SCHEMA_SQL)
|
||||||
|
self._migrate_all_lora_data()
|
||||||
|
|
||||||
|
def _migrate_all_lora_data(self) -> None:
|
||||||
|
"""Bulk migration: split combined lora 'name:strength' into separate keys."""
|
||||||
|
rows = self.conn.execute("SELECT id, data FROM sequences").fetchall()
|
||||||
|
updated = 0
|
||||||
|
self.conn.execute("BEGIN")
|
||||||
|
try:
|
||||||
|
for row in rows:
|
||||||
|
data = json.loads(row["data"])
|
||||||
|
original = row["data"]
|
||||||
|
migrated = self._migrate_lora_keys(data)
|
||||||
|
new_json = json.dumps(migrated)
|
||||||
|
if new_json != original:
|
||||||
|
self.conn.execute(
|
||||||
|
"UPDATE sequences SET data = ? WHERE id = ?",
|
||||||
|
(new_json, row["id"]),
|
||||||
|
)
|
||||||
|
updated += 1
|
||||||
|
self.conn.execute("COMMIT")
|
||||||
|
except Exception:
|
||||||
|
self.conn.execute("ROLLBACK")
|
||||||
|
raise
|
||||||
|
if updated:
|
||||||
|
logger.info("Migrated lora keys in %d/%d sequences", updated, len(rows))
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
@@ -90,6 +128,16 @@ class ProjectDB:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
def list_projects_with_file_counts(self) -> list[dict]:
|
||||||
|
"""List projects with data file counts in a single query."""
|
||||||
|
rows = self.conn.execute(
|
||||||
|
"SELECT p.id, p.name, p.folder_path, p.description, p.created_at, p.updated_at, "
|
||||||
|
"COUNT(df.id) AS file_count "
|
||||||
|
"FROM projects p LEFT JOIN data_files df ON df.project_id = p.id "
|
||||||
|
"GROUP BY p.id ORDER BY p.name"
|
||||||
|
).fetchall()
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
def get_project(self, name: str) -> dict | None:
|
def get_project(self, name: str) -> dict | None:
|
||||||
row = self.conn.execute(
|
row = self.conn.execute(
|
||||||
"SELECT id, name, folder_path, description, created_at, updated_at "
|
"SELECT id, name, folder_path, description, created_at, updated_at "
|
||||||
@@ -98,6 +146,24 @@ class ProjectDB:
|
|||||||
).fetchone()
|
).fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
def rename_project(self, old_name: str, new_name: str) -> bool:
|
||||||
|
now = time.time()
|
||||||
|
cur = self.conn.execute(
|
||||||
|
"UPDATE projects SET name = ?, updated_at = ? WHERE name = ?",
|
||||||
|
(new_name, now, old_name),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
return cur.rowcount > 0
|
||||||
|
|
||||||
|
def update_project_path(self, name: str, folder_path: str) -> bool:
|
||||||
|
now = time.time()
|
||||||
|
cur = self.conn.execute(
|
||||||
|
"UPDATE projects SET folder_path = ?, updated_at = ? WHERE name = ?",
|
||||||
|
(folder_path, now, name),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
return cur.rowcount > 0
|
||||||
|
|
||||||
def delete_project(self, name: str) -> bool:
|
def delete_project(self, name: str) -> bool:
|
||||||
cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,))
|
cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,))
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
@@ -128,6 +194,14 @@ class ProjectDB:
|
|||||||
).fetchall()
|
).fetchall()
|
||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
def count_data_files(self, project_id: int) -> int:
|
||||||
|
"""Return the number of data files for a project."""
|
||||||
|
row = self.conn.execute(
|
||||||
|
"SELECT COUNT(*) AS cnt FROM data_files WHERE project_id = ?",
|
||||||
|
(project_id,),
|
||||||
|
).fetchone()
|
||||||
|
return row["cnt"]
|
||||||
|
|
||||||
def get_data_file(self, project_id: int, name: str) -> dict | None:
|
def get_data_file(self, project_id: int, name: str) -> dict | None:
|
||||||
row = self.conn.execute(
|
row = self.conn.execute(
|
||||||
"SELECT id, project_id, name, data_type, top_level, created_at, updated_at "
|
"SELECT id, project_id, name, data_type, top_level, created_at, updated_at "
|
||||||
@@ -168,12 +242,52 @@ class ProjectDB:
|
|||||||
)
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _migrate_lora_keys(data: dict) -> dict:
|
||||||
|
"""Split combined lora 'name:strength' into separate name and strength keys."""
|
||||||
|
for idx in range(1, 4):
|
||||||
|
for tier in ('high', 'low'):
|
||||||
|
name_key = f'lora {idx} {tier}'
|
||||||
|
str_key = f'lora {idx} {tier} strength'
|
||||||
|
raw = str(data.get(name_key, ''))
|
||||||
|
if raw.startswith('<lora:'):
|
||||||
|
inner = raw.replace('<lora:', '').replace('>', '')
|
||||||
|
if ':' in inner:
|
||||||
|
parts = inner.rsplit(':', 1)
|
||||||
|
data[name_key] = parts[0]
|
||||||
|
try:
|
||||||
|
data[str_key] = float(parts[1])
|
||||||
|
except ValueError:
|
||||||
|
data[str_key] = 1.0
|
||||||
|
else:
|
||||||
|
data[name_key] = inner
|
||||||
|
if str_key not in data:
|
||||||
|
data[str_key] = 1.0
|
||||||
|
elif ':' in raw and raw:
|
||||||
|
parts = raw.rsplit(':', 1)
|
||||||
|
try:
|
||||||
|
strength = float(parts[1])
|
||||||
|
data[name_key] = parts[0]
|
||||||
|
data[str_key] = strength
|
||||||
|
except ValueError:
|
||||||
|
if str_key not in data:
|
||||||
|
data[str_key] = 1.0
|
||||||
|
elif raw:
|
||||||
|
# Name exists without colon, ensure strength key exists
|
||||||
|
if str_key not in data:
|
||||||
|
data[str_key] = 1.0
|
||||||
|
# If name is empty, don't add a strength key
|
||||||
|
return data
|
||||||
|
|
||||||
def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None:
|
def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None:
|
||||||
row = self.conn.execute(
|
row = self.conn.execute(
|
||||||
"SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?",
|
"SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?",
|
||||||
(data_file_id, sequence_number),
|
(data_file_id, sequence_number),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
return json.loads(row["data"]) if row else None
|
if not row:
|
||||||
|
return None
|
||||||
|
data = json.loads(row["data"])
|
||||||
|
return self._migrate_lora_keys(data)
|
||||||
|
|
||||||
def list_sequences(self, data_file_id: int) -> list[int]:
|
def list_sequences(self, data_file_id: int) -> list[int]:
|
||||||
rows = self.conn.execute(
|
rows = self.conn.execute(
|
||||||
@@ -225,22 +339,80 @@ class ProjectDB:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
|
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
|
||||||
|
"""Save history tree, extracting snapshot data into separate table.
|
||||||
|
|
||||||
|
Supports both new format (snapshots dict) and old format (nodes dict).
|
||||||
|
"""
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
if "snapshots" in tree_data:
|
||||||
|
entries = tree_data.get("snapshots", {})
|
||||||
|
entry_key = "snapshots"
|
||||||
|
else:
|
||||||
|
entries = tree_data.get("nodes", {})
|
||||||
|
entry_key = "nodes"
|
||||||
|
slim_tree = dict(tree_data)
|
||||||
|
slim_entries = {}
|
||||||
|
for eid, entry in entries.items():
|
||||||
|
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||||
|
slim_tree[entry_key] = slim_entries
|
||||||
|
|
||||||
|
self.conn.execute("BEGIN IMMEDIATE")
|
||||||
|
try:
|
||||||
|
for eid, entry in entries.items():
|
||||||
|
snap = entry.get("data")
|
||||||
|
if snap:
|
||||||
|
self.conn.execute(
|
||||||
|
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, ?) "
|
||||||
|
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||||
|
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||||
|
(data_file_id, eid, json.dumps(snap), now),
|
||||||
|
)
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||||
"VALUES (?, ?, ?) "
|
"VALUES (?, ?, ?) "
|
||||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||||
(data_file_id, json.dumps(tree_data), now),
|
(data_file_id, json.dumps(slim_tree), now),
|
||||||
)
|
)
|
||||||
self.conn.commit()
|
self.conn.execute("COMMIT")
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
self.conn.execute("ROLLBACK")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
|
|
||||||
def get_history_tree(self, data_file_id: int) -> dict | None:
|
def get_history_tree(self, data_file_id: int) -> dict | None:
|
||||||
|
"""Load history tree metadata (without snapshot data)."""
|
||||||
row = self.conn.execute(
|
row = self.conn.execute(
|
||||||
"SELECT tree_data FROM history_trees WHERE data_file_id = ?",
|
"SELECT tree_data FROM history_trees WHERE data_file_id = ?",
|
||||||
(data_file_id,),
|
(data_file_id,),
|
||||||
).fetchone()
|
).fetchone()
|
||||||
return json.loads(row["tree_data"]) if row else None
|
return json.loads(row["tree_data"]) if row else None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# History snapshots (per-node data, loaded on demand)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_node_snapshot(self, data_file_id: int, node_id: str) -> dict | None:
|
||||||
|
"""Load a single node's snapshot data on demand."""
|
||||||
|
row = self.conn.execute(
|
||||||
|
"SELECT snapshot_data FROM history_snapshots WHERE data_file_id = ? AND node_id = ?",
|
||||||
|
(data_file_id, node_id),
|
||||||
|
).fetchone()
|
||||||
|
return json.loads(row["snapshot_data"]) if row else None
|
||||||
|
|
||||||
|
def delete_node_snapshots(self, data_file_id: int, node_ids: set) -> None:
|
||||||
|
"""Delete snapshots for removed nodes."""
|
||||||
|
if not node_ids:
|
||||||
|
return
|
||||||
|
placeholders = ",".join("?" for _ in node_ids)
|
||||||
|
self.conn.execute(
|
||||||
|
f"DELETE FROM history_snapshots WHERE data_file_id = ? AND node_id IN ({placeholders})",
|
||||||
|
(data_file_id, *node_ids),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Import
|
# Import
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -296,15 +468,36 @@ class ProjectDB:
|
|||||||
(df_id, seq_num, json.dumps(item), now),
|
(df_id, seq_num, json.dumps(item), now),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import history tree
|
# Import history tree (extract snapshots into separate table)
|
||||||
|
# Supports both new format (snapshots dict) and old format (nodes dict)
|
||||||
history_tree = data.get(KEY_HISTORY_TREE)
|
history_tree = data.get(KEY_HISTORY_TREE)
|
||||||
if history_tree and isinstance(history_tree, dict):
|
if history_tree and isinstance(history_tree, dict):
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
if "snapshots" in history_tree:
|
||||||
|
entries = history_tree.get("snapshots", {})
|
||||||
|
entry_key = "snapshots"
|
||||||
|
else:
|
||||||
|
entries = history_tree.get("nodes", {})
|
||||||
|
entry_key = "nodes"
|
||||||
|
slim_tree = dict(history_tree)
|
||||||
|
slim_entries = {}
|
||||||
|
for eid, entry in entries.items():
|
||||||
|
snap = entry.get("data")
|
||||||
|
if snap:
|
||||||
|
self.conn.execute(
|
||||||
|
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, ?) "
|
||||||
|
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||||
|
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||||
|
(df_id, eid, json.dumps(snap), now),
|
||||||
|
)
|
||||||
|
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||||
|
slim_tree[entry_key] = slim_entries
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||||
"VALUES (?, ?, ?) "
|
"VALUES (?, ?, ?) "
|
||||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||||
(df_id, json.dumps(history_tree), now),
|
(df_id, json.dumps(slim_tree), now),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conn.execute("COMMIT")
|
self.conn.execute("COMMIT")
|
||||||
@@ -316,6 +509,60 @@ class ProjectDB:
|
|||||||
pass
|
pass
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Full data reconstruction (replaces load_json for DB-backed files)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def load_full_data(self, project_name: str, file_name: str) -> dict | None:
|
||||||
|
"""Reconstruct the full data dict from DB, matching load_json format.
|
||||||
|
|
||||||
|
Returns None if the project or file doesn't exist in the DB.
|
||||||
|
Result has the same structure as a JSON file: top-level keys +
|
||||||
|
batch_data list + history_tree dict.
|
||||||
|
"""
|
||||||
|
t0 = time.time()
|
||||||
|
df = self.get_data_file_by_names(project_name, file_name)
|
||||||
|
if not df:
|
||||||
|
return None
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
# Start with top-level keys
|
||||||
|
data = df.get("top_level", {})
|
||||||
|
if isinstance(data, str):
|
||||||
|
data = json.loads(data)
|
||||||
|
|
||||||
|
# Load all sequences as batch_data
|
||||||
|
# Group sub-segments (>=1000) after their parent: parent = seq_num / 1000
|
||||||
|
rows = self.conn.execute(
|
||||||
|
"SELECT data FROM sequences WHERE data_file_id = ? "
|
||||||
|
"ORDER BY CASE WHEN sequence_number >= 1000 THEN sequence_number / 1000 "
|
||||||
|
"ELSE sequence_number END, "
|
||||||
|
"CASE WHEN sequence_number >= 1000 THEN 1 ELSE 0 END, "
|
||||||
|
"sequence_number",
|
||||||
|
(df["id"],),
|
||||||
|
).fetchall()
|
||||||
|
batch_data = []
|
||||||
|
for row in rows:
|
||||||
|
seq = json.loads(row["data"])
|
||||||
|
self._migrate_lora_keys(seq)
|
||||||
|
batch_data.append(seq)
|
||||||
|
data["batch_data"] = batch_data
|
||||||
|
t2 = time.time()
|
||||||
|
|
||||||
|
# Load history tree (metadata only, no snapshot data)
|
||||||
|
tree = self.get_history_tree(df["id"])
|
||||||
|
if tree:
|
||||||
|
# Strip any residual snapshot data (supports both formats)
|
||||||
|
for entry in tree.get("snapshots", tree.get("nodes", {})).values():
|
||||||
|
entry.pop("data", None)
|
||||||
|
data["history_tree"] = tree
|
||||||
|
t3 = time.time()
|
||||||
|
|
||||||
|
logger.info("load_full_data %s/%s (%d seqs): lookup=%.3fs seqs=%.3fs tree=%.3fs total=%.3fs",
|
||||||
|
project_name, file_name, len(batch_data),
|
||||||
|
t1 - t0, t2 - t1, t3 - t2, t3 - t0)
|
||||||
|
return data
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Query helpers (for REST API)
|
# Query helpers (for REST API)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
+35
-7
@@ -1,3 +1,4 @@
|
|||||||
|
import html
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -17,7 +18,10 @@ class HistoryTree:
|
|||||||
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
||||||
parent = None
|
parent = None
|
||||||
for item in reversed(old_list):
|
for item in reversed(old_list):
|
||||||
|
for _ in range(10):
|
||||||
node_id = str(uuid.uuid4())[:8]
|
node_id = str(uuid.uuid4())[:8]
|
||||||
|
if node_id not in self.nodes:
|
||||||
|
break
|
||||||
self.nodes[node_id] = {
|
self.nodes[node_id] = {
|
||||||
"id": node_id, "parent": parent, "timestamp": time.time(),
|
"id": node_id, "parent": parent, "timestamp": time.time(),
|
||||||
"data": item, "note": item.get("note", "Legacy Import")
|
"data": item, "note": item.get("note", "Legacy Import")
|
||||||
@@ -27,7 +31,13 @@ class HistoryTree:
|
|||||||
self.head_id = parent
|
self.head_id = parent
|
||||||
|
|
||||||
def commit(self, data: dict[str, Any], note: str = "Snapshot") -> str:
|
def commit(self, data: dict[str, Any], note: str = "Snapshot") -> str:
|
||||||
|
# Generate unique node ID with collision check
|
||||||
|
for _ in range(10):
|
||||||
new_id = str(uuid.uuid4())[:8]
|
new_id = str(uuid.uuid4())[:8]
|
||||||
|
if new_id not in self.nodes:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError("Failed to generate unique node ID after 10 attempts")
|
||||||
|
|
||||||
# Cycle detection: walk parent chain from head to verify no cycle
|
# Cycle detection: walk parent chain from head to verify no cycle
|
||||||
if self.head_id:
|
if self.head_id:
|
||||||
@@ -38,7 +48,7 @@ class HistoryTree:
|
|||||||
raise ValueError(f"Cycle detected in history tree at node {current}")
|
raise ValueError(f"Cycle detected in history tree at node {current}")
|
||||||
visited.add(current)
|
visited.add(current)
|
||||||
node = self.nodes.get(current)
|
node = self.nodes.get(current)
|
||||||
current = node["parent"] if node else None
|
current = node.get("parent") if node else None
|
||||||
|
|
||||||
active_branch = None
|
active_branch = None
|
||||||
for b_name, tip_id in self.branches.items():
|
for b_name, tip_id in self.branches.items():
|
||||||
@@ -66,6 +76,11 @@ class HistoryTree:
|
|||||||
return self.nodes[node_id]["data"]
|
return self.nodes[node_id]["data"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def strip_snapshots(self) -> None:
|
||||||
|
"""Remove snapshot data from all nodes to free memory."""
|
||||||
|
for node in self.nodes.values():
|
||||||
|
node.pop("data", None)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
|
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
|
||||||
|
|
||||||
@@ -114,9 +129,14 @@ class HistoryTree:
|
|||||||
# Build reverse lookup: node_id -> branch name (walk each branch ancestry)
|
# Build reverse lookup: node_id -> branch name (walk each branch ancestry)
|
||||||
node_to_branch: dict[str, str] = {}
|
node_to_branch: dict[str, str] = {}
|
||||||
for b_name, tip_id in self.branches.items():
|
for b_name, tip_id in self.branches.items():
|
||||||
|
visited = set()
|
||||||
current = tip_id
|
current = tip_id
|
||||||
while current and current in self.nodes:
|
while current and current in self.nodes:
|
||||||
if current not in node_to_branch:
|
if current in visited:
|
||||||
|
break
|
||||||
|
if current in node_to_branch:
|
||||||
|
break # this node and all ancestors already assigned
|
||||||
|
visited.add(current)
|
||||||
node_to_branch[current] = b_name
|
node_to_branch[current] = b_name
|
||||||
current = self.nodes[current].get('parent')
|
current = self.nodes[current].get('parent')
|
||||||
|
|
||||||
@@ -154,13 +174,14 @@ class HistoryTree:
|
|||||||
full_note = n.get('note', 'Step')
|
full_note = n.get('note', 'Step')
|
||||||
|
|
||||||
display_note = (full_note[:max_note_len] + '..') if len(full_note) > max_note_len else full_note
|
display_note = (full_note[:max_note_len] + '..') if len(full_note) > max_note_len else full_note
|
||||||
|
display_note = html.escape(display_note)
|
||||||
|
|
||||||
ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp']))
|
ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp']))
|
||||||
|
|
||||||
# Branch label for tip nodes
|
# Branch label for tip nodes
|
||||||
branch_label = ""
|
branch_label = ""
|
||||||
if nid in tip_to_branches:
|
if nid in tip_to_branches:
|
||||||
branch_label = ", ".join(tip_to_branches[nid])
|
branch_label = html.escape(", ".join(tip_to_branches[nid]))
|
||||||
|
|
||||||
# COLORS — per-branch tint, overridden for HEAD and tips
|
# COLORS — per-branch tint, overridden for HEAD and tips
|
||||||
b_name = node_to_branch.get(nid)
|
b_name = node_to_branch.get(nid)
|
||||||
@@ -190,11 +211,18 @@ class HistoryTree:
|
|||||||
+ '</TABLE>>'
|
+ '</TABLE>>'
|
||||||
)
|
)
|
||||||
|
|
||||||
safe_tooltip = full_note.replace('"', "'")
|
safe_tooltip = (full_note
|
||||||
dot.append(f' "{nid}" [label={label}, tooltip="{safe_tooltip}"];')
|
.replace('\\', '\\\\')
|
||||||
|
.replace('"', '\\"')
|
||||||
|
.replace('\n', ' ')
|
||||||
|
.replace('\r', '')
|
||||||
|
.replace(']', ']'))
|
||||||
|
safe_nid = nid.replace('"', '_')
|
||||||
|
dot.append(f' "{safe_nid}" [label={label}, tooltip="{safe_tooltip}"];')
|
||||||
|
|
||||||
if n["parent"] and n["parent"] in self.nodes:
|
if n.get("parent") and n["parent"] in self.nodes:
|
||||||
dot.append(f' "{n["parent"]}" -> "{nid}";')
|
safe_parent = n["parent"].replace('"', '_')
|
||||||
|
dot.append(f' "{safe_parent}" -> "{safe_nid}";')
|
||||||
|
|
||||||
dot.append("}")
|
dot.append("}")
|
||||||
return "\n".join(dot)
|
return "\n".join(dot)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -9,7 +11,7 @@ from utils import (
|
|||||||
load_config, save_config, load_snippets, save_snippets,
|
load_config, save_config, load_snippets, save_snippets,
|
||||||
load_json, save_json, generate_templates, DEFAULTS,
|
load_json, save_json, generate_templates, DEFAULTS,
|
||||||
KEY_BATCH_DATA, KEY_SEQUENCE_NUMBER,
|
KEY_BATCH_DATA, KEY_SEQUENCE_NUMBER,
|
||||||
resolve_path_case_insensitive,
|
resolve_path_case_insensitive, sync_to_db,
|
||||||
)
|
)
|
||||||
from tab_batch_ng import render_batch_processor
|
from tab_batch_ng import render_batch_processor
|
||||||
from tab_timeline_ng import render_timeline_tab
|
from tab_timeline_ng import render_timeline_tab
|
||||||
@@ -156,6 +158,20 @@ def index():
|
|||||||
background: rgba(255,255,255,0.2);
|
background: rgba(255,255,255,0.2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Sub-sequence accent colors (per sub-index, cycling) */
|
||||||
|
.body--dark .subsegment-color-0 > .q-expansion-item__container > .q-item { border-left: 6px solid #06B6D4; padding-left: 10px; }
|
||||||
|
.body--dark .subsegment-color-0 .q-expansion-item__toggle-icon { color: #06B6D4 !important; }
|
||||||
|
.body--dark .subsegment-color-1 > .q-expansion-item__container > .q-item { border-left: 6px solid #A78BFA; padding-left: 10px; }
|
||||||
|
.body--dark .subsegment-color-1 .q-expansion-item__toggle-icon { color: #A78BFA !important; }
|
||||||
|
.body--dark .subsegment-color-2 > .q-expansion-item__container > .q-item { border-left: 6px solid #34D399; padding-left: 10px; }
|
||||||
|
.body--dark .subsegment-color-2 .q-expansion-item__toggle-icon { color: #34D399 !important; }
|
||||||
|
.body--dark .subsegment-color-3 > .q-expansion-item__container > .q-item { border-left: 6px solid #F472B6; padding-left: 10px; }
|
||||||
|
.body--dark .subsegment-color-3 .q-expansion-item__toggle-icon { color: #F472B6 !important; }
|
||||||
|
.body--dark .subsegment-color-4 > .q-expansion-item__container > .q-item { border-left: 6px solid #FBBF24; padding-left: 10px; }
|
||||||
|
.body--dark .subsegment-color-4 .q-expansion-item__toggle-icon { color: #FBBF24 !important; }
|
||||||
|
.body--dark .subsegment-color-5 > .q-expansion-item__container > .q-item { border-left: 6px solid #FB923C; padding-left: 10px; }
|
||||||
|
.body--dark .subsegment-color-5 .q-expansion-item__toggle-icon { color: #FB923C !important; }
|
||||||
|
|
||||||
/* Secondary pane teal accent */
|
/* Secondary pane teal accent */
|
||||||
.pane-secondary .q-field--outlined.q-field--focused .q-field__control:after {
|
.pane-secondary .q-field--outlined.q-field--focused .q-field__control:after {
|
||||||
border-color: #06B6D4 !important;
|
border-color: #06B6D4 !important;
|
||||||
@@ -184,6 +200,9 @@ def index():
|
|||||||
|
|
||||||
@ui.refreshable
|
@ui.refreshable
|
||||||
def render_main_content():
|
def render_main_content():
|
||||||
|
import time as _time
|
||||||
|
_t0 = _time.perf_counter()
|
||||||
|
logger.info("render_main_content START")
|
||||||
max_w = '2400px' if dual_pane['active'] else '1200px'
|
max_w = '2400px' if dual_pane['active'] else '1200px'
|
||||||
with ui.column().classes('w-full q-pa-md').style(f'max-width: {max_w}; margin: 0 auto'):
|
with ui.column().classes('w-full q-pa-md').style(f'max-width: {max_w}; margin: 0 auto'):
|
||||||
if not state.file_path or not state.file_path.exists():
|
if not state.file_path or not state.file_path.exists():
|
||||||
@@ -214,6 +233,8 @@ def index():
|
|||||||
with ui.expansion('ComfyUI Monitor', icon='dns').classes('w-full'):
|
with ui.expansion('ComfyUI Monitor', icon='dns').classes('w-full'):
|
||||||
render_comfy_monitor(state)
|
render_comfy_monitor(state)
|
||||||
|
|
||||||
|
logger.info("render_main_content END (%.3fs)", _time.perf_counter() - _t0)
|
||||||
|
|
||||||
@ui.refreshable
|
@ui.refreshable
|
||||||
def _render_batch_tab_content():
|
def _render_batch_tab_content():
|
||||||
def on_toggle(e):
|
def on_toggle(e):
|
||||||
@@ -255,17 +276,39 @@ def index():
|
|||||||
|
|
||||||
current_val = pane_state.file_path.name if pane_state.file_path else None
|
current_val = pane_state.file_path.name if pane_state.file_path else None
|
||||||
|
|
||||||
def on_select(e):
|
async def on_select(e):
|
||||||
if not e.value:
|
if not e.value:
|
||||||
return
|
return
|
||||||
|
import time as _time
|
||||||
|
_t0 = _time.perf_counter()
|
||||||
|
logger.info("on_select START: %s", e.value)
|
||||||
fp = pane_state.current_dir / e.value
|
fp = pane_state.current_dir / e.value
|
||||||
data, mtime = load_json(fp)
|
file_stem = fp.stem
|
||||||
|
data = None
|
||||||
|
if pane_state.db and pane_state.db_enabled and pane_state.current_project:
|
||||||
|
data = await asyncio.to_thread(
|
||||||
|
pane_state.db.load_full_data, pane_state.current_project, file_stem)
|
||||||
|
if data is None:
|
||||||
|
data, _ = await asyncio.to_thread(load_json, fp)
|
||||||
|
if pane_state.db and pane_state.db_enabled and pane_state.current_project:
|
||||||
|
await asyncio.to_thread(
|
||||||
|
sync_to_db, pane_state.db, pane_state.current_project, fp, data)
|
||||||
|
tree = data.get('history_tree')
|
||||||
|
if tree and isinstance(tree, dict):
|
||||||
|
for entry in tree.get('snapshots', tree.get('nodes', {})).values():
|
||||||
|
entry.pop('data', None)
|
||||||
|
for backup in data.get('history_tree_backup', []):
|
||||||
|
if isinstance(backup, dict):
|
||||||
|
for entry in backup.get('snapshots', backup.get('nodes', {})).values():
|
||||||
|
entry.pop('data', None)
|
||||||
pane_state.data_cache = data
|
pane_state.data_cache = data
|
||||||
pane_state.last_mtime = mtime
|
pane_state.last_mtime = fp.stat().st_mtime if fp.exists() else 0
|
||||||
pane_state.loaded_file = str(fp)
|
pane_state.loaded_file = str(fp)
|
||||||
pane_state.file_path = fp
|
pane_state.file_path = fp
|
||||||
pane_state.restored_indicator = None
|
pane_state.restored_indicator = None
|
||||||
|
pane_state._src_cache = {'data': None, 'batch': [], 'name': None}
|
||||||
_render_batch_tab_content.refresh()
|
_render_batch_tab_content.refresh()
|
||||||
|
logger.info("on_select END (%.3fs)", _time.perf_counter() - _t0)
|
||||||
|
|
||||||
ui.select(
|
ui.select(
|
||||||
file_names,
|
file_names,
|
||||||
@@ -274,19 +317,44 @@ def index():
|
|||||||
on_change=on_select,
|
on_change=on_select,
|
||||||
).classes('w-full')
|
).classes('w-full')
|
||||||
|
|
||||||
def load_file(file_name: str):
|
async def load_file(file_name: str):
|
||||||
"""Load a JSON file and refresh the main content."""
|
"""Load data from DB (fast) with JSON fallback, and refresh the main content."""
|
||||||
|
import time as _time
|
||||||
|
_t0 = _time.perf_counter()
|
||||||
|
logger.info("load_file START: %s", file_name)
|
||||||
fp = state.current_dir / file_name
|
fp = state.current_dir / file_name
|
||||||
if state.loaded_file == str(fp):
|
if state.loaded_file == str(fp):
|
||||||
return
|
return
|
||||||
data, mtime = load_json(fp)
|
file_stem = fp.stem
|
||||||
|
data = None
|
||||||
|
if state.db and state.db_enabled and state.current_project:
|
||||||
|
data = await asyncio.to_thread(
|
||||||
|
state.db.load_full_data, state.current_project, file_stem)
|
||||||
|
if data is None:
|
||||||
|
data, _ = await asyncio.to_thread(load_json, fp)
|
||||||
|
# When loading from JSON fallback and DB is enabled, sync to DB
|
||||||
|
# so snapshots are persisted, then strip from memory
|
||||||
|
if state.db and state.db_enabled and state.current_project:
|
||||||
|
await asyncio.to_thread(
|
||||||
|
sync_to_db, state.db, state.current_project, fp, data)
|
||||||
|
tree = data.get('history_tree')
|
||||||
|
if tree and isinstance(tree, dict):
|
||||||
|
for entry in tree.get('snapshots', tree.get('nodes', {})).values():
|
||||||
|
entry.pop('data', None)
|
||||||
|
# Strip snapshot data from history_tree_backup to prevent RAM/disk bloat
|
||||||
|
for backup in data.get('history_tree_backup', []):
|
||||||
|
if isinstance(backup, dict):
|
||||||
|
for entry in backup.get('snapshots', backup.get('nodes', {})).values():
|
||||||
|
entry.pop('data', None)
|
||||||
state.data_cache = data
|
state.data_cache = data
|
||||||
state.last_mtime = mtime
|
state.last_mtime = fp.stat().st_mtime if fp.exists() else 0
|
||||||
state.loaded_file = str(fp)
|
state.loaded_file = str(fp)
|
||||||
state.file_path = fp
|
state.file_path = fp
|
||||||
state.restored_indicator = None
|
state.restored_indicator = None
|
||||||
|
state._src_cache = {'data': None, 'batch': [], 'name': None}
|
||||||
if state._main_rendered:
|
if state._main_rendered:
|
||||||
render_main_content.refresh()
|
render_main_content.refresh()
|
||||||
|
logger.info("load_file END (%.3fs)", _time.perf_counter() - _t0)
|
||||||
|
|
||||||
# Attach helpers to state so sidebar can call them
|
# Attach helpers to state so sidebar can call them
|
||||||
state._load_file = load_file
|
state._load_file = load_file
|
||||||
@@ -460,16 +528,16 @@ def render_sidebar(state: AppState, dual_pane: dict):
|
|||||||
with ui.expansion('Create New JSON'):
|
with ui.expansion('Create New JSON'):
|
||||||
new_fn_input = ui.input('Filename', placeholder='my_prompt_vace').classes('w-full')
|
new_fn_input = ui.input('Filename', placeholder='my_prompt_vace').classes('w-full')
|
||||||
|
|
||||||
def create_new():
|
async def create_new():
|
||||||
fn = new_fn_input.value
|
fn = new_fn_input.value
|
||||||
if not fn:
|
if not fn:
|
||||||
return
|
return
|
||||||
if not fn.endswith('.json'):
|
if not fn.endswith('.json'):
|
||||||
fn += '.json'
|
fn += '.json'
|
||||||
path = state.current_dir / fn
|
path = state.current_dir / fn
|
||||||
first_item = DEFAULTS.copy()
|
first_item = copy.deepcopy(DEFAULTS)
|
||||||
first_item[KEY_SEQUENCE_NUMBER] = 1
|
first_item[KEY_SEQUENCE_NUMBER] = 1
|
||||||
save_json(path, {KEY_BATCH_DATA: [first_item]})
|
await asyncio.to_thread(save_json, path, {KEY_BATCH_DATA: [first_item]})
|
||||||
new_fn_input.set_value('')
|
new_fn_input.set_value('')
|
||||||
render_file_list.refresh()
|
render_file_list.refresh()
|
||||||
|
|
||||||
@@ -479,15 +547,19 @@ def render_sidebar(state: AppState, dual_pane: dict):
|
|||||||
file_names = [f.name for f in json_files]
|
file_names = [f.name for f in json_files]
|
||||||
current = Path(state.loaded_file).name if state.loaded_file else None
|
current = Path(state.loaded_file).name if state.loaded_file else None
|
||||||
selected = current if current in file_names else (file_names[0] if file_names else None)
|
selected = current if current in file_names else (file_names[0] if file_names else None)
|
||||||
|
async def _on_radio(e):
|
||||||
|
if e.value:
|
||||||
|
await state._load_file(e.value)
|
||||||
|
|
||||||
ui.radio(
|
ui.radio(
|
||||||
file_names,
|
file_names,
|
||||||
value=selected,
|
value=selected,
|
||||||
on_change=lambda e: state._load_file(e.value) if e.value else None,
|
on_change=_on_radio,
|
||||||
).classes('w-full')
|
).classes('w-full')
|
||||||
|
|
||||||
# Auto-load first file if nothing loaded yet
|
# Auto-load first file if nothing loaded yet
|
||||||
if file_names and not state.loaded_file:
|
if file_names and not state.loaded_file:
|
||||||
state._load_file(file_names[0])
|
asyncio.ensure_future(state._load_file(file_names[0]))
|
||||||
|
|
||||||
def _gen_templates():
|
def _gen_templates():
|
||||||
generate_templates(state.current_dir)
|
generate_templates(state.current_dir)
|
||||||
@@ -500,11 +572,11 @@ def render_sidebar(state: AppState, dual_pane: dict):
|
|||||||
state.show_comfy_monitor = e.value
|
state.show_comfy_monitor = e.value
|
||||||
state._render_main.refresh()
|
state._render_main.refresh()
|
||||||
|
|
||||||
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
|
ui.checkbox('Show Comfy Monitor', value=state.show_comfy_monitor, on_change=on_monitor_toggle)
|
||||||
|
|
||||||
|
|
||||||
# Register REST API routes for ComfyUI connectivity (uses the shared DB instance)
|
# Register REST API routes for ComfyUI connectivity (uses the shared DB instance)
|
||||||
if _shared_db is not None:
|
if _shared_db is not None:
|
||||||
register_api_routes(_shared_db)
|
register_api_routes(_shared_db)
|
||||||
|
|
||||||
ui.run(title='AI Settings Manager', port=8080, reload=True)
|
ui.run(title='AI Settings Manager', port=8080, reload=False)
|
||||||
|
|||||||
+97
-5
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
@@ -88,7 +89,7 @@ if PromptServer is not None:
|
|||||||
async def list_projects_proxy(request):
|
async def list_projects_proxy(request):
|
||||||
manager_url = request.query.get("url", "http://localhost:8080")
|
manager_url = request.query.get("url", "http://localhost:8080")
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects"
|
url = f"{manager_url.rstrip('/')}/api/projects"
|
||||||
data = _fetch_json(url)
|
data = await asyncio.to_thread(_fetch_json, url)
|
||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
|
|
||||||
@PromptServer.instance.routes.get("/json_manager/list_project_files")
|
@PromptServer.instance.routes.get("/json_manager/list_project_files")
|
||||||
@@ -96,7 +97,7 @@ if PromptServer is not None:
|
|||||||
manager_url = request.query.get("url", "http://localhost:8080")
|
manager_url = request.query.get("url", "http://localhost:8080")
|
||||||
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
|
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
|
||||||
data = _fetch_json(url)
|
data = await asyncio.to_thread(_fetch_json, url)
|
||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
|
|
||||||
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
|
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
|
||||||
@@ -105,7 +106,7 @@ if PromptServer is not None:
|
|||||||
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
||||||
file_name = urllib.parse.quote(request.query.get("file", ""), safe='')
|
file_name = urllib.parse.quote(request.query.get("file", ""), safe='')
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
|
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
|
||||||
data = _fetch_json(url)
|
data = await asyncio.to_thread(_fetch_json, url)
|
||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
|
|
||||||
@PromptServer.instance.routes.get("/json_manager/get_project_keys")
|
@PromptServer.instance.routes.get("/json_manager/get_project_keys")
|
||||||
@@ -117,7 +118,7 @@ if PromptServer is not None:
|
|||||||
seq = int(request.query.get("seq", "1"))
|
seq = int(request.query.get("seq", "1"))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
seq = 1
|
seq = 1
|
||||||
data = _fetch_keys(manager_url, project, file_name, seq)
|
data = await asyncio.to_thread(_fetch_keys, manager_url, project, file_name, seq)
|
||||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
status = data.get("status", 502)
|
status = data.get("status", 502)
|
||||||
return web.json_response(data, status=status)
|
return web.json_response(data, status=status)
|
||||||
@@ -138,6 +139,7 @@ class ProjectLoaderDynamic:
|
|||||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
"project_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
"file_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||||
|
"refresh": (["off", "on"],),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"output_keys": ("STRING", {"default": ""}),
|
"output_keys": ("STRING", {"default": ""}),
|
||||||
@@ -152,7 +154,7 @@ class ProjectLoaderDynamic:
|
|||||||
OUTPUT_NODE = False
|
OUTPUT_NODE = False
|
||||||
|
|
||||||
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
|
def load_dynamic(self, manager_url, project_name, file_name, sequence_number,
|
||||||
output_keys="", output_types=""):
|
refresh="off", output_keys="", output_types=""):
|
||||||
# Fetch keys metadata (includes total_sequences count)
|
# Fetch keys metadata (includes total_sequences count)
|
||||||
keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number)
|
keys_meta = _fetch_keys(manager_url, project_name, file_name, sequence_number)
|
||||||
if keys_meta.get("error") in ("http_error", "network_error", "parse_error"):
|
if keys_meta.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
@@ -205,11 +207,101 @@ class ProjectLoaderDynamic:
|
|||||||
return (total_sequences,) + tuple(results)
|
return (total_sequences,) + tuple(results)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectSource:
|
||||||
|
"""Config node — holds project connection settings, outputs sequence_number."""
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
||||||
|
"project_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"file_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||||
|
"label": ("STRING", {"default": "source", "multiline": False}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("INT",)
|
||||||
|
RETURN_NAMES = ("sequence_number",)
|
||||||
|
FUNCTION = "hold_config"
|
||||||
|
CATEGORY = "utils/json/project"
|
||||||
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
|
def hold_config(self, manager_url, project_name, file_name, sequence_number, label):
|
||||||
|
return (sequence_number,)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectKey:
|
||||||
|
"""Single-output relay — fetches one key from a ProjectSource."""
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"source_label": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"key_type": ("STRING", {"default": "STRING", "multiline": False}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
||||||
|
"project_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"file_name": ("STRING", {"default": "", "multiline": False}),
|
||||||
|
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (any_type,)
|
||||||
|
RETURN_NAMES = ("value",)
|
||||||
|
FUNCTION = "fetch_key"
|
||||||
|
CATEGORY = "utils/json/project"
|
||||||
|
OUTPUT_NODE = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(cls, **kwargs):
|
||||||
|
return float("nan") # Always re-fetch from API
|
||||||
|
|
||||||
|
def fetch_key(self, source_label, key_name, key_type,
|
||||||
|
manager_url="http://localhost:8080", project_name="",
|
||||||
|
file_name="", sequence_number=1):
|
||||||
|
# source_label is used by JS to identify which ProjectSource to sync
|
||||||
|
# config from. The actual config arrives via the optional widgets below.
|
||||||
|
sequence_number = int(sequence_number)
|
||||||
|
logger.info("ProjectKey.fetch_key: source=%s key=%s url=%s project=%s file=%s seq=%s",
|
||||||
|
source_label, key_name, manager_url, project_name, file_name, sequence_number)
|
||||||
|
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
||||||
|
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||||
|
msg = data.get("message", "Unknown error")
|
||||||
|
logger.warning("ProjectKey.fetch_key failed: %s", msg)
|
||||||
|
# Return empty/default instead of crashing the workflow
|
||||||
|
if key_type == "INT":
|
||||||
|
return (0,)
|
||||||
|
elif key_type == "FLOAT":
|
||||||
|
return (0.0,)
|
||||||
|
else:
|
||||||
|
return ("",)
|
||||||
|
|
||||||
|
val = data.get(key_name, "")
|
||||||
|
|
||||||
|
if key_type == "INT":
|
||||||
|
return (to_int(val),)
|
||||||
|
elif key_type == "FLOAT":
|
||||||
|
return (to_float(val),)
|
||||||
|
elif isinstance(val, bool):
|
||||||
|
return (str(val).lower(),)
|
||||||
|
elif isinstance(val, (int, float)):
|
||||||
|
return (val,)
|
||||||
|
else:
|
||||||
|
return (str(val),)
|
||||||
|
|
||||||
|
|
||||||
# --- Mappings ---
|
# --- Mappings ---
|
||||||
PROJECT_NODE_CLASS_MAPPINGS = {
|
PROJECT_NODE_CLASS_MAPPINGS = {
|
||||||
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
||||||
|
"ProjectSource": ProjectSource,
|
||||||
|
"ProjectKey": ProjectKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
||||||
|
"ProjectSource": "Project Source",
|
||||||
|
"ProjectKey": "Project Key",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,184 @@
|
|||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
KEY_PROMPT_HISTORY = "prompt_history"
|
||||||
|
|
||||||
|
|
||||||
|
class SnapshotTimeline:
|
||||||
|
"""Flat chronological snapshot list — replaces the old HistoryTree DAG."""
|
||||||
|
|
||||||
|
def __init__(self, raw_data: dict[str, Any]) -> None:
|
||||||
|
# Detect and migrate old HistoryTree format
|
||||||
|
if "nodes" in raw_data and "branches" in raw_data:
|
||||||
|
self._migrate_from_tree(raw_data)
|
||||||
|
elif KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list):
|
||||||
|
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
|
||||||
|
else:
|
||||||
|
self.snapshots: dict[str, dict[str, Any]] = raw_data.get("snapshots", {})
|
||||||
|
self.current_id: str | None = raw_data.get("current_id", None)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Migration
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _migrate_from_tree(self, raw_data: dict[str, Any]) -> None:
|
||||||
|
"""Flatten old HistoryTree nodes into snapshot list, discarding DAG info."""
|
||||||
|
self.snapshots = {}
|
||||||
|
nodes = raw_data.get("nodes", {})
|
||||||
|
for nid, node in nodes.items():
|
||||||
|
self.snapshots[nid] = {
|
||||||
|
"id": nid,
|
||||||
|
"timestamp": node.get("timestamp", time.time()),
|
||||||
|
"note": node.get("note", "Migrated"),
|
||||||
|
"pinned": False,
|
||||||
|
"auto": False,
|
||||||
|
"seq_count": self._count_seqs(node.get("data")),
|
||||||
|
}
|
||||||
|
# Preserve snapshot data if present
|
||||||
|
if "data" in node and node["data"]:
|
||||||
|
self.snapshots[nid]["data"] = node["data"]
|
||||||
|
self.current_id = raw_data.get("head_id")
|
||||||
|
|
||||||
|
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
||||||
|
"""Convert ancient prompt_history list into snapshots."""
|
||||||
|
self.snapshots = {}
|
||||||
|
self.current_id = None
|
||||||
|
for item in reversed(old_list):
|
||||||
|
sid = self._make_id()
|
||||||
|
self.snapshots[sid] = {
|
||||||
|
"id": sid,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"note": item.get("note", "Legacy Import"),
|
||||||
|
"pinned": False,
|
||||||
|
"auto": False,
|
||||||
|
"seq_count": self._count_seqs(item),
|
||||||
|
"data": item,
|
||||||
|
}
|
||||||
|
self.current_id = sid
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Core operations
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def record(self, data: dict[str, Any], note: str = "Snapshot",
|
||||||
|
auto: bool = False) -> str:
|
||||||
|
"""Create a new snapshot and return its ID."""
|
||||||
|
sid = self._make_id()
|
||||||
|
self.snapshots[sid] = {
|
||||||
|
"id": sid,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"note": note,
|
||||||
|
"pinned": False,
|
||||||
|
"auto": auto,
|
||||||
|
"seq_count": self._count_seqs(data),
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
self.current_id = sid
|
||||||
|
return sid
|
||||||
|
|
||||||
|
def get_snapshot_data(self, snapshot_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Return the inline snapshot data if present."""
|
||||||
|
snap = self.snapshots.get(snapshot_id)
|
||||||
|
if snap:
|
||||||
|
return snap.get("data")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def toggle_pin(self, snapshot_id: str) -> bool:
|
||||||
|
"""Toggle pinned state, return new value."""
|
||||||
|
snap = self.snapshots.get(snapshot_id)
|
||||||
|
if snap:
|
||||||
|
snap["pinned"] = not snap.get("pinned", False)
|
||||||
|
return snap["pinned"]
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete(self, snapshot_id: str) -> None:
|
||||||
|
"""Remove a snapshot."""
|
||||||
|
self.snapshots.pop(snapshot_id, None)
|
||||||
|
if self.current_id == snapshot_id:
|
||||||
|
# Fall back to most recent remaining
|
||||||
|
if self.snapshots:
|
||||||
|
self.current_id = max(
|
||||||
|
self.snapshots.values(), key=lambda s: s["timestamp"]
|
||||||
|
)["id"]
|
||||||
|
else:
|
||||||
|
self.current_id = None
|
||||||
|
|
||||||
|
def strip_snapshots(self) -> None:
|
||||||
|
"""Remove inline data from all snapshots (for slim JSON storage)."""
|
||||||
|
for snap in self.snapshots.values():
|
||||||
|
snap.pop("data", None)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Serialization
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"snapshots": self.snapshots,
|
||||||
|
"current_id": self.current_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_id(self) -> str:
|
||||||
|
for _ in range(10):
|
||||||
|
sid = str(uuid.uuid4())[:8]
|
||||||
|
if sid not in self.snapshots:
|
||||||
|
return sid
|
||||||
|
raise ValueError("Failed to generate unique snapshot ID after 10 attempts")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _count_seqs(data: dict | None) -> int:
|
||||||
|
if not data:
|
||||||
|
return 0
|
||||||
|
from utils import KEY_BATCH_DATA
|
||||||
|
batch = data.get(KEY_BATCH_DATA, [])
|
||||||
|
return len(batch) if isinstance(batch, list) else 0
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Diff function
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def diff_snapshots(old_batch: list[dict], new_batch: list[dict]) -> list[dict]:
|
||||||
|
"""Compare two batch lists by sequence_number, return per-sequence diffs.
|
||||||
|
|
||||||
|
Returns a list of dicts:
|
||||||
|
{
|
||||||
|
"seq_num": int,
|
||||||
|
"status": "unchanged" | "changed" | "added" | "removed",
|
||||||
|
"changes": [{"field": str, "old": Any, "new": Any}],
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
from utils import KEY_SEQUENCE_NUMBER
|
||||||
|
|
||||||
|
old_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in old_batch}
|
||||||
|
new_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in new_batch}
|
||||||
|
|
||||||
|
all_seqs = sorted(set(old_by_seq) | set(new_by_seq))
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for seq_num in all_seqs:
|
||||||
|
old_item = old_by_seq.get(seq_num)
|
||||||
|
new_item = new_by_seq.get(seq_num)
|
||||||
|
|
||||||
|
if old_item and not new_item:
|
||||||
|
result.append({"seq_num": seq_num, "status": "removed", "changes": []})
|
||||||
|
elif new_item and not old_item:
|
||||||
|
result.append({"seq_num": seq_num, "status": "added", "changes": []})
|
||||||
|
else:
|
||||||
|
# Both exist — field-by-field comparison
|
||||||
|
all_keys = sorted(set(old_item) | set(new_item))
|
||||||
|
changes = []
|
||||||
|
for k in all_keys:
|
||||||
|
old_val = old_item.get(k)
|
||||||
|
new_val = new_item.get(k)
|
||||||
|
if old_val != new_val:
|
||||||
|
changes.append({"field": k, "old": old_val, "new": new_val})
|
||||||
|
status = "changed" if changes else "unchanged"
|
||||||
|
result.append({"seq_num": seq_num, "status": status, "changes": changes})
|
||||||
|
|
||||||
|
return result
|
||||||
@@ -13,7 +13,7 @@ class AppState:
|
|||||||
snippets: dict = field(default_factory=dict)
|
snippets: dict = field(default_factory=dict)
|
||||||
file_path: Path | None = None
|
file_path: Path | None = None
|
||||||
restored_indicator: str | None = None
|
restored_indicator: str | None = None
|
||||||
timeline_selected_nodes: set = field(default_factory=set)
|
timeline_selected_id: str | None = None
|
||||||
live_toggles: dict = field(default_factory=dict)
|
live_toggles: dict = field(default_factory=dict)
|
||||||
show_comfy_monitor: bool = True
|
show_comfy_monitor: bool = True
|
||||||
|
|
||||||
@@ -28,6 +28,7 @@ class AppState:
|
|||||||
_main_rendered: bool = False
|
_main_rendered: bool = False
|
||||||
_live_checkboxes: dict = field(default_factory=dict)
|
_live_checkboxes: dict = field(default_factory=dict)
|
||||||
_live_refreshables: dict = field(default_factory=dict)
|
_live_refreshables: dict = field(default_factory=dict)
|
||||||
|
_src_cache: dict = field(default_factory=lambda: {'data': None, 'batch': [], 'name': None})
|
||||||
|
|
||||||
def create_secondary(self) -> 'AppState':
|
def create_secondary(self) -> 'AppState':
|
||||||
return AppState(
|
return AppState(
|
||||||
|
|||||||
+236
-76
@@ -1,18 +1,28 @@
|
|||||||
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from nicegui import ui
|
from nicegui import ui
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from state import AppState
|
from state import AppState
|
||||||
from utils import (
|
from utils import (
|
||||||
DEFAULTS, save_json, load_json, sync_to_db,
|
DEFAULTS, save_json, load_json, sync_to_db,
|
||||||
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
|
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
|
||||||
)
|
)
|
||||||
from history_tree import HistoryTree
|
from snapshot_timeline import SnapshotTimeline
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'}
|
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'}
|
||||||
|
_AUTO_SNAP_DEBOUNCE = 30 # seconds between auto-snapshots
|
||||||
|
_last_auto_snap: dict[str, float] = {} # file_path -> timestamp
|
||||||
SUB_SEGMENT_MULTIPLIER = 1000
|
SUB_SEGMENT_MULTIPLIER = 1000
|
||||||
|
SUB_SEGMENT_NUM_COLORS = 6
|
||||||
FRAME_TO_SKIP_DEFAULT = DEFAULTS['frame_to_skip']
|
FRAME_TO_SKIP_DEFAULT = DEFAULTS['frame_to_skip']
|
||||||
|
|
||||||
VACE_MODES = [
|
VACE_MODES = [
|
||||||
@@ -76,6 +86,53 @@ def find_insert_position(batch_list, parent_index, parent_seq_num):
|
|||||||
return pos
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
# --- Auto change note ---
|
||||||
|
|
||||||
|
def _auto_change_note(timeline, batch_list, state=None, file_path=None):
|
||||||
|
"""Compare current batch_list against last snapshot and describe changes."""
|
||||||
|
# Get previous batch data from the current snapshot
|
||||||
|
if not timeline.current_id or timeline.current_id not in timeline.snapshots:
|
||||||
|
return f'Initial save ({len(batch_list)} sequences)'
|
||||||
|
|
||||||
|
# Load previous snapshot from inline data or DB
|
||||||
|
prev_data = timeline.get_snapshot_data(timeline.current_id)
|
||||||
|
if not prev_data and state and state.db_enabled and state.db and state.current_project and file_path:
|
||||||
|
df = state.db.get_data_file_by_names(state.current_project, file_path.stem)
|
||||||
|
if df:
|
||||||
|
prev_data = state.db.get_node_snapshot(df['id'], timeline.current_id)
|
||||||
|
prev_batch = (prev_data or {}).get(KEY_BATCH_DATA, [])
|
||||||
|
|
||||||
|
prev_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in prev_batch}
|
||||||
|
curr_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in batch_list}
|
||||||
|
|
||||||
|
added = sorted(set(curr_by_seq) - set(prev_by_seq))
|
||||||
|
removed = sorted(set(prev_by_seq) - set(curr_by_seq))
|
||||||
|
|
||||||
|
changed_keys = set()
|
||||||
|
for seq_num in sorted(set(curr_by_seq) & set(prev_by_seq)):
|
||||||
|
old, new = prev_by_seq[seq_num], curr_by_seq[seq_num]
|
||||||
|
all_keys = set(old) | set(new)
|
||||||
|
for k in all_keys:
|
||||||
|
if old.get(k) != new.get(k):
|
||||||
|
changed_keys.add(k)
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
if added:
|
||||||
|
parts.append(f'Added seq {", ".join(str(s) for s in added)}')
|
||||||
|
if removed:
|
||||||
|
parts.append(f'Removed seq {", ".join(str(s) for s in removed)}')
|
||||||
|
if changed_keys:
|
||||||
|
# Show up to 4 changed field names
|
||||||
|
keys_list = sorted(changed_keys)
|
||||||
|
if len(keys_list) > 4:
|
||||||
|
keys_str = ', '.join(keys_list[:4]) + f' +{len(keys_list) - 4} more'
|
||||||
|
else:
|
||||||
|
keys_str = ', '.join(keys_list)
|
||||||
|
parts.append(f'Changed: {keys_str}')
|
||||||
|
|
||||||
|
return '; '.join(parts) if parts else 'No changes detected'
|
||||||
|
|
||||||
|
|
||||||
# --- Helper for repetitive dict-bound inputs ---
|
# --- Helper for repetitive dict-bound inputs ---
|
||||||
|
|
||||||
def dict_input(element_fn, label, seq, key, **kwargs):
|
def dict_input(element_fn, label, seq, key, **kwargs):
|
||||||
@@ -99,6 +156,8 @@ def dict_number(label, seq, key, default=0, **kwargs):
|
|||||||
try:
|
try:
|
||||||
# Try float first to handle "1.5" strings, then check if it's a clean int
|
# Try float first to handle "1.5" strings, then check if it's a clean int
|
||||||
fval = float(val)
|
fval = float(val)
|
||||||
|
if not math.isfinite(fval):
|
||||||
|
fval = float(default)
|
||||||
val = int(fval) if fval == int(fval) else fval
|
val = int(fval) if fval == int(fval) else fval
|
||||||
except (ValueError, TypeError, OverflowError):
|
except (ValueError, TypeError, OverflowError):
|
||||||
val = default
|
val = default
|
||||||
@@ -109,6 +168,9 @@ def dict_number(label, seq, key, default=0, **kwargs):
|
|||||||
if v is None:
|
if v is None:
|
||||||
v = d
|
v = d
|
||||||
elif isinstance(v, float):
|
elif isinstance(v, float):
|
||||||
|
if not math.isfinite(v):
|
||||||
|
v = d
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
v = int(v) if v == int(v) else v
|
v = int(v) if v == int(v) else v
|
||||||
except (OverflowError, ValueError):
|
except (OverflowError, ValueError):
|
||||||
@@ -137,6 +199,8 @@ def dict_textarea(label, seq, key, **kwargs):
|
|||||||
# ======================================================================
|
# ======================================================================
|
||||||
|
|
||||||
def render_batch_processor(state: AppState):
|
def render_batch_processor(state: AppState):
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
logger.info("render_batch_processor START")
|
||||||
data = state.data_cache
|
data = state.data_cache
|
||||||
file_path = state.file_path
|
file_path = state.file_path
|
||||||
if isinstance(data, list):
|
if isinstance(data, list):
|
||||||
@@ -148,7 +212,7 @@ def render_batch_processor(state: AppState):
|
|||||||
ui.label('This is a Single file. To use Batch mode, create a copy.').classes(
|
ui.label('This is a Single file. To use Batch mode, create a copy.').classes(
|
||||||
'text-warning')
|
'text-warning')
|
||||||
|
|
||||||
def create_batch():
|
async def create_batch():
|
||||||
new_name = f'batch_{file_path.name}'
|
new_name = f'batch_{file_path.name}'
|
||||||
new_path = file_path.parent / new_name
|
new_path = file_path.parent / new_name
|
||||||
if new_path.exists():
|
if new_path.exists():
|
||||||
@@ -160,9 +224,9 @@ def render_batch_processor(state: AppState):
|
|||||||
first_item[KEY_SEQUENCE_NUMBER] = 1
|
first_item[KEY_SEQUENCE_NUMBER] = 1
|
||||||
new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {},
|
new_data = {KEY_BATCH_DATA: [first_item], KEY_HISTORY_TREE: {},
|
||||||
KEY_PROMPT_HISTORY: []}
|
KEY_PROMPT_HISTORY: []}
|
||||||
save_json(new_path, new_data)
|
await asyncio.to_thread(save_json, new_path, new_data)
|
||||||
if state.db_enabled and state.current_project and state.db:
|
if state.db_enabled and state.current_project and state.db:
|
||||||
sync_to_db(state.db, state.current_project, new_path, new_data)
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, new_path, new_data)
|
||||||
ui.notify(f'Created {new_name}', type='positive')
|
ui.notify(f'Created {new_name}', type='positive')
|
||||||
|
|
||||||
ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch)
|
ui.button('Create Batch Copy', icon='content_copy', on_click=create_batch)
|
||||||
@@ -190,12 +254,16 @@ def render_batch_processor(state: AppState):
|
|||||||
|
|
||||||
src_seq_select = ui.select([], label='Source Sequence:').classes('w-64')
|
src_seq_select = ui.select([], label='Source Sequence:').classes('w-64')
|
||||||
|
|
||||||
# Track loaded source data
|
# Track loaded source data (on state so it's cleared on file switch)
|
||||||
_src_cache = {'data': None, 'batch': [], 'name': None}
|
_src_cache = state._src_cache
|
||||||
|
|
||||||
def _update_src():
|
def _update_src():
|
||||||
name = src_file_select.value
|
name = src_file_select.value
|
||||||
if name and name != _src_cache['name']:
|
if name and name != _src_cache['name']:
|
||||||
|
# Reuse current data if source is the same file
|
||||||
|
if name == file_path.name:
|
||||||
|
src_data = data
|
||||||
|
else:
|
||||||
src_data, _ = load_json(state.current_dir / name)
|
src_data, _ = load_json(state.current_dir / name)
|
||||||
_src_cache['data'] = src_data
|
_src_cache['data'] = src_data
|
||||||
_src_cache['batch'] = src_data.get(KEY_BATCH_DATA, [])
|
_src_cache['batch'] = src_data.get(KEY_BATCH_DATA, [])
|
||||||
@@ -210,39 +278,41 @@ def render_batch_processor(state: AppState):
|
|||||||
src_file_select.on_value_change(lambda _: _update_src())
|
src_file_select.on_value_change(lambda _: _update_src())
|
||||||
_update_src()
|
_update_src()
|
||||||
|
|
||||||
def _add_sequence(new_item):
|
async def _add_sequence(new_item):
|
||||||
new_item[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1
|
new_item[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1
|
||||||
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, 'note', 'loras']:
|
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, 'note', 'loras']:
|
||||||
new_item.pop(k, None)
|
new_item.pop(k, None)
|
||||||
batch_list.append(new_item)
|
batch_list.append(new_item)
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(save_json, file_path, snapshot)
|
||||||
if state.db_enabled and state.current_project and state.db:
|
if state.db_enabled and state.current_project and state.db:
|
||||||
sync_to_db(state.db, state.current_project, file_path, data)
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
||||||
render_sequence_list.refresh()
|
render_sequence_list.refresh()
|
||||||
|
|
||||||
with ui.row().classes('q-mt-sm'):
|
with ui.row().classes('q-mt-sm'):
|
||||||
def add_empty():
|
async def add_empty():
|
||||||
_add_sequence(DEFAULTS.copy())
|
await _add_sequence(copy.deepcopy(DEFAULTS))
|
||||||
|
|
||||||
def add_from_source():
|
async def add_from_source():
|
||||||
item = copy.deepcopy(DEFAULTS)
|
item = copy.deepcopy(DEFAULTS)
|
||||||
src_batch = _src_cache['batch']
|
src_batch = _src_cache['batch']
|
||||||
sel_idx = src_seq_select.value
|
sel_idx = src_seq_select.value
|
||||||
if src_batch and sel_idx is not None:
|
if src_batch and sel_idx is not None and int(sel_idx) < len(src_batch):
|
||||||
item.update(copy.deepcopy(src_batch[int(sel_idx)]))
|
item.update(copy.deepcopy(src_batch[int(sel_idx)]))
|
||||||
elif _src_cache['data']:
|
elif _src_cache['data']:
|
||||||
item.update(copy.deepcopy(_src_cache['data']))
|
item.update(copy.deepcopy(_src_cache['data']))
|
||||||
_add_sequence(item)
|
await _add_sequence(item)
|
||||||
|
|
||||||
ui.button('Add Empty', icon='add', on_click=add_empty)
|
ui.button('Add Empty', icon='add', on_click=add_empty)
|
||||||
ui.button('From Source', icon='file_download', on_click=add_from_source)
|
ui.button('From Source', icon='file_download', on_click=add_from_source)
|
||||||
|
|
||||||
# --- Standard / LoRA / VACE key sets ---
|
# --- Standard / LoRA / VACE key sets ---
|
||||||
lora_keys = ['lora 1 high', 'lora 1 low', 'lora 2 high', 'lora 2 low',
|
lora_keys = ['lora 1 high', 'lora 1 high strength', 'lora 1 low', 'lora 1 low strength',
|
||||||
'lora 3 high', 'lora 3 low']
|
'lora 2 high', 'lora 2 high strength', 'lora 2 low', 'lora 2 low strength',
|
||||||
|
'lora 3 high', 'lora 3 high strength', 'lora 3 low', 'lora 3 low strength']
|
||||||
standard_keys = {
|
standard_keys = {
|
||||||
'general_prompt', 'general_negative', 'current_prompt', 'negative', 'prompt',
|
'name', 'mode', 'general_prompt', 'general_negative', 'current_prompt', 'negative', 'prompt',
|
||||||
'seed', 'cfg', 'camera', 'flf', KEY_SEQUENCE_NUMBER,
|
'seed', 'cfg', 'camera', 'flf', KEY_SEQUENCE_NUMBER,
|
||||||
'frame_to_skip', 'end_frame', 'transition', 'vace_length',
|
'frame_to_skip', 'end_frame', 'transition', 'vace_length',
|
||||||
'input_a_frames', 'input_b_frames', 'reference switch', 'vace schedule',
|
'input_a_frames', 'input_b_frames', 'reference switch', 'vace schedule',
|
||||||
@@ -250,18 +320,21 @@ def render_batch_processor(state: AppState):
|
|||||||
}
|
}
|
||||||
standard_keys.update(lora_keys)
|
standard_keys.update(lora_keys)
|
||||||
|
|
||||||
def sort_by_number():
|
async def sort_by_number():
|
||||||
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(save_json, file_path, snapshot)
|
||||||
if state.db_enabled and state.current_project and state.db:
|
if state.db_enabled and state.current_project and state.db:
|
||||||
sync_to_db(state.db, state.current_project, file_path, data)
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
||||||
ui.notify('Sorted by sequence number!', type='positive')
|
ui.notify('Sorted by sequence number!', type='positive')
|
||||||
render_sequence_list.refresh()
|
render_sequence_list.refresh()
|
||||||
|
|
||||||
# --- Sequence list + mass update (inside refreshable so they stay in sync) ---
|
# --- Sequence list + mass update (inside refreshable so they stay in sync) ---
|
||||||
@ui.refreshable
|
@ui.refreshable
|
||||||
def render_sequence_list():
|
def render_sequence_list():
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
logger.info("render_sequence_list START (%d sequences)", len(batch_list))
|
||||||
# Mass update (rebuilt on refresh so checkboxes match current sequences)
|
# Mass update (rebuilt on refresh so checkboxes match current sequences)
|
||||||
_render_mass_update(batch_list, data, file_path, state, render_sequence_list)
|
_render_mass_update(batch_list, data, file_path, state, render_sequence_list)
|
||||||
|
|
||||||
@@ -276,8 +349,10 @@ def render_batch_processor(state: AppState):
|
|||||||
_src_cache, src_seq_select,
|
_src_cache, src_seq_select,
|
||||||
standard_keys, render_sequence_list,
|
standard_keys, render_sequence_list,
|
||||||
)
|
)
|
||||||
|
logger.info("render_sequence_list END (%.3fs)", time.perf_counter() - t1)
|
||||||
|
|
||||||
render_sequence_list()
|
render_sequence_list()
|
||||||
|
logger.info("render_batch_processor END (%.3fs)", time.perf_counter() - t0)
|
||||||
|
|
||||||
# --- Save & Snap ---
|
# --- Save & Snap ---
|
||||||
with ui.card().classes('w-full q-pa-md q-mt-lg'):
|
with ui.card().classes('w-full q-pa-md q-mt-lg'):
|
||||||
@@ -285,20 +360,46 @@ def render_batch_processor(state: AppState):
|
|||||||
commit_input = ui.input('Change Note (Optional)',
|
commit_input = ui.input('Change Note (Optional)',
|
||||||
placeholder='e.g. Added sequence 3').classes('col')
|
placeholder='e.g. Added sequence 3').classes('col')
|
||||||
|
|
||||||
def save_and_snap():
|
async def save_and_snap():
|
||||||
|
t_ss = time.perf_counter()
|
||||||
|
logger.info("save_and_snap START")
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||||
htree = HistoryTree(tree_data)
|
timeline = SnapshotTimeline(tree_data)
|
||||||
snapshot_payload = copy.deepcopy(data)
|
note = commit_input.value if commit_input.value else _auto_change_note(timeline, batch_list, state=state, file_path=file_path)
|
||||||
snapshot_payload.pop(KEY_HISTORY_TREE, None)
|
# Single serialization: json roundtrip gives us an isolated snapshot
|
||||||
note = commit_input.value if commit_input.value else 'Batch Update'
|
t1 = time.perf_counter()
|
||||||
htree.commit(snapshot_payload, note=note)
|
snapshot_json = json.dumps({k: v for k, v in data.items()
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
if k != KEY_HISTORY_TREE})
|
||||||
save_json(file_path, data)
|
snapshot_payload = json.loads(snapshot_json)
|
||||||
|
logger.info("save_and_snap snapshot %.3fs", time.perf_counter() - t1)
|
||||||
|
try:
|
||||||
|
timeline.record(snapshot_payload, note=note)
|
||||||
|
except ValueError as e:
|
||||||
|
ui.notify(f'Save failed: {e}', type='negative')
|
||||||
|
return
|
||||||
if state.db_enabled and state.current_project and state.db:
|
if state.db_enabled and state.current_project and state.db:
|
||||||
sync_to_db(state.db, state.current_project, file_path, data)
|
full_tree = timeline.to_dict()
|
||||||
|
data[KEY_HISTORY_TREE] = full_tree
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
db_snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot)
|
||||||
|
logger.info("save_and_snap sync_to_db %.3fs", time.perf_counter() - t1)
|
||||||
|
timeline.strip_snapshots()
|
||||||
|
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
slim_snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(save_json, file_path, slim_snapshot)
|
||||||
|
logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1)
|
||||||
|
else:
|
||||||
|
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
save_snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(save_json, file_path, save_snapshot)
|
||||||
|
logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1)
|
||||||
state.restored_indicator = None
|
state.restored_indicator = None
|
||||||
commit_input.set_value('')
|
commit_input.set_value('')
|
||||||
|
logger.info("save_and_snap END (%.3fs)", time.perf_counter() - t_ss)
|
||||||
ui.notify('Batch Saved & Snapshot Created!', type='positive')
|
ui.notify('Batch Saved & Snapshot Created!', type='positive')
|
||||||
|
|
||||||
ui.button('Save & Snap', icon='save', on_click=save_and_snap).props('color=primary')
|
ui.button('Save & Snap', icon='save', on_click=save_and_snap).props('color=primary')
|
||||||
@@ -311,31 +412,72 @@ def render_batch_processor(state: AppState):
|
|||||||
def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
||||||
src_cache, src_seq_select, standard_keys,
|
src_cache, src_seq_select, standard_keys,
|
||||||
refresh_list):
|
refresh_list):
|
||||||
def commit(message=None):
|
async def commit(message=None):
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
# Auto-snapshot with debounce
|
||||||
|
fp_key = str(file_path)
|
||||||
|
now = time.time()
|
||||||
|
did_snap = False
|
||||||
|
if now - _last_auto_snap.get(fp_key, 0) >= _AUTO_SNAP_DEBOUNCE:
|
||||||
|
timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {}))
|
||||||
|
snap_json = json.dumps({k: v for k, v in data.items()
|
||||||
|
if k != KEY_HISTORY_TREE})
|
||||||
|
snap_payload = json.loads(snap_json)
|
||||||
|
try:
|
||||||
|
timeline.record(snap_payload, note=message or "Auto-save", auto=True)
|
||||||
if state.db_enabled and state.current_project and state.db:
|
if state.db_enabled and state.current_project and state.db:
|
||||||
sync_to_db(state.db, state.current_project, file_path, data)
|
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||||
|
db_snap = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap)
|
||||||
|
timeline.strip_snapshots()
|
||||||
|
did_snap = True
|
||||||
|
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||||
|
_last_auto_snap[fp_key] = now
|
||||||
|
except ValueError:
|
||||||
|
pass # Non-critical: skip auto-snapshot on ID collision
|
||||||
|
snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(save_json, file_path, snapshot)
|
||||||
|
if state.db_enabled and state.current_project and state.db and not did_snap:
|
||||||
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
||||||
if message:
|
if message:
|
||||||
ui.notify(message, type='positive')
|
ui.notify(message, type='positive')
|
||||||
refresh_list.refresh()
|
refresh_list.refresh()
|
||||||
|
|
||||||
seq_num = seq.get(KEY_SEQUENCE_NUMBER, i + 1)
|
seq_num = seq.get(KEY_SEQUENCE_NUMBER, i + 1)
|
||||||
|
seq_name = seq.get('name', '')
|
||||||
|
|
||||||
if is_subsegment(seq_num):
|
if is_subsegment(seq_num):
|
||||||
label = f'Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)} ({int(seq_num)})'
|
label = f'Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)} ({int(seq_num)})'
|
||||||
else:
|
else:
|
||||||
label = f'Sequence #{seq_num}'
|
label = f'Sequence #{seq_num}'
|
||||||
|
if seq_name:
|
||||||
|
label += f' — {seq_name}'
|
||||||
|
|
||||||
with ui.expansion(label, icon='movie').classes('w-full'):
|
if is_subsegment(seq_num):
|
||||||
|
color_idx = (sub_index_of(seq_num) - 1) % SUB_SEGMENT_NUM_COLORS
|
||||||
|
exp_classes = f'w-full subsegment-color-{color_idx}'
|
||||||
|
else:
|
||||||
|
exp_classes = 'w-full'
|
||||||
|
with ui.expansion(label, icon='movie').classes(exp_classes) as expansion:
|
||||||
# --- Action row ---
|
# --- Action row ---
|
||||||
with ui.row().classes('w-full q-gutter-sm action-row'):
|
with ui.row().classes('w-full q-gutter-sm action-row'):
|
||||||
|
# Rename
|
||||||
|
async def rename(s=seq):
|
||||||
|
result = await ui.run_javascript(
|
||||||
|
f'prompt("Rename sequence:", {json.dumps(s.get("name", ""))})',
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
if result is not None:
|
||||||
|
s['name'] = result
|
||||||
|
commit('Renamed!')
|
||||||
|
|
||||||
|
ui.button('Rename', icon='edit', on_click=rename).props('outline')
|
||||||
# Copy from source
|
# Copy from source
|
||||||
def copy_source(idx=i, sn=seq_num):
|
def copy_source(idx=i, sn=seq_num):
|
||||||
item = copy.deepcopy(DEFAULTS)
|
item = copy.deepcopy(DEFAULTS)
|
||||||
src_batch = src_cache['batch']
|
src_batch = src_cache['batch']
|
||||||
sel_idx = src_seq_select.value
|
sel_idx = src_seq_select.value
|
||||||
if src_batch and sel_idx is not None:
|
if src_batch and sel_idx is not None and int(sel_idx) < len(src_batch):
|
||||||
item.update(copy.deepcopy(src_batch[int(sel_idx)]))
|
item.update(copy.deepcopy(src_batch[int(sel_idx)]))
|
||||||
elif src_cache['data']:
|
elif src_cache['data']:
|
||||||
item.update(copy.deepcopy(src_cache['data']))
|
item.update(copy.deepcopy(src_cache['data']))
|
||||||
@@ -390,6 +532,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
|
|
||||||
# Delete
|
# Delete
|
||||||
def delete(idx=i):
|
def delete(idx=i):
|
||||||
|
if idx < len(batch_list):
|
||||||
batch_list.pop(idx)
|
batch_list.pop(idx)
|
||||||
commit()
|
commit()
|
||||||
|
|
||||||
@@ -410,6 +553,9 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
'w-full q-mt-sm').props('outlined rows=2')
|
'w-full q-mt-sm').props('outlined rows=2')
|
||||||
|
|
||||||
with splitter.after:
|
with splitter.after:
|
||||||
|
# Mode
|
||||||
|
dict_number('Mode', seq, 'mode').props('outlined').classes('w-full')
|
||||||
|
|
||||||
# Sequence number
|
# Sequence number
|
||||||
sn_label = (
|
sn_label = (
|
||||||
f'Seq Number (Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)})'
|
f'Seq Number (Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)})'
|
||||||
@@ -463,20 +609,14 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
||||||
for lora_idx in range(1, 4):
|
for lora_idx in range(1, 4):
|
||||||
for tier, tier_label in [('high', 'High'), ('low', 'Low')]:
|
for tier, tier_label in [('high', 'High'), ('low', 'Low')]:
|
||||||
k = f'lora {lora_idx} {tier}'
|
lora_key = f'lora {lora_idx} {tier}'
|
||||||
raw = str(seq.get(k, ''))
|
|
||||||
inner = raw.replace('<lora:', '').replace('>', '')
|
lora_name = str(seq.get(lora_key, ''))
|
||||||
# Split "name:strength" or just "name"
|
strength_key = f'lora {lora_idx} {tier} strength'
|
||||||
if ':' in inner:
|
lora_strength = seq.get(strength_key, 1.0)
|
||||||
parts = inner.rsplit(':', 1)
|
|
||||||
lora_name = parts[0]
|
|
||||||
try:
|
try:
|
||||||
lora_strength = float(parts[1])
|
lora_strength = float(lora_strength)
|
||||||
except ValueError:
|
except (ValueError, TypeError):
|
||||||
lora_name = inner
|
|
||||||
lora_strength = 1.0
|
|
||||||
else:
|
|
||||||
lora_name = inner
|
|
||||||
lora_strength = 1.0
|
lora_strength = 1.0
|
||||||
|
|
||||||
with ui.row().classes('w-full items-center q-gutter-sm'):
|
with ui.row().classes('w-full items-center q-gutter-sm'):
|
||||||
@@ -493,10 +633,9 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
format='%.1f',
|
format='%.1f',
|
||||||
).props('outlined dense').style('max-width: 80px')
|
).props('outlined dense').style('max-width: 80px')
|
||||||
|
|
||||||
def _lora_sync(key=k, n_inp=name_input, s_inp=strength_input):
|
def _lora_sync(k=lora_key, sk=strength_key, n_inp=name_input, s_inp=strength_input):
|
||||||
name = n_inp.value or ''
|
seq[k] = n_inp.value or ''
|
||||||
strength = s_inp.value if s_inp.value is not None else 1.0
|
seq[sk] = float(s_inp.value) if s_inp.value is not None else 1.0
|
||||||
seq[key] = f'<lora:{name}:{strength:.1f}>' if name else ''
|
|
||||||
|
|
||||||
name_input.on('blur', lambda _, s=_lora_sync: s())
|
name_input.on('blur', lambda _, s=_lora_sync: s())
|
||||||
name_input.on('update:model-value', lambda _, s=_lora_sync: s())
|
name_input.on('update:model-value', lambda _, s=_lora_sync: s())
|
||||||
@@ -541,7 +680,13 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
|
|
||||||
def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list):
|
def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list):
|
||||||
# VACE Schedule (needed early for both columns)
|
# VACE Schedule (needed early for both columns)
|
||||||
sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1))
|
def _safe_int(val, default=0):
|
||||||
|
try:
|
||||||
|
return int(float(val))
|
||||||
|
except (ValueError, TypeError, OverflowError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
sched_val = max(0, min(_safe_int(seq.get('vace schedule', 1), 1), len(VACE_MODES) - 1))
|
||||||
|
|
||||||
# Mode reference dialog
|
# Mode reference dialog
|
||||||
with ui.dialog() as ref_dlg, ui.card():
|
with ui.dialog() as ref_dlg, ui.card():
|
||||||
@@ -562,23 +707,24 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_li
|
|||||||
fts_input = dict_number('Frame to Skip', seq, 'frame_to_skip').classes(
|
fts_input = dict_number('Frame to Skip', seq, 'frame_to_skip').classes(
|
||||||
'col').props('outlined')
|
'col').props('outlined')
|
||||||
|
|
||||||
_original_fts = int(seq.get('frame_to_skip', FRAME_TO_SKIP_DEFAULT))
|
_original_fts = _safe_int(seq.get('frame_to_skip', FRAME_TO_SKIP_DEFAULT), FRAME_TO_SKIP_DEFAULT)
|
||||||
|
|
||||||
def shift_fts(idx=i, orig=_original_fts):
|
async def shift_fts(idx=i, orig=_original_fts):
|
||||||
new_fts = int(fts_input.value) if fts_input.value is not None else orig
|
new_fts = _safe_int(fts_input.value, orig)
|
||||||
delta = new_fts - orig
|
delta = new_fts - orig
|
||||||
if delta == 0:
|
if delta == 0:
|
||||||
ui.notify('No change to shift', type='info')
|
ui.notify('No change to shift', type='info')
|
||||||
return
|
return
|
||||||
shifted = 0
|
shifted = 0
|
||||||
for j in range(idx + 1, len(batch_list)):
|
for j in range(idx + 1, len(batch_list)):
|
||||||
batch_list[j]['frame_to_skip'] = int(
|
batch_list[j]['frame_to_skip'] = _safe_int(
|
||||||
batch_list[j].get('frame_to_skip', FRAME_TO_SKIP_DEFAULT)) + delta
|
batch_list[j].get('frame_to_skip', FRAME_TO_SKIP_DEFAULT), FRAME_TO_SKIP_DEFAULT) + delta
|
||||||
shifted += 1
|
shifted += 1
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(save_json, file_path, snapshot)
|
||||||
if state.db_enabled and state.current_project and state.db:
|
if state.db_enabled and state.current_project and state.db:
|
||||||
sync_to_db(state.db, state.current_project, file_path, data)
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
||||||
ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive')
|
ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive')
|
||||||
refresh_list.refresh()
|
refresh_list.refresh()
|
||||||
|
|
||||||
@@ -597,7 +743,7 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_li
|
|||||||
ui.button(icon='help', on_click=ref_dlg.open).props('flat dense round')
|
ui.button(icon='help', on_click=ref_dlg.open).props('flat dense round')
|
||||||
|
|
||||||
def update_mode_label(e):
|
def update_mode_label(e):
|
||||||
idx = int(e.sender.value) if e.sender.value is not None else 0
|
idx = _safe_int(e.sender.value, 0)
|
||||||
idx = max(0, min(idx, len(VACE_MODES) - 1))
|
idx = max(0, min(idx, len(VACE_MODES) - 1))
|
||||||
mode_label.set_text(VACE_MODES[idx])
|
mode_label.set_text(VACE_MODES[idx])
|
||||||
|
|
||||||
@@ -611,10 +757,10 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_li
|
|||||||
'outlined').classes('w-full q-mt-sm')
|
'outlined').classes('w-full q-mt-sm')
|
||||||
|
|
||||||
# VACE Length + output calculation
|
# VACE Length + output calculation
|
||||||
input_a = int(seq.get('input_a_frames', 16))
|
input_a = _safe_int(seq.get('input_a_frames', 16), 16)
|
||||||
input_b = int(seq.get('input_b_frames', 16))
|
input_b = _safe_int(seq.get('input_b_frames', 16), 16)
|
||||||
stored_total = int(seq.get('vace_length', 49))
|
stored_total = _safe_int(seq.get('vace_length', 49), 49)
|
||||||
mode_idx = int(seq.get('vace schedule', 1))
|
mode_idx = _safe_int(seq.get('vace schedule', 1), 1)
|
||||||
|
|
||||||
if mode_idx == 0:
|
if mode_idx == 0:
|
||||||
base_length = max(stored_total - input_a, 1)
|
base_length = max(stored_total - input_a, 1)
|
||||||
@@ -633,10 +779,10 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_li
|
|||||||
|
|
||||||
# Recalculate VACE output when any input changes
|
# Recalculate VACE output when any input changes
|
||||||
def recalc_vace(*_args):
|
def recalc_vace(*_args):
|
||||||
mi = int(vs_input.value) if vs_input.value is not None else 0
|
mi = _safe_int(vs_input.value, 0)
|
||||||
ia = int(ia_input.value) if ia_input.value is not None else 16
|
ia = _safe_int(ia_input.value, 16)
|
||||||
ib = int(ib_input.value) if ib_input.value is not None else 16
|
ib = _safe_int(ib_input.value, 16)
|
||||||
nb = int(vl_input.value) if vl_input.value is not None else 1
|
nb = _safe_int(vl_input.value, 1)
|
||||||
|
|
||||||
if mi == 0:
|
if mi == 0:
|
||||||
raw = nb + ia
|
raw = nb + ia
|
||||||
@@ -696,7 +842,7 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
|
|||||||
|
|
||||||
select_all_cb.on_value_change(on_select_all)
|
select_all_cb.on_value_change(on_select_all)
|
||||||
|
|
||||||
def apply_mass_update():
|
async def apply_mass_update():
|
||||||
src_idx = source_select.value
|
src_idx = source_select.value
|
||||||
if src_idx is None or src_idx >= len(batch_list):
|
if src_idx is None or src_idx >= len(batch_list):
|
||||||
ui.notify('Source sequence no longer exists', type='warning')
|
ui.notify('Source sequence no longer exists', type='warning')
|
||||||
@@ -718,14 +864,28 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
|
|||||||
batch_list[idx][key] = copy.deepcopy(source_seq.get(key))
|
batch_list[idx][key] = copy.deepcopy(source_seq.get(key))
|
||||||
|
|
||||||
data[KEY_BATCH_DATA] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
htree = HistoryTree(data.get(KEY_HISTORY_TREE, {}))
|
timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {}))
|
||||||
snapshot = copy.deepcopy(data)
|
snapshot_json = json.dumps({k: v for k, v in data.items()
|
||||||
snapshot.pop(KEY_HISTORY_TREE, None)
|
if k != KEY_HISTORY_TREE})
|
||||||
htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
snapshot = json.loads(snapshot_json)
|
||||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
try:
|
||||||
save_json(file_path, data)
|
timeline.record(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
||||||
|
except ValueError as e:
|
||||||
|
ui.notify(f'Mass update failed: {e}', type='negative')
|
||||||
|
return
|
||||||
if state.db_enabled and state.current_project and state.db:
|
if state.db_enabled and state.current_project and state.db:
|
||||||
sync_to_db(state.db, state.current_project, file_path, data)
|
full_tree = timeline.to_dict()
|
||||||
|
data[KEY_HISTORY_TREE] = full_tree
|
||||||
|
db_snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot)
|
||||||
|
timeline.strip_snapshots()
|
||||||
|
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||||
|
slim_snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(save_json, file_path, slim_snapshot)
|
||||||
|
else:
|
||||||
|
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||||
|
save_snapshot = json.loads(json.dumps(data))
|
||||||
|
await asyncio.to_thread(save_json, file_path, save_snapshot)
|
||||||
ui.notify(f'Updated {len(targets)} sequences', type='positive')
|
ui.notify(f'Updated {len(targets)} sequences', type='positive')
|
||||||
if refresh_list:
|
if refresh_list:
|
||||||
refresh_list.refresh()
|
refresh_list.refresh()
|
||||||
|
|||||||
+5
-2
@@ -82,6 +82,7 @@ def render_comfy_monitor(state: AppState):
|
|||||||
_live_refreshables = state._live_refreshables
|
_live_refreshables = state._live_refreshables
|
||||||
|
|
||||||
def poll_all():
|
def poll_all():
|
||||||
|
try:
|
||||||
timeout_val = config.get('monitor_timeout', 0)
|
timeout_val = config.get('monitor_timeout', 0)
|
||||||
if timeout_val > 0:
|
if timeout_val > 0:
|
||||||
for key, start_time in list(state.live_toggles.items()):
|
for key, start_time in list(state.live_toggles.items()):
|
||||||
@@ -91,6 +92,8 @@ def render_comfy_monitor(state: AppState):
|
|||||||
_live_checkboxes[key].set_value(False)
|
_live_checkboxes[key].set_value(False)
|
||||||
if key in _live_refreshables:
|
if key in _live_refreshables:
|
||||||
_live_refreshables[key].refresh()
|
_live_refreshables[key].refresh()
|
||||||
|
except RuntimeError:
|
||||||
|
pass # Parent slot deleted during refresh
|
||||||
|
|
||||||
ui.timer(300, poll_all)
|
ui.timer(300, poll_all)
|
||||||
|
|
||||||
@@ -139,7 +142,7 @@ def _render_single_instance(state: AppState, instance_config: dict, index: int,
|
|||||||
|
|
||||||
async def refresh_status():
|
async def refresh_status():
|
||||||
status_container.clear()
|
status_container.clear()
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
res, err = await loop.run_in_executor(
|
res, err = await loop.run_in_executor(
|
||||||
None, lambda: _fetch_blocking(f'{comfy_url}/queue'))
|
None, lambda: _fetch_blocking(f'{comfy_url}/queue'))
|
||||||
with status_container:
|
with status_container:
|
||||||
@@ -237,7 +240,7 @@ def _render_single_instance(state: AppState, instance_config: dict, index: int,
|
|||||||
|
|
||||||
async def check_image():
|
async def check_image():
|
||||||
img_container.clear()
|
img_container.clear()
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
res, err = await loop.run_in_executor(
|
res, err = await loop.run_in_executor(
|
||||||
None, lambda: _fetch_blocking(f'{comfy_url}/history', timeout=2))
|
None, lambda: _fetch_blocking(f'{comfy_url}/history', timeout=2))
|
||||||
with img_container:
|
with img_container:
|
||||||
|
|||||||
+94
-9
@@ -1,11 +1,14 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from nicegui import ui
|
from nicegui import ui
|
||||||
|
|
||||||
from state import AppState
|
from state import AppState
|
||||||
from db import ProjectDB
|
from db import ProjectDB
|
||||||
from utils import save_config, sync_to_db, KEY_BATCH_DATA
|
from utils import save_config, sync_to_db
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -40,13 +43,13 @@ def render_projects_tab(state: AppState):
|
|||||||
name_input = ui.input('Project Name', placeholder='my_project').classes('w-full')
|
name_input = ui.input('Project Name', placeholder='my_project').classes('w-full')
|
||||||
desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full')
|
desc_input = ui.input('Description (optional)', placeholder='A short description').classes('w-full')
|
||||||
|
|
||||||
def create_project():
|
async def create_project():
|
||||||
name = name_input.value.strip()
|
name = name_input.value.strip()
|
||||||
if not name:
|
if not name:
|
||||||
ui.notify('Please enter a project name', type='warning')
|
ui.notify('Please enter a project name', type='warning')
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
state.db.create_project(name, str(state.current_dir), desc_input.value.strip())
|
await asyncio.to_thread(state.db.create_project, name, str(state.current_dir), desc_input.value.strip())
|
||||||
name_input.set_value('')
|
name_input.set_value('')
|
||||||
desc_input.set_value('')
|
desc_input.set_value('')
|
||||||
ui.notify(f'Created project "{name}"', type='positive')
|
ui.notify(f'Created project "{name}"', type='positive')
|
||||||
@@ -57,14 +60,50 @@ def render_projects_tab(state: AppState):
|
|||||||
ui.button('Create Project', icon='add', on_click=create_project).classes('w-full')
|
ui.button('Create Project', icon='add', on_click=create_project).classes('w-full')
|
||||||
|
|
||||||
# --- Active project indicator ---
|
# --- Active project indicator ---
|
||||||
|
# Fetch once with file counts and reuse in render_project_list
|
||||||
|
_cached_projects = state.db.list_projects_with_file_counts()
|
||||||
|
|
||||||
if state.current_project:
|
if state.current_project:
|
||||||
|
# Check if active project actually exists in the database
|
||||||
|
project_exists = any(p['name'] == state.current_project for p in _cached_projects)
|
||||||
|
if project_exists:
|
||||||
ui.label(f'Active Project: {state.current_project}').classes(
|
ui.label(f'Active Project: {state.current_project}').classes(
|
||||||
'text-bold text-primary q-pa-sm')
|
'text-bold text-primary q-pa-sm')
|
||||||
|
else:
|
||||||
|
with ui.card().classes('w-full q-pa-sm q-mb-sm').style(
|
||||||
|
'border-left: 3px solid orange;'):
|
||||||
|
ui.label(f'Stale project reference: "{state.current_project}" '
|
||||||
|
'(not found in database)').classes('text-warning')
|
||||||
|
with ui.row().classes('q-gutter-sm'):
|
||||||
|
def clear_stale():
|
||||||
|
state.current_project = ''
|
||||||
|
state.config['current_project'] = ''
|
||||||
|
save_config(state.current_dir,
|
||||||
|
state.config.get('favorites', []),
|
||||||
|
state.config)
|
||||||
|
ui.notify('Cleared stale project reference', type='info')
|
||||||
|
render_project_content.refresh()
|
||||||
|
|
||||||
|
def recreate_project():
|
||||||
|
name = state.current_project
|
||||||
|
try:
|
||||||
|
state.db.create_project(name, str(state.current_dir))
|
||||||
|
ui.notify(f'Recreated project "{name}"', type='positive')
|
||||||
|
render_project_content.refresh()
|
||||||
|
except Exception as e:
|
||||||
|
ui.notify(f'Error: {e}', type='negative')
|
||||||
|
|
||||||
|
ui.button('Clear Reference', icon='clear',
|
||||||
|
on_click=clear_stale).props('flat dense')
|
||||||
|
ui.button('Recreate Project', icon='add_circle',
|
||||||
|
on_click=recreate_project).props('flat dense color=primary')
|
||||||
|
|
||||||
# --- Project list ---
|
# --- Project list ---
|
||||||
@ui.refreshable
|
@ui.refreshable
|
||||||
def render_project_list():
|
def render_project_list():
|
||||||
projects = state.db.list_projects()
|
nonlocal _cached_projects
|
||||||
|
projects = state.db.list_projects_with_file_counts()
|
||||||
|
_cached_projects = projects
|
||||||
if not projects:
|
if not projects:
|
||||||
ui.label('No projects yet. Create one above.').classes('text-caption q-pa-md')
|
ui.label('No projects yet. Create one above.').classes('text-caption q-pa-md')
|
||||||
return
|
return
|
||||||
@@ -80,8 +119,7 @@ def render_projects_tab(state: AppState):
|
|||||||
if proj['description']:
|
if proj['description']:
|
||||||
ui.label(proj['description']).classes('text-caption')
|
ui.label(proj['description']).classes('text-caption')
|
||||||
ui.label(f'Path: {proj["folder_path"]}').classes('text-caption')
|
ui.label(f'Path: {proj["folder_path"]}').classes('text-caption')
|
||||||
files = state.db.list_data_files(proj['id'])
|
ui.label(f'{proj["file_count"]} data file(s)').classes('text-caption')
|
||||||
ui.label(f'{len(files)} data file(s)').classes('text-caption')
|
|
||||||
|
|
||||||
with ui.row().classes('q-gutter-xs'):
|
with ui.row().classes('q-gutter-xs'):
|
||||||
if not is_active:
|
if not is_active:
|
||||||
@@ -109,14 +147,57 @@ def render_projects_tab(state: AppState):
|
|||||||
ui.button('Deactivate', icon='cancel',
|
ui.button('Deactivate', icon='cancel',
|
||||||
on_click=deactivate).props('flat dense')
|
on_click=deactivate).props('flat dense')
|
||||||
|
|
||||||
|
async def rename_proj(name=proj['name']):
|
||||||
|
new_name = await ui.run_javascript(
|
||||||
|
f'prompt("Rename project:", {json.dumps(name)})',
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
if new_name and new_name.strip() and new_name.strip() != name:
|
||||||
|
new_name = new_name.strip()
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(state.db.rename_project, name, new_name)
|
||||||
|
if state.current_project == name:
|
||||||
|
state.current_project = new_name
|
||||||
|
state.config['current_project'] = new_name
|
||||||
|
save_config(state.current_dir,
|
||||||
|
state.config.get('favorites', []),
|
||||||
|
state.config)
|
||||||
|
ui.notify(f'Renamed to "{new_name}"', type='positive')
|
||||||
|
render_project_list.refresh()
|
||||||
|
except sqlite3.IntegrityError:
|
||||||
|
ui.notify(f'A project named "{new_name}" already exists',
|
||||||
|
type='warning')
|
||||||
|
except Exception as e:
|
||||||
|
ui.notify(f'Error: {e}', type='negative')
|
||||||
|
|
||||||
|
ui.button('Rename', icon='edit',
|
||||||
|
on_click=rename_proj).props('flat dense')
|
||||||
|
|
||||||
|
async def change_path(name=proj['name'], path=proj['folder_path']):
|
||||||
|
new_path = await ui.run_javascript(
|
||||||
|
f'prompt("New path for project:", {json.dumps(path)})',
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
if new_path and new_path.strip() and new_path.strip() != path:
|
||||||
|
new_path = new_path.strip()
|
||||||
|
if not Path(new_path).is_dir():
|
||||||
|
ui.notify(f'Warning: "{new_path}" does not exist',
|
||||||
|
type='warning')
|
||||||
|
await asyncio.to_thread(state.db.update_project_path, name, new_path)
|
||||||
|
ui.notify(f'Path updated to "{new_path}"', type='positive')
|
||||||
|
render_project_list.refresh()
|
||||||
|
|
||||||
|
ui.button('Path', icon='folder',
|
||||||
|
on_click=change_path).props('flat dense')
|
||||||
|
|
||||||
def import_folder(pid=proj['id'], pname=proj['name']):
|
def import_folder(pid=proj['id'], pname=proj['name']):
|
||||||
_import_folder(state, pid, pname, render_project_list)
|
_import_folder(state, pid, pname, render_project_list)
|
||||||
|
|
||||||
ui.button('Import Folder', icon='folder_open',
|
ui.button('Import Folder', icon='folder_open',
|
||||||
on_click=import_folder).props('flat dense')
|
on_click=import_folder).props('flat dense')
|
||||||
|
|
||||||
def delete_proj(name=proj['name']):
|
async def delete_proj(name=proj['name']):
|
||||||
state.db.delete_project(name)
|
await asyncio.to_thread(state.db.delete_project, name)
|
||||||
if state.current_project == name:
|
if state.current_project == name:
|
||||||
state.current_project = ''
|
state.current_project = ''
|
||||||
state.config['current_project'] = ''
|
state.config['current_project'] = ''
|
||||||
@@ -134,7 +215,7 @@ def render_projects_tab(state: AppState):
|
|||||||
render_project_content()
|
render_project_content()
|
||||||
|
|
||||||
|
|
||||||
def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn):
|
async def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn):
|
||||||
"""Bulk import all .json files from current directory into a project."""
|
"""Bulk import all .json files from current directory into a project."""
|
||||||
json_files = sorted(state.current_dir.glob('*.json'))
|
json_files = sorted(state.current_dir.glob('*.json'))
|
||||||
json_files = [f for f in json_files if f.name not in (
|
json_files = [f for f in json_files if f.name not in (
|
||||||
@@ -144,6 +225,7 @@ def _import_folder(state: AppState, project_id: int, project_name: str, refresh_
|
|||||||
ui.notify('No JSON files in current directory', type='warning')
|
ui.notify('No JSON files in current directory', type='warning')
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def _do_import():
|
||||||
imported = 0
|
imported = 0
|
||||||
skipped = 0
|
skipped = 0
|
||||||
for jf in json_files:
|
for jf in json_files:
|
||||||
@@ -157,6 +239,9 @@ def _import_folder(state: AppState, project_id: int, project_name: str, refresh_
|
|||||||
imported += 1
|
imported += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to import {jf}: {e}")
|
logger.warning(f"Failed to import {jf}: {e}")
|
||||||
|
return imported, skipped
|
||||||
|
|
||||||
|
imported, skipped = await asyncio.to_thread(_do_import)
|
||||||
|
|
||||||
msg = f'Imported {imported} file(s)'
|
msg = f'Imported {imported} file(s)'
|
||||||
if skipped:
|
if skipped:
|
||||||
|
|||||||
+7
-8
@@ -1,4 +1,4 @@
|
|||||||
import copy
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from nicegui import ui
|
from nicegui import ui
|
||||||
@@ -21,11 +21,10 @@ def render_raw_editor(state: AppState):
|
|||||||
|
|
||||||
@ui.refreshable
|
@ui.refreshable
|
||||||
def render_editor():
|
def render_editor():
|
||||||
# Prepare display data
|
# Prepare display data — shallow copy, just pop keys
|
||||||
if hide_history.value:
|
if hide_history.value:
|
||||||
display_data = copy.deepcopy(data)
|
display_data = {k: v for k, v in data.items()
|
||||||
display_data.pop(KEY_HISTORY_TREE, None)
|
if k not in (KEY_HISTORY_TREE, KEY_PROMPT_HISTORY)}
|
||||||
display_data.pop(KEY_PROMPT_HISTORY, None)
|
|
||||||
else:
|
else:
|
||||||
display_data = data
|
display_data = data
|
||||||
|
|
||||||
@@ -40,7 +39,7 @@ def render_raw_editor(state: AppState):
|
|||||||
value=json_str,
|
value=json_str,
|
||||||
).classes('w-full font-mono').props('outlined rows=30')
|
).classes('w-full font-mono').props('outlined rows=30')
|
||||||
|
|
||||||
def do_save():
|
async def do_save():
|
||||||
try:
|
try:
|
||||||
input_data = json.loads(text_area.value)
|
input_data = json.loads(text_area.value)
|
||||||
|
|
||||||
@@ -51,9 +50,9 @@ def render_raw_editor(state: AppState):
|
|||||||
if KEY_PROMPT_HISTORY in data:
|
if KEY_PROMPT_HISTORY in data:
|
||||||
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
||||||
|
|
||||||
save_json(file_path, input_data)
|
await asyncio.to_thread(save_json, file_path, input_data)
|
||||||
if state.db_enabled and state.current_project and state.db:
|
if state.db_enabled and state.current_project and state.db:
|
||||||
sync_to_db(state.db, state.current_project, file_path, input_data)
|
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, input_data)
|
||||||
|
|
||||||
data.clear()
|
data.clear()
|
||||||
data.update(input_data)
|
data.update(input_data)
|
||||||
|
|||||||
+528
-484
File diff suppressed because it is too large
Load Diff
+3
-3
@@ -208,10 +208,10 @@ class TestHistoryTrees:
|
|||||||
def test_upsert_updates(self, db):
|
def test_upsert_updates(self, db):
|
||||||
pid = db.create_project("p1", "/p1")
|
pid = db.create_project("p1", "/p1")
|
||||||
df_id = db.create_data_file(pid, "batch", "generic")
|
df_id = db.create_data_file(pid, "batch", "generic")
|
||||||
db.save_history_tree(df_id, {"v": 1})
|
db.save_history_tree(df_id, {"snapshots": {}, "v": 1})
|
||||||
db.save_history_tree(df_id, {"v": 2})
|
db.save_history_tree(df_id, {"snapshots": {}, "v": 2})
|
||||||
result = db.get_history_tree(df_id)
|
result = db.get_history_tree(df_id)
|
||||||
assert result == {"v": 2}
|
assert result == {"snapshots": {}, "v": 2}
|
||||||
|
|
||||||
def test_get_nonexistent(self, db):
|
def test_get_nonexistent(self, db):
|
||||||
pid = db.create_project("p1", "/p1")
|
pid = db.create_project("p1", "/p1")
|
||||||
|
|||||||
@@ -203,9 +203,152 @@ class TestProjectLoaderDynamic:
|
|||||||
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
|
assert ProjectLoaderDynamic.CATEGORY == "utils/json/project"
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectSource:
|
||||||
|
def test_input_types(self):
|
||||||
|
from project_loader import ProjectSource
|
||||||
|
inputs = ProjectSource.INPUT_TYPES()
|
||||||
|
assert "manager_url" in inputs["required"]
|
||||||
|
assert "project_name" in inputs["required"]
|
||||||
|
assert "file_name" in inputs["required"]
|
||||||
|
assert "sequence_number" in inputs["required"]
|
||||||
|
assert "label" in inputs["required"]
|
||||||
|
|
||||||
|
def test_outputs_sequence_number(self):
|
||||||
|
from project_loader import ProjectSource
|
||||||
|
assert ProjectSource.RETURN_TYPES == ("INT",)
|
||||||
|
assert ProjectSource.RETURN_NAMES == ("sequence_number",)
|
||||||
|
|
||||||
|
def test_hold_config_returns_sequence_number(self):
|
||||||
|
from project_loader import ProjectSource
|
||||||
|
node = ProjectSource()
|
||||||
|
result = node.hold_config(
|
||||||
|
manager_url="http://localhost:8080",
|
||||||
|
project_name="proj1",
|
||||||
|
file_name="batch_i2v",
|
||||||
|
sequence_number=42,
|
||||||
|
label="my_source"
|
||||||
|
)
|
||||||
|
assert result == (42,)
|
||||||
|
|
||||||
|
def test_category(self):
|
||||||
|
from project_loader import ProjectSource
|
||||||
|
assert ProjectSource.CATEGORY == "utils/json/project"
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectKey:
|
||||||
|
def test_input_types(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
inputs = ProjectKey.INPUT_TYPES()
|
||||||
|
assert "source_label" in inputs["required"]
|
||||||
|
assert "key_name" in inputs["required"]
|
||||||
|
assert "key_type" in inputs["required"]
|
||||||
|
|
||||||
|
def test_single_output(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
assert len(ProjectKey.RETURN_TYPES) == 1
|
||||||
|
assert len(ProjectKey.RETURN_NAMES) == 1
|
||||||
|
|
||||||
|
def test_fetch_key_string(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
node = ProjectKey()
|
||||||
|
data = {"prompt": "hello", "seed": 42}
|
||||||
|
with patch("project_loader._fetch_data", return_value=data):
|
||||||
|
result = node.fetch_key(
|
||||||
|
source_label="my_source",
|
||||||
|
key_name="prompt",
|
||||||
|
key_type="STRING",
|
||||||
|
manager_url="http://localhost:8080",
|
||||||
|
project_name="proj1",
|
||||||
|
file_name="batch_i2v",
|
||||||
|
sequence_number=1,
|
||||||
|
)
|
||||||
|
assert result == ("hello",)
|
||||||
|
|
||||||
|
def test_fetch_key_int_coercion(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
node = ProjectKey()
|
||||||
|
data = {"seed": "42"}
|
||||||
|
with patch("project_loader._fetch_data", return_value=data):
|
||||||
|
result = node.fetch_key(
|
||||||
|
source_label="my_source",
|
||||||
|
key_name="seed",
|
||||||
|
key_type="INT",
|
||||||
|
manager_url="http://localhost:8080",
|
||||||
|
project_name="proj1",
|
||||||
|
file_name="batch_i2v",
|
||||||
|
sequence_number=1,
|
||||||
|
)
|
||||||
|
assert result == (42,)
|
||||||
|
|
||||||
|
def test_fetch_key_float_coercion(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
node = ProjectKey()
|
||||||
|
data = {"cfg": "1.5"}
|
||||||
|
with patch("project_loader._fetch_data", return_value=data):
|
||||||
|
result = node.fetch_key(
|
||||||
|
source_label="my_source",
|
||||||
|
key_name="cfg",
|
||||||
|
key_type="FLOAT",
|
||||||
|
manager_url="http://localhost:8080",
|
||||||
|
project_name="proj1",
|
||||||
|
file_name="batch_i2v",
|
||||||
|
sequence_number=1,
|
||||||
|
)
|
||||||
|
assert result == (1.5,)
|
||||||
|
|
||||||
|
def test_fetch_key_missing_key(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
node = ProjectKey()
|
||||||
|
with patch("project_loader._fetch_data", return_value={}):
|
||||||
|
result = node.fetch_key(
|
||||||
|
source_label="my_source",
|
||||||
|
key_name="nonexistent",
|
||||||
|
key_type="STRING",
|
||||||
|
manager_url="http://localhost:8080",
|
||||||
|
project_name="proj1",
|
||||||
|
file_name="batch_i2v",
|
||||||
|
sequence_number=1,
|
||||||
|
)
|
||||||
|
assert result == ("",)
|
||||||
|
|
||||||
|
def test_fetch_key_network_error_returns_default(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
node = ProjectKey()
|
||||||
|
error_resp = {"error": "network_error", "message": "Connection refused"}
|
||||||
|
with patch("project_loader._fetch_data", return_value=error_resp):
|
||||||
|
result = node.fetch_key(
|
||||||
|
source_label="my_source",
|
||||||
|
key_name="prompt",
|
||||||
|
key_type="STRING",
|
||||||
|
manager_url="http://localhost:8080",
|
||||||
|
project_name="proj1",
|
||||||
|
file_name="batch_i2v",
|
||||||
|
sequence_number=1,
|
||||||
|
)
|
||||||
|
assert result == ("",)
|
||||||
|
|
||||||
|
def test_fetch_key_error_returns_int_default(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
node = ProjectKey()
|
||||||
|
error_resp = {"error": "http_error", "status": 404, "message": "Not found"}
|
||||||
|
with patch("project_loader._fetch_data", return_value=error_resp):
|
||||||
|
result = node.fetch_key(
|
||||||
|
source_label="s", key_name="seed", key_type="INT",
|
||||||
|
manager_url="http://localhost:8080", project_name="p",
|
||||||
|
file_name="f", sequence_number=1,
|
||||||
|
)
|
||||||
|
assert result == (0,)
|
||||||
|
|
||||||
|
def test_category(self):
|
||||||
|
from project_loader import ProjectKey
|
||||||
|
assert ProjectKey.CATEGORY == "utils/json/project"
|
||||||
|
|
||||||
|
|
||||||
class TestNodeMappings:
|
class TestNodeMappings:
|
||||||
def test_mappings_exist(self):
|
def test_mappings_exist(self):
|
||||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
||||||
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 1
|
assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS
|
||||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 1
|
assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS
|
||||||
|
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 3
|
||||||
|
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 3
|
||||||
|
|||||||
@@ -0,0 +1,159 @@
|
|||||||
|
import pytest
|
||||||
|
from snapshot_timeline import SnapshotTimeline, diff_snapshots
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_creates_snapshot():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
sid = tl.record({"batch_data": [{"seed": 42}]}, note="first")
|
||||||
|
assert sid in tl.snapshots
|
||||||
|
assert tl.current_id == sid
|
||||||
|
assert tl.snapshots[sid]["note"] == "first"
|
||||||
|
assert tl.snapshots[sid]["auto"] is False
|
||||||
|
assert tl.snapshots[sid]["seq_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_auto_flag():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
sid = tl.record({"batch_data": []}, note="auto save", auto=True)
|
||||||
|
assert tl.snapshots[sid]["auto"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_records():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
id1 = tl.record({"batch_data": [{"a": 1}]}, note="one")
|
||||||
|
id2 = tl.record({"batch_data": [{"b": 2}]}, note="two")
|
||||||
|
assert len(tl.snapshots) == 2
|
||||||
|
assert tl.current_id == id2
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_dict_roundtrip():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
tl.record({"batch_data": [{"x": 1}]}, note="test")
|
||||||
|
d = tl.to_dict()
|
||||||
|
tl2 = SnapshotTimeline(d)
|
||||||
|
assert tl2.current_id == tl.current_id
|
||||||
|
assert set(tl2.snapshots.keys()) == set(tl.snapshots.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def test_migrate_from_history_tree():
|
||||||
|
"""Old HistoryTree format should be flattened into snapshots."""
|
||||||
|
old_data = {
|
||||||
|
"nodes": {
|
||||||
|
"aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First", "data": {"batch_data": [{"seed": 1}]}},
|
||||||
|
"bbb": {"id": "bbb", "parent": "aaa", "timestamp": 2000, "note": "Second", "data": {"batch_data": [{"seed": 2}]}},
|
||||||
|
},
|
||||||
|
"branches": {"main": "bbb"},
|
||||||
|
"head_id": "bbb",
|
||||||
|
}
|
||||||
|
tl = SnapshotTimeline(old_data)
|
||||||
|
assert len(tl.snapshots) == 2
|
||||||
|
assert tl.current_id == "bbb"
|
||||||
|
assert tl.snapshots["aaa"]["note"] == "First"
|
||||||
|
assert tl.snapshots["bbb"]["note"] == "Second"
|
||||||
|
# Data should be preserved
|
||||||
|
assert tl.snapshots["aaa"]["data"]["batch_data"] == [{"seed": 1}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_migrate_from_history_tree_no_data():
|
||||||
|
"""Slim tree nodes (no inline data) should still migrate."""
|
||||||
|
old_data = {
|
||||||
|
"nodes": {
|
||||||
|
"aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First"},
|
||||||
|
},
|
||||||
|
"branches": {"main": "aaa"},
|
||||||
|
"head_id": "aaa",
|
||||||
|
}
|
||||||
|
tl = SnapshotTimeline(old_data)
|
||||||
|
assert len(tl.snapshots) == 1
|
||||||
|
assert tl.snapshots["aaa"]["seq_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_migrate_legacy_prompt_history():
|
||||||
|
legacy = {
|
||||||
|
"prompt_history": [
|
||||||
|
{"note": "A", "seed": 1},
|
||||||
|
{"note": "B", "seed": 2},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
tl = SnapshotTimeline(legacy)
|
||||||
|
assert len(tl.snapshots) == 2
|
||||||
|
assert tl.current_id is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_toggle_pin():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
sid = tl.record({"batch_data": []}, note="test")
|
||||||
|
assert tl.snapshots[sid]["pinned"] is False
|
||||||
|
result = tl.toggle_pin(sid)
|
||||||
|
assert result is True
|
||||||
|
assert tl.snapshots[sid]["pinned"] is True
|
||||||
|
result = tl.toggle_pin(sid)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_snapshot():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
id1 = tl.record({"batch_data": []}, note="one")
|
||||||
|
id2 = tl.record({"batch_data": []}, note="two")
|
||||||
|
tl.delete(id2)
|
||||||
|
assert id2 not in tl.snapshots
|
||||||
|
assert tl.current_id == id1
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_all_snapshots():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
sid = tl.record({"batch_data": []}, note="only")
|
||||||
|
tl.delete(sid)
|
||||||
|
assert len(tl.snapshots) == 0
|
||||||
|
assert tl.current_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_snapshots():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
tl.record({"batch_data": [{"a": 1}]}, note="test")
|
||||||
|
tl.strip_snapshots()
|
||||||
|
for snap in tl.snapshots.values():
|
||||||
|
assert "data" not in snap
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_snapshot_data():
|
||||||
|
tl = SnapshotTimeline({})
|
||||||
|
sid = tl.record({"batch_data": [{"x": 1}]}, note="test")
|
||||||
|
data = tl.get_snapshot_data(sid)
|
||||||
|
assert data == {"batch_data": [{"x": 1}]}
|
||||||
|
assert tl.get_snapshot_data("nonexistent") is None
|
||||||
|
|
||||||
|
|
||||||
|
# --- diff_snapshots tests ---
|
||||||
|
|
||||||
|
def test_diff_unchanged():
|
||||||
|
batch = [{"sequence_number": 1, "seed": 42}]
|
||||||
|
result = diff_snapshots(batch, batch)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["status"] == "unchanged"
|
||||||
|
assert result[0]["changes"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_diff_changed():
|
||||||
|
old = [{"sequence_number": 1, "seed": 42, "cfg": 1.5}]
|
||||||
|
new = [{"sequence_number": 1, "seed": 99, "cfg": 1.5}]
|
||||||
|
result = diff_snapshots(old, new)
|
||||||
|
assert result[0]["status"] == "changed"
|
||||||
|
assert len(result[0]["changes"]) == 1
|
||||||
|
assert result[0]["changes"][0]["field"] == "seed"
|
||||||
|
assert result[0]["changes"][0]["old"] == 42
|
||||||
|
assert result[0]["changes"][0]["new"] == 99
|
||||||
|
|
||||||
|
|
||||||
|
def test_diff_added_and_removed():
|
||||||
|
old = [{"sequence_number": 1, "seed": 1}]
|
||||||
|
new = [{"sequence_number": 2, "seed": 2}]
|
||||||
|
result = diff_snapshots(old, new)
|
||||||
|
assert len(result) == 2
|
||||||
|
statuses = {r["seq_num"]: r["status"] for r in result}
|
||||||
|
assert statuses[1] == "removed"
|
||||||
|
assert statuses[2] == "added"
|
||||||
|
|
||||||
|
|
||||||
|
def test_diff_empty():
|
||||||
|
assert diff_snapshots([], []) == []
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -30,6 +31,7 @@ DEFAULTS = {
|
|||||||
"cfg": 1.5,
|
"cfg": 1.5,
|
||||||
|
|
||||||
# --- Settings ---
|
# --- Settings ---
|
||||||
|
"mode": 0,
|
||||||
"camera": "static",
|
"camera": "static",
|
||||||
"flf": 0.0,
|
"flf": 0.0,
|
||||||
|
|
||||||
@@ -47,10 +49,19 @@ DEFAULTS = {
|
|||||||
"reference path": "",
|
"reference path": "",
|
||||||
"flf image path": "",
|
"flf image path": "",
|
||||||
|
|
||||||
# --- LoRAs ---
|
# --- LoRAs (name as STRING, strength as FLOAT) ---
|
||||||
"lora 1 high": "", "lora 1 low": "",
|
"lora 1 high": "",
|
||||||
"lora 2 high": "", "lora 2 low": "",
|
"lora 1 high strength": 1.0,
|
||||||
"lora 3 high": "", "lora 3 low": ""
|
"lora 1 low": "",
|
||||||
|
"lora 1 low strength": 1.0,
|
||||||
|
"lora 2 high": "",
|
||||||
|
"lora 2 high strength": 1.0,
|
||||||
|
"lora 2 low": "",
|
||||||
|
"lora 2 low strength": 1.0,
|
||||||
|
"lora 3 high": "",
|
||||||
|
"lora 3 high strength": 1.0,
|
||||||
|
"lora 3 low": "",
|
||||||
|
"lora 3 low strength": 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
CONFIG_FILE = Path(".editor_config.json")
|
CONFIG_FILE = Path(".editor_config.json")
|
||||||
@@ -112,14 +123,17 @@ def save_config(current_dir, favorites, extra_data=None):
|
|||||||
existing = load_config()
|
existing = load_config()
|
||||||
data.update(existing)
|
data.update(existing)
|
||||||
|
|
||||||
data["last_dir"] = str(current_dir)
|
|
||||||
data["favorites"] = favorites
|
|
||||||
|
|
||||||
if extra_data:
|
if extra_data:
|
||||||
data.update(extra_data)
|
data.update(extra_data)
|
||||||
|
|
||||||
with open(CONFIG_FILE, 'w') as f:
|
# Force-set explicit params last so extra_data can't override them
|
||||||
|
data["last_dir"] = str(current_dir)
|
||||||
|
data["favorites"] = favorites
|
||||||
|
|
||||||
|
tmp = CONFIG_FILE.with_suffix('.json.tmp')
|
||||||
|
with open(tmp, 'w') as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
|
os.replace(tmp, CONFIG_FILE)
|
||||||
|
|
||||||
def load_snippets():
|
def load_snippets():
|
||||||
if SNIPPETS_FILE.exists():
|
if SNIPPETS_FILE.exists():
|
||||||
@@ -131,27 +145,96 @@ def load_snippets():
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def save_snippets(snippets):
|
def save_snippets(snippets):
|
||||||
with open(SNIPPETS_FILE, 'w') as f:
|
tmp = SNIPPETS_FILE.with_suffix('.json.tmp')
|
||||||
|
with open(tmp, 'w') as f:
|
||||||
json.dump(snippets, f, indent=4)
|
json.dump(snippets, f, indent=4)
|
||||||
|
os.replace(tmp, SNIPPETS_FILE)
|
||||||
|
|
||||||
|
def _migrate_lora_keys(data: dict) -> None:
|
||||||
|
"""Split combined lora 'name:strength' into separate name and strength keys.
|
||||||
|
|
||||||
|
Handles legacy formats:
|
||||||
|
1. <lora:Name:0.5> → name_key='Name', str_key=0.5
|
||||||
|
2. 'Name:0.5' (merged) → name_key='Name', str_key=0.5
|
||||||
|
3. Already split (name_key + str_key exist) → no change
|
||||||
|
"""
|
||||||
|
for item in data.get(KEY_BATCH_DATA, []):
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
for idx in range(1, 4):
|
||||||
|
for tier in ('high', 'low'):
|
||||||
|
name_key = f'lora {idx} {tier}'
|
||||||
|
str_key = f'lora {idx} {tier} strength'
|
||||||
|
raw = str(item.get(name_key, ''))
|
||||||
|
|
||||||
|
if raw.startswith('<lora:'):
|
||||||
|
# Legacy <lora:Name:0.5> format
|
||||||
|
inner = raw.replace('<lora:', '').replace('>', '')
|
||||||
|
if ':' in inner:
|
||||||
|
parts = inner.rsplit(':', 1)
|
||||||
|
item[name_key] = parts[0]
|
||||||
|
try:
|
||||||
|
item[str_key] = float(parts[1])
|
||||||
|
except ValueError:
|
||||||
|
item[str_key] = 1.0
|
||||||
|
else:
|
||||||
|
item[name_key] = inner
|
||||||
|
if str_key not in item:
|
||||||
|
item[str_key] = 1.0
|
||||||
|
elif ':' in raw and raw:
|
||||||
|
# Combined 'name:strength' format → split
|
||||||
|
parts = raw.rsplit(':', 1)
|
||||||
|
try:
|
||||||
|
strength = float(parts[1])
|
||||||
|
item[name_key] = parts[0]
|
||||||
|
item[str_key] = strength
|
||||||
|
except ValueError:
|
||||||
|
# Not a valid strength, leave as-is
|
||||||
|
if str_key not in item:
|
||||||
|
item[str_key] = 1.0
|
||||||
|
elif raw:
|
||||||
|
# Name exists without colon, ensure strength key exists
|
||||||
|
if str_key not in item:
|
||||||
|
item[str_key] = 1.0
|
||||||
|
# If name is empty, don't add a strength key
|
||||||
|
|
||||||
|
|
||||||
def load_json(path: str | Path) -> tuple[dict[str, Any], float]:
|
def load_json(path: str | Path) -> tuple[dict[str, Any], float]:
|
||||||
|
t0 = time.time()
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return DEFAULTS.copy(), 0
|
return DEFAULTS.copy(), 0
|
||||||
try:
|
try:
|
||||||
with open(path, 'r') as f:
|
with open(path, 'r') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return data, path.stat().st_mtime
|
t1 = time.time()
|
||||||
|
_migrate_lora_keys(data)
|
||||||
|
t2 = time.time()
|
||||||
|
mtime = path.stat().st_mtime
|
||||||
|
logger.info("load_json %s: read=%.3fs migrate=%.3fs total=%.3fs",
|
||||||
|
path.name, t1 - t0, t2 - t1, t2 - t0)
|
||||||
|
return data, mtime
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading JSON: {e}")
|
logger.error(f"Error loading JSON: {e}")
|
||||||
return DEFAULTS.copy(), 0
|
return DEFAULTS.copy(), 0
|
||||||
|
|
||||||
def save_json(path: str | Path, data: dict[str, Any]) -> None:
|
def save_json(path: str | Path, data: dict[str, Any]) -> None:
|
||||||
|
t0 = time.time()
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
tmp = path.with_suffix('.json.tmp')
|
tmp = path.with_suffix('.json.tmp')
|
||||||
with open(tmp, 'w') as f:
|
with open(tmp, 'w') as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
os.replace(tmp, path)
|
os.replace(tmp, path)
|
||||||
|
logger.info("save_json %s: %.3fs", path.name, time.time() - t0)
|
||||||
|
|
||||||
|
|
||||||
|
def snapshot_data(data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Create a thread-safe deep copy via JSON roundtrip.
|
||||||
|
|
||||||
|
Must be called on the main thread before passing data to asyncio.to_thread,
|
||||||
|
to avoid 'dict changed size during iteration' when the UI mutates data.
|
||||||
|
"""
|
||||||
|
return json.loads(json.dumps(data))
|
||||||
|
|
||||||
def get_file_mtime(path: str | Path) -> float:
|
def get_file_mtime(path: str | Path) -> float:
|
||||||
"""Returns the modification time of a file, or 0 if it doesn't exist."""
|
"""Returns the modification time of a file, or 0 if it doesn't exist."""
|
||||||
@@ -166,6 +249,7 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
|||||||
Resolves (or creates) the data_file, upserts all sequences from batch_data,
|
Resolves (or creates) the data_file, upserts all sequences from batch_data,
|
||||||
and saves the history_tree. All writes happen in a single transaction.
|
and saves the history_tree. All writes happen in a single transaction.
|
||||||
"""
|
"""
|
||||||
|
t0 = time.time()
|
||||||
if not db or not project_name:
|
if not db or not project_name:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
@@ -177,11 +261,11 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
|||||||
# Use a single transaction for atomicity
|
# Use a single transaction for atomicity
|
||||||
db.conn.execute("BEGIN IMMEDIATE")
|
db.conn.execute("BEGIN IMMEDIATE")
|
||||||
try:
|
try:
|
||||||
|
now = time.time()
|
||||||
df = db.get_data_file(proj["id"], file_name)
|
df = db.get_data_file(proj["id"], file_name)
|
||||||
top_level = {k: v for k, v in data.items()
|
top_level = {k: v for k, v in data.items()
|
||||||
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
||||||
if not df:
|
if not df:
|
||||||
now = __import__('time').time()
|
|
||||||
cur = db.conn.execute(
|
cur = db.conn.execute(
|
||||||
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
||||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
@@ -191,7 +275,6 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
|||||||
else:
|
else:
|
||||||
df_id = df["id"]
|
df_id = df["id"]
|
||||||
# Update top_level metadata
|
# Update top_level metadata
|
||||||
now = __import__('time').time()
|
|
||||||
db.conn.execute(
|
db.conn.execute(
|
||||||
"UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?",
|
"UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?",
|
||||||
(json.dumps(top_level), now, df_id),
|
(json.dumps(top_level), now, df_id),
|
||||||
@@ -200,27 +283,74 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
|||||||
# Sync sequences
|
# Sync sequences
|
||||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
batch_data = data.get(KEY_BATCH_DATA, [])
|
||||||
if isinstance(batch_data, list):
|
if isinstance(batch_data, list):
|
||||||
db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
new_seq_nums = set()
|
||||||
for item in batch_data:
|
for item in batch_data:
|
||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
continue
|
continue
|
||||||
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
|
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
|
||||||
now = __import__('time').time()
|
new_seq_nums.add(seq_num)
|
||||||
db.conn.execute(
|
db.conn.execute(
|
||||||
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
||||||
"VALUES (?, ?, ?, ?)",
|
"VALUES (?, ?, ?, ?) "
|
||||||
|
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
|
||||||
(df_id, seq_num, json.dumps(item), now),
|
(df_id, seq_num, json.dumps(item), now),
|
||||||
)
|
)
|
||||||
|
# Remove sequences that no longer exist
|
||||||
|
if new_seq_nums:
|
||||||
|
placeholders = ','.join('?' * len(new_seq_nums))
|
||||||
|
db.conn.execute(
|
||||||
|
f"DELETE FROM sequences WHERE data_file_id = ? AND sequence_number NOT IN ({placeholders})",
|
||||||
|
(df_id, *new_seq_nums),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
||||||
|
|
||||||
# Sync history tree
|
# Sync history tree (extract snapshot data into separate table)
|
||||||
|
# Supports both new format (snapshots dict) and old format (nodes dict)
|
||||||
history_tree = data.get(KEY_HISTORY_TREE)
|
history_tree = data.get(KEY_HISTORY_TREE)
|
||||||
if history_tree and isinstance(history_tree, dict):
|
if history_tree and isinstance(history_tree, dict):
|
||||||
now = __import__('time').time()
|
# Detect format: new has "snapshots", old has "nodes"
|
||||||
|
if "snapshots" in history_tree:
|
||||||
|
entries = history_tree.get("snapshots", {})
|
||||||
|
else:
|
||||||
|
entries = history_tree.get("nodes", {})
|
||||||
|
slim_tree = dict(history_tree)
|
||||||
|
slim_entries = {}
|
||||||
|
for eid, entry in entries.items():
|
||||||
|
snap = entry.get("data")
|
||||||
|
if snap:
|
||||||
|
db.conn.execute(
|
||||||
|
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, ?) "
|
||||||
|
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||||
|
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||||
|
(df_id, eid, json.dumps(snap), now),
|
||||||
|
)
|
||||||
|
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||||
|
# Write back slim version using the correct key
|
||||||
|
if "snapshots" in history_tree:
|
||||||
|
slim_tree["snapshots"] = slim_entries
|
||||||
|
else:
|
||||||
|
slim_tree["nodes"] = slim_entries
|
||||||
db.conn.execute(
|
db.conn.execute(
|
||||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||||
"VALUES (?, ?, ?) "
|
"VALUES (?, ?, ?) "
|
||||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||||
(df_id, json.dumps(history_tree), now),
|
(df_id, json.dumps(slim_tree), now),
|
||||||
|
)
|
||||||
|
# Clean up orphaned snapshots
|
||||||
|
current_ids = set(entries.keys())
|
||||||
|
if current_ids:
|
||||||
|
placeholders = ",".join("?" for _ in current_ids)
|
||||||
|
db.conn.execute(
|
||||||
|
f"DELETE FROM history_snapshots WHERE data_file_id = ? "
|
||||||
|
f"AND node_id NOT IN ({placeholders})",
|
||||||
|
(df_id, *current_ids),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
db.conn.execute(
|
||||||
|
"DELETE FROM history_snapshots WHERE data_file_id = ?",
|
||||||
|
(df_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
db.conn.execute("COMMIT")
|
db.conn.execute("COMMIT")
|
||||||
@@ -232,14 +362,18 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"sync_to_db failed: {e}")
|
logger.warning(f"sync_to_db failed: {e}")
|
||||||
|
return
|
||||||
|
batch_count = len(data.get(KEY_BATCH_DATA, []))
|
||||||
|
logger.info("sync_to_db %s (%d seqs): %.3fs",
|
||||||
|
Path(file_path).name, batch_count, time.time() - t0)
|
||||||
|
|
||||||
|
|
||||||
def generate_templates(current_dir: Path) -> None:
|
def generate_templates(current_dir: Path) -> None:
|
||||||
"""Creates batch template files if folder is empty."""
|
"""Creates batch template files if folder is empty."""
|
||||||
first = DEFAULTS.copy()
|
first = copy.deepcopy(DEFAULTS)
|
||||||
first[KEY_SEQUENCE_NUMBER] = 1
|
first[KEY_SEQUENCE_NUMBER] = 1
|
||||||
save_json(current_dir / "batch_prompt_i2v.json", {KEY_BATCH_DATA: [first]})
|
save_json(current_dir / "batch_prompt_i2v.json", {KEY_BATCH_DATA: [first]})
|
||||||
|
|
||||||
first2 = DEFAULTS.copy()
|
first2 = copy.deepcopy(DEFAULTS)
|
||||||
first2[KEY_SEQUENCE_NUMBER] = 1
|
first2[KEY_SEQUENCE_NUMBER] = 1
|
||||||
save_json(current_dir / "batch_prompt_vace_extend.json", {KEY_BATCH_DATA: [first2]})
|
save_json(current_dir / "batch_prompt_vace_extend.json", {KEY_BATCH_DATA: [first2]})
|
||||||
|
|||||||
+18
-12
@@ -34,7 +34,7 @@ app.registerExtension({
|
|||||||
|
|
||||||
// Auto-refresh with 500ms debounce on widget changes
|
// Auto-refresh with 500ms debounce on widget changes
|
||||||
this._refreshTimer = null;
|
this._refreshTimer = null;
|
||||||
const autoRefreshWidgets = ["project_name", "file_name", "sequence_number"];
|
const autoRefreshWidgets = ["project_name", "file_name", "sequence_number", "refresh"];
|
||||||
for (const widgetName of autoRefreshWidgets) {
|
for (const widgetName of autoRefreshWidgets) {
|
||||||
const w = this.widgets?.find(w => w.name === widgetName);
|
const w = this.widgets?.find(w => w.name === widgetName);
|
||||||
if (w) {
|
if (w) {
|
||||||
@@ -117,11 +117,11 @@ app.registerExtension({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store keys and types in hidden widgets for persistence (comma-separated)
|
// Store keys and types in hidden widgets for persistence (JSON)
|
||||||
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
||||||
if (okWidget) okWidget.value = keys.join(",");
|
if (okWidget) okWidget.value = JSON.stringify(keys);
|
||||||
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
||||||
if (otWidget) otWidget.value = types.join(",");
|
if (otWidget) otWidget.value = JSON.stringify(types);
|
||||||
|
|
||||||
// Slot 0 is always total_sequences (INT) — ensure it exists
|
// Slot 0 is always total_sequences (INT) — ensure it exists
|
||||||
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
||||||
@@ -198,12 +198,18 @@ app.registerExtension({
|
|||||||
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
const okWidget = this.widgets?.find(w => w.name === "output_keys");
|
||||||
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
const otWidget = this.widgets?.find(w => w.name === "output_types");
|
||||||
|
|
||||||
const keys = okWidget?.value
|
let keys = [];
|
||||||
? okWidget.value.split(",").filter(k => k.trim())
|
let types = [];
|
||||||
: [];
|
if (okWidget?.value) {
|
||||||
const types = otWidget?.value
|
try { keys = JSON.parse(okWidget.value); } catch (_) {
|
||||||
? otWidget.value.split(",")
|
keys = okWidget.value.split(",").filter(k => k.trim());
|
||||||
: [];
|
}
|
||||||
|
}
|
||||||
|
if (otWidget?.value) {
|
||||||
|
try { types = JSON.parse(otWidget.value); } catch (_) {
|
||||||
|
types = otWidget.value.split(",");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure slot 0 is total_sequences (INT)
|
// Ensure slot 0 is total_sequences (INT)
|
||||||
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
if (this.outputs.length === 0 || this.outputs[0].name !== "total_sequences") {
|
||||||
@@ -245,8 +251,8 @@ app.registerExtension({
|
|||||||
} else if (this.outputs.length > 1) {
|
} else if (this.outputs.length > 1) {
|
||||||
// Widget values empty but serialized dynamic outputs exist — sync widgets
|
// Widget values empty but serialized dynamic outputs exist — sync widgets
|
||||||
const dynamicOutputs = this.outputs.slice(1);
|
const dynamicOutputs = this.outputs.slice(1);
|
||||||
if (okWidget) okWidget.value = dynamicOutputs.map(o => o.name).join(",");
|
if (okWidget) okWidget.value = JSON.stringify(dynamicOutputs.map(o => o.name));
|
||||||
if (otWidget) otWidget.value = dynamicOutputs.map(o => o.type).join(",");
|
if (otWidget) otWidget.value = JSON.stringify(dynamicOutputs.map(o => o.type));
|
||||||
}
|
}
|
||||||
|
|
||||||
this.setSize(this.computeSize());
|
this.setSize(this.computeSize());
|
||||||
|
|||||||
@@ -0,0 +1,275 @@
|
|||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
import { api } from "../../scripts/api.js";
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "json.manager.project.key",
|
||||||
|
|
||||||
|
// Re-sync all ProjectKey nodes from their sources before queueing
|
||||||
|
// This fixes stale config when the user edits a ProjectSource after
|
||||||
|
// a ProjectKey already selected it.
|
||||||
|
async beforeQueuePrompt() {
|
||||||
|
if (!app.graph?._nodes) return;
|
||||||
|
for (const node of app.graph._nodes) {
|
||||||
|
if (node.type === "ProjectKey" && node._syncFromSource) {
|
||||||
|
node._syncFromSource();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||||
|
if (nodeData.name !== "ProjectKey") return;
|
||||||
|
|
||||||
|
// Helper: properly hide a widget (works for all types including INT)
|
||||||
|
function hideWidget(widget) {
|
||||||
|
if (widget.origType === undefined) widget.origType = widget.type;
|
||||||
|
widget.type = "hidden";
|
||||||
|
widget.hidden = true;
|
||||||
|
widget.computeSize = () => [0, -4];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper: replace a STRING widget with a proper combo widget
|
||||||
|
function replaceWithCombo(node, name, values, callback) {
|
||||||
|
const idx = node.widgets?.findIndex(w => w.name === name);
|
||||||
|
if (idx === -1 || idx === undefined) return null;
|
||||||
|
const oldWidget = node.widgets[idx];
|
||||||
|
const savedValue = oldWidget.value || "";
|
||||||
|
// Ensure values list is never empty (combo shows undefined otherwise)
|
||||||
|
const comboValues = values.length > 0 ? values : [""];
|
||||||
|
// Always preserve saved value — it may not be in the list yet (load-order race)
|
||||||
|
if (savedValue && !comboValues.includes(savedValue)) {
|
||||||
|
comboValues.unshift(savedValue);
|
||||||
|
}
|
||||||
|
const defaultValue = savedValue || comboValues[0];
|
||||||
|
// Remove old STRING widget
|
||||||
|
node.widgets.splice(idx, 1);
|
||||||
|
// Insert a real combo widget at the same position
|
||||||
|
const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues });
|
||||||
|
// Move it from the end to the original position
|
||||||
|
if (node.widgets.length > 1) {
|
||||||
|
node.widgets.splice(node.widgets.length - 1, 1);
|
||||||
|
node.widgets.splice(idx, 0, combo);
|
||||||
|
}
|
||||||
|
return combo;
|
||||||
|
}
|
||||||
|
|
||||||
|
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
||||||
|
nodeType.prototype.onNodeCreated = function () {
|
||||||
|
origOnNodeCreated?.apply(this, arguments);
|
||||||
|
this._configured = false;
|
||||||
|
|
||||||
|
// Hide the connection-config widgets (synced from source by JS)
|
||||||
|
for (const name of ["manager_url", "project_name", "file_name", "sequence_number", "key_type"]) {
|
||||||
|
const w = this.widgets?.find(w => w.name === name);
|
||||||
|
if (w) hideWidget(w);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace source_label STRING with a proper combo widget
|
||||||
|
const node = this;
|
||||||
|
const sourceLabels = this._getSourceLabels?.() || [];
|
||||||
|
const srcCombo = replaceWithCombo(this, "source_label", sourceLabels, function (value) {
|
||||||
|
node._syncFromSource();
|
||||||
|
node._refreshKeys();
|
||||||
|
});
|
||||||
|
// Set first available source or "none" placeholder
|
||||||
|
if (srcCombo) srcCombo.value = sourceLabels[0] || "";
|
||||||
|
|
||||||
|
// Replace key_name STRING with a proper combo widget
|
||||||
|
const keyCombo = replaceWithCombo(this, "key_name", [], function (value) {
|
||||||
|
node._applyKeySelection();
|
||||||
|
});
|
||||||
|
if (keyCombo) keyCombo.value = "";
|
||||||
|
|
||||||
|
queueMicrotask(() => {
|
||||||
|
if (!this._configured) {
|
||||||
|
// New node — set output to a generic slot
|
||||||
|
if (this.outputs.length === 0) {
|
||||||
|
this.addOutput("value", "*");
|
||||||
|
}
|
||||||
|
this.setSize(this.computeSize());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Find all ProjectSource nodes and their labels (deduplicated) ---
|
||||||
|
nodeType.prototype._getSourceLabels = function () {
|
||||||
|
const seen = new Set();
|
||||||
|
const labels = [];
|
||||||
|
if (!this.graph) return labels;
|
||||||
|
for (const node of this.graph._nodes) {
|
||||||
|
if (node.type === "ProjectSource") {
|
||||||
|
const lw = node.widgets?.find(w => w.name === "label");
|
||||||
|
if (lw?.value && !seen.has(lw.value)) {
|
||||||
|
seen.add(lw.value);
|
||||||
|
labels.push(lw.value);
|
||||||
|
} else if (lw?.value && seen.has(lw.value)) {
|
||||||
|
console.warn(`[ProjectKey] Duplicate source label "${lw.value}" (node ${node.id}) — only first will be used`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return labels;
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Find the ProjectSource node matching a label ---
|
||||||
|
nodeType.prototype._findSource = function (label) {
|
||||||
|
if (!this.graph || !label) return null;
|
||||||
|
for (const node of this.graph._nodes) {
|
||||||
|
if (node.type === "ProjectSource") {
|
||||||
|
const lw = node.widgets?.find(w => w.name === "label");
|
||||||
|
if (lw?.value === label) return node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Copy config from source node into hidden widgets ---
|
||||||
|
nodeType.prototype._syncFromSource = function () {
|
||||||
|
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
||||||
|
const source = this._findSource(srcWidget?.value);
|
||||||
|
if (!source) {
|
||||||
|
console.log(`[ProjectKey] _syncFromSource id=${this.id}: no source found for label="${srcWidget?.value}"`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
||||||
|
const dst = this.widgets?.find(w => w.name === name);
|
||||||
|
const src = source.widgets?.find(w => w.name === name);
|
||||||
|
if (dst && src) {
|
||||||
|
dst.value = src.value;
|
||||||
|
console.log(`[ProjectKey] _syncFromSource id=${this.id}: ${name}="${src.value}"`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Fetch keys from API and populate key_name dropdown ---
|
||||||
|
nodeType.prototype._refreshKeys = async function () {
|
||||||
|
const urlW = this.widgets?.find(w => w.name === "manager_url");
|
||||||
|
const projW = this.widgets?.find(w => w.name === "project_name");
|
||||||
|
const fileW = this.widgets?.find(w => w.name === "file_name");
|
||||||
|
const seqW = this.widgets?.find(w => w.name === "sequence_number");
|
||||||
|
|
||||||
|
console.log(`[ProjectKey] _refreshKeys id=${this.id}: url="${urlW?.value}" project="${projW?.value}" file="${fileW?.value}" seq=${seqW?.value}`);
|
||||||
|
if (!urlW?.value || !projW?.value || !fileW?.value) {
|
||||||
|
console.log(`[ProjectKey] _refreshKeys: skipped (missing config)`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const resp = await api.fetchApi(
|
||||||
|
`/json_manager/get_project_keys?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}&file=${encodeURIComponent(fileW.value)}&seq=${seqW?.value || 1}`
|
||||||
|
);
|
||||||
|
if (!resp.ok) return;
|
||||||
|
|
||||||
|
const data = await resp.json();
|
||||||
|
if (data.error || !Array.isArray(data.keys)) return;
|
||||||
|
|
||||||
|
// Store keys/types for lookup
|
||||||
|
this._availableKeys = data.keys;
|
||||||
|
this._availableTypes = data.types;
|
||||||
|
|
||||||
|
// Update key_name combo options only — never change the selection
|
||||||
|
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
||||||
|
if (keyWidget) {
|
||||||
|
keyWidget.options.values = data.keys;
|
||||||
|
// Selection is sticky: user must change it manually
|
||||||
|
this._applyKeySelection();
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error("[ProjectKey] Failed to refresh keys:", e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Update output slot based on selected key ---
|
||||||
|
nodeType.prototype._applyKeySelection = function () {
|
||||||
|
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
||||||
|
if (!keyWidget?.value) return;
|
||||||
|
|
||||||
|
const keyIdx = (this._availableKeys || []).indexOf(keyWidget.value);
|
||||||
|
const keyType = keyIdx >= 0 ? (this._availableTypes[keyIdx] || "*") : "*";
|
||||||
|
|
||||||
|
// Update hidden key_type widget
|
||||||
|
const ktWidget = this.widgets?.find(w => w.name === "key_type");
|
||||||
|
if (ktWidget) ktWidget.value = keyType;
|
||||||
|
|
||||||
|
// Update output slot
|
||||||
|
if (this.outputs.length > 0) {
|
||||||
|
this.outputs[0].name = keyWidget.value;
|
||||||
|
this.outputs[0].label = keyWidget.value;
|
||||||
|
this.outputs[0].type = keyType;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.title = keyWidget.value ? `Key: ${keyWidget.value}` : "Project Key";
|
||||||
|
this.setSize(this.computeSize());
|
||||||
|
app.graph?.setDirtyCanvas(true, true);
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Sync config on click (lazy, no key refresh to avoid race) ---
|
||||||
|
const origOnMouseDown = nodeType.prototype.onMouseDown;
|
||||||
|
nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) {
|
||||||
|
origOnMouseDown?.apply(this, arguments);
|
||||||
|
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
||||||
|
if (srcWidget) {
|
||||||
|
srcWidget.options.values = this._getSourceLabels();
|
||||||
|
}
|
||||||
|
// Sync config values from source (synchronous, safe)
|
||||||
|
this._syncFromSource();
|
||||||
|
};
|
||||||
|
|
||||||
|
// --- Restore state on workflow load ---
|
||||||
|
const origOnConfigure = nodeType.prototype.onConfigure;
|
||||||
|
nodeType.prototype.onConfigure = function (info) {
|
||||||
|
origOnConfigure?.apply(this, arguments);
|
||||||
|
this._configured = true;
|
||||||
|
|
||||||
|
// Hide config widgets
|
||||||
|
for (const name of ["manager_url", "project_name", "file_name", "sequence_number", "key_type"]) {
|
||||||
|
const w = this.widgets?.find(w => w.name === name);
|
||||||
|
if (w) hideWidget(w);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure source_label is a proper combo (may still be STRING from serialization)
|
||||||
|
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
||||||
|
if (srcWidget && srcWidget.type !== "combo") {
|
||||||
|
const node = this;
|
||||||
|
replaceWithCombo(this, "source_label", this._getSourceLabels(), function (value) {
|
||||||
|
node._syncFromSource();
|
||||||
|
node._refreshKeys();
|
||||||
|
});
|
||||||
|
} else if (srcWidget) {
|
||||||
|
srcWidget.options.values = this._getSourceLabels();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure key_name is a proper combo
|
||||||
|
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
||||||
|
if (keyWidget && keyWidget.type !== "combo") {
|
||||||
|
const node = this;
|
||||||
|
replaceWithCombo(this, "key_name", [], function (value) {
|
||||||
|
node._applyKeySelection();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-find widgets after possible replacement
|
||||||
|
const finalKeyWidget = this.widgets?.find(w => w.name === "key_name");
|
||||||
|
|
||||||
|
// Update title from saved key
|
||||||
|
if (finalKeyWidget?.value) {
|
||||||
|
this.title = `Key: ${finalKeyWidget.value}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore output slot name from saved key_name
|
||||||
|
if (finalKeyWidget?.value && this.outputs.length > 0) {
|
||||||
|
this.outputs[0].name = finalKeyWidget.value;
|
||||||
|
this.outputs[0].label = finalKeyWidget.value;
|
||||||
|
const ktWidget = this.widgets?.find(w => w.name === "key_type");
|
||||||
|
if (ktWidget?.value) this.outputs[0].type = ktWidget.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.setSize(this.computeSize());
|
||||||
|
|
||||||
|
// Deferred: sync from source and refresh key dropdown once graph is ready
|
||||||
|
const node = this;
|
||||||
|
queueMicrotask(() => {
|
||||||
|
node._syncFromSource();
|
||||||
|
node._refreshKeys();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
},
|
||||||
|
});
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
import { api } from "../../scripts/api.js";
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: "json.manager.project.source",
|
||||||
|
|
||||||
|
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||||
|
if (nodeData.name !== "ProjectSource") return;
|
||||||
|
|
||||||
|
// Helper: replace a STRING widget with a proper combo widget
|
||||||
|
function replaceWithCombo(node, name, values, callback) {
|
||||||
|
const idx = node.widgets?.findIndex(w => w.name === name);
|
||||||
|
if (idx === -1 || idx === undefined) return null;
|
||||||
|
const oldWidget = node.widgets[idx];
|
||||||
|
const savedValue = oldWidget.value || "";
|
||||||
|
const comboValues = values.length > 0 ? values : [""];
|
||||||
|
// Always preserve saved value (may not be in list yet)
|
||||||
|
if (savedValue && !comboValues.includes(savedValue)) {
|
||||||
|
comboValues.unshift(savedValue);
|
||||||
|
}
|
||||||
|
const defaultValue = savedValue || comboValues[0];
|
||||||
|
node.widgets.splice(idx, 1);
|
||||||
|
const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues });
|
||||||
|
if (node.widgets.length > 1) {
|
||||||
|
node.widgets.splice(node.widgets.length - 1, 1);
|
||||||
|
node.widgets.splice(idx, 0, combo);
|
||||||
|
}
|
||||||
|
return combo;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch file list from API and update file_name combo
|
||||||
|
async function refreshFiles(node) {
|
||||||
|
const urlW = node.widgets?.find(w => w.name === "manager_url");
|
||||||
|
const projW = node.widgets?.find(w => w.name === "project_name");
|
||||||
|
if (!urlW?.value || !projW?.value) return;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const resp = await api.fetchApi(
|
||||||
|
`/json_manager/list_project_files?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}`
|
||||||
|
);
|
||||||
|
if (!resp.ok) return;
|
||||||
|
const data = await resp.json();
|
||||||
|
const fileList = (data.files || []).map(f => f.name || f);
|
||||||
|
console.log(`[ProjectSource] refreshFiles: got ${fileList.length} files:`, fileList);
|
||||||
|
|
||||||
|
const fileW = node.widgets?.find(w => w.name === "file_name");
|
||||||
|
if (fileW) {
|
||||||
|
const currentValue = fileW.value;
|
||||||
|
fileW.options.values = fileList.length > 0 ? fileList : [""];
|
||||||
|
// Keep current selection if still valid
|
||||||
|
if (currentValue && fileList.includes(currentValue)) {
|
||||||
|
fileW.value = currentValue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error("[ProjectSource] Failed to refresh files:", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify all ProjectKey nodes referencing this source to re-sync
|
||||||
|
function notifyRelays(sourceNode) {
|
||||||
|
if (!sourceNode.graph?._nodes) return;
|
||||||
|
const labelW = sourceNode.widgets?.find(w => w.name === "label");
|
||||||
|
if (!labelW?.value) return;
|
||||||
|
console.log(`[ProjectSource] notifyRelays: label="${labelW.value}", scanning ${sourceNode.graph._nodes.length} nodes`);
|
||||||
|
let matched = 0;
|
||||||
|
for (const node of sourceNode.graph._nodes) {
|
||||||
|
if (node.type === "ProjectKey" && node._syncFromSource && node._refreshKeys) {
|
||||||
|
const srcW = node.widgets?.find(w => w.name === "source_label");
|
||||||
|
console.log(`[ProjectSource] ProjectKey id=${node.id} source_label="${srcW?.value}"`);
|
||||||
|
if (srcW?.value === labelW.value) {
|
||||||
|
matched++;
|
||||||
|
node._syncFromSource();
|
||||||
|
node._refreshKeys();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
console.log(`[ProjectSource] notifyRelays: matched ${matched} relays`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
||||||
|
nodeType.prototype.onNodeCreated = function () {
|
||||||
|
origOnNodeCreated?.apply(this, arguments);
|
||||||
|
|
||||||
|
const node = this;
|
||||||
|
|
||||||
|
// Replace file_name STRING with a combo
|
||||||
|
replaceWithCombo(this, "file_name", [], function (value) {
|
||||||
|
notifyRelays(node);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Hook manager_url and project_name to refresh file list + notify relays
|
||||||
|
for (const name of ["manager_url", "project_name"]) {
|
||||||
|
const w = this.widgets?.find(w => w.name === name);
|
||||||
|
if (w) {
|
||||||
|
const origCb = w.callback;
|
||||||
|
w.callback = function (...args) {
|
||||||
|
origCb?.apply(this, args);
|
||||||
|
refreshFiles(node);
|
||||||
|
notifyRelays(node);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hook sequence_number to notify relays
|
||||||
|
const seqW = this.widgets?.find(w => w.name === "sequence_number");
|
||||||
|
if (seqW) {
|
||||||
|
const origCb = seqW.callback;
|
||||||
|
seqW.callback = function (...args) {
|
||||||
|
origCb?.apply(this, args);
|
||||||
|
notifyRelays(node);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update title when label changes
|
||||||
|
const labelWidget = this.widgets?.find(w => w.name === "label");
|
||||||
|
if (labelWidget) {
|
||||||
|
const origCallback = labelWidget.callback;
|
||||||
|
labelWidget.callback = function (...args) {
|
||||||
|
origCallback?.apply(this, args);
|
||||||
|
node.title = labelWidget.value
|
||||||
|
? `Source: ${labelWidget.value}`
|
||||||
|
: "Project Source";
|
||||||
|
app.graph?.setDirtyCanvas(true, true);
|
||||||
|
};
|
||||||
|
// Set initial title
|
||||||
|
if (labelWidget.value) {
|
||||||
|
this.title = `Source: ${labelWidget.value}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const origOnConfigure = nodeType.prototype.onConfigure;
|
||||||
|
nodeType.prototype.onConfigure = function (info) {
|
||||||
|
origOnConfigure?.apply(this, arguments);
|
||||||
|
|
||||||
|
// Ensure file_name is a combo (may be STRING from serialization)
|
||||||
|
const fileW = this.widgets?.find(w => w.name === "file_name");
|
||||||
|
if (fileW && fileW.type !== "combo") {
|
||||||
|
const node = this;
|
||||||
|
replaceWithCombo(this, "file_name", [], function (value) {
|
||||||
|
notifyRelays(node);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const labelWidget = this.widgets?.find(w => w.name === "label");
|
||||||
|
if (labelWidget?.value) {
|
||||||
|
this.title = `Source: ${labelWidget.value}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deferred: refresh file list once graph is ready
|
||||||
|
const node = this;
|
||||||
|
queueMicrotask(() => {
|
||||||
|
refreshFiles(node);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
},
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user