Fix 6 bugs found during code review
- Fix NameError: pass state to _render_vace_settings (tab_batch_ng.py) - Fix non-atomic sync_to_db: use BEGIN IMMEDIATE transaction with rollback - Fix create_secondary() missing db/current_project/db_enabled fields - Fix URL encoding: percent-encode project/file names in API URLs - Fix import_json_file crash on re-import: upsert instead of insert - Fix dual DB instances: share single ProjectDB between UI and API routes - Also fixes top_level metadata never being updated on existing data_files Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
21
db.py
21
db.py
@@ -228,15 +228,30 @@ class ProjectDB:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
|
||||
"""Import a JSON file into the database, splitting batch_data into sequences."""
|
||||
"""Import a JSON file into the database, splitting batch_data into sequences.
|
||||
|
||||
Safe to call repeatedly — existing data_file is updated, sequences are
|
||||
replaced, and history_tree is upserted.
|
||||
"""
|
||||
json_path = Path(json_path)
|
||||
data, _ = load_json(json_path)
|
||||
file_name = json_path.stem
|
||||
|
||||
# Extract top-level keys that aren't batch_data or history_tree
|
||||
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
||||
|
||||
df_id = self.create_data_file(project_id, file_name, data_type, top_level)
|
||||
existing = self.get_data_file(project_id, file_name)
|
||||
if existing:
|
||||
df_id = existing["id"]
|
||||
now = time.time()
|
||||
self.conn.execute(
|
||||
"UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?",
|
||||
(data_type, json.dumps(top_level), now, df_id),
|
||||
)
|
||||
self.conn.commit()
|
||||
# Clear old sequences before re-importing
|
||||
self.delete_sequences_for_file(df_id)
|
||||
else:
|
||||
df_id = self.create_data_file(project_id, file_name, data_type, top_level)
|
||||
|
||||
# Import sequences from batch_data
|
||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
||||
|
||||
24
main.py
24
main.py
@@ -21,6 +21,13 @@ from api_routes import register_api_routes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Single shared DB instance for both the UI and API routes
|
||||
_shared_db: ProjectDB | None = None
|
||||
try:
|
||||
_shared_db = ProjectDB()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize ProjectDB: {e}")
|
||||
|
||||
|
||||
@ui.page('/')
|
||||
def index():
|
||||
@@ -166,12 +173,8 @@ def index():
|
||||
current_project=config.get('current_project', ''),
|
||||
)
|
||||
|
||||
# Initialize project database
|
||||
try:
|
||||
state.db = ProjectDB()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize ProjectDB: {e}")
|
||||
state.db = None
|
||||
# Use the shared DB instance
|
||||
state.db = _shared_db
|
||||
|
||||
dual_pane = {'active': False, 'state': None}
|
||||
|
||||
@@ -500,11 +503,8 @@ def render_sidebar(state: AppState, dual_pane: dict):
|
||||
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
|
||||
|
||||
|
||||
# Register REST API routes for ComfyUI connectivity
|
||||
try:
|
||||
_api_db = ProjectDB()
|
||||
register_api_routes(_api_db)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register API routes: {e}")
|
||||
# Register REST API routes for ComfyUI connectivity (uses the shared DB instance)
|
||||
if _shared_db is not None:
|
||||
register_api_routes(_shared_db)
|
||||
|
||||
ui.run(title='AI Settings Manager', port=8080, reload=True)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Any
|
||||
@@ -49,13 +50,17 @@ def _fetch_json(url: str) -> dict:
|
||||
|
||||
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
    """Fetch sequence data from the NiceGUI REST API.

    Project and file names are percent-encoded so names containing spaces
    or reserved characters produce a valid URL path.
    """
    base = manager_url.rstrip('/')
    proj_q = urllib.parse.quote(project, safe='')
    file_q = urllib.parse.quote(file, safe='')
    return _fetch_json(f"{base}/api/projects/{proj_q}/files/{file_q}/data?seq={seq}")
|
||||
|
||||
|
||||
def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict:
    """Fetch keys/types from the NiceGUI REST API.

    Mirrors `_fetch_data` but hits the `/keys` endpoint; project and file
    names are percent-encoded before being placed in the URL path.
    """
    base = manager_url.rstrip('/')
    proj_q = urllib.parse.quote(project, safe='')
    file_q = urllib.parse.quote(file, safe='')
    return _fetch_json(f"{base}/api/projects/{proj_q}/files/{file_q}/keys?seq={seq}")
|
||||
|
||||
|
||||
@@ -71,7 +76,7 @@ if PromptServer is not None:
|
||||
@PromptServer.instance.routes.get("/json_manager/list_project_files")
async def list_project_files_proxy(request):
    # Proxy endpoint: forward a "list files for project" query to the
    # NiceGUI manager REST API and relay its JSON response.
    manager_url = request.query.get("url", "http://localhost:8080")
    # Percent-encode the project name so it is safe inside the URL path.
    project = urllib.parse.quote(request.query.get("project", ""), safe='')
    payload = _fetch_json(f"{manager_url.rstrip('/')}/api/projects/{project}/files")
    return web.json_response(payload)
|
||||
@@ -79,8 +84,8 @@ if PromptServer is not None:
|
||||
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
async def list_project_sequences_proxy(request):
    # Proxy endpoint: forward a "list sequences for file" query to the
    # NiceGUI manager REST API and relay its JSON response.
    manager_url = request.query.get("url", "http://localhost:8080")
    # Percent-encode both path components before building the URL.
    project = urllib.parse.quote(request.query.get("project", ""), safe='')
    file_name = urllib.parse.quote(request.query.get("file", ""), safe='')
    payload = _fetch_json(
        f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
    )
    return web.json_response(payload)
|
||||
@@ -98,6 +103,7 @@ if PromptServer is not None:
|
||||
return web.json_response(data)
|
||||
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 0. DYNAMIC NODE (Project-based)
|
||||
# ==========================================
|
||||
|
||||
3
state.py
3
state.py
@@ -34,4 +34,7 @@ class AppState:
|
||||
config=self.config,
|
||||
current_dir=self.current_dir,
|
||||
snippets=self.snippets,
|
||||
db=self.db,
|
||||
current_project=self.current_project,
|
||||
db_enabled=self.db_enabled,
|
||||
)
|
||||
|
||||
@@ -457,7 +457,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
||||
|
||||
# --- VACE Settings (full width) ---
|
||||
with ui.expansion('VACE Settings', icon='settings').classes('w-full'):
|
||||
_render_vace_settings(i, seq, batch_list, data, file_path, refresh_list)
|
||||
_render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list)
|
||||
|
||||
# --- LoRA Settings ---
|
||||
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
||||
@@ -539,7 +539,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
||||
# VACE Settings sub-section
|
||||
# ======================================================================
|
||||
|
||||
def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list):
|
||||
def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list):
|
||||
# VACE Schedule (needed early for both columns)
|
||||
sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1))
|
||||
|
||||
|
||||
@@ -246,6 +246,32 @@ class TestImport:
|
||||
seqs = db.list_sequences(df_id)
|
||||
assert seqs == []
|
||||
|
||||
def test_reimport_updates_existing(self, db, tmp_path):
    """Re-importing the same file should update data, not crash."""
    pid = db.create_project("p1", "/p1")
    json_path = tmp_path / "batch.json"

    # First import: one sequence, data_type "i2v".
    json_path.write_text(
        json.dumps({KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]})
    )
    first_id = db.import_json_file(pid, json_path, "i2v")

    # Second import of the same file: updated prompt, an extra sequence,
    # and a different data_type.
    json_path.write_text(
        json.dumps({
            KEY_BATCH_DATA: [
                {"sequence_number": 1, "prompt": "v2"},
                {"sequence_number": 2, "prompt": "new"},
            ]
        })
    )
    second_id = db.import_json_file(pid, json_path, "vace")

    # The existing data_file row is reused rather than duplicated.
    assert second_id == first_id
    # The data_type metadata follows the latest import.
    assert db.get_data_file(pid, "batch")["data_type"] == "vace"
    # Sequences reflect the second import, including the replaced prompt.
    assert db.list_sequences(second_id) == [1, 2]
    assert db.get_sequence(second_id, 1)["prompt"] == "v2"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query helpers
|
||||
|
||||
@@ -61,6 +61,16 @@ class TestFetchHelpers:
|
||||
called_url = mock.call_args[0][0]
|
||||
assert "//api" not in called_url
|
||||
|
||||
def test_fetch_data_encodes_special_chars(self):
    """Project/file names with spaces or special chars should be percent-encoded."""
    with patch(
        "project_loader.urllib.request.urlopen",
        return_value=_mock_urlopen({"prompt": "hello"}),
    ) as mock:
        _fetch_data("http://localhost:8080", "my project", "batch file", 1)
    requested_url = mock.call_args[0][0]
    # Both path components must be percent-encoded...
    assert "my%20project" in requested_url
    assert "batch%20file" in requested_url
    # ...leaving no raw spaces anywhere in the URL path.
    assert " " not in requested_url.split("?")[0]
|
||||
|
||||
|
||||
class TestProjectLoaderDynamic:
|
||||
def test_load_dynamic_with_keys(self):
|
||||
|
||||
66
utils.py
66
utils.py
@@ -164,7 +164,7 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
||||
"""Dual-write helper: sync JSON data to the project database.
|
||||
|
||||
Resolves (or creates) the data_file, upserts all sequences from batch_data,
|
||||
and saves the history_tree.
|
||||
and saves the history_tree. All writes happen in a single transaction.
|
||||
"""
|
||||
if not db or not project_name:
|
||||
return
|
||||
@@ -173,26 +173,58 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
||||
if not proj:
|
||||
return
|
||||
file_name = Path(file_path).stem
|
||||
df = db.get_data_file(proj["id"], file_name)
|
||||
if not df:
|
||||
|
||||
# Use a single transaction for atomicity
|
||||
db.conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
df = db.get_data_file(proj["id"], file_name)
|
||||
top_level = {k: v for k, v in data.items()
|
||||
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
||||
df_id = db.create_data_file(proj["id"], file_name, "generic", top_level)
|
||||
else:
|
||||
df_id = df["id"]
|
||||
if not df:
|
||||
now = __import__('time').time()
|
||||
cur = db.conn.execute(
|
||||
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(proj["id"], file_name, "generic", json.dumps(top_level), now, now),
|
||||
)
|
||||
df_id = cur.lastrowid
|
||||
else:
|
||||
df_id = df["id"]
|
||||
# Update top_level metadata
|
||||
now = __import__('time').time()
|
||||
db.conn.execute(
|
||||
"UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?",
|
||||
(json.dumps(top_level), now, df_id),
|
||||
)
|
||||
|
||||
# Sync sequences
|
||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
||||
if isinstance(batch_data, list):
|
||||
db.delete_sequences_for_file(df_id)
|
||||
for item in batch_data:
|
||||
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
|
||||
db.upsert_sequence(df_id, seq_num, item)
|
||||
# Sync sequences
|
||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
||||
if isinstance(batch_data, list):
|
||||
db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
||||
for item in batch_data:
|
||||
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
|
||||
now = __import__('time').time()
|
||||
db.conn.execute(
|
||||
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?)",
|
||||
(df_id, seq_num, json.dumps(item), now),
|
||||
)
|
||||
|
||||
# Sync history tree
|
||||
history_tree = data.get(KEY_HISTORY_TREE)
|
||||
if history_tree and isinstance(history_tree, dict):
|
||||
db.save_history_tree(df_id, history_tree)
|
||||
# Sync history tree
|
||||
history_tree = data.get(KEY_HISTORY_TREE)
|
||||
if history_tree and isinstance(history_tree, dict):
|
||||
now = __import__('time').time()
|
||||
db.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
"VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||
(df_id, json.dumps(history_tree), now),
|
||||
)
|
||||
|
||||
db.conn.commit()
|
||||
except Exception:
|
||||
db.conn.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"sync_to_db failed: {e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user