From ba8f104bc1310a4d67f8221d380de88a1875ba40 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:25:31 +0100 Subject: [PATCH] Fix 6 bugs found during code review - Fix NameError: pass state to _render_vace_settings (tab_batch_ng.py) - Fix non-atomic sync_to_db: use BEGIN IMMEDIATE transaction with rollback - Fix create_secondary() missing db/current_project/db_enabled fields - Fix URL encoding: percent-encode project/file names in API URLs - Fix import_json_file crash on re-import: upsert instead of insert - Fix dual DB instances: share single ProjectDB between UI and API routes - Also fixes top_level metadata never being updated on existing data_files Co-Authored-By: Claude Opus 4.6 --- db.py | 21 ++++++++++-- main.py | 24 ++++++------- project_loader.py | 16 ++++++--- state.py | 3 ++ tab_batch_ng.py | 4 +-- tests/test_db.py | 26 ++++++++++++++ tests/test_project_loader.py | 10 ++++++ utils.py | 66 ++++++++++++++++++++++++++---------- 8 files changed, 131 insertions(+), 39 deletions(-) diff --git a/db.py b/db.py index 11efc7d..b9b17a4 100644 --- a/db.py +++ b/db.py @@ -228,15 +228,30 @@ class ProjectDB: # ------------------------------------------------------------------ def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int: - """Import a JSON file into the database, splitting batch_data into sequences.""" + """Import a JSON file into the database, splitting batch_data into sequences. + + Safe to call repeatedly — existing data_file is updated, sequences are + replaced, and history_tree is upserted. + """ json_path = Path(json_path) data, _ = load_json(json_path) file_name = json_path.stem - # Extract top-level keys that aren't batch_data or history_tree top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} - df_id = self.create_data_file(project_id, file_name, data_type, top_level) + existing = self.get_data_file(project_id, file_name) + if existing: + df_id = existing["id"] + now = time.time() + self.conn.execute( + "UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?", + (data_type, json.dumps(top_level), now, df_id), + ) + self.conn.commit() + # Clear old sequences before re-importing + self.delete_sequences_for_file(df_id) + else: + df_id = self.create_data_file(project_id, file_name, data_type, top_level) # Import sequences from batch_data batch_data = data.get(KEY_BATCH_DATA, []) diff --git a/main.py b/main.py index 5936aa8..efa2619 100644 --- a/main.py +++ b/main.py @@ -21,6 +21,13 @@ from api_routes import register_api_routes logger = logging.getLogger(__name__) +# Single shared DB instance for both the UI and API routes +_shared_db: ProjectDB | None = None +try: + _shared_db = ProjectDB() +except Exception as e: + logger.warning(f"Failed to initialize ProjectDB: {e}") + @ui.page('/') def index(): @@ -166,12 +173,8 @@ def index(): current_project=config.get('current_project', ''), ) - # Initialize project database - try: - state.db = ProjectDB() - except Exception as e: - logger.warning(f"Failed to initialize ProjectDB: {e}") - state.db = None + # Use the shared DB instance + state.db = _shared_db dual_pane = {'active': False, 'state': None} @@ -500,11 +503,8 @@ def render_sidebar(state: AppState, dual_pane: dict): ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle) -# Register REST API routes for ComfyUI connectivity -try: - _api_db = ProjectDB() - register_api_routes(_api_db) -except Exception as e: - logger.warning(f"Failed to register API routes: {e}") +# Register REST API routes for ComfyUI connectivity (uses the shared DB instance) +if _shared_db is not None: + register_api_routes(_shared_db) ui.run(title='AI Settings Manager', port=8080, reload=True) diff --git a/project_loader.py b/project_loader.py index 2adc6c3..1634135 100644 --- a/project_loader.py +++ b/project_loader.py @@ -1,5 +1,6 @@ import json import logging +import urllib.parse import urllib.request import urllib.error from typing import Any @@ -49,13 +50,17 @@ def _fetch_json(url: str) -> dict: def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: """Fetch sequence data from the NiceGUI REST API.""" - url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/data?seq={seq}" + p = urllib.parse.quote(project, safe='') + f = urllib.parse.quote(file, safe='') + url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/data?seq={seq}" return _fetch_json(url) def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict: """Fetch keys/types from the NiceGUI REST API.""" - url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/keys?seq={seq}" + p = urllib.parse.quote(project, safe='') + f = urllib.parse.quote(file, safe='') + url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/keys?seq={seq}" return _fetch_json(url) @@ -71,7 +76,7 @@ if PromptServer is not None: @PromptServer.instance.routes.get("/json_manager/list_project_files") async def list_project_files_proxy(request): manager_url = request.query.get("url", "http://localhost:8080") - project = request.query.get("project", "") + project = urllib.parse.quote(request.query.get("project", ""), safe='') url = f"{manager_url.rstrip('/')}/api/projects/{project}/files" data = _fetch_json(url) return web.json_response(data) @@ -79,8 +84,8 @@ if PromptServer is not None: @PromptServer.instance.routes.get("/json_manager/list_project_sequences") async def list_project_sequences_proxy(request): manager_url = request.query.get("url", "http://localhost:8080") - project = request.query.get("project", "") - file_name = request.query.get("file", "") + project = urllib.parse.quote(request.query.get("project", ""), safe='') + file_name = urllib.parse.quote(request.query.get("file", ""), safe='') url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences" data = _fetch_json(url) return web.json_response(data) @@ -98,6 +103,7 @@ if PromptServer is not None: return web.json_response(data) + # ========================================== # 0. DYNAMIC NODE (Project-based) # ========================================== diff --git a/state.py b/state.py index 891a14e..bef8818 100644 --- a/state.py +++ b/state.py @@ -34,4 +34,7 @@ class AppState: config=self.config, current_dir=self.current_dir, snippets=self.snippets, + db=self.db, + current_project=self.current_project, + db_enabled=self.db_enabled, ) diff --git a/tab_batch_ng.py b/tab_batch_ng.py index 36abee8..7ec49ae 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -457,7 +457,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, # --- VACE Settings (full width) --- with ui.expansion('VACE Settings', icon='settings').classes('w-full'): - _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list) + _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list) # --- LoRA Settings --- with ui.expansion('LoRA Settings', icon='style').classes('w-full'): @@ -539,7 +539,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, # VACE Settings sub-section # ====================================================================== -def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list): +def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list): # VACE Schedule (needed early for both columns) sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1)) diff --git a/tests/test_db.py b/tests/test_db.py index 73427c1..a0dc7ab 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -246,6 +246,32 @@ class TestImport: seqs = db.list_sequences(df_id) assert seqs == [] + def test_reimport_updates_existing(self, db, tmp_path): + """Re-importing the same file should update data, not crash.""" + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "batch.json" + + # First import + data_v1 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]} + json_path.write_text(json.dumps(data_v1)) + df_id_1 = db.import_json_file(pid, json_path, "i2v") + + # Second import (same file, updated data) + data_v2 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v2"}, {"sequence_number": 2, "prompt": "new"}]} + json_path.write_text(json.dumps(data_v2)) + df_id_2 = db.import_json_file(pid, json_path, "vace") + + # Should reuse the same data_file row + assert df_id_1 == df_id_2 + # Data type should be updated + df = db.get_data_file(pid, "batch") + assert df["data_type"] == "vace" + # Sequences should reflect v2 + seqs = db.list_sequences(df_id_2) + assert seqs == [1, 2] + s1 = db.get_sequence(df_id_2, 1) + assert s1["prompt"] == "v2" + # ------------------------------------------------------------------ # Query helpers diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index 00a59a7..fd77eb9 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -61,6 +61,16 @@ class TestFetchHelpers: called_url = mock.call_args[0][0] assert "//api" not in called_url + def test_fetch_data_encodes_special_chars(self): + """Project/file names with spaces or special chars should be percent-encoded.""" + data = {"prompt": "hello"} + with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock: + _fetch_data("http://localhost:8080", "my project", "batch file", 1) + called_url = mock.call_args[0][0] + assert "my%20project" in called_url + assert "batch%20file" in called_url + assert " " not in called_url.split("?")[0] # no raw spaces in path + class TestProjectLoaderDynamic: def test_load_dynamic_with_keys(self): diff --git a/utils.py b/utils.py index ea5cdc9..44707a2 100644 --- a/utils.py +++ b/utils.py @@ -164,7 +164,7 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: """Dual-write helper: sync JSON data to the project database. Resolves (or creates) the data_file, upserts all sequences from batch_data, - and saves the history_tree. + and saves the history_tree. All writes happen in a single transaction. """ if not db or not project_name: return @@ -173,26 +173,58 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: if not proj: return file_name = Path(file_path).stem - df = db.get_data_file(proj["id"], file_name) - if not df: + + # Use a single transaction for atomicity + db.conn.execute("BEGIN IMMEDIATE") + try: + df = db.get_data_file(proj["id"], file_name) top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} - df_id = db.create_data_file(proj["id"], file_name, "generic", top_level) - else: - df_id = df["id"] + if not df: + now = __import__('time').time() + cur = db.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (proj["id"], file_name, "generic", json.dumps(top_level), now, now), + ) + df_id = cur.lastrowid + else: + df_id = df["id"] + # Update top_level metadata + now = __import__('time').time() + db.conn.execute( + "UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?", + (json.dumps(top_level), now, df_id), + ) - # Sync sequences - batch_data = data.get(KEY_BATCH_DATA, []) - if isinstance(batch_data, list): - db.delete_sequences_for_file(df_id) - for item in batch_data: - seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) - db.upsert_sequence(df_id, seq_num, item) + # Sync sequences + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) + for item in batch_data: + seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0)) + now = __import__('time').time() + db.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?)", + (df_id, seq_num, json.dumps(item), now), + ) - # Sync history tree - history_tree = data.get(KEY_HISTORY_TREE) - if history_tree and isinstance(history_tree, dict): - db.save_history_tree(df_id, history_tree) + # Sync history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + now = __import__('time').time() + db.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) " + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (df_id, json.dumps(history_tree), now), + ) + + db.conn.commit() + except Exception: + db.conn.rollback() + raise except Exception as e: logger.warning(f"sync_to_db failed: {e}")