Files
Comfyui-JSON-Manager/db.py
Ethanfel d07a308865 Harden ROLLBACK against I/O errors in transactions
If the original error (e.g., disk full) also prevents ROLLBACK from
executing, catch and suppress the ROLLBACK failure so the original
exception propagates cleanly and the connection isn't left in a
permanently broken state.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-28 21:44:12 +01:00

335 lines
13 KiB
Python

import json
import logging
import sqlite3
import time
from pathlib import Path
from typing import Any
from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
# Module-level logger.
logger = logging.getLogger(__name__)
# Database file used by ProjectDB when no explicit db_path is given.
DEFAULT_DB_PATH = Path.home() / ".comfyui_json_manager" / "projects.db"
# Schema applied by ProjectDB every time it opens a database (executescript).
# Every statement is CREATE TABLE IF NOT EXISTS, so re-running it against an
# existing database is harmless -- but it also never migrates a table whose
# definition has changed.
#   projects      - one row per project, unique by name.
#   data_files    - files inside a project, unique per (project_id, name);
#                   top_level holds a JSON object serialized as text.
#   sequences     - entries of a data file keyed by sequence_number; data
#                   holds a JSON object serialized as text.
#   history_trees - at most one JSON tree per data file (UNIQUE data_file_id).
# The ON DELETE CASCADE clauses only take effect on connections that have run
# PRAGMA foreign_keys=ON. Timestamp columns are REAL (ProjectDB writes
# time.time() values).
SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS projects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
folder_path TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
created_at REAL NOT NULL,
updated_at REAL NOT NULL
);
CREATE TABLE IF NOT EXISTS data_files (
id INTEGER PRIMARY KEY AUTOINCREMENT,
project_id INTEGER NOT NULL REFERENCES projects(id) ON DELETE CASCADE,
name TEXT NOT NULL,
data_type TEXT NOT NULL DEFAULT 'generic',
top_level TEXT NOT NULL DEFAULT '{}',
created_at REAL NOT NULL,
updated_at REAL NOT NULL,
UNIQUE(project_id, name)
);
CREATE TABLE IF NOT EXISTS sequences (
id INTEGER PRIMARY KEY AUTOINCREMENT,
data_file_id INTEGER NOT NULL REFERENCES data_files(id) ON DELETE CASCADE,
sequence_number INTEGER NOT NULL,
data TEXT NOT NULL DEFAULT '{}',
updated_at REAL NOT NULL,
UNIQUE(data_file_id, sequence_number)
);
CREATE TABLE IF NOT EXISTS history_trees (
id INTEGER PRIMARY KEY AUTOINCREMENT,
data_file_id INTEGER NOT NULL UNIQUE REFERENCES data_files(id) ON DELETE CASCADE,
tree_data TEXT NOT NULL DEFAULT '{}',
updated_at REAL NOT NULL
);
"""
class ProjectDB:
"""SQLite database for project-based data management."""
def __init__(self, db_path: str | Path | None = None):
self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.conn = sqlite3.connect(
str(self.db_path),
check_same_thread=False,
isolation_level=None, # autocommit — explicit BEGIN/COMMIT only
)
self.conn.row_factory = sqlite3.Row
self.conn.execute("PRAGMA journal_mode=WAL")
self.conn.execute("PRAGMA foreign_keys=ON")
self.conn.executescript(SCHEMA_SQL)
def close(self):
self.conn.close()
# ------------------------------------------------------------------
# Projects CRUD
# ------------------------------------------------------------------
def create_project(self, name: str, folder_path: str, description: str = "") -> int:
now = time.time()
cur = self.conn.execute(
"INSERT INTO projects (name, folder_path, description, created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?)",
(name, folder_path, description, now, now),
)
self.conn.commit()
return cur.lastrowid
def list_projects(self) -> list[dict]:
rows = self.conn.execute(
"SELECT id, name, folder_path, description, created_at, updated_at "
"FROM projects ORDER BY name"
).fetchall()
return [dict(r) for r in rows]
def get_project(self, name: str) -> dict | None:
row = self.conn.execute(
"SELECT id, name, folder_path, description, created_at, updated_at "
"FROM projects WHERE name = ?",
(name,),
).fetchone()
return dict(row) if row else None
def delete_project(self, name: str) -> bool:
cur = self.conn.execute("DELETE FROM projects WHERE name = ?", (name,))
self.conn.commit()
return cur.rowcount > 0
# ------------------------------------------------------------------
# Data files
# ------------------------------------------------------------------
def create_data_file(
self, project_id: int, name: str, data_type: str = "generic", top_level: dict | None = None
) -> int:
now = time.time()
tl = json.dumps(top_level or {})
cur = self.conn.execute(
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?)",
(project_id, name, data_type, tl, now, now),
)
self.conn.commit()
return cur.lastrowid
def list_data_files(self, project_id: int) -> list[dict]:
rows = self.conn.execute(
"SELECT id, project_id, name, data_type, created_at, updated_at "
"FROM data_files WHERE project_id = ? ORDER BY name",
(project_id,),
).fetchall()
return [dict(r) for r in rows]
def get_data_file(self, project_id: int, name: str) -> dict | None:
row = self.conn.execute(
"SELECT id, project_id, name, data_type, top_level, created_at, updated_at "
"FROM data_files WHERE project_id = ? AND name = ?",
(project_id, name),
).fetchone()
if row is None:
return None
d = dict(row)
d["top_level"] = json.loads(d["top_level"])
return d
def get_data_file_by_names(self, project_name: str, file_name: str) -> dict | None:
row = self.conn.execute(
"SELECT df.id, df.project_id, df.name, df.data_type, df.top_level, "
"df.created_at, df.updated_at "
"FROM data_files df JOIN projects p ON df.project_id = p.id "
"WHERE p.name = ? AND df.name = ?",
(project_name, file_name),
).fetchone()
if row is None:
return None
d = dict(row)
d["top_level"] = json.loads(d["top_level"])
return d
# ------------------------------------------------------------------
# Sequences
# ------------------------------------------------------------------
def upsert_sequence(self, data_file_id: int, sequence_number: int, data: dict) -> None:
now = time.time()
self.conn.execute(
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
"VALUES (?, ?, ?, ?) "
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
(data_file_id, sequence_number, json.dumps(data), now),
)
self.conn.commit()
def get_sequence(self, data_file_id: int, sequence_number: int) -> dict | None:
row = self.conn.execute(
"SELECT data FROM sequences WHERE data_file_id = ? AND sequence_number = ?",
(data_file_id, sequence_number),
).fetchone()
return json.loads(row["data"]) if row else None
def list_sequences(self, data_file_id: int) -> list[int]:
rows = self.conn.execute(
"SELECT sequence_number FROM sequences WHERE data_file_id = ? ORDER BY sequence_number",
(data_file_id,),
).fetchall()
return [r["sequence_number"] for r in rows]
def get_sequence_keys(self, data_file_id: int, sequence_number: int) -> tuple[list[str], list[str]]:
"""Returns (keys, types) for a sequence's data dict."""
data = self.get_sequence(data_file_id, sequence_number)
if not data:
return [], []
keys = []
types = []
for k, v in data.items():
keys.append(k)
if isinstance(v, bool):
types.append("STRING")
elif isinstance(v, int):
types.append("INT")
elif isinstance(v, float):
types.append("FLOAT")
else:
types.append("STRING")
return keys, types
def delete_sequences_for_file(self, data_file_id: int) -> None:
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (data_file_id,))
self.conn.commit()
# ------------------------------------------------------------------
# History trees
# ------------------------------------------------------------------
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
now = time.time()
self.conn.execute(
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
"VALUES (?, ?, ?) "
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
(data_file_id, json.dumps(tree_data), now),
)
self.conn.commit()
def get_history_tree(self, data_file_id: int) -> dict | None:
row = self.conn.execute(
"SELECT tree_data FROM history_trees WHERE data_file_id = ?",
(data_file_id,),
).fetchone()
return json.loads(row["tree_data"]) if row else None
# ------------------------------------------------------------------
# Import
# ------------------------------------------------------------------
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
"""Import a JSON file into the database, splitting batch_data into sequences.
Safe to call repeatedly — existing data_file is updated, sequences are
replaced, and history_tree is upserted. Atomic: all-or-nothing.
"""
json_path = Path(json_path)
data, _ = load_json(json_path)
file_name = json_path.stem
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
self.conn.execute("BEGIN IMMEDIATE")
try:
existing = self.conn.execute(
"SELECT id FROM data_files WHERE project_id = ? AND name = ?",
(project_id, file_name),
).fetchone()
if existing:
df_id = existing["id"]
now = time.time()
self.conn.execute(
"UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?",
(data_type, json.dumps(top_level), now, df_id),
)
self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
else:
now = time.time()
cur = self.conn.execute(
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?)",
(project_id, file_name, data_type, json.dumps(top_level), now, now),
)
df_id = cur.lastrowid
# Import sequences from batch_data
batch_data = data.get(KEY_BATCH_DATA, [])
if isinstance(batch_data, list):
for item in batch_data:
if not isinstance(item, dict):
continue
seq_num = int(item.get("sequence_number", 0))
now = time.time()
self.conn.execute(
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
"VALUES (?, ?, ?, ?) "
"ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at",
(df_id, seq_num, json.dumps(item), now),
)
# Import history tree
history_tree = data.get(KEY_HISTORY_TREE)
if history_tree and isinstance(history_tree, dict):
now = time.time()
self.conn.execute(
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
"VALUES (?, ?, ?) "
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
(df_id, json.dumps(history_tree), now),
)
self.conn.execute("COMMIT")
return df_id
except Exception:
try:
self.conn.execute("ROLLBACK")
except Exception:
pass
raise
# ------------------------------------------------------------------
# Query helpers (for REST API)
# ------------------------------------------------------------------
def query_sequence_data(self, project_name: str, file_name: str, sequence_number: int) -> dict | None:
"""Query a single sequence by project name, file name, and sequence number."""
df = self.get_data_file_by_names(project_name, file_name)
if not df:
return None
return self.get_sequence(df["id"], sequence_number)
def query_sequence_keys(self, project_name: str, file_name: str, sequence_number: int) -> tuple[list[str], list[str]]:
"""Query keys and types for a sequence."""
df = self.get_data_file_by_names(project_name, file_name)
if not df:
return [], []
return self.get_sequence_keys(df["id"], sequence_number)
def list_project_files(self, project_name: str) -> list[dict]:
"""List data files for a project by name."""
proj = self.get_project(project_name)
if not proj:
return []
return self.list_data_files(proj["id"])
def list_project_sequences(self, project_name: str, file_name: str) -> list[int]:
"""List sequence numbers for a file in a project."""
df = self.get_data_file_by_names(project_name, file_name)
if not df:
return []
return self.list_sequences(df["id"])