# Commit 589c84fd95
#
# - tab_batch_ng.py: async create_batch with to_thread save/sync
# - tab_raw_ng.py: async do_save with to_thread; replace deepcopy with a dict
#   comprehension for display data
# - main.py: async create_new with to_thread save
# - tab_projects_ng.py: replace per-project count_data_files with a single
#   list_projects_with_file_counts JOIN query
# - db.py: add list_projects_with_file_counts method
#
# Zero blocking I/O calls remain in UI callbacks.
#
# Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
# (426 lines, 17 KiB, Python)
import json
import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any

from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE


logger = logging.getLogger(__name__)

# Database location used by ProjectDB when no explicit path is given.
DEFAULT_DB_PATH = Path.home() / ".comfyui_json_manager" / "projects.db"

# Schema applied on every ProjectDB open. Every statement uses IF NOT EXISTS,
# so re-running it against an existing database is a no-op.
#   projects       -> data_files     (ON DELETE CASCADE)
#   data_files     -> sequences      (ON DELETE CASCADE, one row per sequence_number)
#   data_files     -> history_trees  (ON DELETE CASCADE, at most one row per file)
# Cascades only fire when the connection has PRAGMA foreign_keys=ON.
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS projects (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL UNIQUE,
    folder_path TEXT NOT NULL,
    description TEXT NOT NULL DEFAULT '',
    created_at REAL NOT NULL,
    updated_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS data_files (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE,
    name TEXT NOT NULL,
    data_type TEXT NOT NULL DEFAULT 'generic',
    top_level TEXT NOT NULL DEFAULT '{}',
    created_at REAL NOT NULL,
    updated_at REAL NOT NULL,
    UNIQUE(project_id, name)
);

CREATE TABLE IF NOT EXISTS sequences (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE,
    sequence_number INTEGER NOT NULL,
    data TEXT NOT NULL DEFAULT '{}',
    updated_at REAL NOT NULL,
    UNIQUE(data_file_id, sequence_number)
);

CREATE TABLE IF NOT EXISTS history_trees (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE,
    tree_data TEXT NOT NULL DEFAULT '{}',
    updated_at REAL NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_data_files_project_id ON data_files(project_id);
CREATE INDEX IF NOT EXISTS idx_sequences_data_file_id ON sequences(data_file_id);
"""


class ProjectDB:
|
|
"""SQLite database for project-based data management."""
|
|
|
|
def __init__(self, db_path: str | Path | None = None):
|
|
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
|
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
self.conn = sqlite3.connect(
|
|
str(self.db_path),
|
|
check_same_thread=False,
|
|
isolation_level=None, # autocommit — explicit BEGIN/COMMIT only
|
|
)
|
|
self.conn.row_factory = sqlite3.Row
|
|
self.conn.execute("PRAGMA journal_mode=WAL")
|
|
self.conn.execute("PRAGMA foreign_keys=ON")
|
|
self.conn.executescript(SCHEMA_SQL)
|
|
|
|
def close(self):
|
|
self.conn.close()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Projects CRUD
|
|
# ------------------------------------------------------------------
|
|
|
|
def create_project(self, name: str, folder_path: str, description: str = "") -> int:
|
|
now = time.time()
|
|
cur = self.conn.execute(
|
|
"INSERT INTO projects (name, folder_path, description, created_at, updated_at) "
|
|
"VALUES (?, ?, ?, ?, ?)",
|
|
(name, folder_path, description, now, now),
|
|
)
|
|
self.conn.commit()
|
|
return cur.lastrowid
|
|
|
|
def list_projects(self) -> list[dict]:
|
|
rows = self.conn.execute(
|
|
"SELECT id, name, folder_path, description, created_at, updated_at "
|
|
"FROM projects ORDER BY name"
|
|
).fetchall()
|
|
return [dict(r) for r in rows]
|
|
|
|
def list_projects_with_file_counts(self) -> list[dict]:
|
|
"""List projects with data file counts in a single query."""
|
|
rows = self.conn.execute(
|
|
"SELECT p.id, p.name, p.folder_path, p.description, p.created_at, p.updated_at, "
|
|
"COUNT(df.id) AS file_count "
|
|
"FROM projects p LEFT JOIN data_files df ON df.project_id = p.id "
|
|
"GROUP BY p.id ORDER BY p.name"
|
|
).fetchall()
|
|
return [dict(r) for r in rows]
|
|
|
|
def get_project(self, name: str) -> dict | None:
|
|
row = self.conn.execute(
|
|
"SELECT id, name, folder_path, description, created_at, updated_at "
|
|
"FROM projects WHERE name = ?",
|
|
(name,),
|
|
).fetchone()
|
|
return dict(row) if row else None
|
|
|
|
def rename_project(self, old_name: str, new_name: str) -> bool:
|
|
now = time.time()
|
|
cur = self.conn.execute(
|
|
"UPDATE projects SET name = ?, updated_at = ? WHERE name = ?",
|
|
(new_name, now, old_name),
|
|
)
|
|
self.conn.commit()
|
|
return cur.rowcount > 0
|
|
|
|
def update_project_path(self, name: str, folder_path: str) -> bool:
|
|
now = time.time()
|
|
cur = self.conn.execute(
|
|
"UPDATE projects SET folder_path = ?, updated_at = ? WHERE name = ?",
|
|
(folder_path, now, name),
|
|
)
|
|
self.conn.commit()
|
|
return cur.rowcount > 0
|
|
|
|
def delete_project(self, name: str) -> bool:
|
|
cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,))
|
|
self.conn.commit()
|
|
return cur.rowcount > 0
|
|
|
|
# ------------------------------------------------------------------
|
|
# Data files
|
|
# ------------------------------------------------------------------
|
|
|
|
def create_data_file(
|
|
self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None
|
|
) -> int:
|
|
now = time.time()
|
|
tl = json.dumps(top_level or {})
|
|
cur = self.conn.execute(
|
|
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
|
"VALUES (?, ?, ?, ?, ?, ?)",
|
|
(project_id, name, data_type, tl, now, now),
|
|
)
|
|
self.conn.commit()
|
|
return cur.lastrowid
|
|
|
|
def list_data_files(self, project_id: int) -> list[dict]:
|
|
rows = self.conn.execute(
|
|
"SELECT id, project_id, name, data_type, created_at, updated_at "
|
|
"FROM data_files WHERE project_id = ? ORDER BY name",
|
|
(project_id,),
|
|
).fetchall()
|
|
return [dict(r) for r in rows]
|
|
|
|
def count_data_files(self, project_id: int) -> int:
|
|
"""Return the number of data files for a project."""
|
|
row = self.conn.execute(
|
|
"SELECT COUNT(*) AS cnt FROM data_files WHERE project_id = ?",
|
|
(project_id,),
|
|
).fetchone()
|
|
return row["cnt"]
|
|
|
|
def get_data_file(self, project_id: int, name: str) -> dict | None:
|
|
row = self.conn.execute(
|
|
"SELECT id, project_id, name, data_type, top_level, created_at, updated_at "
|
|
"FROM data_files WHERE project_id = ? AND name = ?",
|
|
(project_id, name),
|
|
).fetchone()
|
|
if row is None:
|
|
return None
|
|
d = dict(row)
|
|
d["top_level"] = json.loads(d["top_level"])
|
|
return d
|
|
|
|
def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None:
|
|
row = self.conn.execute(
|
|
"SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, "
|
|
"df.created_at, df.updated_at "
|
|
"FROM data_files df JOIN projects p ON df.project_id = p.id "
|
|
"WHERE p.name = ? AND df.name = ?",
|
|
(project_name, file_name),
|
|
).fetchone()
|
|
if row is None:
|
|
return None
|
|
d = dict(row)
|
|
d["top_level"] = json.loads(d["top_level"])
|
|
return d
|
|
|
|
# ------------------------------------------------------------------
|
|
# Sequences
|
|
# ------------------------------------------------------------------
|
|
|
|
def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None:
|
|
now = time.time()
|
|
self.conn.execute(
|
|
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
|
"VALUES (?, ?, ?, ?) "
|
|
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
|
|
(data_file_id, sequence_number, json.dumps(data), now),
|
|
)
|
|
self.conn.commit()
|
|
|
|
@staticmethod
|
|
def _migrate_lora_keys(data: dict) -> dict:
|
|
"""Split legacy <lora:name:strength> values into separate name/strength keys."""
|
|
for idx in range(1, 4):
|
|
for tier in ('high', 'low'):
|
|
name_key = f'lora {idx} {tier}'
|
|
str_key = f'lora {idx} {tier} strength'
|
|
raw = str(data.get(name_key, ''))
|
|
if raw.startswith('<lora:'):
|
|
inner = raw.replace('<lora:', '').replace('>', '')
|
|
if ':' in inner:
|
|
parts = inner.rsplit(':', 1)
|
|
data[name_key] = parts[0]
|
|
try:
|
|
data[str_key] = float(parts[1])
|
|
except ValueError:
|
|
data.setdefault(str_key, 1.0)
|
|
else:
|
|
data[name_key] = inner
|
|
data.setdefault(str_key, 1.0)
|
|
elif name_key in data and str_key not in data:
|
|
data[str_key] = 1.0
|
|
# Ensure strength is always a float (JSON may deserialize 1 as int)
|
|
if str_key in data:
|
|
data[str_key] = float(data[str_key])
|
|
return data
|
|
|
|
def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None:
|
|
row = self.conn.execute(
|
|
"SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?",
|
|
(data_file_id, sequence_number),
|
|
).fetchone()
|
|
if not row:
|
|
return None
|
|
data = json.loads(row["data"])
|
|
return self._migrate_lora_keys(data)
|
|
|
|
def list_sequences(self, data_file_id: int) -> list[int]:
|
|
rows = self.conn.execute(
|
|
"SELECT sequence_number FROM sequences WHERE data_file_id = ? ORDER BY sequence_number",
|
|
(data_file_id,),
|
|
).fetchall()
|
|
return [r["sequence_number"] for r in rows]
|
|
|
|
def count_sequences(self, data_file_id: int) -> int:
|
|
"""Return the number of sequences for a data file."""
|
|
row = self.conn.execute(
|
|
"SELECT COUNT(*) AS cnt FROM sequences WHERE data_file_id = ?",
|
|
(data_file_id,),
|
|
).fetchone()
|
|
return row["cnt"]
|
|
|
|
def query_total_sequences(self, project_name: str, file_name: str) -> int:
|
|
"""Return total sequence count by project and file names."""
|
|
df = self.get_data_file_by_names(project_name, file_name)
|
|
if not df:
|
|
return 0
|
|
return self.count_sequences(df["id"])
|
|
|
|
_FLOAT_KEYS = frozenset(
|
|
f'lora {idx} {tier} strength'
|
|
for idx in range(1, 4) for tier in ('high', 'low')
|
|
)
|
|
|
|
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
|
|
"""Returns (keys, types) for a sequence's data dict."""
|
|
data = self.get_sequence(data_file_id, sequence_number)
|
|
if not data:
|
|
return [], []
|
|
keys = []
|
|
types = []
|
|
for k, v in data.items():
|
|
keys.append(k)
|
|
if k in self._FLOAT_KEYS:
|
|
types.append("FLOAT")
|
|
elif isinstance(v, bool):
|
|
types.append("STRING")
|
|
elif isinstance(v, int):
|
|
types.append("INT")
|
|
elif isinstance(v, float):
|
|
types.append("FLOAT")
|
|
else:
|
|
types.append("STRING")
|
|
return keys, types
|
|
|
|
def delete_sequences_for_file(self, data_file_id: int) -> None:
|
|
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,))
|
|
self.conn.commit()
|
|
|
|
# ------------------------------------------------------------------
|
|
# History trees
|
|
# ------------------------------------------------------------------
|
|
|
|
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
|
|
now = time.time()
|
|
self.conn.execute(
|
|
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
|
"VALUES (?, ?, ?) "
|
|
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
|
(data_file_id, json.dumps(tree_data), now),
|
|
)
|
|
self.conn.commit()
|
|
|
|
def get_history_tree(self, data_file_id: int) -> dict | None:
|
|
row = self.conn.execute(
|
|
"SELECT tree_data FROM history_trees WHERE data_file_id = ?",
|
|
(data_file_id,),
|
|
).fetchone()
|
|
return json.loads(row["tree_data"]) if row else None
|
|
|
|
# ------------------------------------------------------------------
|
|
# Import
|
|
# ------------------------------------------------------------------
|
|
|
|
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
|
|
"""Import a JSON file into the database, splitting batch_data into sequences.
|
|
|
|
Safe to call repeatedly — existing data_file is updated, sequences are
|
|
replaced, and history_tree is upserted. Atomic: all-or-nothing.
|
|
"""
|
|
json_path = Path(json_path)
|
|
data, _ = load_json(json_path)
|
|
file_name = json_path.stem
|
|
|
|
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
|
|
|
self.conn.execute("BEGIN IMMEDIATE")
|
|
try:
|
|
existing = self.conn.execute(
|
|
"SELECT id FROM data_files WHERE project_id = ? AND name = ?",
|
|
(project_id, file_name),
|
|
).fetchone()
|
|
|
|
if existing:
|
|
df_id = existing["id"]
|
|
now = time.time()
|
|
self.conn.execute(
|
|
"UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?",
|
|
(data_type, json.dumps(top_level), now, df_id),
|
|
)
|
|
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
|
else:
|
|
now = time.time()
|
|
cur = self.conn.execute(
|
|
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
|
"VALUES (?, ?, ?, ?, ?, ?)",
|
|
(project_id, file_name, data_type, json.dumps(top_level), now, now),
|
|
)
|
|
df_id = cur.lastrowid
|
|
|
|
# Import sequences from batch_data
|
|
batch_data = data.get(KEY_BATCH_DATA, [])
|
|
if isinstance(batch_data, list):
|
|
for item in batch_data:
|
|
if not isinstance(item, dict):
|
|
continue
|
|
seq_num = int(item.get("sequence_number", 0))
|
|
now = time.time()
|
|
self.conn.execute(
|
|
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
|
"VALUES (?, ?, ?, ?) "
|
|
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
|
|
(df_id, seq_num, json.dumps(item), now),
|
|
)
|
|
|
|
# Import history tree
|
|
history_tree = data.get(KEY_HISTORY_TREE)
|
|
if history_tree and isinstance(history_tree, dict):
|
|
now = time.time()
|
|
self.conn.execute(
|
|
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
|
"VALUES (?, ?, ?) "
|
|
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
|
(df_id, json.dumps(history_tree), now),
|
|
)
|
|
|
|
self.conn.execute("COMMIT")
|
|
return df_id
|
|
except Exception:
|
|
try:
|
|
self.conn.execute("ROLLBACK")
|
|
except Exception:
|
|
pass
|
|
raise
|
|
|
|
# ------------------------------------------------------------------
|
|
# Query helpers (for REST API)
|
|
# ------------------------------------------------------------------
|
|
|
|
def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None:
|
|
"""Query a single sequence by project name, file name, and sequence number."""
|
|
df = self.get_data_file_by_names(project_name, file_name)
|
|
if not df:
|
|
return None
|
|
return self.get_sequence(df["id"], sequence_number)
|
|
|
|
def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]:
|
|
"""Query keys and types for a sequence."""
|
|
df = self.get_data_file_by_names(project_name, file_name)
|
|
if not df:
|
|
return [], []
|
|
return self.get_sequence_keys(df["id"], sequence_number)
|
|
|
|
def list_project_files(self, project_name: str) -> list[dict]:
|
|
"""List data files for a project by name."""
|
|
proj = self.get_project(project_name)
|
|
if not proj:
|
|
return []
|
|
return self.list_data_files(proj["id"])
|
|
|
|
def list_project_sequences(self, project_name: str, file_name: str) -> list[int]:
|
|
"""List sequence numbers for a file in a project."""
|
|
df = self.get_data_file_by_names(project_name, file_name)
|
|
if not df:
|
|
return []
|
|
return self.list_sequences(df["id"])
|