From b499eb4dfd71c3277f19250d8f58586a47a328d8 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sat, 28 Feb 2026 21:32:35 +0100 Subject: [PATCH] Fix 8 bugs from second code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HIGH: - Fix JS TypeError on empty API response: validate keys/types are arrays before using them; add HTTP status check (resp.ok) - Fix BEGIN IMMEDIATE conflict: set isolation_level=None (autocommit) on SQLite connection so explicit transactions work without implicit ones MEDIUM: - Fix import_json_file non-atomic: wrap entire operation in BEGIN/COMMIT with ROLLBACK on error — no more partial imports - Fix crash on non-dict batch_data items: skip non-dict elements - Fix comma-in-key corruption: store keys/types as JSON arrays in hidden widgets instead of comma-delimited strings (backward-compat fallback) - Fix blocking I/O in API routes: change async def to def so FastAPI auto-threads the synchronous SQLite calls LOW: - Fix missing ?. on app.graph.setDirtyCanvas in refreshDynamicOutputs Co-Authored-By: Claude Opus 4.6 --- api_routes.py | 10 ++--- db.py | 85 ++++++++++++++++++++++++------------ project_loader.py | 8 +++- tests/test_db.py | 38 ++++++++++++++++ tests/test_project_loader.py | 14 ++++++ utils.py | 4 +- web/project_dynamic.js | 48 ++++++++++++-------- 7 files changed, 155 insertions(+), 52 deletions(-) diff --git a/api_routes.py b/api_routes.py index 36e84cd..6d5b42e 100644 --- a/api_routes.py +++ b/api_routes.py @@ -35,25 +35,25 @@ def _get_db() -> ProjectDB: return _db -async def _list_projects() -> dict[str, Any]: +def _list_projects() -> dict[str, Any]: db = _get_db() projects = db.list_projects() return {"projects": [p["name"] for p in projects]} -async def _list_files(name: str) -> dict[str, Any]: +def _list_files(name: str) -> dict[str, Any]: db = _get_db() files = db.list_project_files(name) return {"files": [{"name": f["name"], "data_type": f["data_type"]} for f in files]} -async def _list_sequences(name: str, file_name: str) -> dict[str, Any]: +def _list_sequences(name: str, file_name: str) -> dict[str, Any]: db = _get_db() seqs = db.list_project_sequences(name, file_name) return {"sequences": seqs} -async def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: +def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: db = _get_db() data = db.query_sequence_data(name, file_name, seq) if data is None: @@ -61,7 +61,7 @@ async def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> d return data -async def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: +def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: db = _get_db() keys, types = db.query_sequence_keys(name, file_name, seq) return {"keys": keys, "types": types} diff --git a/db.py b/db.py index b9b17a4..000f577 100644 --- a/db.py +++ b/db.py @@ -56,12 +56,15 @@ class ProjectDB: def __init__(self, db_path: str | Path | None = None): self.db_path = Path(db_path) if db_path else DEFAULT_DB_PATH self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self.conn = sqlite3.connect( + str(self.db_path), + check_same_thread=False, + isolation_level=None, # autocommit — explicit BEGIN/COMMIT only + ) self.conn.row_factory = sqlite3.Row self.conn.execute("PRAGMA journal_mode=WAL") self.conn.execute("PRAGMA foreign_keys=ON") self.conn.executescript(SCHEMA_SQL) - self.conn.commit() def close(self): self.conn.close() @@ -231,7 +234,7 @@ class ProjectDB: """Import a JSON file into the database, splitting batch_data into sequences. Safe to call repeatedly — existing data_file is updated, sequences are - replaced, and history_tree is upserted. + replaced, and history_tree is upserted. Atomic: all-or-nothing. """ json_path = Path(json_path) data, _ = load_json(json_path) @@ -239,33 +242,61 @@ class ProjectDB: top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)} - existing = self.get_data_file(project_id, file_name) - if existing: - df_id = existing["id"] - now = time.time() - self.conn.execute( - "UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?", - (data_type, json.dumps(top_level), now, df_id), - ) - self.conn.commit() - # Clear old sequences before re-importing - self.delete_sequences_for_file(df_id) - else: - df_id = self.create_data_file(project_id, file_name, data_type, top_level) + self.conn.execute("BEGIN IMMEDIATE") + try: + existing = self.conn.execute( + "SELECT id FROM data_files WHERE project_id = ? AND name = ?", + (project_id, file_name), + ).fetchone() - # Import sequences from batch_data - batch_data = data.get(KEY_BATCH_DATA, []) - if isinstance(batch_data, list): - for item in batch_data: - seq_num = int(item.get("sequence_number", 0)) - self.upsert_sequence(df_id, seq_num, item) + if existing: + df_id = existing["id"] + now = time.time() + self.conn.execute( + "UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?", + (data_type, json.dumps(top_level), now, df_id), + ) + self.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) + else: + now = time.time() + cur = self.conn.execute( + "INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + (project_id, file_name, data_type, json.dumps(top_level), now, now), + ) + df_id = cur.lastrowid - # Import history tree - history_tree = data.get(KEY_HISTORY_TREE) - if history_tree and isinstance(history_tree, dict): - self.save_history_tree(df_id, history_tree) + # Import sequences from batch_data + batch_data = data.get(KEY_BATCH_DATA, []) + if isinstance(batch_data, list): + for item in batch_data: + if not isinstance(item, dict): + continue + seq_num = int(item.get("sequence_number", 0)) + now = time.time() + self.conn.execute( + "INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) " + "VALUES (?, ?, ?, ?) " + "ON CONFLICT(data_file_id, sequence_number) DO UPDATE SET data=excluded.data, updated_at=excluded.updated_at", + (df_id, seq_num, json.dumps(item), now), + ) - return df_id + # Import history tree + history_tree = data.get(KEY_HISTORY_TREE) + if history_tree and isinstance(history_tree, dict): + now = time.time() + self.conn.execute( + "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " + "VALUES (?, ?, ?) " + "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", + (df_id, json.dumps(history_tree), now), + ) + + self.conn.execute("COMMIT") + return df_id + except Exception: + self.conn.execute("ROLLBACK") + raise # ------------------------------------------------------------------ # Query helpers (for REST API) diff --git a/project_loader.py b/project_loader.py index 1634135..c7a2cc9 100644 --- a/project_loader.py +++ b/project_loader.py @@ -134,7 +134,13 @@ class ProjectLoaderDynamic: output_keys="", output_types=""): data = _fetch_data(manager_url, project_name, file_name, sequence_number) - keys = [k.strip() for k in output_keys.split(",") if k.strip()] if output_keys else [] + # Parse keys — try JSON array first, fall back to comma-split for compat + keys = [] + if output_keys: + try: + keys = json.loads(output_keys) + except (json.JSONDecodeError, TypeError): + keys = [k.strip() for k in output_keys.split(",") if k.strip()] results = [] for key in keys: diff --git a/tests/test_db.py b/tests/test_db.py index a0dc7ab..341edcb 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -272,6 +272,44 @@ class TestImport: s1 = db.get_sequence(df_id_2, 1) assert s1["prompt"] == "v2" + def test_import_skips_non_dict_batch_items(self, db, tmp_path): + """Non-dict elements in batch_data should be silently skipped, not crash.""" + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "mixed.json" + data = {KEY_BATCH_DATA: [ + {"sequence_number": 1, "prompt": "valid"}, + "not a dict", + 42, + None, + {"sequence_number": 3, "prompt": "also valid"}, + ]} + json_path.write_text(json.dumps(data)) + df_id = db.import_json_file(pid, json_path) + + seqs = db.list_sequences(df_id) + assert seqs == [1, 3] + + def test_import_atomic_on_error(self, db, tmp_path): + """If import fails partway, no partial data should be committed.""" + pid = db.create_project("p1", "/p1") + json_path = tmp_path / "batch.json" + data = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "hello"}]} + json_path.write_text(json.dumps(data)) + db.import_json_file(pid, json_path) + + # Now try to import with bad data that will cause an error + # (overwrite the file with invalid sequence_number that causes int() to fail) + bad_data = {KEY_BATCH_DATA: [{"sequence_number": "not_a_number", "prompt": "bad"}]} + json_path.write_text(json.dumps(bad_data)) + with pytest.raises(ValueError): + db.import_json_file(pid, json_path) + + # Original data should still be intact (rollback worked) + df = db.get_data_file(pid, "batch") + assert df is not None + s1 = db.get_sequence(df["id"], 1) + assert s1["prompt"] == "hello" + # ------------------------------------------------------------------ # Query helpers diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index fd77eb9..41399ca 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -86,6 +86,20 @@ class TestProjectLoaderDynamic: assert result[2] == 1.5 assert len(result) == MAX_DYNAMIC_OUTPUTS + def test_load_dynamic_with_json_encoded_keys(self): + """JSON-encoded output_keys should be parsed correctly.""" + import json as _json + data = {"my,key": "comma_val", "normal": "ok"} + node = ProjectLoaderDynamic() + keys_json = _json.dumps(["my,key", "normal"]) + with patch("project_loader._fetch_data", return_value=data): + result = node.load_dynamic( + "http://localhost:8080", "proj1", "batch_i2v", 1, + output_keys=keys_json + ) + assert result[0] == "comma_val" + assert result[1] == "ok" + def test_load_dynamic_empty_keys(self): node = ProjectLoaderDynamic() with patch("project_loader._fetch_data", return_value={"prompt": "hello"}): diff --git a/utils.py b/utils.py index 44707a2..809d58f 100644 --- a/utils.py +++ b/utils.py @@ -221,9 +221,9 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: (df_id, json.dumps(history_tree), now), ) - db.conn.commit() + db.conn.execute("COMMIT") except Exception: - db.conn.rollback() + db.conn.execute("ROLLBACK") raise except Exception as e: logger.warning(f"sync_to_db failed: {e}") diff --git a/web/project_dynamic.js b/web/project_dynamic.js index 9346f3d..ed830a8 100644 --- a/web/project_dynamic.js +++ b/web/project_dynamic.js @@ -56,20 +56,27 @@ app.registerExtension({ const resp = await api.fetchApi( `/json_manager/get_project_keys?url=${encodeURIComponent(urlWidget.value)}&project=${encodeURIComponent(projectWidget.value)}&file=${encodeURIComponent(fileWidget.value)}&seq=${seqWidget?.value || 1}` ); - const data = await resp.json(); - const { keys, types } = data; - // If the API returned an error, keep existing outputs and links intact - if (data.error) { - console.warn("[ProjectLoaderDynamic] API error, keeping existing outputs:", data.error); + if (!resp.ok) { + console.warn("[ProjectLoaderDynamic] HTTP error", resp.status, "— keeping existing outputs"); return; } - // Store keys and types in hidden widgets for persistence + const data = await resp.json(); + const keys = data.keys; + const types = data.types; + + // If the API returned an error or missing data, keep existing outputs and links intact + if (data.error || !Array.isArray(keys) || !Array.isArray(types)) { + console.warn("[ProjectLoaderDynamic] API error or missing data, keeping existing outputs:", data.error || "no keys/types"); + return; + } + + // Store keys and types in hidden widgets for persistence (JSON-encoded) const okWidget = this.widgets?.find(w => w.name === "output_keys"); - if (okWidget) okWidget.value = keys.join(","); + if (okWidget) okWidget.value = JSON.stringify(keys); const otWidget = this.widgets?.find(w => w.name === "output_types"); - if (otWidget) otWidget.value = types.join(","); + if (otWidget) otWidget.value = JSON.stringify(types); // Build a map of current output names to slot indices const oldSlots = {}; @@ -116,7 +123,7 @@ app.registerExtension({ } this.setSize(this.computeSize()); - app.graph.setDirtyCanvas(true, true); + app.graph?.setDirtyCanvas(true, true); } catch (e) { console.error("[ProjectLoaderDynamic] Refresh failed:", e); } @@ -137,12 +144,19 @@ app.registerExtension({ const okWidget = this.widgets?.find(w => w.name === "output_keys"); const otWidget = this.widgets?.find(w => w.name === "output_types"); - const keys = okWidget?.value - ? okWidget.value.split(",").filter(k => k.trim()) - : []; - const types = otWidget?.value - ? otWidget.value.split(",") - : []; + // Parse keys/types — try JSON array first, fall back to comma-split + let keys = []; + if (okWidget?.value) { + try { keys = JSON.parse(okWidget.value); } catch (_) { + keys = okWidget.value.split(",").map(k => k.trim()).filter(Boolean); + } + } + let types = []; + if (otWidget?.value) { + try { types = JSON.parse(otWidget.value); } catch (_) { + types = otWidget.value.split(","); + } + } if (keys.length > 0) { // On load, LiteGraph already restored serialized outputs with links. @@ -159,8 +173,8 @@ app.registerExtension({ } else if (this.outputs.length > 0) { // Widget values empty but serialized outputs exist — sync widgets // from the outputs LiteGraph already restored (fallback). - if (okWidget) okWidget.value = this.outputs.map(o => o.name).join(","); - if (otWidget) otWidget.value = this.outputs.map(o => o.type).join(","); + if (okWidget) okWidget.value = JSON.stringify(this.outputs.map(o => o.name)); + if (otWidget) otWidget.value = JSON.stringify(this.outputs.map(o => o.type)); } this.setSize(this.computeSize());