fix: read sequence data directly from JSON file in API endpoints
_get_data and _get_keys were querying the SQLite DB which only gets populated when db_enabled is on. JSON file is always the source of truth, so read from it directly — fixes missing keys (e.g. resolutions) when DB hasn't been synced. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+35
-18
@@ -1,16 +1,18 @@
|
|||||||
"""REST API endpoints for ComfyUI to query project data from SQLite.
|
"""REST API endpoints for ComfyUI to query project data from JSON files.
|
||||||
|
|
||||||
All endpoints are read-only. Mounted on the NiceGUI/FastAPI server.
|
All endpoints are read-only. Mounted on the NiceGUI/FastAPI server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import HTTPException, Query
|
from fastapi import HTTPException, Query
|
||||||
from nicegui import app
|
from nicegui import app
|
||||||
|
|
||||||
from db import ProjectDB
|
from db import ProjectDB
|
||||||
|
from utils import load_json, KEY_BATCH_DATA, KEY_SEQUENCE_NUMBER
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -54,34 +56,49 @@ def _list_sequences(name: str, file_name: str) -> dict[str, Any]:
|
|||||||
return {"sequences": seqs}
|
return {"sequences": seqs}
|
||||||
|
|
||||||
|
|
||||||
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
def _load_sequences(name: str, file_name: str) -> list[dict]:
|
||||||
t0 = time.perf_counter()
|
"""Load the batch_data list directly from the JSON file."""
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
proj = db.get_project(name)
|
proj = db.get_project(name)
|
||||||
if not proj:
|
if not proj:
|
||||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
||||||
df = db.get_data_file_by_names(name, file_name)
|
json_path = Path(proj["folder_path"]) / f"{file_name}.json"
|
||||||
if not df:
|
if not json_path.exists():
|
||||||
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
||||||
data = db.get_sequence(df["id"], seq)
|
data, _ = load_json(json_path)
|
||||||
if data is None:
|
return data.get(KEY_BATCH_DATA, [])
|
||||||
|
|
||||||
|
|
||||||
|
def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
sequences = _load_sequences(name, file_name)
|
||||||
|
match = next((s for s in sequences if int(s.get(KEY_SEQUENCE_NUMBER, 0)) == seq), None)
|
||||||
|
if match is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
||||||
logger.info("API _get_data %s/%s seq=%d (%d keys): %.3fs",
|
logger.info("API _get_data %s/%s seq=%d (%d keys): %.3fs",
|
||||||
name, file_name, seq, len(data), time.perf_counter() - t0)
|
name, file_name, seq, len(match), time.perf_counter() - t0)
|
||||||
return data
|
return match
|
||||||
|
|
||||||
|
|
||||||
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]:
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
db = _get_db()
|
sequences = _load_sequences(name, file_name)
|
||||||
proj = db.get_project(name)
|
match = next((s for s in sequences if int(s.get(KEY_SEQUENCE_NUMBER, 0)) == seq), None)
|
||||||
if not proj:
|
if match is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Project '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Sequence {seq} not found")
|
||||||
df = db.get_data_file_by_names(name, file_name)
|
keys = [k for k in match.keys() if k != KEY_SEQUENCE_NUMBER]
|
||||||
if not df:
|
types = []
|
||||||
raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'")
|
for k in keys:
|
||||||
keys, types = db.get_sequence_keys(df["id"], seq)
|
v = match[k]
|
||||||
total = db.count_sequences(df["id"])
|
if isinstance(v, bool):
|
||||||
|
types.append("BOOLEAN")
|
||||||
|
elif isinstance(v, int):
|
||||||
|
types.append("INT")
|
||||||
|
elif isinstance(v, float):
|
||||||
|
types.append("FLOAT")
|
||||||
|
else:
|
||||||
|
types.append("STRING")
|
||||||
|
total = len(sequences)
|
||||||
logger.info("API _get_keys %s/%s seq=%d (%d keys): %.3fs",
|
logger.info("API _get_keys %s/%s seq=%d (%d keys): %.3fs",
|
||||||
name, file_name, seq, len(keys), time.perf_counter() - t0)
|
name, file_name, seq, len(keys), time.perf_counter() - t0)
|
||||||
return {"keys": keys, "types": types, "total_sequences": total}
|
return {"keys": keys, "types": types, "total_sequences": total}
|
||||||
|
|||||||
Reference in New Issue
Block a user