diff --git a/api_routes.py b/api_routes.py index 34f61ff..123dad8 100644 --- a/api_routes.py +++ b/api_routes.py @@ -1,16 +1,19 @@ -"""REST API endpoints for ComfyUI to query project data from SQLite. +"""REST API endpoints for ComfyUI to query project data from JSON files. All endpoints are read-only. Mounted on the NiceGUI/FastAPI server. """ import logging import time +from pathlib import Path from typing import Any from fastapi import HTTPException, Query +from fastapi.responses import FileResponse from nicegui import app from db import ProjectDB +from utils import load_json, load_config, KEY_BATCH_DATA, KEY_SEQUENCE_NUMBER logger = logging.getLogger(__name__) @@ -24,10 +27,13 @@ def register_api_routes(db: ProjectDB) -> None: _db = db app.add_api_route("/api/projects", _list_projects, methods=["GET"]) + app.add_api_route("/api/active-project", _get_active_project, methods=["GET"]) + app.add_api_route("/api/projects/{name}", _get_project, methods=["GET"]) app.add_api_route("/api/projects/{name}/files", _list_files, methods=["GET"]) app.add_api_route("/api/projects/{name}/files/{file_name}/sequences", _list_sequences, methods=["GET"]) app.add_api_route("/api/projects/{name}/files/{file_name}/data", _get_data, methods=["GET"]) app.add_api_route("/api/projects/{name}/files/{file_name}/keys", _get_keys, methods=["GET"]) + app.add_api_route("/api/image-preview", _serve_image, methods=["GET"]) def _get_db() -> ProjectDB: @@ -42,6 +48,20 @@ def _list_projects() -> dict[str, Any]: return {"projects": [p["name"] for p in projects]} +def _get_active_project() -> dict[str, Any]: + config = load_config() + return {"project": config.get("current_project", "")} + + +def _get_project(name: str) -> dict[str, Any]: + db = _get_db() + proj = db.get_project(name) + if not proj: + raise HTTPException(status_code=404, detail=f"Project '{name}' not found") + return {"name": proj["name"], "folder_path": proj["folder_path"], + "description": proj.get("description", "")} + + def _list_files(name: str) -> dict[str, Any]: db = _get_db() files = db.list_project_files(name) @@ -54,34 +74,73 @@ def _list_sequences(name: str, file_name: str) -> dict[str, Any]: return {"sequences": seqs} -def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: - t0 = time.perf_counter() +def _load_sequences(name: str, file_name: str) -> list[dict]: + """Load the batch_data list directly from the JSON file.""" db = _get_db() proj = db.get_project(name) if not proj: raise HTTPException(status_code=404, detail=f"Project '{name}' not found") - df = db.get_data_file_by_names(name, file_name) - if not df: + json_path = Path(proj["folder_path"]) / f"{file_name}.json" + if not json_path.exists(): raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'") - data = db.get_sequence(df["id"], seq) - if data is None: + data, _ = load_json(json_path) + return data.get(KEY_BATCH_DATA, []) + + +def _get_data(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: + t0 = time.perf_counter() + sequences = _load_sequences(name, file_name) + match = next((s for s in sequences if int(s.get(KEY_SEQUENCE_NUMBER, 0)) == seq), None) + if match is None: raise HTTPException(status_code=404, detail=f"Sequence {seq} not found") + result = dict(match) + for out_key, src_key in ( + ("start_name", "start frame path"), + ("middle_name", "middle frame path"), + ("end_name", "end frame path"), + ): + path_val = result.get(src_key, "") + result[out_key] = Path(path_val).stem if path_val else "" logger.info("API _get_data %s/%s seq=%d (%d keys): %.3fs", - name, file_name, seq, len(data), time.perf_counter() - t0) - return data + name, file_name, seq, len(result), time.perf_counter() - t0) + return result def _get_keys(name: str, file_name: str, seq: int = Query(default=1)) -> dict[str, Any]: t0 = time.perf_counter() - db = _get_db() - proj = db.get_project(name) - if not proj: - raise HTTPException(status_code=404, detail=f"Project '{name}' not found") - df = db.get_data_file_by_names(name, file_name) - if not df: - raise HTTPException(status_code=404, detail=f"File '{file_name}' not found in project '{name}'") - keys, types = db.get_sequence_keys(df["id"], seq) - total = db.count_sequences(df["id"]) + sequences = _load_sequences(name, file_name) + match = next((s for s in sequences if int(s.get(KEY_SEQUENCE_NUMBER, 0)) == seq), None) + if match is None: + raise HTTPException(status_code=404, detail=f"Sequence {seq} not found") + keys = [k for k in match.keys() if k != KEY_SEQUENCE_NUMBER] + types = [] + for k in keys: + v = match[k] + if isinstance(v, bool): + types.append("BOOLEAN") + elif isinstance(v, int): + types.append("INT") + elif isinstance(v, float): + types.append("FLOAT") + else: + types.append("STRING") + # Computed keys derived from frame paths + for out_key, src_key in ( + ("start_name", "start frame path"), + ("middle_name", "middle frame path"), + ("end_name", "end frame path"), + ): + if src_key in match: + keys.append(out_key) + types.append("STRING") + total = len(sequences) logger.info("API _get_keys %s/%s seq=%d (%d keys): %.3fs", name, file_name, seq, len(keys), time.perf_counter() - t0) return {"keys": keys, "types": types, "total_sequences": total} + + +def _serve_image(path: str = Query(...)) -> FileResponse: + p = Path(path) + if not p.exists() or not p.is_file(): + raise HTTPException(status_code=404, detail="Image not found") + return FileResponse(str(p)) diff --git a/db.py b/db.py index 81e9461..377c367 100644 --- a/db.py +++ b/db.py @@ -9,7 +9,7 @@ from utils import load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE logger = logging.getLogger(__name__) -DEFAULT_DB_PATH = Path.home() / ".comfyui_json_manager" / "projects.db" +DEFAULT_DB_PATH = Path(__file__).parent / "projects.db" SCHEMA_SQL = """ CREATE TABLE IF NOT EXISTS projects ( diff --git a/docs/plans/2026-04-04-binary-index-decoder-design.md b/docs/plans/2026-04-04-binary-index-decoder-design.md new file mode 100644 index 0000000..dbb286a --- /dev/null +++ b/docs/plans/2026-04-04-binary-index-decoder-design.md @@ -0,0 +1,67 @@ +# BinaryIndexDecoder Node — Design + +## Summary + +A standalone ComfyUI utility node that converts an integer index into 3 boolean +outputs using binary (bit-field) encoding. Intended for use with loop counters to +gate multiple processing branches simultaneously. + +## Node Spec + +| Field | Value | +|---|---| +| Class name | `BinaryIndexDecoder` | +| Display name | `Binary Index Decoder` | +| Category | `JSON Manager/utils` | +| Function | `decode` | + +### Inputs + +| Name | Type | Default | Range | +|---|---|---|---| +| `index` | INT | 0 | 0–7 | + +### Outputs + +| Name | Type | +|---|---| +| `flag_0` | BOOLEAN | +| `flag_1` | BOOLEAN | +| `flag_2` | BOOLEAN | + +### Logic + +``` +flag_0 = bool((index >> 0) & 1) +flag_1 = bool((index >> 1) & 1) +flag_2 = bool((index >> 2) & 1) +``` + +### Truth table + +| index | flag_0 | flag_1 | flag_2 | +|---|---|---|---| +| 0 | F | F | F | +| 1 | T | F | F | +| 2 | F | T | F | +| 3 | T | T | F | +| 4 | F | F | T | +| 5 | T | F | T | +| 6 | F | T | T | +| 7 | T | T | T | + +## Implementation Notes + +- Lives in `project_loader.py` alongside other project nodes +- Added to `PROJECT_NODE_CLASS_MAPPINGS` and `PROJECT_NODE_DISPLAY_NAME_MAPPINGS` +- No JavaScript extension needed (no source sync, no dynamic widgets) +- No NiceGUI UI changes needed +- `IS_CHANGED` not needed (output is deterministic from input) + +## Testing + +9 tests in `tests/test_project_loader.py::TestBinaryIndexDecoder`: +- Input types include `index` as INT +- All 8 index values (0–7) produce correct boolean tuple +- Out-of-range index (e.g. 8) clamps to 0–7 or wraps gracefully +- `NodeMappings` test updated: 5 nodes, mappings length == 5 diff --git a/docs/plans/2026-04-04-binary-index-decoder-plan.md b/docs/plans/2026-04-04-binary-index-decoder-plan.md new file mode 100644 index 0000000..deac578 --- /dev/null +++ b/docs/plans/2026-04-04-binary-index-decoder-plan.md @@ -0,0 +1,166 @@ +# BinaryIndexDecoder Node — Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add a standalone ComfyUI node `BinaryIndexDecoder` that converts an integer index to 3 boolean outputs using binary (bit-field) encoding. + +**Architecture:** Single class in `project_loader.py`, no JS extension needed, no NiceGUI changes. Takes `index` INT, returns `(flag_0, flag_1, flag_2)` as BOOLEAN using bit-shift logic. Added to existing node mappings. + +**Tech Stack:** Python, ComfyUI node API, pytest + +--- + +### Task 1: Write failing tests for BinaryIndexDecoder + +**Files:** +- Modify: `tests/test_project_loader.py` (append new test class at end, before `TestNodeMappings`) +- Modify: `tests/test_project_loader.py` — update `TestNodeMappings.test_mappings_exist` to expect 5 nodes + +**Step 1: Add the test class** + +Append this class before `TestNodeMappings` in `tests/test_project_loader.py`: + +```python +class TestBinaryIndexDecoder: + def test_input_types(self): + from project_loader import BinaryIndexDecoder + inputs = BinaryIndexDecoder.INPUT_TYPES() + assert "index" in inputs["required"] + assert inputs["required"]["index"][0] == "INT" + + def test_three_boolean_outputs(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder.RETURN_TYPES == ("BOOLEAN", "BOOLEAN", "BOOLEAN") + assert BinaryIndexDecoder.RETURN_NAMES == ("flag_0", "flag_1", "flag_2") + + def test_category(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder.CATEGORY == "JSON Manager/utils" + + def test_index_0(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(0) == (False, False, False) + + def test_index_1(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(1) == (True, False, False) + + def test_index_2(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(2) == (False, True, False) + + def test_index_3(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(3) == (True, True, False) + + def test_index_4(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(4) == (False, False, True) + + def test_index_7(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(7) == (True, True, True) +``` + +Also update `TestNodeMappings.test_mappings_exist`: + +```python +def test_mappings_exist(self): + from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS + assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectResolution" in PROJECT_NODE_CLASS_MAPPINGS + assert "BinaryIndexDecoder" in PROJECT_NODE_CLASS_MAPPINGS + assert len(PROJECT_NODE_CLASS_MAPPINGS) == 5 + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 5 +``` + +**Step 2: Run tests to verify they fail** + +```bash +python -m pytest tests/test_project_loader.py::TestBinaryIndexDecoder -v +``` + +Expected: FAIL with `ImportError: cannot import name 'BinaryIndexDecoder'` + +**Step 3: Commit the failing tests** + +```bash +git add tests/test_project_loader.py +git commit -m "test: add failing tests for BinaryIndexDecoder node" +``` + +--- + +### Task 2: Implement BinaryIndexDecoder + +**Files:** +- Modify: `project_loader.py` — add class after `ProjectResolution`, update mappings + +**Step 1: Add the class** + +Insert after the `ProjectResolution` class (before `# --- Mappings ---`) in `project_loader.py`: + +```python +class BinaryIndexDecoder: + """Decodes an integer index into 3 boolean flags using binary (bit-field) encoding. + + index 0 → (False, False, False) + index 1 → (True, False, False) # bit 0 + index 2 → (False, True, False) # bit 1 + index 3 → (True, True, False) # bits 0+1 + index 4 → (False, False, True) # bit 2 + ... + index 7 → (True, True, True) + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "index": ("INT", {"default": 0, "min": 0, "max": 7}), + } + } + + RETURN_TYPES = ("BOOLEAN", "BOOLEAN", "BOOLEAN") + RETURN_NAMES = ("flag_0", "flag_1", "flag_2") + FUNCTION = "decode" + CATEGORY = "JSON Manager/utils" + OUTPUT_NODE = False + + def decode(self, index: int): + return ( + bool((index >> 0) & 1), + bool((index >> 1) & 1), + bool((index >> 2) & 1), + ) +``` + +**Step 2: Update mappings** + +In `PROJECT_NODE_CLASS_MAPPINGS`, add: +```python +"BinaryIndexDecoder": BinaryIndexDecoder, +``` + +In `PROJECT_NODE_DISPLAY_NAME_MAPPINGS`, add: +```python +"BinaryIndexDecoder": "Binary Index Decoder", +``` + +**Step 3: Run all tests** + +```bash +python -m pytest tests/test_project_loader.py -v +``` + +Expected: all tests PASS (42 existing + 10 new = 52 total) + +**Step 4: Commit** + +```bash +git add project_loader.py tests/test_project_loader.py +git commit -m "feat: add BinaryIndexDecoder node (INT index → 3 BOOLEANs, binary encoding)" +git push +``` diff --git a/project_loader.py b/project_loader.py index 8827191..6830b41 100644 --- a/project_loader.py +++ b/project_loader.py @@ -67,6 +67,13 @@ def _fetch_json(url: str) -> dict: return {"error": "parse_error", "message": str(e)} +def _fetch_project(manager_url: str, project: str) -> dict: + """Fetch project details (including folder_path) from the NiceGUI REST API.""" + p = urllib.parse.quote(project, safe='') + url = f"{manager_url.rstrip('/')}/api/projects/{p}" + return _fetch_json(url) + + def _fetch_data(manager_url: str, project: str, file: str, seq: int) -> dict: """Fetch sequence data from the NiceGUI REST API.""" p = urllib.parse.quote(project, safe='') @@ -150,7 +157,7 @@ class ProjectLoaderDynamic: RETURN_TYPES = ("INT",) + tuple(any_type for _ in range(MAX_DYNAMIC_OUTPUTS)) RETURN_NAMES = ("total_sequences",) + tuple(f"output_{i}" for i in range(MAX_DYNAMIC_OUTPUTS)) FUNCTION = "load_dynamic" - CATEGORY = "utils/json/project" + CATEGORY = "JSON Manager/project" OUTPUT_NODE = False def load_dynamic(self, manager_url, project_name, file_name, sequence_number, @@ -221,14 +228,24 @@ class ProjectSource: }, } - RETURN_TYPES = ("INT", "STRING",) - RETURN_NAMES = ("sequence_number", "file_name",) + RETURN_TYPES = ("INT", "STRING", "STRING") + RETURN_NAMES = ("sequence_number", "file_name", "project_path") FUNCTION = "hold_config" - CATEGORY = "utils/json/project" + CATEGORY = "JSON Manager/project" OUTPUT_NODE = True def hold_config(self, manager_url, project_name, file_name, sequence_number, label): - return (sequence_number, file_name,) + name = project_name.strip() + if not name: + active = _fetch_json(f"{manager_url.rstrip('/')}/api/active-project") + name = active.get("project", "") if "error" not in active else "" + folder_path = "" + if name: + proj = _fetch_project(manager_url, name) + folder_path = proj.get("folder_path", "") if "error" not in proj else "" + if folder_path and not folder_path.endswith("/"): + folder_path += "/" + return (sequence_number, file_name, folder_path) class ProjectKey: @@ -252,7 +269,7 @@ class ProjectKey: RETURN_TYPES = (any_type,) RETURN_NAMES = ("value",) FUNCTION = "fetch_key" - CATEGORY = "utils/json/project" + CATEGORY = "JSON Manager/project" OUTPUT_NODE = False @classmethod @@ -282,26 +299,122 @@ class ProjectKey: val = data.get(key_name, "") if key_type == "INT": - return (to_int(val),) + result = to_int(val) + return {"ui": {"value": [str(result)]}, "result": (result,)} elif key_type == "FLOAT": - return (to_float(val),) + result = to_float(val) + return {"ui": {"value": [f"{result:.4g}"]}, "result": (result,)} elif isinstance(val, bool): - return (str(val).lower(),) + return {"ui": {"value": [str(val).lower()]}, "result": (str(val).lower(),)} elif isinstance(val, (int, float)): - return (val,) + return {"ui": {"value": [str(val)]}, "result": (val,)} else: return (str(val),) +class ProjectResolution: + """Fetches a (width, height) pair from a resolution series by loop index.""" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source_label": ("STRING", {"default": "", "multiline": False}), + "key_name": ("STRING", {"default": "resolutions", "multiline": False}), + "index": ("INT", {"default": 0, "min": 0, "max": 9999}), + }, + "optional": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }, + } + + RETURN_TYPES = ("INT", "INT", "INT") + RETURN_NAMES = ("width", "height", "seed") + FUNCTION = "fetch_resolution" + CATEGORY = "JSON Manager/project" + OUTPUT_NODE = False + + @classmethod + def IS_CHANGED(cls, **kwargs): + return float("nan") + + def fetch_resolution(self, source_label, key_name, index, + manager_url="http://localhost:8080", project_name="", + file_name="", sequence_number=1): + sequence_number = int(sequence_number) + logger.info("ProjectResolution.fetch_resolution: source=%s key=%s url=%s project=%s file=%s seq=%s index=%s", + source_label, key_name, manager_url, project_name, file_name, sequence_number, index) + # source_label is used by JS to identify which ProjectSource to sync + # config from. The actual config arrives via the optional widgets below. + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + if data.get("error") in ("http_error", "network_error", "parse_error"): + logger.warning("ProjectResolution.fetch_resolution failed: %s", data.get("message")) + return (512, 512, 0) + + series = data.get(key_name) + if not isinstance(series, list) or len(series) == 0: + logger.warning("ProjectResolution: key '%s' is not a resolution series", key_name) + return (512, 512, 0) + + clamped = max(0, min(index, len(series) - 1)) + entry = series[clamped] + if not isinstance(entry, (list, tuple)) or len(entry) < 2: + logger.warning("ProjectResolution: entry at index %d is malformed: %r", clamped, entry) + return (512, 512, 0) + + seed = to_int(entry[2]) if len(entry) >= 3 else 0 + return (to_int(entry[0]), to_int(entry[1]), seed) + + +class BinaryIndexDecoder: + """Decodes an integer index into 3 boolean flags using binary (bit-field) encoding. + + index 0 → (False, False, False) + index 1 → (True, False, False) # bit 0 + index 2 → (False, True, False) # bit 1 + index 3 → (True, True, False) # bits 0+1 + index 4 → (False, False, True) # bit 2 + ... + index 7 → (True, True, True) + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "index": ("INT", {"default": 0, "min": 0, "max": 7}), + } + } + + RETURN_TYPES = ("BOOLEAN", "BOOLEAN", "BOOLEAN") + RETURN_NAMES = ("flag_0", "flag_1", "flag_2") + FUNCTION = "decode" + CATEGORY = "JSON Manager/utils" + OUTPUT_NODE = False + + def decode(self, index: int): + f0 = bool((index >> 0) & 1) + f1 = bool((index >> 1) & 1) + f2 = bool((index >> 2) & 1) + return {"ui": {"values": [str(f0).lower(), str(f1).lower(), str(f2).lower()]}, + "result": (f0, f1, f2)} + + # --- Mappings --- PROJECT_NODE_CLASS_MAPPINGS = { "ProjectLoaderDynamic": ProjectLoaderDynamic, "ProjectSource": ProjectSource, "ProjectKey": ProjectKey, + "ProjectResolution": ProjectResolution, + "BinaryIndexDecoder": BinaryIndexDecoder, } PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { "ProjectLoaderDynamic": "Project Loader (Dynamic)", "ProjectSource": "Project Source", "ProjectKey": "Project Key", + "ProjectResolution": "Project Resolution", + "BinaryIndexDecoder": "Binary Index Decoder", } diff --git a/tab_batch_ng.py b/tab_batch_ng.py index c0bd53c..69ab195 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -6,6 +6,7 @@ import math import random import time from pathlib import Path +from urllib.parse import quote from nicegui import ui @@ -314,9 +315,12 @@ def render_batch_processor(state: AppState): standard_keys = { 'name', 'mode', 'general_prompt', 'general_negative', 'current_prompt', 'negative', 'prompt', 'seed', 'cfg', 'camera', 'flf', KEY_SEQUENCE_NUMBER, - 'frame_to_skip', 'end_frame', 'transition', 'vace_length', + 'frame_to_skip', 'end_frame', 'logic index', 'transition', 'vace_length', 'input_a_frames', 'input_b_frames', 'reference switch', 'vace schedule', - 'reference path', 'video file path', 'reference image path', 'flf image path', + 'start frame path', 'start frame strength', + 'middle frame path', 'middle frame strength', + 'end frame path', 'end frame strength', + 'video file path', } standard_keys.update(lora_keys) @@ -409,6 +413,7 @@ def render_batch_processor(state: AppState): # Single sequence card # ====================================================================== + def _render_sequence_card(i, seq, batch_list, data, file_path, state, src_cache, src_seq_select, standard_keys, refresh_list): @@ -469,11 +474,11 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, ) if result is not None: s['name'] = result - commit('Renamed!') + await commit('Renamed!') ui.button('Rename', icon='edit', on_click=rename).props('outline') # Copy from source - def copy_source(idx=i, sn=seq_num): + async def copy_source(idx=i, sn=seq_num): item = copy.deepcopy(DEFAULTS) src_batch = src_cache['batch'] sel_idx = src_seq_select.value @@ -485,12 +490,12 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, item.pop(KEY_PROMPT_HISTORY, None) item.pop(KEY_HISTORY_TREE, None) batch_list[idx] = item - commit('Copied!') + await commit('Copied!') ui.button('Copy Src', icon='file_download', on_click=copy_source).props('outline') # Clone Next - def clone_next(idx=i, sn=seq_num, s=seq): + async def clone_next(idx=i, sn=seq_num, s=seq): new_seq = copy.deepcopy(s) new_seq[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1 if not is_subsegment(sn): @@ -498,21 +503,21 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, else: pos = idx + 1 batch_list.insert(pos, new_seq) - commit('Cloned to Next!') + await commit('Cloned to Next!') ui.button('Clone Next', icon='content_copy', on_click=clone_next).props('outline') # Clone End - def clone_end(s=seq): + async def clone_end(s=seq): new_seq = copy.deepcopy(s) new_seq[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1 batch_list.append(new_seq) - commit('Cloned to End!') + await commit('Cloned to End!') ui.button('Clone End', icon='vertical_align_bottom', on_click=clone_end).props('outline') # Clone Sub - def clone_sub(idx=i, sn=seq_num, s=seq): + async def clone_sub(idx=i, sn=seq_num, s=seq): new_seq = copy.deepcopy(s) p_seq = parent_of(sn) p_idx = idx @@ -524,23 +529,24 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, new_seq[KEY_SEQUENCE_NUMBER] = next_sub_segment_number(batch_list, p_seq) pos = find_insert_position(batch_list, p_idx, p_seq) batch_list.insert(pos, new_seq) - commit(f'Created {format_seq_label(new_seq[KEY_SEQUENCE_NUMBER])}!') + await commit(f'Created {format_seq_label(new_seq[KEY_SEQUENCE_NUMBER])}!') ui.button('Clone Sub', icon='link', on_click=clone_sub).props('outline') ui.element('div').classes('col') # Delete - def delete(idx=i): + async def delete(idx=i): if idx < len(batch_list): batch_list.pop(idx) - commit() + await commit() ui.button(icon='delete', on_click=delete).props('color=negative') ui.separator() # --- Prompts + Settings (2-column) --- + frame_switches = [] # populated below, used for bidirectional sync with logic index with ui.splitter(value=66).classes('w-full') as splitter: with splitter.before: dict_textarea('General Prompt', seq, 'general_prompt').classes( @@ -552,6 +558,35 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, dict_textarea('Specific Negative', seq, 'negative').classes( 'w-full q-mt-sm').props('outlined rows=2') + # --- Frame paths (start / middle / end) --- + logic_val = int(seq.get('logic index', 0)) + for bit, img_label, img_key, str_key in [ + (0, 'Start Frame', 'start frame path', 'start frame strength'), + (1, 'Middle Frame', 'middle frame path', 'middle frame strength'), + (2, 'End Frame', 'end frame path', 'end frame strength'), + ]: + ui.label(img_label).classes('text-caption text-weight-bold q-mt-sm') + with ui.row().classes('w-full items-center no-wrap q-mt-xs'): + inp = dict_input(ui.input, 'Path', seq, img_key).classes( + 'col').props('outlined dense input-style="text-align: right"') + img_path = Path(seq.get(img_key, '')) if seq.get(img_key) else None + if (img_path and img_path.exists() and + img_path.suffix.lower() in IMAGE_EXTENSIONS): + img_url = f'/api/image-preview?path={quote(str(img_path))}' + with ui.dialog() as img_dlg, ui.card().style('max-width:90vw; padding:0'): + ui.html(f'') + ui.html( + f'' + ).on('click', img_dlg.open) + str_inp = dict_number('Strength', seq, str_key, default=1.0, + step=0.05, format='%.2f').style( + 'width:80px').props('outlined dense') + sw = ui.switch(value=bool((logic_val >> bit) & 1)) + frame_switches.append(sw) + with splitter.after: # Mode dict_number('Mode', seq, 'mode').props('outlined').classes('w-full') @@ -581,25 +616,68 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, dict_input(ui.input, 'Camera', seq, 'camera').props('outlined').classes('w-full') dict_input(ui.input, 'FLF', seq, 'flf').props('outlined').classes('w-full') - dict_number('End Frame', seq, 'end_frame').props('outlined').classes('w-full') + ef_input = dict_number('End Frame', seq, 'end_frame').props('outlined').classes('w-full') + seq.setdefault('logic index', 0) + li_input = dict_number('Logic Index', seq, 'logic index').props('outlined readonly').classes('w-full') + with li_input: + ui.tooltip( + 'Binary flags — bit 0: start frame | bit 1: middle frame | bit 2: end frame\n' + '0: none 1: start 2: middle 3: start+middle\n' + '4: end 5: start+end 6: middle+end 7: all' + ) dict_input(ui.input, 'Video File Path', seq, 'video file path').props( - 'outlined input-style="direction: rtl"').classes('w-full') + 'outlined input-style="text-align: right"').classes('w-full') - # Image paths with preview - for img_label, img_key in [ - ('Reference Image Path', 'reference image path'), - ('Reference Path', 'reference path'), - ('FLF Image Path', 'flf image path'), - ]: - with ui.row().classes('w-full items-center'): - inp = dict_input(ui.input, img_label, seq, img_key).classes( - 'col').props('outlined input-style="direction: rtl"') - img_path = Path(seq.get(img_key, '')) if seq.get(img_key) else None - if (img_path and img_path.exists() and - img_path.suffix.lower() in IMAGE_EXTENSIONS): - with ui.dialog() as dlg, ui.card(): - ui.image(str(img_path)).classes('w-full') - ui.button(icon='visibility', on_click=dlg.open).props('flat dense') + # Switches → logic index (sole writer) + def _sync_switches_to_logic(li=li_input, switches=frame_switches, s=seq): + v = sum(int(sw.value) << b for b, sw in enumerate(switches)) + s['logic index'] = v + li.set_value(v) + + for frame_sw in frame_switches: + frame_sw.on('update:model-value', lambda _, s=_sync_switches_to_logic: s()) + + # --- Resolutions (8 fixed slots) --- + resolutions = seq.setdefault('resolutions', []) + while len(resolutions) < 8: + resolutions.append([512, 512, 0]) + for r_i in range(len(resolutions)): + if len(resolutions[r_i]) < 3: + resolutions[r_i] = list(resolutions[r_i]) + [0] + with ui.expansion('Resolutions', icon='aspect_ratio').classes('w-full'): + for idx in range(8): + entry = resolutions[idx] + with ui.row().classes('items-center w-full q-mt-xs no-wrap'): + ui.label(str(idx)).classes('text-caption').style('min-width:16px') + w_inp = ui.number(value=int(entry[0]), min=1, step=1, label='W').style( + 'width:70px').props('outlined dense hide-bottom-space') + h_inp = ui.number(value=int(entry[1]), min=1, step=1, label='H').style( + 'width:70px').props('outlined dense hide-bottom-space') + seed_inp = ui.number(value=int(entry[2]), min=0, step=1, label='Seed').style( + 'flex:1; min-width:60px').props('outlined dense hide-bottom-space') + + async def _sync_entry(r=idx, wi=w_inp, hi=h_inp, si=seed_inp): + seq['resolutions'][r] = [ + int(wi.value) if wi.value else 512, + int(hi.value) if hi.value else 512, + int(si.value) if si.value else 0, + ] + await commit() + + async def _randomize(si=seed_inp, r=idx): + si.value = random.randint(0, 2**32 - 1) + seq['resolutions'][r][2] = int(si.value) + await commit() + + ui.button(icon='casino', on_click=_randomize).props( + 'flat dense round').classes('q-ml-xs') + + w_inp.on('blur', lambda _, s=_sync_entry: s()) + w_inp.on('update:model-value', lambda _, s=_sync_entry: s()) + h_inp.on('blur', lambda _, s=_sync_entry: s()) + h_inp.on('update:model-value', lambda _, s=_sync_entry: s()) + seed_inp.on('blur', lambda _, s=_sync_entry: s()) + seed_inp.on('update:model-value', lambda _, s=_sync_entry: s()) # --- VACE Settings (full width) --- with ui.expansion('VACE Settings', icon='settings').classes('w-full'): @@ -645,16 +723,16 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, # --- Custom Parameters --- ui.label('Custom Parameters').classes('section-header q-mt-md') - custom_keys = [k for k in seq.keys() if k not in standard_keys] + custom_keys = [k for k in seq.keys() if k not in standard_keys and k != 'resolutions'] if custom_keys: for k in custom_keys: with ui.row().classes('w-full items-center'): ui.input('Key', value=k).props('readonly outlined dense').classes('w-32') dict_input(ui.input, 'Value', seq, k).props('outlined dense').classes('col') - def del_custom(key=k): + async def del_custom(key=k): del seq[key] - commit() + await commit() ui.button(icon='delete', on_click=del_custom).props('flat dense color=negative') @@ -662,14 +740,14 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, new_k_input = ui.input('Key').props('outlined dense') new_v_input = ui.input('Value').props('outlined dense') - def add_param(): + async def add_param(): k = new_k_input.value v = new_v_input.value if k and k not in seq: seq[k] = v new_k_input.set_value('') new_v_input.set_value('') - commit() + await commit() ui.button('Add', on_click=add_param).props('flat') diff --git a/tab_projects_ng.py b/tab_projects_ng.py index ddc1570..eca356f 100644 --- a/tab_projects_ng.py +++ b/tab_projects_ng.py @@ -216,8 +216,10 @@ def render_projects_tab(state: AppState): async def _import_folder(state: AppState, project_id: int, project_name: str, refresh_fn): - """Bulk import all .json files from current directory into a project.""" - json_files = sorted(state.current_dir.glob('*.json')) + """Bulk import all .json files from the project's folder_path into a project.""" + proj = state.db.get_project(project_name) + scan_dir = Path(proj['folder_path']) if proj else state.current_dir + json_files = sorted(scan_dir.glob('*.json')) json_files = [f for f in json_files if f.name not in ( '.editor_config.json', '.editor_snippets.json')] diff --git a/tab_timeline_ng.py b/tab_timeline_ng.py index 3986a2e..9c48d85 100644 --- a/tab_timeline_ng.py +++ b/tab_timeline_ng.py @@ -602,6 +602,36 @@ def _render_preview_fields(item_data: dict): ui.input('Video Path', value=str(item_data.get('video file path', ''))).props('readonly outlined') + resolutions = item_data.get('resolutions') + if isinstance(resolutions, list) and resolutions: + with ui.expansion('Resolutions'): + with ui.grid(columns=4).classes('w-full'): + for i, entry in enumerate(resolutions): + if isinstance(entry, (list, tuple)) and len(entry) >= 2: + w, h = entry[0], entry[1] + seed = entry[2] if len(entry) >= 3 else 0 + ui.input(f'#{i} W', value=str(w)).props('readonly outlined dense') + ui.input(f'#{i} H', value=str(h)).props('readonly outlined dense') + ui.input(f'#{i} Seed', value=str(seed)).props('readonly outlined dense') + ui.label('') # grid spacer for 4th column + + known_keys = { + 'sequence_number', 'general_prompt', 'general_negative', 'current_prompt', 'prompt', + 'negative', 'camera', 'flf', 'seed', 'resolutions', + 'frame_to_skip', 'vace schedule', 'video file path', 'middle frame path', 'end frame path', 'start frame path', + 'logic index', + } + # also skip lora keys + custom_keys = [ + k for k in item_data + if k not in known_keys and not k.startswith('lora ') + ] + if custom_keys: + with ui.expansion('Custom Fields'): + with ui.grid(columns=2).classes('w-full'): + for k in custom_keys: + ui.input(k, value=str(item_data[k])).props('readonly outlined dense') + def _truncate(val, max_len=60): """Truncate a value for display.""" diff --git a/tests/test_project_loader.py b/tests/test_project_loader.py index 153b88a..956325d 100644 --- a/tests/test_project_loader.py +++ b/tests/test_project_loader.py @@ -200,7 +200,7 @@ class TestProjectLoaderDynamic: assert "sequence_number" in inputs["required"] def test_category(self): - assert ProjectLoaderDynamic.CATEGORY == "utils/json/project" + assert ProjectLoaderDynamic.CATEGORY == "JSON Manager/project" class TestProjectSource: @@ -232,7 +232,7 @@ class TestProjectSource: def test_category(self): from project_loader import ProjectSource - assert ProjectSource.CATEGORY == "utils/json/project" + assert ProjectSource.CATEGORY == "JSON Manager/project" class TestProjectKey: @@ -341,7 +341,159 @@ class TestProjectKey: def test_category(self): from project_loader import ProjectKey - assert ProjectKey.CATEGORY == "utils/json/project" + assert ProjectKey.CATEGORY == "JSON Manager/project" + + +class TestProjectResolution: + def test_input_types(self): + from project_loader import ProjectResolution + inputs = ProjectResolution.INPUT_TYPES() + assert "source_label" in inputs["required"] + assert "key_name" in inputs["required"] + assert "index" in inputs["required"] + assert inputs["required"]["index"][0] == "INT" + + def test_three_outputs(self): + from project_loader import ProjectResolution + assert ProjectResolution.RETURN_TYPES == ("INT", "INT", "INT") + assert ProjectResolution.RETURN_NAMES == ("width", "height", "seed") + + def test_fetch_resolution_basic(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512, 0], [768, 1344, 12345], [1344, 768, 99]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=1, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (768, 1344, 12345) + + def test_fetch_resolution_index_zero(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512, 42], [1024, 1024, 0]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512, 42) + + def test_fetch_resolution_clamps_on_out_of_bounds(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512, 0], [1024, 1024, 7]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=99, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (1024, 1024, 7) # last entry + + def test_fetch_resolution_old_format_no_seed(self): + """Old [w, h] entries without seed should return seed=0.""" + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[576, 384], [960, 640]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (576, 384, 0) + + def test_fetch_resolution_missing_key_returns_defaults(self): + from project_loader import ProjectResolution + node = ProjectResolution() + with patch("project_loader._fetch_data", return_value={}): + result = node.fetch_resolution( + source_label="src", key_name="nonexistent", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512, 0) + + def test_fetch_resolution_network_error_returns_defaults(self): + from project_loader import ProjectResolution + node = ProjectResolution() + error_resp = {"error": "network_error", "message": "Connection refused"} + with patch("project_loader._fetch_data", return_value=error_resp): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512, 0) + + def test_fetch_resolution_malformed_entry_returns_defaults(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512]]} # single-element, not a valid pair + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512, 0) + + def test_category(self): + from project_loader import ProjectResolution + assert ProjectResolution.CATEGORY == "JSON Manager/project" + + +class TestBinaryIndexDecoder: + def test_input_types(self): + from project_loader import BinaryIndexDecoder + inputs = BinaryIndexDecoder.INPUT_TYPES() + assert "index" in inputs["required"] + assert inputs["required"]["index"][0] == "INT" + + def test_three_boolean_outputs(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder.RETURN_TYPES == ("BOOLEAN", "BOOLEAN", "BOOLEAN") + assert BinaryIndexDecoder.RETURN_NAMES == ("flag_0", "flag_1", "flag_2") + + def test_category(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder.CATEGORY == "JSON Manager/utils" + + def test_index_0(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(0) == (False, False, False) + + def test_index_1(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(1) == (True, False, False) + + def test_index_2(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(2) == (False, True, False) + + def test_index_3(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(3) == (True, True, False) + + def test_index_4(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(4) == (False, False, True) + + def test_index_5(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(5) == (True, False, True) + + def test_index_6(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(6) == (False, True, True) + + def test_index_7(self): + from project_loader import BinaryIndexDecoder + assert BinaryIndexDecoder().decode(7) == (True, True, True) class TestNodeMappings: @@ -350,5 +502,7 @@ class TestNodeMappings: assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS - assert len(PROJECT_NODE_CLASS_MAPPINGS) == 3 - assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 3 + assert "ProjectResolution" in PROJECT_NODE_CLASS_MAPPINGS + assert "BinaryIndexDecoder" in PROJECT_NODE_CLASS_MAPPINGS + assert len(PROJECT_NODE_CLASS_MAPPINGS) == 5 + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 5 diff --git a/utils.py b/utils.py index ab5ed85..b9f558c 100644 --- a/utils.py +++ b/utils.py @@ -38,6 +38,7 @@ DEFAULTS = { # --- I2V / VACE Specifics --- "frame_to_skip": 81, "end_frame": 0, + "logic index": 0, "transition": "1-2", "vace_length": 49, "vace schedule": 1, @@ -45,9 +46,12 @@ DEFAULTS = { "input_b_frames": 16, "reference switch": 1, "video file path": "", - "reference image path": "", - "reference path": "", - "flf image path": "", + "start frame path": "", + "start frame strength": 1.0, + "middle frame path": "", + "middle frame strength": 1.0, + "end frame path": "", + "end frame strength": 1.0, # --- LoRAs (name as STRING, strength as FLOAT) --- "lora 1 high": "", @@ -150,6 +154,19 @@ def save_snippets(snippets): json.dump(snippets, f, indent=4) os.replace(tmp, SNIPPETS_FILE) +def _migrate_key_renames(data: dict) -> None: + """Rename legacy keys to their current names.""" + for item in data.get(KEY_BATCH_DATA, []): + if not isinstance(item, dict): + continue + if 'reference path' in item and 'middle frame path' not in item: + item['middle frame path'] = item.pop('reference path') + if 'flf image path' in item and 'end frame path' not in item: + item['end frame path'] = item.pop('flf image path') + if 'reference image path' in item and 'start frame path' not in item: + item['start frame path'] = item.pop('reference image path') + + def _migrate_lora_keys(data: dict) -> None: """Split combined lora 'name:strength' into separate name and strength keys. @@ -208,6 +225,7 @@ def load_json(path: str | Path) -> tuple[dict[str, Any], float]: with open(path, 'r') as f: data = json.load(f) t1 = time.time() + _migrate_key_renames(data) _migrate_lora_keys(data) t2 = time.time() mtime = path.stat().st_mtime diff --git a/web/binary_index_decoder.js b/web/binary_index_decoder.js new file mode 100644 index 0000000..5304e65 --- /dev/null +++ b/web/binary_index_decoder.js @@ -0,0 +1,20 @@ +import { app } from "../../scripts/app.js"; + +app.registerExtension({ + name: "json.manager.binary_index_decoder", + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name !== "BinaryIndexDecoder") return; + + nodeType.prototype.onExecuted = function (output) { + if (!output?.values) return; + for (let i = 0; i < Math.min(output.values.length, this.outputs.length); i++) { + const val = output.values[i]; + this.outputs[i].label = `${val} ${this.outputs[i].name}`; + this.outputs[i].color_on = (val === "true") ? "#4caf50" : "#888888"; + this.outputs[i].color_off = (val === "true") ? "#4caf50" : "#888888"; + } + app.graph?.setDirtyCanvas(true, true); + }; + }, +}); diff --git a/web/project_key.js b/web/project_key.js index 325daca..2ce42b3 100644 --- a/web/project_key.js +++ b/web/project_key.js @@ -201,6 +201,60 @@ app.registerExtension({ app.graph?.setDirtyCanvas(true, true); }; + // --- Show live value on output slot after execution (INT/FLOAT/BOOL only) --- + nodeType.prototype.onExecuted = function (output) { + if (!this.outputs.length) return; + const val = output?.value?.[0]; + if (val === undefined) return; + const keyWidget = this.widgets?.find(w => w.name === "key_name"); + const name = keyWidget?.value || this.outputs[0].name; + this.outputs[0].label = `${val} ${name}`; + const slotType = this.outputs[0].type; + const TYPE_COLORS = { "INT": "#3d7eb5", "FLOAT": "#68a468", "BOOLEAN": null }; + let color; + if (slotType === "BOOLEAN") { + color = (val === "true") ? "#4caf50" : "#888888"; + } else { + color = TYPE_COLORS[slotType] + ?? LGraphCanvas?.link_type_colors?.[slotType] + ?? app.canvas?.default_connection_color_byType?.[slotType]; + } + if (color) { + this.outputs[0].color_on = color; + this.outputs[0].color_off = color; + } + app.graph?.setDirtyCanvas(true, true); + }; + + // --- Highlight all ProjectKey nodes sharing the same key_name on select --- + nodeType.prototype.onSelected = function () { + const keyWidget = this.widgets?.find(w => w.name === "key_name"); + const myKey = keyWidget?.value; + if (!myKey || !this.graph) return; + for (const node of this.graph._nodes) { + if (node === this || node.type !== "ProjectKey") continue; + const kw = node.widgets?.find(w => w.name === "key_name"); + if (kw?.value !== myKey) continue; + node._savedColor = node.color; + node._savedBgColor = node.bgcolor; + node.color = "#c8a000"; + node.bgcolor = "#4a3800"; + } + app.graph?.setDirtyCanvas(true, true); + }; + + nodeType.prototype.onDeselected = function () { + if (!this.graph) return; + for (const node of this.graph._nodes) { + if (node.type !== "ProjectKey" || !("_savedColor" in node)) continue; + node.color = node._savedColor; + node.bgcolor = node._savedBgColor; + delete node._savedColor; + delete node._savedBgColor; + } + app.graph?.setDirtyCanvas(true, true); + }; + // --- Sync config on click (lazy, no key refresh to avoid race) --- const origOnMouseDown = nodeType.prototype.onMouseDown; nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) { diff --git a/web/project_resolution.js b/web/project_resolution.js new file mode 100644 index 0000000..b6b95e9 --- /dev/null +++ b/web/project_resolution.js @@ -0,0 +1,191 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; + +app.registerExtension({ + name: "json.manager.project.resolution", + + async beforeQueuePrompt() { + if (!app.graph?._nodes) return; + for (const node of app.graph._nodes) { + if (node.type === "ProjectResolution" && node._syncFromSource) { + node._syncFromSource(); + } + } + }, + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name !== "ProjectResolution") return; + + function hideWidget(widget) { + if (widget.origType === undefined) widget.origType = widget.type; + widget.type = "hidden"; + widget.hidden = true; + widget.computeSize = () => [0, -4]; + } + + function replaceWithCombo(node, name, values, callback) { + const idx = node.widgets?.findIndex(w => w.name === name); + if (idx === -1 || idx === undefined) return null; + const oldWidget = node.widgets[idx]; + const savedValue = oldWidget.value || ""; + const comboValues = values.length > 0 ? values : [""]; + if (savedValue && !comboValues.includes(savedValue)) { + comboValues.unshift(savedValue); + } + const defaultValue = savedValue || comboValues[0]; + node.widgets.splice(idx, 1); + const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues }); + if (node.widgets.length > 1) { + node.widgets.splice(node.widgets.length - 1, 1); + node.widgets.splice(idx, 0, combo); + } + return combo; + } + + const origOnNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + origOnNodeCreated?.apply(this, arguments); + this._configured = false; + + // Hide synced config widgets — index stays visible, user wires it from loop node + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) hideWidget(w); + } + + const node = this; + const sourceLabels = this._getSourceLabels?.() || []; + const srcCombo = replaceWithCombo(this, "source_label", sourceLabels, function (value) { + node._syncFromSource(); + node._refreshKeys(); + }); + if (srcCombo) srcCombo.value = sourceLabels[0] || ""; + + const keyCombo = replaceWithCombo(this, "key_name", [], function (value) { + node.title = value ? `Resolution: ${value}` : "Project Resolution"; + app.graph?.setDirtyCanvas(true, true); + }); + if (keyCombo && !keyCombo.value) keyCombo.value = "resolutions"; + + queueMicrotask(() => { + if (!this._configured) { + this.setSize(this.computeSize()); + } + }); + }; + + nodeType.prototype._getSourceLabels = function () { + const seen = new Set(); + const labels = []; + if (!this.graph) return labels; + for (const node of this.graph._nodes) { + if (node.type === "ProjectSource") { + const lw = node.widgets?.find(w => w.name === "label"); + if (lw?.value && !seen.has(lw.value)) { + seen.add(lw.value); + labels.push(lw.value); + } + } + } + return labels; + }; + + nodeType.prototype._findSource = function (label) { + if (!this.graph || !label) return null; + for (const node of this.graph._nodes) { + if (node.type === "ProjectSource") { + const lw = node.widgets?.find(w => w.name === "label"); + if (lw?.value === label) return node; + } + } + return null; + }; + + nodeType.prototype._syncFromSource = function () { + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + const source = this._findSource(srcWidget?.value); + if (!source) return; + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const dst = this.widgets?.find(w => w.name === name); + const src = source.widgets?.find(w => w.name === name); + if (dst && src) dst.value = src.value; + } + }; + + nodeType.prototype._refreshKeys = async function () { + const urlW = this.widgets?.find(w => w.name === "manager_url"); + const projW = this.widgets?.find(w => w.name === "project_name"); + const fileW = this.widgets?.find(w => w.name === "file_name"); + const seqW = this.widgets?.find(w => w.name === "sequence_number"); + if (!urlW?.value || !projW?.value || !fileW?.value) return; + + try { + const resp = await api.fetchApi( + `/json_manager/get_project_keys?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}&file=${encodeURIComponent(fileW.value)}&seq=${seqW?.value || 1}` + ); + if (!resp.ok) return; + const data = await resp.json(); + if (data.error || !Array.isArray(data.keys)) return; + + const keyWidget = this.widgets?.find(w => w.name === "key_name"); + if (keyWidget) { + keyWidget.options.values = data.keys.length > 0 ? data.keys : [""]; + } + } catch (e) { + console.error("[ProjectResolution] Failed to refresh keys:", e); + } + }; + + const origOnMouseDown = nodeType.prototype.onMouseDown; + nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) { + origOnMouseDown?.apply(this, arguments); + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + if (srcWidget) srcWidget.options.values = this._getSourceLabels(); + this._syncFromSource(); + }; + + const origOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + origOnConfigure?.apply(this, arguments); + this._configured = true; + + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) hideWidget(w); + } + + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + if (srcWidget && srcWidget.type !== "combo") { + const node = this; + replaceWithCombo(this, "source_label", this._getSourceLabels(), function (value) { + node._syncFromSource(); + node._refreshKeys(); + }); + } else if (srcWidget) { + srcWidget.options.values = this._getSourceLabels(); + } + + const keyWidget = this.widgets?.find(w => w.name === "key_name"); + if (keyWidget && keyWidget.type !== "combo") { + const node = this; + replaceWithCombo(this, "key_name", [], function (value) { + node.title = value ? `Resolution: ${value}` : "Project Resolution"; + app.graph?.setDirtyCanvas(true, true); + }); + } + + const finalKeyWidget = this.widgets?.find(w => w.name === "key_name"); + if (finalKeyWidget?.value) { + this.title = `Resolution: ${finalKeyWidget.value}`; + } + + this.setSize(this.computeSize()); + + const node = this; + queueMicrotask(() => { + node._syncFromSource(); + node._refreshKeys(); + }); + }; + }, +}); diff --git a/web/project_source.js b/web/project_source.js index cde2a30..e9fb699 100644 --- a/web/project_source.js +++ b/web/project_source.js @@ -28,6 +28,35 @@ app.registerExtension({ return combo; } + // Fetch active project from Manager and update project_name + title + async function refreshActiveProject(node) { + const urlW = node.widgets?.find(w => w.name === "manager_url"); + if (!urlW?.value) return; + try { + const resp = await fetch(`${urlW.value}/api/active-project`); + if (!resp.ok) return; + const data = await resp.json(); + const project = data.project || ""; + const projW = node.widgets?.find(w => w.name === "project_name"); + if (projW && projW.value !== project) { + projW.value = project; + await refreshFiles(node); + } + _updateTitle(node); + } catch (e) { + console.warn("[ProjectSource] Failed to fetch active project:", e); + } + } + + function _updateTitle(node) { + const labelW = node.widgets?.find(w => w.name === "label"); + const projW = node.widgets?.find(w => w.name === "project_name"); + const label = labelW?.value || ""; + const project = projW?.value || "?"; + node.title = label ? `Source: ${label} [${project}]` : `Project Source [${project}]`; + app.graph?.setDirtyCanvas(true, true); + } + // Fetch file list from API and update file_name combo async function refreshFiles(node) { const urlW = node.widgets?.find(w => w.name === "manager_url"); @@ -84,22 +113,28 @@ app.registerExtension({ const node = this; + // Hide project_name — it is auto-filled from the Manager's active project + const projW = this.widgets?.find(w => w.name === "project_name"); + if (projW) { + if (projW.origType === undefined) projW.origType = projW.type; + projW.type = "hidden"; + projW.hidden = true; + projW.computeSize = () => [0, -4]; + } + // Replace file_name STRING with a combo replaceWithCombo(this, "file_name", [], function (value) { notifyRelays(node); }); - // Hook manager_url and project_name to refresh file list + notify relays - for (const name of ["manager_url", "project_name"]) { - const w = this.widgets?.find(w => w.name === name); - if (w) { - const origCb = w.callback; - w.callback = function (...args) { - origCb?.apply(this, args); - refreshFiles(node); - notifyRelays(node); - }; - } + // Hook manager_url to refresh active project + files + notify relays + const urlW = this.widgets?.find(w => w.name === "manager_url"); + if (urlW) { + const origCb = urlW.callback; + urlW.callback = function (...args) { + origCb?.apply(this, args); + refreshActiveProject(node).then(() => notifyRelays(node)); + }; } // Hook sequence_number to notify relays @@ -118,22 +153,27 @@ app.registerExtension({ const origCallback = labelWidget.callback; labelWidget.callback = function (...args) { origCallback?.apply(this, args); - node.title = labelWidget.value - ? `Source: ${labelWidget.value}` - : "Project Source"; - app.graph?.setDirtyCanvas(true, true); + _updateTitle(node); }; - // Set initial title - if (labelWidget.value) { - this.title = `Source: ${labelWidget.value}`; - } } + + // Auto-fetch active project on creation + queueMicrotask(() => refreshActiveProject(node)); }; const origOnConfigure = nodeType.prototype.onConfigure; nodeType.prototype.onConfigure = function (info) { origOnConfigure?.apply(this, arguments); + // Hide project_name (may have been serialized as visible) + const projW = this.widgets?.find(w => w.name === "project_name"); + if (projW) { + if (projW.origType === undefined) projW.origType = projW.type; + projW.type = "hidden"; + projW.hidden = true; + projW.computeSize = () => [0, -4]; + } + // Ensure file_name is a combo (may be STRING from serialization) const fileW = this.widgets?.find(w => w.name === "file_name"); if (fileW && fileW.type !== "combo") { @@ -143,16 +183,18 @@ app.registerExtension({ }); } - const labelWidget = this.widgets?.find(w => w.name === "label"); - if (labelWidget?.value) { - this.title = `Source: ${labelWidget.value}`; - } + _updateTitle(this); - // Deferred: refresh file list once graph is ready + // Deferred: fetch active project (and files) once graph is ready const node = this; - queueMicrotask(() => { - refreshFiles(node); - }); + queueMicrotask(() => refreshActiveProject(node)); + }; + + // Re-check active project on click (picks up changes made in the Manager) + const origOnMouseDown = nodeType.prototype.onMouseDown; + nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) { + origOnMouseDown?.apply(this, arguments); + refreshActiveProject(this); }; }, });