Fix 6 bugs found during code review
- Fix NameError: pass state to _render_vace_settings (tab_batch_ng.py) - Fix non-atomic sync_to_db: use BEGIN IMMEDIATE transaction with rollback - Fix create_secondary() missing db/current_project/db_enabled fields - Fix URL encoding: percent-encode project/file names in API URLs - Fix import_json_file crash on re-import: upsert instead of insert - Fix dual DB instances: share single ProjectDB between UI and API routes - Also fixes top_level metadata never being updated on existing data_files Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
19
db.py
19
db.py
@@ -228,14 +228,29 @@ class ProjectDB:
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
|
def import_json_file(self, project_id: int, json_path: str | Path, data_type: str = "generic") -> int:
|
||||||
"""Import a JSON file into the database, splitting batch_data into sequences."""
|
"""Import a JSON file into the database, splitting batch_data into sequences.
|
||||||
|
|
||||||
|
Safe to call repeatedly — existing data_file is updated, sequences are
|
||||||
|
replaced, and history_tree is upserted.
|
||||||
|
"""
|
||||||
json_path = Path(json_path)
|
json_path = Path(json_path)
|
||||||
data, _ = load_json(json_path)
|
data, _ = load_json(json_path)
|
||||||
file_name = json_path.stem
|
file_name = json_path.stem
|
||||||
|
|
||||||
# Extract top-level keys that aren't batch_data or history_tree
|
|
||||||
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
top_level = {k: v for k, v in data.items() if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
||||||
|
|
||||||
|
existing = self.get_data_file(project_id, file_name)
|
||||||
|
if existing:
|
||||||
|
df_id = existing["id"]
|
||||||
|
now = time.time()
|
||||||
|
self.conn.execute(
|
||||||
|
"UPDATE data_files SET data_type = ?, top_level = ?, updated_at = ? WHERE id = ?",
|
||||||
|
(data_type, json.dumps(top_level), now, df_id),
|
||||||
|
)
|
||||||
|
self.conn.commit()
|
||||||
|
# Clear old sequences before re-importing
|
||||||
|
self.delete_sequences_for_file(df_id)
|
||||||
|
else:
|
||||||
df_id = self.create_data_file(project_id, file_name, data_type, top_level)
|
df_id = self.create_data_file(project_id, file_name, data_type, top_level)
|
||||||
|
|
||||||
# Import sequences from batch_data
|
# Import sequences from batch_data
|
||||||
|
|||||||
24
main.py
24
main.py
@@ -21,6 +21,13 @@ from api_routes import register_api_routes
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Single shared DB instance for both the UI and API routes
|
||||||
|
_shared_db: ProjectDB | None = None
|
||||||
|
try:
|
||||||
|
_shared_db = ProjectDB()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to initialize ProjectDB: {e}")
|
||||||
|
|
||||||
|
|
||||||
@ui.page('/')
|
@ui.page('/')
|
||||||
def index():
|
def index():
|
||||||
@@ -166,12 +173,8 @@ def index():
|
|||||||
current_project=config.get('current_project', ''),
|
current_project=config.get('current_project', ''),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize project database
|
# Use the shared DB instance
|
||||||
try:
|
state.db = _shared_db
|
||||||
state.db = ProjectDB()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to initialize ProjectDB: {e}")
|
|
||||||
state.db = None
|
|
||||||
|
|
||||||
dual_pane = {'active': False, 'state': None}
|
dual_pane = {'active': False, 'state': None}
|
||||||
|
|
||||||
@@ -500,11 +503,8 @@ def render_sidebar(state: AppState, dual_pane: dict):
|
|||||||
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
|
ui.checkbox('Show Comfy Monitor', value=True, on_change=on_monitor_toggle)
|
||||||
|
|
||||||
|
|
||||||
# Register REST API routes for ComfyUI connectivity
|
# Register REST API routes for ComfyUI connectivity (uses the shared DB instance)
|
||||||
try:
|
if _shared_db is not None:
|
||||||
_api_db = ProjectDB()
|
register_api_routes(_shared_db)
|
||||||
register_api_routes(_api_db)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to register API routes: {e}")
|
|
||||||
|
|
||||||
ui.run(title='AI Settings Manager', port=8080, reload=True)
|
ui.run(title='AI Settings Manager', port=8080, reload=True)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import urllib.parse
|
||||||
import urllib.request
|
import urllib.request
|
||||||
import urllib.error
|
import urllib.error
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -49,13 +50,17 @@ def _fetch_json(url: str) -> dict:
|
|||||||
|
|
||||||
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
|
def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict:
|
||||||
"""Fetch sequence data from the NiceGUI REST API."""
|
"""Fetch sequence data from the NiceGUI REST API."""
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/data?seq={seq}"
|
p = urllib.parse.quote(project, safe='')
|
||||||
|
f = urllib.parse.quote(file, safe='')
|
||||||
|
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/data?seq={seq}"
|
||||||
return _fetch_json(url)
|
return _fetch_json(url)
|
||||||
|
|
||||||
|
|
||||||
def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict:
|
def _fetch_keys(manager_url: str, project: str, file: str, seq: int) -> dict:
|
||||||
"""Fetch keys/types from the NiceGUI REST API."""
|
"""Fetch keys/types from the NiceGUI REST API."""
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file}/keys?seq={seq}"
|
p = urllib.parse.quote(project, safe='')
|
||||||
|
f = urllib.parse.quote(file, safe='')
|
||||||
|
url = f"{manager_url.rstrip('/')}/api/projects/{p}/files/{f}/keys?seq={seq}"
|
||||||
return _fetch_json(url)
|
return _fetch_json(url)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +76,7 @@ if PromptServer is not None:
|
|||||||
@PromptServer.instance.routes.get("/json_manager/list_project_files")
|
@PromptServer.instance.routes.get("/json_manager/list_project_files")
|
||||||
async def list_project_files_proxy(request):
|
async def list_project_files_proxy(request):
|
||||||
manager_url = request.query.get("url", "http://localhost:8080")
|
manager_url = request.query.get("url", "http://localhost:8080")
|
||||||
project = request.query.get("project", "")
|
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
|
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files"
|
||||||
data = _fetch_json(url)
|
data = _fetch_json(url)
|
||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
@@ -79,8 +84,8 @@ if PromptServer is not None:
|
|||||||
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
|
@PromptServer.instance.routes.get("/json_manager/list_project_sequences")
|
||||||
async def list_project_sequences_proxy(request):
|
async def list_project_sequences_proxy(request):
|
||||||
manager_url = request.query.get("url", "http://localhost:8080")
|
manager_url = request.query.get("url", "http://localhost:8080")
|
||||||
project = request.query.get("project", "")
|
project = urllib.parse.quote(request.query.get("project", ""), safe='')
|
||||||
file_name = request.query.get("file", "")
|
file_name = urllib.parse.quote(request.query.get("file", ""), safe='')
|
||||||
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
|
url = f"{manager_url.rstrip('/')}/api/projects/{project}/files/{file_name}/sequences"
|
||||||
data = _fetch_json(url)
|
data = _fetch_json(url)
|
||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
@@ -98,6 +103,7 @@ if PromptServer is not None:
|
|||||||
return web.json_response(data)
|
return web.json_response(data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# 0. DYNAMIC NODE (Project-based)
|
# 0. DYNAMIC NODE (Project-based)
|
||||||
# ==========================================
|
# ==========================================
|
||||||
|
|||||||
3
state.py
3
state.py
@@ -34,4 +34,7 @@ class AppState:
|
|||||||
config=self.config,
|
config=self.config,
|
||||||
current_dir=self.current_dir,
|
current_dir=self.current_dir,
|
||||||
snippets=self.snippets,
|
snippets=self.snippets,
|
||||||
|
db=self.db,
|
||||||
|
current_project=self.current_project,
|
||||||
|
db_enabled=self.db_enabled,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -457,7 +457,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
|
|
||||||
# --- VACE Settings (full width) ---
|
# --- VACE Settings (full width) ---
|
||||||
with ui.expansion('VACE Settings', icon='settings').classes('w-full'):
|
with ui.expansion('VACE Settings', icon='settings').classes('w-full'):
|
||||||
_render_vace_settings(i, seq, batch_list, data, file_path, refresh_list)
|
_render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list)
|
||||||
|
|
||||||
# --- LoRA Settings ---
|
# --- LoRA Settings ---
|
||||||
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
with ui.expansion('LoRA Settings', icon='style').classes('w-full'):
|
||||||
@@ -539,7 +539,7 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
|||||||
# VACE Settings sub-section
|
# VACE Settings sub-section
|
||||||
# ======================================================================
|
# ======================================================================
|
||||||
|
|
||||||
def _render_vace_settings(i, seq, batch_list, data, file_path, refresh_list):
|
def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_list):
|
||||||
# VACE Schedule (needed early for both columns)
|
# VACE Schedule (needed early for both columns)
|
||||||
sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1))
|
sched_val = max(0, min(int(seq.get('vace schedule', 1)), len(VACE_MODES) - 1))
|
||||||
|
|
||||||
|
|||||||
@@ -246,6 +246,32 @@ class TestImport:
|
|||||||
seqs = db.list_sequences(df_id)
|
seqs = db.list_sequences(df_id)
|
||||||
assert seqs == []
|
assert seqs == []
|
||||||
|
|
||||||
|
def test_reimport_updates_existing(self, db, tmp_path):
|
||||||
|
"""Re-importing the same file should update data, not crash."""
|
||||||
|
pid = db.create_project("p1", "/p1")
|
||||||
|
json_path = tmp_path / "batch.json"
|
||||||
|
|
||||||
|
# First import
|
||||||
|
data_v1 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v1"}]}
|
||||||
|
json_path.write_text(json.dumps(data_v1))
|
||||||
|
df_id_1 = db.import_json_file(pid, json_path, "i2v")
|
||||||
|
|
||||||
|
# Second import (same file, updated data)
|
||||||
|
data_v2 = {KEY_BATCH_DATA: [{"sequence_number": 1, "prompt": "v2"}, {"sequence_number": 2, "prompt": "new"}]}
|
||||||
|
json_path.write_text(json.dumps(data_v2))
|
||||||
|
df_id_2 = db.import_json_file(pid, json_path, "vace")
|
||||||
|
|
||||||
|
# Should reuse the same data_file row
|
||||||
|
assert df_id_1 == df_id_2
|
||||||
|
# Data type should be updated
|
||||||
|
df = db.get_data_file(pid, "batch")
|
||||||
|
assert df["data_type"] == "vace"
|
||||||
|
# Sequences should reflect v2
|
||||||
|
seqs = db.list_sequences(df_id_2)
|
||||||
|
assert seqs == [1, 2]
|
||||||
|
s1 = db.get_sequence(df_id_2, 1)
|
||||||
|
assert s1["prompt"] == "v2"
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Query helpers
|
# Query helpers
|
||||||
|
|||||||
@@ -61,6 +61,16 @@ class TestFetchHelpers:
|
|||||||
called_url = mock.call_args[0][0]
|
called_url = mock.call_args[0][0]
|
||||||
assert "//api" not in called_url
|
assert "//api" not in called_url
|
||||||
|
|
||||||
|
def test_fetch_data_encodes_special_chars(self):
|
||||||
|
"""Project/file names with spaces or special chars should be percent-encoded."""
|
||||||
|
data = {"prompt": "hello"}
|
||||||
|
with patch("project_loader.urllib.request.urlopen", return_value=_mock_urlopen(data)) as mock:
|
||||||
|
_fetch_data("http://localhost:8080", "my project", "batch file", 1)
|
||||||
|
called_url = mock.call_args[0][0]
|
||||||
|
assert "my%20project" in called_url
|
||||||
|
assert "batch%20file" in called_url
|
||||||
|
assert " " not in called_url.split("?")[0] # no raw spaces in path
|
||||||
|
|
||||||
|
|
||||||
class TestProjectLoaderDynamic:
|
class TestProjectLoaderDynamic:
|
||||||
def test_load_dynamic_with_keys(self):
|
def test_load_dynamic_with_keys(self):
|
||||||
|
|||||||
44
utils.py
44
utils.py
@@ -164,7 +164,7 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
|||||||
"""Dual-write helper: sync JSON data to the project database.
|
"""Dual-write helper: sync JSON data to the project database.
|
||||||
|
|
||||||
Resolves (or creates) the data_file, upserts all sequences from batch_data,
|
Resolves (or creates) the data_file, upserts all sequences from batch_data,
|
||||||
and saves the history_tree.
|
and saves the history_tree. All writes happen in a single transaction.
|
||||||
"""
|
"""
|
||||||
if not db or not project_name:
|
if not db or not project_name:
|
||||||
return
|
return
|
||||||
@@ -173,26 +173,58 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
|||||||
if not proj:
|
if not proj:
|
||||||
return
|
return
|
||||||
file_name = Path(file_path).stem
|
file_name = Path(file_path).stem
|
||||||
|
|
||||||
|
# Use a single transaction for atomicity
|
||||||
|
db.conn.execute("BEGIN IMMEDIATE")
|
||||||
|
try:
|
||||||
df = db.get_data_file(proj["id"], file_name)
|
df = db.get_data_file(proj["id"], file_name)
|
||||||
if not df:
|
|
||||||
top_level = {k: v for k, v in data.items()
|
top_level = {k: v for k, v in data.items()
|
||||||
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
if k not in (KEY_BATCH_DATA, KEY_HISTORY_TREE)}
|
||||||
df_id = db.create_data_file(proj["id"], file_name, "generic", top_level)
|
if not df:
|
||||||
|
now = __import__('time').time()
|
||||||
|
cur = db.conn.execute(
|
||||||
|
"INSERT INTO data_files (project_id, name, data_type, top_level, created_at, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
(proj["id"], file_name, "generic", json.dumps(top_level), now, now),
|
||||||
|
)
|
||||||
|
df_id = cur.lastrowid
|
||||||
else:
|
else:
|
||||||
df_id = df["id"]
|
df_id = df["id"]
|
||||||
|
# Update top_level metadata
|
||||||
|
now = __import__('time').time()
|
||||||
|
db.conn.execute(
|
||||||
|
"UPDATE data_files SET top_level = ?, updated_at = ? WHERE id = ?",
|
||||||
|
(json.dumps(top_level), now, df_id),
|
||||||
|
)
|
||||||
|
|
||||||
# Sync sequences
|
# Sync sequences
|
||||||
batch_data = data.get(KEY_BATCH_DATA, [])
|
batch_data = data.get(KEY_BATCH_DATA, [])
|
||||||
if isinstance(batch_data, list):
|
if isinstance(batch_data, list):
|
||||||
db.delete_sequences_for_file(df_id)
|
db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
||||||
for item in batch_data:
|
for item in batch_data:
|
||||||
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
|
seq_num = int(item.get(KEY_SEQUENCE_NUMBER, 0))
|
||||||
db.upsert_sequence(df_id, seq_num, item)
|
now = __import__('time').time()
|
||||||
|
db.conn.execute(
|
||||||
|
"INSERT INTO sequences (data_file_id, sequence_number, data, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, ?)",
|
||||||
|
(df_id, seq_num, json.dumps(item), now),
|
||||||
|
)
|
||||||
|
|
||||||
# Sync history tree
|
# Sync history tree
|
||||||
history_tree = data.get(KEY_HISTORY_TREE)
|
history_tree = data.get(KEY_HISTORY_TREE)
|
||||||
if history_tree and isinstance(history_tree, dict):
|
if history_tree and isinstance(history_tree, dict):
|
||||||
db.save_history_tree(df_id, history_tree)
|
now = __import__('time').time()
|
||||||
|
db.conn.execute(
|
||||||
|
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||||
|
"VALUES (?, ?, ?) "
|
||||||
|
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||||
|
(df_id, json.dumps(history_tree), now),
|
||||||
|
)
|
||||||
|
|
||||||
|
db.conn.commit()
|
||||||
|
except Exception:
|
||||||
|
db.conn.rollback()
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"sync_to_db failed: {e}")
|
logger.warning(f"sync_to_db failed: {e}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user