Fix 6 bugs found during code review

- Fix NameError: pass state to _render_vace_settings (tab_batch_ng.py)
- Fix non-atomic sync_to_db: use BEGIN IMMEDIATE transaction with rollback
- Fix create_secondary() missing db/current_project/db_enabled fields
- Fix URL encoding: percent-encode project/file names in API URLs
- Fix import_json_file crash on re-import: upsert instead of insert
- Fix dual DB instances: share single ProjectDB between UI and API routes
- Also fixes top_level metadata never being updated on existing data_files

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-28 21:25:31 +01:00
parent 6b7e9ea682
commit ba8f104bc1
8 changed files with 131 additions and 39 deletions

21
db.py
View File

@@ -228,15 +228,30 @@ class ProjectDB:
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int: def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
"""Import a JSON file into the database, splitting batch_data into sequences.""" """Import a JSON file into the database, splitting batch_data into sequences.
Safe to call repeatedly — existing data_file is updated, sequences are
replaced, and history_tree is upserted.
"""
json_path = Path(json_path) json_path = Path(json_path)
data, _ = load_json(json_path) data, _ = load_json(json_path)
file_name = json_path.stem file_name = json_path.stem
# Extract top-level keys that aren't batch_data or history_tree
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
df_id = self.create_data_file(project_id, file_name, data_type, top_level) existing = self.get_data_file(project_id, file_name)
if existing:
df_id = existing["id"]
now = time.time()
self.conn.execute(
"UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?",
(data_type, json.dumps(top_level), now, df_id),
)
self.conn.commit()
# Clear old sequences before re-importing
self.delete_sequences_for_file(df_id)
else:
df_id = self.create_data_file(project_id, file_name, data_type, top_level)
# Import sequences from batch_data # Import sequences from batch_data
batch_data = data.get(KEY_BATCH_DATA, []) batch_data = data.get(KEY_BATCH_DATA, [])

24
main.py
View File

@@ -21,6 +21,13 @@ from api_routes import register_api_routes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Single shared DB instance for both the UI and API routes
_shared_db: ProjectDB | None = None
try:
_shared_db = ProjectDB()
except Exception as e:
logger.warning(f"Failed to initialize ProjectDB: {e}")
@ui.page('/') @ui.page('/')
def index(): def index():
@@ -166,12 +173,8 @@ def index():
current_project=config.get('current_project', ''), current_project=config.get('current_project', ''),
) )
# Initialize project database # Use the shared DB instance
try: state.db = _shared_db
state.db = ProjectDB()
except Exception as e:
logger.warning(f"Failed to initialize ProjectDB: {e}")
state.db = None
dual_pane = {'active': False, 'state': None} dual_pane = {'active': False, 'state': None}
@@ -500,11 +503,8 @@ def render_sidebar(state: AppState, dual_pane: dict):
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle) ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
# Register REST API routes for ComfyUI connectivity # Register REST API routes for ComfyUI connectivity (uses the shared DB instance)
try: if _shared_db is not None:
_api_db = ProjectDB() register_api_routes(_shared_db)
register_api_routes(_api_db)
except Exception as e:
logger.warning(f"Failed to register API routes: {e}")
ui.run(title='AI Settings Manager', port=8080, reload=True) ui.run(title='AI Settings Manager', port=8080, reload=True)

View File

@@ -1,5 +1,6 @@
import json import json
import logging import logging
import urllib.parse
import urllib.request import urllib.request
import urllib.error import urllib.error
from typing import Any from typing import Any
@@ -49,13 +50,17 @@ def _fetch_json(url: str) -> dict:
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
"""Fetch sequence data from the NiceGUI REST API.""" """Fetch sequence data from the NiceGUI REST API."""
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/data?seq={seq}" p = urllib.parse.quote(project, safe='')
f = urllib.parse.quote(file, safe='')
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/data?seq={seq}"
return _fetch_json(url) return _fetch_json(url)
def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict: def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict:
"""Fetch keys/types from the NiceGUI REST API.""" """Fetch keys/types from the NiceGUI REST API."""
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/keys?seq={seq}" p = urllib.parse.quote(project, safe='')
f = urllib.parse.quote(file, safe='')
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/keys?seq={seq}"
return _fetch_json(url) return _fetch_json(url)
@@ -71,7 +76,7 @@ if PromptServer is not None:
@PromptServer.instance.routes.get("/json_manager/list_project_files") @PromptServer.instance.routes.get("/json_manager/list_project_files")
async def list_project_files_proxy(request): async def list_project_files_proxy(request):
manager_url = request.query.get("url", "http://localhost:8080") manager_url = request.query.get("url", "http://localhost:8080")
project = request.query.get("project", "") project = urllib.parse.quote(request.query.get("project", ""), safe='')
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files" url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
data = _fetch_json(url) data = _fetch_json(url)
return web.json_response(data) return web.json_response(data)
@@ -79,8 +84,8 @@ if PromptServer is not None:
@PromptServer.instance.routes.get("/json_manager/list_project_sequences") @PromptServer.instance.routes.get("/json_manager/list_project_sequences")
async def list_project_sequences_proxy(request): async def list_project_sequences_proxy(request):
manager_url = request.query.get("url", "http://localhost:8080") manager_url = request.query.get("url", "http://localhost:8080")
project = request.query.get("project", "") project = urllib.parse.quote(request.query.get("project", ""), safe='')
file_name = request.query.get("file", "") file_name = urllib.parse.quote(request.query.get("file", ""), safe='')
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences" url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
data = _fetch_json(url) data = _fetch_json(url)
return web.json_response(data) return web.json_response(data)
@@ -98,6 +103,7 @@ if PromptServer is not None:
return web.json_response(data) return web.json_response(data)
# ========================================== # ==========================================
# 0. DYNAMIC NODE (Project-based) # 0. DYNAMIC NODE (Project-based)
# ========================================== # ==========================================

View File

@@ -34,4 +34,7 @@ class AppState:
config=self.config, config=self.config,
current_dir=self.current_dir, current_dir=self.current_dir,
snippets=self.snippets, snippets=self.snippets,
db=self.db,
current_project=self.current_project,
db_enabled=self.db_enabled,
) )

View File

@@ -457,7 +457,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
# --- VACE Settings (full width) --- # --- VACE Settings (full width) ---
with ui.expansion('VACE Settings', icon='settings').classes('w-full'): with ui.expansion('VACE Settings', icon='settings').classes('w-full'):
_render_vace_settings(i, seq, batch_list, data, file_path, refresh_list) _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list)
# --- LoRA Settings --- # --- LoRA Settings ---
with ui.expansion('LoRA Settings', icon='style').classes('w-full'): with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
@@ -539,7 +539,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
# VACE Settings sub-section # VACE Settings sub-section
# ====================================================================== # ======================================================================
def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list): def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list):
# VACE Schedule (needed early for both columns) # VACE Schedule (needed early for both columns)
sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1)) sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1))

View File

@@ -246,6 +246,32 @@ class TestImport:
seqs = db.list_sequences(df_id) seqs = db.list_sequences(df_id)
assert seqs == [] assert seqs == []
def test_reimport_updates_existing(self, db, tmp_path):
"""Re-importing the same file should update data, not crash."""
pid = db.create_project("p1", "/p1")
json_path = tmp_path / "batch.json"
# First import
data_v1 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]}
json_path.write_text(json.dumps(data_v1))
df_id_1 = db.import_json_file(pid, json_path, "i2v")
# Second import (same file, updated data)
data_v2 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v2"}, {"sequence_number": 2, "prompt": "new"}]}
json_path.write_text(json.dumps(data_v2))
df_id_2 = db.import_json_file(pid, json_path, "vace")
# Should reuse the same data_file row
assert df_id_1 == df_id_2
# Data type should be updated
df = db.get_data_file(pid, "batch")
assert df["data_type"] == "vace"
# Sequences should reflect v2
seqs = db.list_sequences(df_id_2)
assert seqs == [1, 2]
s1 = db.get_sequence(df_id_2, 1)
assert s1["prompt"] == "v2"
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Query helpers # Query helpers

View File

@@ -61,6 +61,16 @@ class TestFetchHelpers:
called_url = mock.call_args[0][0] called_url = mock.call_args[0][0]
assert "//api" not in called_url assert "//api" not in called_url
def test_fetch_data_encodes_special_chars(self):
"""Project/file names with spaces or special chars should be percent-encoded."""
data = {"prompt": "hello"}
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
_fetch_data("http://localhost:8080", "my project", "batch file", 1)
called_url = mock.call_args[0][0]
assert "my%20project" in called_url
assert "batch%20file" in called_url
assert " " not in called_url.split("?")[0] # no raw spaces in path
class TestProjectLoaderDynamic: class TestProjectLoaderDynamic:
def test_load_dynamic_with_keys(self): def test_load_dynamic_with_keys(self):

View File

@@ -164,7 +164,7 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
"""Dual-write helper: sync JSON data to the project database. """Dual-write helper: sync JSON data to the project database.
Resolves (or creates) the data_file, upserts all sequences from batch_data, Resolves (or creates) the data_file, upserts all sequences from batch_data,
and saves the history_tree. and saves the history_tree. All writes happen in a single transaction.
""" """
if not db or not project_name: if not db or not project_name:
return return
@@ -173,26 +173,58 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
if not proj: if not proj:
return return
file_name = Path(file_path).stem file_name = Path(file_path).stem
df = db.get_data_file(proj["id"], file_name)
if not df: # Use a single transaction for atomicity
db.conn.execute("BEGIN IMMEDIATE")
try:
df = db.get_data_file(proj["id"], file_name)
top_level = {k: v for k, v in data.items() top_level = {k: v for k, v in data.items()
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
df_id = db.create_data_file(proj["id"], file_name, "generic", top_level) if not df:
else: now = __import__('time').time()
df_id = df["id"] cur = db.conn.execute(
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?)",
(proj["id"], file_name, "generic", json.dumps(top_level), now, now),
)
df_id = cur.lastrowid
else:
df_id = df["id"]
# Update top_level metadata
now = __import__('time').time()
db.conn.execute(
"UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?",
(json.dumps(top_level), now, df_id),
)
# Sync sequences # Sync sequences
batch_data = data.get(KEY_BATCH_DATA, []) batch_data = data.get(KEY_BATCH_DATA, [])
if isinstance(batch_data, list): if isinstance(batch_data, list):
db.delete_sequences_for_file(df_id) db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
for item in batch_data: for item in batch_data:
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
db.upsert_sequence(df_id, seq_num, item) now = __import__('time').time()
db.conn.execute(
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
"VALUES (?, ?, ?, ?)",
(df_id, seq_num, json.dumps(item), now),
)
# Sync history tree # Sync history tree
history_tree = data.get(KEY_HISTORY_TREE) history_tree = data.get(KEY_HISTORY_TREE)
if history_tree and isinstance(history_tree, dict): if history_tree and isinstance(history_tree, dict):
db.save_history_tree(df_id, history_tree) now = __import__('time').time()
db.conn.execute(
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
"VALUES (?, ?, ?) "
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
(df_id, json.dumps(history_tree), now),
)
db.conn.commit()
except Exception:
db.conn.rollback()
raise
except Exception as e: except Exception as e:
logger.warning(f"sync_to_db failed: {e}") logger.warning(f"sync_to_db failed: {e}")