diff --git a/db.py b/db.py index c870e8f..81e9461 100644 --- a/db.py +++ b/db.py @@ -242,7 +242,6 @@ class ProjectDB: ) self.conn.commit() - @staticmethod @staticmethod def _migrate_lora_keys(data: dict) -> dict: """Split combined lora 'name:strength' into separate name and strength keys.""" @@ -340,27 +339,34 @@ class ProjectDB: # ------------------------------------------------------------------ def save_history_tree(self, data_file_id: int, tree_data: dict) -> None: - """Save history tree, extracting node snapshots into separate table.""" + """Save history tree, extracting snapshot data into separate table. + + Supports both new format (snapshots dict) and old format (nodes dict). + """ now = time.time() - nodes = tree_data.get("nodes", {}) + if "snapshots" in tree_data: + entries = tree_data.get("snapshots", {}) + entry_key = "snapshots" + else: + entries = tree_data.get("nodes", {}) + entry_key = "nodes" slim_tree = dict(tree_data) - slim_nodes = {} - for nid, node in nodes.items(): - slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"} - slim_tree["nodes"] = slim_nodes + slim_entries = {} + for eid, entry in entries.items(): + slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"} + slim_tree[entry_key] = slim_entries self.conn.execute("BEGIN IMMEDIATE") try: - # Extract snapshot data from nodes into history_snapshots table - for nid, node in nodes.items(): - snap = node.get("data") + for eid, entry in entries.items(): + snap = entry.get("data") if snap: self.conn.execute( "INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) " "VALUES (?, ?, ?, ?) " "ON CONFLICT(data_file_id, node_id) DO UPDATE SET " "snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at", - (data_file_id, nid, json.dumps(snap), now), + (data_file_id, eid, json.dumps(snap), now), ) self.conn.execute( "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " @@ -463,24 +469,30 @@ class ProjectDB: ) # Import history tree (extract snapshots into separate table) + # Supports both new format (snapshots dict) and old format (nodes dict) history_tree = data.get(KEY_HISTORY_TREE) if history_tree and isinstance(history_tree, dict): now = time.time() - nodes = history_tree.get("nodes", {}) + if "snapshots" in history_tree: + entries = history_tree.get("snapshots", {}) + entry_key = "snapshots" + else: + entries = history_tree.get("nodes", {}) + entry_key = "nodes" slim_tree = dict(history_tree) - slim_nodes = {} - for nid, node in nodes.items(): - snap = node.get("data") + slim_entries = {} + for eid, entry in entries.items(): + snap = entry.get("data") if snap: self.conn.execute( "INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) " "VALUES (?, ?, ?, ?) " "ON CONFLICT(data_file_id, node_id) DO UPDATE SET " "snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at", - (df_id, nid, json.dumps(snap), now), + (df_id, eid, json.dumps(snap), now), ) - slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"} - slim_tree["nodes"] = slim_nodes + slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"} + slim_tree[entry_key] = slim_entries self.conn.execute( "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " "VALUES (?, ?, ?) " @@ -540,9 +552,9 @@ class ProjectDB: # Load history tree (metadata only, no snapshot data) tree = self.get_history_tree(df["id"]) if tree: - # Strip any residual snapshot data from nodes - for node in tree.get("nodes", {}).values(): - node.pop("data", None) + # Strip any residual snapshot data (supports both formats) + for entry in tree.get("snapshots", tree.get("nodes", {})).values(): + entry.pop("data", None) data["history_tree"] = tree t3 = time.time() diff --git a/docs/plans/2026-04-03-resolution-series-design.md b/docs/plans/2026-04-03-resolution-series-design.md new file mode 100644 index 0000000..683d591 --- /dev/null +++ b/docs/plans/2026-04-03-resolution-series-design.md @@ -0,0 +1,81 @@ +# Resolution Series Design + +## Problem + +When running ComfyUI loop nodes for multi-step upscaling (e.g. 3+ resolutions at different sizes), +managing portrait vs landscape width/height per iteration is tedious. Users need a structured way +to define N resolution pairs in the manager UI and retrieve them by loop index in ComfyUI. + +## Design + +### Data Model + +Resolution series are stored as a JSON array under a user-chosen key in the sequence data: + +```json +"upscale_resolutions": [[512, 512], [768, 1344], [1344, 768], [2048, 2048]] +``` + +- Each element is `[width, height]` (both INT) +- Key name is chosen by the user (any string) +- Number of entries is configurable (add/remove rows) +- Stored in the same project JSON file and sequence — no schema change required +- Index out of bounds → clamp to last entry + +### NiceGUI UI (tab_batch_ng.py) + +A resolution series editor is rendered in the left column of the sequence card, directly below +the "Specific Negative" textarea. + +Layout: + +``` +── Resolution Series ────────────────── + key name: [upscale_resolutions ] + # Width Height + 1 [2048] [2048] [x] + 2 [768 ] [1344] [x] + 3 [1344] [768 ] [x] + [+ Add row] +``` + +- Key name is editable (defaults to `resolutions`) +- Rows added/removed inline; each change calls `commit()` immediately +- Hidden behind an "Add Resolution Series" button when no resolution key exists yet +- A value is detected as a resolution series if it is a list of `[int, int]` pairs + +### ComfyUI Node (`ProjectResolution`) + +New node class in `project_loader.py`, sibling to `ProjectKey`. + +**Inputs:** +- `source_label` (STRING) — references a `ProjectSource` by label +- `key_name` (STRING) — the resolution series key name +- `index` (INT, min 0) — wired from loop node's current index output +- `manager_url`, `project_name`, `file_name`, `sequence_number` — optional, synced from `ProjectSource` via JS + +**Outputs:** `width` (INT), `height` (INT) + +**Execution:** fetches the sequence data, reads `data[key_name]`, indexes into the array with +clamp-to-last on out-of-bounds, returns `(width, height)`. + +**JS (`web/project_resolution.js`):** +- Same `_syncFromSource` mechanism as `project_key.js` +- `key_name` widget is replaced with a combo dropdown populated with keys whose value is a + resolution series (list of `[int, int]` pairs), detected via the existing keys API +- Registered in `PROJECT_NODE_CLASS_MAPPINGS` and `PROJECT_NODE_DISPLAY_NAME_MAPPINGS` + +### API + +No new endpoints. Uses existing: +- `/json_manager/get_project_keys` — for key discovery (JS combo population) +- `_fetch_data()` — for execution-time data fetch + +### Files Changed + +| File | Change | +|------|--------| +| `project_loader.py` | Add `ProjectResolution` class + register in mappings | +| `web/project_resolution.js` | New JS extension for the node | +| `tab_batch_ng.py` | Resolution series editor below Specific Negative | +| `__init__.py` | Register new JS file if needed | diff --git a/docs/plans/2026-04-03-resolution-series-plan.md b/docs/plans/2026-04-03-resolution-series-plan.md new file mode 100644 index 0000000..9e2cc9d --- /dev/null +++ b/docs/plans/2026-04-03-resolution-series-plan.md @@ -0,0 +1,640 @@ +# Resolution Series Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add a `ProjectResolution` ComfyUI node and NiceGUI editor that let users define N `(width, height)` pairs per sequence and retrieve them by loop index. + +**Architecture:** Resolution series are stored as a JSON array of `[width, height]` pairs under a user-chosen key in sequence data (e.g. `"upscale_resolutions": [[512,512],[768,1344]]`). A new `ProjectResolution` ComfyUI node (sibling of `ProjectKey`) accepts a `source_label`, `key_name`, and `index` INT from a loop node, and returns `width` + `height`. The NiceGUI sequence card gets an inline table editor placed directly below the "Specific Negative" textarea. + +**Tech Stack:** Python (ComfyUI node), NiceGUI (UI), JavaScript (ComfyUI frontend extension), pytest + +**Branch:** Create and work on `feat/resolution-series` branched from `main`: +```bash +git checkout main && git checkout -b feat/resolution-series +``` + +--- + +### Task 0: Fix pre-existing test failures on `main` + +When `file_name` was added as a second output to `ProjectSource`, two tests were not updated. +They fail on `main` before any new code is written. + +**Files:** +- Modify: `tests/test_project_loader.py` (`TestProjectSource` class, lines ~216-231) + +**Step 1: Update the two broken tests** + +```python +def test_outputs_sequence_number(self): + from project_loader import ProjectSource + assert ProjectSource.RETURN_TYPES == ("INT", "STRING",) + assert ProjectSource.RETURN_NAMES == ("sequence_number", "file_name",) + +def test_hold_config_returns_sequence_number(self): + from project_loader import ProjectSource + node = ProjectSource() + result = node.hold_config( + manager_url="http://localhost:8080", + project_name="proj1", + file_name="batch_i2v", + sequence_number=42, + label="my_source" + ) + assert result == (42, "batch_i2v") +``` + +**Step 2: Verify they now pass** + +```bash +pytest tests/test_project_loader.py::TestProjectSource -v +``` +Expected: all 4 PASS + +**Step 3: Commit** + +```bash +git add tests/test_project_loader.py +git commit -m "fix: update ProjectSource tests for file_name output" +``` + +--- + +### Task 1: Python node — `ProjectResolution` + +**Files:** +- Modify: `project_loader.py` (after the `ProjectKey` class, before `# --- Mappings ---`) +- Modify: `tests/test_project_loader.py` (add `TestProjectResolution` class) + +**Step 1: Write failing tests** + +Add this class to `tests/test_project_loader.py`: + +```python +class TestProjectResolution: + def test_input_types(self): + from project_loader import ProjectResolution + inputs = ProjectResolution.INPUT_TYPES() + assert "source_label" in inputs["required"] + assert "key_name" in inputs["required"] + assert "index" in inputs["required"] + assert inputs["required"]["index"][0] == "INT" + + def test_two_outputs(self): + from project_loader import ProjectResolution + assert ProjectResolution.RETURN_TYPES == ("INT", "INT") + assert ProjectResolution.RETURN_NAMES == ("width", "height") + + def test_fetch_resolution_basic(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512], [768, 1344], [1344, 768]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=1, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (768, 1344) + + def test_fetch_resolution_index_zero(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512], [1024, 1024]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512) + + def test_fetch_resolution_clamps_on_out_of_bounds(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512], [1024, 1024]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=99, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (1024, 1024) # last entry + + def test_fetch_resolution_missing_key_returns_defaults(self): + from project_loader import ProjectResolution + node = ProjectResolution() + with patch("project_loader._fetch_data", return_value={}): + result = node.fetch_resolution( + source_label="src", key_name="nonexistent", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512) + + def test_fetch_resolution_network_error_returns_defaults(self): + from project_loader import ProjectResolution + node = ProjectResolution() + error_resp = {"error": "network_error", "message": "Connection refused"} + with patch("project_loader._fetch_data", return_value=error_resp): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512) + + def test_category(self): + from project_loader import ProjectResolution + assert ProjectResolution.CATEGORY == "utils/json/project" +``` + +**Step 2: Run tests to verify they fail** + +```bash +pytest tests/test_project_loader.py::TestProjectResolution -v +``` +Expected: `ImportError: cannot import name 'ProjectResolution'` + +**Step 3: Implement `ProjectResolution` in `project_loader.py`** + +Insert this class after `ProjectKey` (line ~294), before `# --- Mappings ---`: + +```python +class ProjectResolution: + """Fetches a (width, height) pair from a resolution series by loop index.""" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source_label": ("STRING", {"default": "", "multiline": False}), + "key_name": ("STRING", {"default": "resolutions", "multiline": False}), + "index": ("INT", {"default": 0, "min": 0, "max": 9999}), + }, + "optional": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }, + } + + RETURN_TYPES = ("INT", "INT") + RETURN_NAMES = ("width", "height") + FUNCTION = "fetch_resolution" + CATEGORY = "utils/json/project" + OUTPUT_NODE = False + + @classmethod + def IS_CHANGED(cls, **kwargs): + return float("nan") + + def fetch_resolution(self, source_label, key_name, index, + manager_url="http://localhost:8080", project_name="", + file_name="", sequence_number=1): + sequence_number = int(sequence_number) + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + if data.get("error") in ("http_error", "network_error", "parse_error"): + logger.warning("ProjectResolution.fetch_resolution failed: %s", data.get("message")) + return (512, 512) + + series = data.get(key_name) + if not isinstance(series, list) or len(series) == 0: + logger.warning("ProjectResolution: key '%s' is not a resolution series", key_name) + return (512, 512) + + clamped = min(index, len(series) - 1) + entry = series[clamped] + if not isinstance(entry, (list, tuple)) or len(entry) < 2: + logger.warning("ProjectResolution: entry at index %d is malformed: %r", clamped, entry) + return (512, 512) + + return (to_int(entry[0]), to_int(entry[1])) +``` + +**Step 4: Run tests to verify they pass** + +```bash +pytest tests/test_project_loader.py::TestProjectResolution -v +``` +Expected: all 7 tests PASS + +**Step 5: Commit** + +```bash +git add project_loader.py tests/test_project_loader.py +git commit -m "feat: add ProjectResolution node" +``` + +--- + +### Task 2: Register `ProjectResolution` in mappings + fix mapping tests + +**Files:** +- Modify: `project_loader.py` (mappings section, lines ~297-307) +- Modify: `tests/test_project_loader.py` (`TestNodeMappings` class) + +**Step 1: Update mappings in `project_loader.py`** + +Change the mappings at the bottom of the file: + +```python +PROJECT_NODE_CLASS_MAPPINGS = { + "ProjectLoaderDynamic": ProjectLoaderDynamic, + "ProjectSource": ProjectSource, + "ProjectKey": ProjectKey, + "ProjectResolution": ProjectResolution, +} + +PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { + "ProjectLoaderDynamic": "Project Loader (Dynamic)", + "ProjectSource": "Project Source", + "ProjectKey": "Project Key", + "ProjectResolution": "Project Resolution", +} +``` + +**Step 2: Update the mapping test** + +In `tests/test_project_loader.py`, update `TestNodeMappings.test_mappings_exist`: + +```python +class TestNodeMappings: + def test_mappings_exist(self): + from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS + assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectResolution" in PROJECT_NODE_CLASS_MAPPINGS + assert len(PROJECT_NODE_CLASS_MAPPINGS) == 4 + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4 +``` + +**Step 3: Run all project_loader tests** + +```bash +pytest tests/test_project_loader.py -v +``` +Expected: all tests PASS + +**Step 4: Commit** + +```bash +git add project_loader.py tests/test_project_loader.py +git commit -m "feat: register ProjectResolution in node mappings" +``` + +--- + +### Task 3: NiceGUI resolution series editor in `tab_batch_ng.py` + +**Files:** +- Modify: `tab_batch_ng.py` + +The resolution series editor goes inside `splitter.before`, directly after the "Specific Negative" textarea (currently line ~552-553). No new file needed. + +**Step 1: Add the helper function** + +Add this function near the other helper functions at the top of the render section (before `_render_sequence_card`): + +```python +def _is_resolution_series(val) -> bool: + """Return True if val is a list of [width, height] int pairs.""" + if not isinstance(val, list) or len(val) == 0: + return False + return all( + isinstance(entry, (list, tuple)) and len(entry) == 2 + and all(isinstance(v, (int, float)) for v in entry) + for entry in val + ) +``` + +Note: `Any` is intentionally omitted — `tab_batch_ng.py` does not import `typing.Any`. + +**Step 2: Add the resolution series render section** + +After the "Specific Negative" textarea in `splitter.before` (after line ~553), add: + +```python + # --- Resolution Series --- + res_keys = [k for k, v in seq.items() if _is_resolution_series(v)] + if res_keys: + ui.label('Resolution Series').classes('text-caption text-weight-bold q-mt-md') + for res_key in res_keys: + series: list = seq[res_key] + with ui.card().classes('w-full q-pa-sm q-mt-xs').props('flat bordered'): + with ui.row().classes('items-center q-mb-xs'): + ui.label(res_key).classes('text-caption col') + def del_series(k=res_key): + del seq[k] + commit() + ui.button(icon='delete', on_click=del_series).props( + 'flat dense round size=xs color=negative') + with ui.row().classes('text-caption text-grey q-mb-xs'): + ui.label('#').style('width:24px') + ui.label('Width').classes('col') + ui.label('Height').classes('col') + ui.label('').style('width:28px') + for idx, entry in enumerate(series): + with ui.row().classes('items-center w-full'): + ui.label(str(idx + 1)).classes('text-caption').style('width:24px') + w_inp = ui.number(value=int(entry[0]), min=1, step=1).classes( + 'col').props('outlined dense hide-bottom-space') + h_inp = ui.number(value=int(entry[1]), min=1, step=1).classes( + 'col').props('outlined dense hide-bottom-space') + + def _sync_wh(i=idx, k=res_key, wi=w_inp, hi=h_inp): + seq[k][i] = [ + int(wi.value) if wi.value else 512, + int(hi.value) if hi.value else 512, + ] + commit() + + w_inp.on('blur', lambda _, s=_sync_wh: s()) + h_inp.on('blur', lambda _, s=_sync_wh: s()) + + def del_row(i=idx, k=res_key): + seq[k].pop(i) + commit() + ui.button(icon='remove', on_click=del_row).props( + 'flat dense round size=xs') + + def add_row(k=res_key): + seq[k].append([512, 512]) + commit() + ui.button('+ Add row', icon='add', on_click=add_row).props( + 'flat dense size=sm').classes('q-mt-xs') + + with ui.expansion('Add Resolution Series', icon='straighten').classes('w-full q-mt-sm'): + new_res_key = ui.input('Key name', value='resolutions').props('outlined dense') + def add_res_series(): + k = new_res_key.value.strip() + if k and k not in seq: + seq[k] = [[512, 512], [1024, 1024]] + commit() + ui.button('Add', icon='add', on_click=add_res_series).props('outlined dense') +``` + +**Step 3: Run all tests** + +```bash +pytest tests/ -q +``` +Expected: all tests PASS (no Python tests cover the NiceGUI render path, but no regressions) + +**Important:** Also update the `custom_keys` filter in `_render_sequence_card` (line ~648) to exclude +resolution series keys — otherwise they'd render in both the resolution editor AND "Custom Parameters": + +```python +# Find this line: +custom_keys = [k for k in seq.keys() if k not in standard_keys] +# Replace with: +custom_keys = [k for k in seq.keys() if k not in standard_keys and not _is_resolution_series(seq.get(k))] +``` + +**Step 4: Commit** + +```bash +git add tab_batch_ng.py +git commit -m "feat: resolution series editor in sequence card" +``` + +--- + +### Task 4: JS extension `web/project_resolution.js` + +**Files:** +- Create: `web/project_resolution.js` + +This file mirrors `web/project_key.js` exactly, with two differences: +1. It targets `"ProjectResolution"` instead of `"ProjectKey"` +2. `_refreshKeys` filters to only show keys whose value is a resolution series (list of `[int, int]` pairs) — but since the keys API only returns key names (not values), the filter is done by naming convention or we just show all keys and let the user pick. For simplicity, show all keys (same as ProjectKey) and let the user pick. +3. The `index` widget is **not** hidden — the user wires it from a loop node +4. The node has two outputs (`width`, `height`) so no output slot name update is needed + +**Step 1: Create `web/project_resolution.js`** + +```javascript +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; + +app.registerExtension({ + name: "json.manager.project.resolution", + + async beforeQueuePrompt() { + if (!app.graph?._nodes) return; + for (const node of app.graph._nodes) { + if (node.type === "ProjectResolution" && node._syncFromSource) { + node._syncFromSource(); + } + } + }, + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name !== "ProjectResolution") return; + + function hideWidget(widget) { + if (widget.origType === undefined) widget.origType = widget.type; + widget.type = "hidden"; + widget.hidden = true; + widget.computeSize = () => [0, -4]; + } + + function replaceWithCombo(node, name, values, callback) { + const idx = node.widgets?.findIndex(w => w.name === name); + if (idx === -1 || idx === undefined) return null; + const oldWidget = node.widgets[idx]; + const savedValue = oldWidget.value || ""; + const comboValues = values.length > 0 ? values : [""]; + if (savedValue && !comboValues.includes(savedValue)) { + comboValues.unshift(savedValue); + } + const defaultValue = savedValue || comboValues[0]; + node.widgets.splice(idx, 1); + const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues }); + if (node.widgets.length > 1) { + node.widgets.splice(node.widgets.length - 1, 1); + node.widgets.splice(idx, 0, combo); + } + return combo; + } + + const origOnNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + origOnNodeCreated?.apply(this, arguments); + this._configured = false; + + // Hide synced config widgets (index stays visible — user wires it) + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) hideWidget(w); + } + + const node = this; + const sourceLabels = this._getSourceLabels?.() || []; + const srcCombo = replaceWithCombo(this, "source_label", sourceLabels, function (value) { + node._syncFromSource(); + node._refreshKeys(); + }); + if (srcCombo) srcCombo.value = sourceLabels[0] || ""; + + const keyCombo = replaceWithCombo(this, "key_name", [], function (value) { + node.title = value ? `Resolution: ${value}` : "Project Resolution"; + app.graph?.setDirtyCanvas(true, true); + }); + if (keyCombo) keyCombo.value = ""; + + queueMicrotask(() => { + if (!this._configured) { + this.setSize(this.computeSize()); + } + }); + }; + + nodeType.prototype._getSourceLabels = function () { + const seen = new Set(); + const labels = []; + if (!this.graph) return labels; + for (const node of this.graph._nodes) { + if (node.type === "ProjectSource") { + const lw = node.widgets?.find(w => w.name === "label"); + if (lw?.value && !seen.has(lw.value)) { + seen.add(lw.value); + labels.push(lw.value); + } + } + } + return labels; + }; + + nodeType.prototype._findSource = function (label) { + if (!this.graph || !label) return null; + for (const node of this.graph._nodes) { + if (node.type === "ProjectSource") { + const lw = node.widgets?.find(w => w.name === "label"); + if (lw?.value === label) return node; + } + } + return null; + }; + + nodeType.prototype._syncFromSource = function () { + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + const source = this._findSource(srcWidget?.value); + if (!source) return; + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const dst = this.widgets?.find(w => w.name === name); + const src = source.widgets?.find(w => w.name === name); + if (dst && src) dst.value = src.value; + } + }; + + nodeType.prototype._refreshKeys = async function () { + const urlW = this.widgets?.find(w => w.name === "manager_url"); + const projW = this.widgets?.find(w => w.name === "project_name"); + const fileW = this.widgets?.find(w => w.name === "file_name"); + const seqW = this.widgets?.find(w => w.name === "sequence_number"); + if (!urlW?.value || !projW?.value || !fileW?.value) return; + + try { + const resp = await api.fetchApi( + `/json_manager/get_project_keys?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}&file=${encodeURIComponent(fileW.value)}&seq=${seqW?.value || 1}` + ); + if (!resp.ok) return; + const data = await resp.json(); + if (data.error || !Array.isArray(data.keys)) return; + + const keyWidget = this.widgets?.find(w => w.name === "key_name"); + if (keyWidget) { + keyWidget.options.values = data.keys.length > 0 ? data.keys : [""]; + } + } catch (e) { + console.error("[ProjectResolution] Failed to refresh keys:", e); + } + }; + + const origOnMouseDown = nodeType.prototype.onMouseDown; + nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) { + origOnMouseDown?.apply(this, arguments); + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + if (srcWidget) srcWidget.options.values = this._getSourceLabels(); + this._syncFromSource(); + }; + + const origOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + origOnConfigure?.apply(this, arguments); + this._configured = true; + + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) hideWidget(w); + } + + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + if (srcWidget && srcWidget.type !== "combo") { + const node = this; + replaceWithCombo(this, "source_label", this._getSourceLabels(), function (value) { + node._syncFromSource(); + node._refreshKeys(); + }); + } else if (srcWidget) { + srcWidget.options.values = this._getSourceLabels(); + } + + const keyWidget = this.widgets?.find(w => w.name === "key_name"); + if (keyWidget && keyWidget.type !== "combo") { + const node = this; + replaceWithCombo(this, "key_name", [], function (value) { + node.title = value ? `Resolution: ${value}` : "Project Resolution"; + app.graph?.setDirtyCanvas(true, true); + }); + } + + const finalKeyWidget = this.widgets?.find(w => w.name === "key_name"); + if (finalKeyWidget?.value) { + this.title = `Resolution: ${finalKeyWidget.value}`; + } + + this.setSize(this.computeSize()); + + const node = this; + queueMicrotask(() => { + node._syncFromSource(); + node._refreshKeys(); + }); + }; + }, +}); +``` + +**Step 2: Run all tests** + +```bash +pytest tests/ -q +``` +Expected: all tests PASS (JS has no Python tests) + +**Step 3: Commit** + +```bash +git add web/project_resolution.js +git commit -m "feat: ProjectResolution JS extension for ComfyUI frontend" +``` + +--- + +### Task 5: Final verification and push + +**Step 1: Run full test suite** + +```bash +pytest tests/ -v +``` +Expected: all tests PASS + +**Step 2: Push branch** + +```bash +git push origin HEAD +``` diff --git a/main.py b/main.py index d0ba139..3ce2e93 100644 --- a/main.py +++ b/main.py @@ -295,12 +295,12 @@ def index(): sync_to_db, pane_state.db, pane_state.current_project, fp, data) tree = data.get('history_tree') if tree and isinstance(tree, dict): - for node in tree.get('nodes', {}).values(): - node.pop('data', None) + for entry in tree.get('snapshots', tree.get('nodes', {})).values(): + entry.pop('data', None) for backup in data.get('history_tree_backup', []): if isinstance(backup, dict): - for node in backup.get('nodes', {}).values(): - node.pop('data', None) + for entry in backup.get('snapshots', backup.get('nodes', {})).values(): + entry.pop('data', None) pane_state.data_cache = data pane_state.last_mtime = fp.stat().st_mtime if fp.exists() else 0 pane_state.loaded_file = str(fp) @@ -339,13 +339,13 @@ def index(): sync_to_db, state.db, state.current_project, fp, data) tree = data.get('history_tree') if tree and isinstance(tree, dict): - for node in tree.get('nodes', {}).values(): - node.pop('data', None) + for entry in tree.get('snapshots', tree.get('nodes', {})).values(): + entry.pop('data', None) # Strip snapshot data from history_tree_backup to prevent RAM/disk bloat for backup in data.get('history_tree_backup', []): if isinstance(backup, dict): - for node in backup.get('nodes', {}).values(): - node.pop('data', None) + for entry in backup.get('snapshots', backup.get('nodes', {})).values(): + entry.pop('data', None) state.data_cache = data state.last_mtime = fp.stat().st_mtime if fp.exists() else 0 state.loaded_file = str(fp) diff --git a/snapshot_timeline.py b/snapshot_timeline.py new file mode 100644 index 0000000..2a8e312 --- /dev/null +++ b/snapshot_timeline.py @@ -0,0 +1,184 @@ +import time +import uuid +from typing import Any + +KEY_PROMPT_HISTORY = "prompt_history" + + +class SnapshotTimeline: + """Flat chronological snapshot list — replaces the old HistoryTree DAG.""" + + def __init__(self, raw_data: dict[str, Any]) -> None: + # Detect and migrate old HistoryTree format + if "nodes" in raw_data and "branches" in raw_data: + self._migrate_from_tree(raw_data) + elif KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list): + self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY]) + else: + self.snapshots: dict[str, dict[str, Any]] = raw_data.get("snapshots", {}) + self.current_id: str | None = raw_data.get("current_id", None) + + # ------------------------------------------------------------------ + # Migration + # ------------------------------------------------------------------ + + def _migrate_from_tree(self, raw_data: dict[str, Any]) -> None: + """Flatten old HistoryTree nodes into snapshot list, discarding DAG info.""" + self.snapshots = {} + nodes = raw_data.get("nodes", {}) + for nid, node in nodes.items(): + self.snapshots[nid] = { + "id": nid, + "timestamp": node.get("timestamp", time.time()), + "note": node.get("note", "Migrated"), + "pinned": False, + "auto": False, + "seq_count": self._count_seqs(node.get("data")), + } + # Preserve snapshot data if present + if "data" in node and node["data"]: + self.snapshots[nid]["data"] = node["data"] + self.current_id = raw_data.get("head_id") + + def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None: + """Convert ancient prompt_history list into snapshots.""" + self.snapshots = {} + self.current_id = None + for item in reversed(old_list): + sid = self._make_id() + self.snapshots[sid] = { + "id": sid, + "timestamp": time.time(), + "note": item.get("note", "Legacy Import"), + "pinned": False, + "auto": False, + "seq_count": self._count_seqs(item), + "data": item, + } + self.current_id = sid + + # ------------------------------------------------------------------ + # Core operations + # ------------------------------------------------------------------ + + def record(self, data: dict[str, Any], note: str = "Snapshot", + auto: bool = False) -> str: + """Create a new snapshot and return its ID.""" + sid = self._make_id() + self.snapshots[sid] = { + "id": sid, + "timestamp": time.time(), + "note": note, + "pinned": False, + "auto": auto, + "seq_count": self._count_seqs(data), + "data": data, + } + self.current_id = sid + return sid + + def get_snapshot_data(self, snapshot_id: str) -> dict[str, Any] | None: + """Return the inline snapshot data if present.""" + snap = self.snapshots.get(snapshot_id) + if snap: + return snap.get("data") + return None + + def toggle_pin(self, snapshot_id: str) -> bool: + """Toggle pinned state, return new value.""" + snap = self.snapshots.get(snapshot_id) + if snap: + snap["pinned"] = not snap.get("pinned", False) + return snap["pinned"] + return False + + def delete(self, snapshot_id: str) -> None: + """Remove a snapshot.""" + self.snapshots.pop(snapshot_id, None) + if self.current_id == snapshot_id: + # Fall back to most recent remaining + if self.snapshots: + self.current_id = max( + self.snapshots.values(), key=lambda s: s["timestamp"] + )["id"] + else: + self.current_id = None + + def strip_snapshots(self) -> None: + """Remove inline data from all snapshots (for slim JSON storage).""" + for snap in self.snapshots.values(): + snap.pop("data", None) + + # ------------------------------------------------------------------ + # Serialization + # ------------------------------------------------------------------ + + def to_dict(self) -> dict[str, Any]: + return { + "snapshots": self.snapshots, + "current_id": self.current_id, + } + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _make_id(self) -> str: + for _ in range(10): + sid = str(uuid.uuid4())[:8] + if sid not in self.snapshots: + return sid + raise ValueError("Failed to generate unique snapshot ID after 10 attempts") + + @staticmethod + def _count_seqs(data: dict | None) -> int: + if not data: + return 0 + from utils import KEY_BATCH_DATA + batch = data.get(KEY_BATCH_DATA, []) + return len(batch) if isinstance(batch, list) else 0 + + +# ------------------------------------------------------------------ +# Diff function +# ------------------------------------------------------------------ + +def diff_snapshots(old_batch: list[dict], new_batch: list[dict]) -> list[dict]: + """Compare two batch lists by sequence_number, return per-sequence diffs. + + Returns a list of dicts: + { + "seq_num": int, + "status": "unchanged" | "changed" | "added" | "removed", + "changes": [{"field": str, "old": Any, "new": Any}], + } + """ + from utils import KEY_SEQUENCE_NUMBER + + old_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in old_batch} + new_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in new_batch} + + all_seqs = sorted(set(old_by_seq) | set(new_by_seq)) + result = [] + + for seq_num in all_seqs: + old_item = old_by_seq.get(seq_num) + new_item = new_by_seq.get(seq_num) + + if old_item and not new_item: + result.append({"seq_num": seq_num, "status": "removed", "changes": []}) + elif new_item and not old_item: + result.append({"seq_num": seq_num, "status": "added", "changes": []}) + else: + # Both exist — field-by-field comparison + all_keys = sorted(set(old_item) | set(new_item)) + changes = [] + for k in all_keys: + old_val = old_item.get(k) + new_val = new_item.get(k) + if old_val != new_val: + changes.append({"field": k, "old": old_val, "new": new_val}) + status = "changed" if changes else "unchanged" + result.append({"seq_num": seq_num, "status": status, "changes": changes}) + + return result diff --git a/state.py b/state.py index 4f8d7a4..86f236a 100644 --- a/state.py +++ b/state.py @@ -13,7 +13,7 @@ class AppState: snippets: dict = field(default_factory=dict) file_path: Path | None = None restored_indicator: str | None = None - timeline_selected_nodes: set = field(default_factory=set) + timeline_selected_id: str | None = None live_toggles: dict = field(default_factory=dict) show_comfy_monitor: bool = True diff --git a/tab_batch_ng.py b/tab_batch_ng.py index 7bf4a59..c0bd53c 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -16,9 +16,11 @@ from utils import ( DEFAULTS, save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, ) -from history_tree import HistoryTree +from snapshot_timeline import SnapshotTimeline IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'} +_AUTO_SNAP_DEBOUNCE = 30 # seconds between auto-snapshots +_last_auto_snap: dict[str, float] = {} # file_path -> timestamp SUB_SEGMENT_MULTIPLIER = 1000 SUB_SEGMENT_NUM_COLORS = 6 FRAME_TO_SKIP_DEFAULT = DEFAULTS['frame_to_skip'] @@ -86,18 +88,18 @@ def find_insert_position(batch_list, parent_index, parent_seq_num): # --- Auto change note --- -def _auto_change_note(htree, batch_list, state=None, file_path=None): +def _auto_change_note(timeline, batch_list, state=None, file_path=None): """Compare current batch_list against last snapshot and describe changes.""" - # Get previous batch data from the current head - if not htree.head_id or htree.head_id not in htree.nodes: + # Get previous batch data from the current snapshot + if not timeline.current_id or timeline.current_id not in timeline.snapshots: return f'Initial save ({len(batch_list)} sequences)' - # Load previous snapshot from DB (nodes no longer hold data in memory) - prev_data = htree.nodes[htree.head_id].get('data') + # Load previous snapshot from inline data or DB + prev_data = timeline.get_snapshot_data(timeline.current_id) if not prev_data and state and state.db_enabled and state.db and state.current_project and file_path: df = state.db.get_data_file_by_names(state.current_project, file_path.stem) if df: - prev_data = state.db.get_node_snapshot(df['id'], htree.head_id) + prev_data = state.db.get_node_snapshot(df['id'], timeline.current_id) prev_batch = (prev_data or {}).get(KEY_BATCH_DATA, []) prev_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in prev_batch} @@ -363,38 +365,34 @@ def render_batch_processor(state: AppState): logger.info("save_and_snap START") data[KEY_BATCH_DATA] = batch_list tree_data = data.get(KEY_HISTORY_TREE, {}) - htree = HistoryTree(tree_data) - note = commit_input.value if commit_input.value else _auto_change_note(htree, batch_list, state=state, file_path=file_path) + timeline = SnapshotTimeline(tree_data) + note = commit_input.value if commit_input.value else _auto_change_note(timeline, batch_list, state=state, file_path=file_path) # Single serialization: json roundtrip gives us an isolated snapshot - # without the expensive deepcopy t1 = time.perf_counter() snapshot_json = json.dumps({k: v for k, v in data.items() if k != KEY_HISTORY_TREE}) snapshot_payload = json.loads(snapshot_json) logger.info("save_and_snap snapshot %.3fs", time.perf_counter() - t1) try: - htree.commit(snapshot_payload, note=note) + timeline.record(snapshot_payload, note=note) except ValueError as e: ui.notify(f'Save failed: {e}', type='negative') return if state.db_enabled and state.current_project and state.db: - # DB path: sync full tree (with snapshots) to DB, then - # write slim tree (no snapshots) to JSON and memory - full_tree = htree.to_dict() + full_tree = timeline.to_dict() data[KEY_HISTORY_TREE] = full_tree t1 = time.perf_counter() db_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot) logger.info("save_and_snap sync_to_db %.3fs", time.perf_counter() - t1) - htree.strip_snapshots() - data[KEY_HISTORY_TREE] = htree.to_dict() + timeline.strip_snapshots() + data[KEY_HISTORY_TREE] = timeline.to_dict() t1 = time.perf_counter() slim_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, slim_snapshot) logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1) else: - # No DB: write full tree (with snapshots) to JSON - data[KEY_HISTORY_TREE] = htree.to_dict() + data[KEY_HISTORY_TREE] = timeline.to_dict() t1 = time.perf_counter() save_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, save_snapshot) @@ -416,9 +414,30 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, refresh_list): async def commit(message=None): data[KEY_BATCH_DATA] = batch_list + # Auto-snapshot with debounce + fp_key = str(file_path) + now = time.time() + did_snap = False + if now - _last_auto_snap.get(fp_key, 0) >= _AUTO_SNAP_DEBOUNCE: + timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {})) + snap_json = json.dumps({k: v for k, v in data.items() + if k != KEY_HISTORY_TREE}) + snap_payload = json.loads(snap_json) + try: + timeline.record(snap_payload, note=message or "Auto-save", auto=True) + if state.db_enabled and state.current_project and state.db: + data[KEY_HISTORY_TREE] = timeline.to_dict() + db_snap = json.loads(json.dumps(data)) + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap) + timeline.strip_snapshots() + did_snap = True + data[KEY_HISTORY_TREE] = timeline.to_dict() + _last_auto_snap[fp_key] = now + except ValueError: + pass # Non-critical: skip auto-snapshot on ID collision snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, snapshot) - if state.db_enabled and state.current_project and state.db: + if state.db_enabled and state.current_project and state.db and not did_snap: await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) if message: ui.notify(message, type='positive') @@ -845,26 +864,26 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li batch_list[idx][key] = copy.deepcopy(source_seq.get(key)) data[KEY_BATCH_DATA] = batch_list - htree = HistoryTree(data.get(KEY_HISTORY_TREE, {})) + timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {})) snapshot_json = json.dumps({k: v for k, v in data.items() if k != KEY_HISTORY_TREE}) snapshot = json.loads(snapshot_json) try: - htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}") + timeline.record(snapshot, f"Mass update: {', '.join(selected_keys)}") except ValueError as e: ui.notify(f'Mass update failed: {e}', type='negative') return if state.db_enabled and state.current_project and state.db: - full_tree = htree.to_dict() + full_tree = timeline.to_dict() data[KEY_HISTORY_TREE] = full_tree db_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot) - htree.strip_snapshots() - data[KEY_HISTORY_TREE] = htree.to_dict() + timeline.strip_snapshots() + data[KEY_HISTORY_TREE] = timeline.to_dict() slim_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, slim_snapshot) else: - data[KEY_HISTORY_TREE] = htree.to_dict() + data[KEY_HISTORY_TREE] = timeline.to_dict() save_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, save_snapshot) ui.notify(f'Updated {len(targets)} sequences', type='positive') diff --git a/tab_timeline_ng.py b/tab_timeline_ng.py index 7e89b40..3986a2e 100644 --- a/tab_timeline_ng.py +++ b/tab_timeline_ng.py @@ -1,5 +1,5 @@ import asyncio -import hashlib +import copy import json import logging import time @@ -7,373 +7,15 @@ import time from nicegui import ui from state import AppState -from history_tree import HistoryTree +from snapshot_timeline import SnapshotTimeline, diff_snapshots from utils import save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE logger = logging.getLogger(__name__) -def _delete_nodes(htree, data, file_path, node_ids, state=None): - """Delete nodes with backup, branch cleanup, re-parenting, and head fallback.""" - if 'history_tree_backup' not in data: - data['history_tree_backup'] = [] - # Back up tree metadata only (no snapshot data) to avoid bloating JSON - backup = json.loads(json.dumps(htree.to_dict())) - for node in backup.get('nodes', {}).values(): - node.pop('data', None) - data['history_tree_backup'].append(backup) - data['history_tree_backup'] = data['history_tree_backup'][-10:] - # Save deleted node parents before removal (needed for branch re-pointing) - deleted_parents = {} - for nid in node_ids: - deleted_node = htree.nodes.get(nid) - if deleted_node: - deleted_parents[nid] = deleted_node.get('parent') - # Re-parent children of deleted nodes — walk up to find a surviving ancestor - for nid in node_ids: - surviving_parent = deleted_parents.get(nid) - while surviving_parent in node_ids: - surviving_parent = deleted_parents.get(surviving_parent) - for child in htree.nodes.values(): - if child.get('parent') == nid: - child['parent'] = surviving_parent - for nid in node_ids: - htree.nodes.pop(nid, None) - # Re-point branches whose tip was deleted to a surviving ancestor - for b, tip in list(htree.branches.items()): - if tip in node_ids: - new_tip = deleted_parents.get(tip) - while new_tip in node_ids: - new_tip = deleted_parents.get(new_tip) - if new_tip and new_tip in htree.nodes: - htree.branches[b] = new_tip - else: - del htree.branches[b] - if htree.head_id in node_ids: - if htree.nodes: - htree.head_id = sorted(htree.nodes.values(), - key=lambda x: x['timestamp'])[-1]['id'] - else: - htree.head_id = None - data[KEY_HISTORY_TREE] = htree.to_dict() - # Clean up DB snapshots for deleted nodes - if state and state.db_enabled and state.db and state.current_project: - df = state.db.get_data_file_by_names(state.current_project, file_path.stem) - if df: - state.db.delete_node_snapshots(df['id'], set(node_ids)) - - -def _render_selection_picker(all_nodes, htree, state, refresh_fn): - """Multi-select picker for batch-deleting timeline nodes.""" - all_ids = [n['id'] for n in all_nodes] - - def fmt_option(nid): - n = htree.nodes[nid] - ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp'])) - note = n.get('note', 'Step') - head = ' (HEAD)' if nid == htree.head_id else '' - return f'{note} - {ts} ({nid[:6]}){head}' - - options = {nid: fmt_option(nid) for nid in all_ids} - - def on_selection_change(e): - state.timeline_selected_nodes = set(e.value) if e.value else set() - - ui.select( - options, - value=list(state.timeline_selected_nodes), - multiple=True, - label='Select nodes to delete:', - on_change=on_selection_change, - ).classes('w-full') - - with ui.row(): - def select_all(): - state.timeline_selected_nodes = set(all_ids) - refresh_fn() - def deselect_all(): - state.timeline_selected_nodes = set() - refresh_fn() - ui.button('Select All', on_click=select_all).props('flat dense') - ui.button('Deselect All', on_click=deselect_all).props('flat dense') - - -def _render_graph_or_log(mode, all_nodes, htree, selected_nodes, - selection_mode_on, toggle_select_fn, restore_fn, - selected=None): - """Render graph visualization or linear log view.""" - if mode in ('Horizontal', 'Vertical'): - direction = 'LR' if mode == 'Horizontal' else 'TB' - with ui.card().classes('w-full q-pa-md'): - try: - graph_dot = htree.generate_graph(direction=direction) - sel_id = selected.get('node_id') if selected else None - _render_graphviz(graph_dot, selected_node_id=sel_id) - except Exception as e: - ui.label(f'Graph Error: {e}').classes('text-negative') - - elif mode == 'Linear Log': - ui.label('Chronological list of all snapshots.').classes('text-caption') - for n in all_nodes: - is_head = n['id'] == htree.head_id - is_selected = n['id'] in selected_nodes - - card_style = '' - if is_selected: - card_style = 'background: rgba(239, 68, 68, 0.1) !important; border-left: 3px solid var(--negative);' - elif is_head: - card_style = 'background: var(--accent-subtle) !important; border-left: 3px solid var(--accent);' - with ui.card().classes('w-full q-mb-sm').style(card_style): - with ui.row().classes('w-full items-center'): - if selection_mode_on: - ui.checkbox( - '', - value=is_selected, - on_change=lambda e, nid=n['id']: toggle_select_fn( - nid, e.value), - ) - - icon = 'location_on' if is_head else 'circle' - ui.icon(icon).classes( - 'text-primary' if is_head else 'text-grey') - - with ui.column().classes('col'): - note = n.get('note', 'Step') - ts = time.strftime('%b %d %H:%M', - time.localtime(n['timestamp'])) - label = f'{note} (Current)' if is_head else note - ui.label(label).classes('text-bold') - ui.label( - f'ID: {n["id"][:6]} - {ts}').classes('text-caption') - - if not is_head and not selection_mode_on: - ui.button( - 'Restore', - icon='restore', - on_click=lambda node=n: restore_fn(node), - ).props('flat dense color=primary') - - -def _render_batch_delete(htree, data, file_path, state, refresh_fn): - """Render batch delete controls for selected timeline nodes.""" - valid = state.timeline_selected_nodes & set(htree.nodes.keys()) - state.timeline_selected_nodes = valid - count = len(valid) - if count == 0: - return - - ui.label( - f'{count} node{"s" if count != 1 else ""} selected for deletion.' - ).classes('text-warning q-mt-md') - - async def do_batch_delete(): - current_valid = state.timeline_selected_nodes & set(htree.nodes.keys()) - _delete_nodes(htree, data, file_path, current_valid, state=state) - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - state.timeline_selected_nodes = set() - ui.notify( - f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!', - type='positive') - refresh_fn() - - ui.button( - f'Delete {count} Node{"s" if count != 1 else ""}', - icon='delete', - on_click=do_batch_delete, - ).props('color=negative') - - -def _walk_branch_nodes(htree, tip_id): - """Walk parent pointers from tip, returning nodes newest-first.""" - nodes = [] - visited = set() - current = tip_id - while current and current in htree.nodes: - if current in visited: - break - visited.add(current) - nodes.append(htree.nodes[current]) - current = htree.nodes[current].get('parent') - return nodes - - -def _find_active_branch(htree): - """Return branch name whose tip == head_id, or None if detached.""" - if not htree.head_id: - return None - for b_name, tip_id in htree.branches.items(): - if tip_id == htree.head_id: - return b_name - return None - - -def _find_branch_for_node(htree, node_id): - """Return the branch name whose ancestry contains node_id, or None.""" - for b_name, tip_id in htree.branches.items(): - visited = set() - current = tip_id - while current and current in htree.nodes: - if current in visited: - break - if current == node_id: - return b_name - visited.add(current) - current = htree.nodes[current].get('parent') - return None - - -def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn, - selected, state=None): - """Render branch-grouped node manager with restore, rename, delete, and preview.""" - ui.label('Manage Version').classes('section-header') - - active_branch = _find_active_branch(htree) - - # --- (a) Branch selector --- - def fmt_branch(b_name): - count = len(_walk_branch_nodes(htree, htree.branches.get(b_name))) - suffix = ' (active)' if b_name == active_branch else '' - return f'{b_name} ({count} nodes){suffix}' - - branch_options = {b: fmt_branch(b) for b in htree.branches} - - def on_branch_change(e): - selected['branch'] = e.value - tip = htree.branches.get(e.value) - if tip: - selected['node_id'] = tip - render_branch_nodes.refresh() - - ui.select( - branch_options, - value=selected['branch'], - label='Branch:', - on_change=on_branch_change, - ).classes('w-full') - - # --- (b) Node list + (c) Actions panel --- - @ui.refreshable - def render_branch_nodes(): - branch_name = selected['branch'] - tip_id = htree.branches.get(branch_name) - nodes = _walk_branch_nodes(htree, tip_id) if tip_id else [] - - if not nodes: - ui.label('No nodes on this branch.').classes('text-caption q-pa-sm') - return - - with ui.scroll_area().classes('w-full').style('max-height: 350px'): - for n in nodes: - nid = n['id'] - is_head = nid == htree.head_id - is_tip = nid == tip_id - is_selected = nid == selected['node_id'] - - card_style = '' - if is_selected: - card_style = 'border-left: 3px solid var(--primary);' - elif is_head: - card_style = 'border-left: 3px solid var(--accent);' - - with ui.card().classes('w-full q-mb-xs q-pa-xs').style(card_style): - with ui.row().classes('w-full items-center no-wrap'): - icon = 'location_on' if is_head else 'circle' - icon_size = 'sm' if is_head else 'xs' - ui.icon(icon, size=icon_size).classes( - 'text-primary' if is_head else 'text-grey') - - with ui.column().classes('col q-ml-xs').style('min-width: 0'): - note = n.get('note', 'Step') - ts = time.strftime('%b %d %H:%M', - time.localtime(n['timestamp'])) - label_text = note - lbl = ui.label(label_text).classes('text-body2 ellipsis') - if is_head: - lbl.classes('text-bold') - ui.label(f'{ts} \u2022 {nid[:6]}').classes( - 'text-caption text-grey') - - if is_head: - ui.badge('HEAD', color='amber').props('dense') - if is_tip and not is_head: - ui.badge('tip', color='green', outline=True).props('dense') - - def select_node(node_id=nid): - selected['node_id'] = node_id - render_branch_nodes.refresh() - - ui.button(icon='check_circle', on_click=select_node).props( - 'flat dense round size=sm' - ).tooltip('Select this node') - - # --- (c) Actions panel --- - sel_id = selected['node_id'] - if not sel_id or sel_id not in htree.nodes: - return - - sel_node = htree.nodes[sel_id] - sel_note = sel_node.get('note', 'Step') - is_head = sel_id == htree.head_id - - ui.separator().classes('q-my-sm') - ui.label(f'Selected: {sel_note} ({sel_id[:6]})').classes( - 'text-caption text-bold') - - with ui.row().classes('w-full items-end q-gutter-sm'): - if not is_head: - def restore_selected(): - if sel_id in htree.nodes: - restore_fn(htree.nodes[sel_id]) - ui.button('Restore', icon='restore', - on_click=restore_selected).props('color=primary dense') - - # Rename - rename_input = ui.input('Rename Label').classes('col').props('dense') - - async def rename_node(): - if sel_id in htree.nodes and rename_input.value: - htree.nodes[sel_id]['note'] = rename_input.value - data[KEY_HISTORY_TREE] = htree.to_dict() - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state and state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - ui.notify('Label updated', type='positive') - refresh_fn() - - ui.button('Update Label', on_click=rename_node).props('flat dense') - - # Danger zone - with ui.expansion('Danger Zone', icon='warning').classes( - 'w-full q-mt-sm').style('border-left: 3px solid var(--negative)'): - ui.label('Deleting a node cannot be undone.').classes('text-warning') - - async def delete_selected(): - if sel_id in htree.nodes: - _delete_nodes(htree, data, file_path, {sel_id}, state=state) - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state and state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - # Reset selection if branch was removed - if selected['branch'] not in htree.branches: - selected['branch'] = next(iter(htree.branches), None) - selected['node_id'] = htree.head_id - ui.notify('Node Deleted', type='positive') - refresh_fn() - - ui.button('Delete This Node', icon='delete', - on_click=delete_selected).props('color=negative dense') - - # Data preview - with ui.expansion('Data Preview', icon='preview').classes('w-full q-mt-sm'): - _render_data_preview(sel_id, htree, state=state, file_path=file_path) - - render_branch_nodes() - +# ====================================================================== +# Main entry point +# ====================================================================== def render_timeline_tab(state: AppState): t0 = time.perf_counter() @@ -383,266 +25,541 @@ def render_timeline_tab(state: AppState): tree_data = data.get(KEY_HISTORY_TREE, {}) if not tree_data: - ui.label('No history timeline exists. Make some changes in the Editor first!').classes( + ui.label('No version history exists. Make some changes in the Editor first!').classes( 'text-subtitle1 q-pa-md') return - htree = HistoryTree(tree_data) + timeline = SnapshotTimeline(tree_data) + if not timeline.snapshots: + ui.label('No snapshots found in history.').classes('text-subtitle1 q-pa-md') + return - # --- Shared selected-node state (survives refreshes, shared by graph + manager) --- - active_branch = _find_active_branch(htree) - default_branch = active_branch - if not default_branch and htree.head_id: - for b_name, tip_id in htree.branches.items(): - for n in _walk_branch_nodes(htree, tip_id): - if n['id'] == htree.head_id: - default_branch = b_name - break - if default_branch: - break - if not default_branch and htree.branches: - default_branch = next(iter(htree.branches)) - selected = {'node_id': htree.head_id, 'branch': default_branch} + # Local UI state + ui_state = { + 'selected_id': state.timeline_selected_id or timeline.current_id, + 'search': '', + 'filter': 'All', # All | Pinned | Auto + } if state.restored_indicator: ui.label(f'Editing Restored Version: {state.restored_indicator}').classes( 'text-info q-pa-sm') - # --- View mode + Selection toggle --- - with ui.row().classes('w-full items-center q-gutter-md q-mb-md'): - ui.label('Version History').classes('text-h6 col') - view_mode = ui.toggle( - ['Horizontal', 'Vertical', 'Linear Log'], - value='Horizontal', - ) - selection_mode = ui.switch('Select to Delete') + ui.label('Version History').classes('text-h6 q-mb-sm') - @ui.refreshable - def render_timeline(): - t_rt = time.perf_counter() - logger.info("render_timeline START (%d nodes)", len(htree.nodes)) - all_nodes = sorted(htree.nodes.values(), key=lambda x: x['timestamp'], reverse=True) - selected_nodes = state.timeline_selected_nodes if selection_mode.value else set() + # Mutable container so left/right panels can cross-reference each other's refreshables + panels: dict = {} - if selection_mode.value: - _render_selection_picker(all_nodes, htree, state, render_timeline.refresh) + # ====================================================================== + # Splitter layout: 35% left (list) / 65% right (detail) + # ====================================================================== + with ui.splitter(value=35).classes('w-full').style('height: calc(100vh - 200px); min-height: 600px') as splitter: - _render_graph_or_log( - view_mode.value, all_nodes, htree, selected_nodes, - selection_mode.value, _toggle_select, _restore_and_refresh, - selected=selected) + # ============================================================== + # LEFT PANEL — Snapshot list + # ============================================================== + with splitter.before: + with ui.column().classes('w-full q-pa-sm').style('height: 100%'): + # Search + filter + search_input = ui.input( + placeholder='Search notes...', + ).classes('w-full').props('dense outlined clearable') - if selection_mode.value and state.timeline_selected_nodes: - _render_batch_delete(htree, data, file_path, state, render_timeline.refresh) + with ui.row().classes('w-full q-gutter-xs'): + filter_toggle = ui.toggle( + ['All', 'Pinned', 'Auto'], value='All', + ).props('dense no-caps') - with ui.card().classes('w-full q-pa-md q-mt-md'): - _render_node_manager( - all_nodes, htree, data, file_path, - _restore_and_refresh, render_timeline.refresh, - selected, state=state) - logger.info("render_timeline END (%.3fs)", time.perf_counter() - t_rt) + @ui.refreshable + def render_snapshot_list(): + _render_snapshot_list( + timeline, ui_state, data, file_path, state, + render_snapshot_list, panels) - def _toggle_select(nid, checked): - if checked: - state.timeline_selected_nodes.add(nid) - else: - state.timeline_selected_nodes.discard(nid) - render_timeline.refresh() + panels['list'] = render_snapshot_list - async def _restore_and_refresh(node): - await _restore_node(data, node, htree, file_path, state) - # Refresh all tabs (batch, raw, timeline) so they pick up the restored data - state._render_main.refresh() + def _on_search(e): + ui_state['search'] = search_input.value or '' + render_snapshot_list.refresh() + + def _on_filter(e): + ui_state['filter'] = e.value + render_snapshot_list.refresh() + + search_input.on('update:model-value', _on_search) + filter_toggle.on_value_change(_on_filter) + + render_snapshot_list() + + # ============================================================== + # RIGHT PANEL — Detail tabs + # ============================================================== + with splitter.after: + @ui.refreshable + def render_detail_panel(): + _render_detail_panel(timeline, ui_state, data, file_path, state, + panels) + + panels['detail'] = render_detail_panel + render_detail_panel() - view_mode.on_value_change(lambda _: render_timeline.refresh()) - selection_mode.on_value_change(lambda _: render_timeline.refresh()) - render_timeline() logger.info("render_timeline_tab END (%.3fs)", time.perf_counter() - t0) - # --- Poll for graph node clicks (JS → Python bridge) --- - graph_timer = None - async def _poll_graph_click(): - if view_mode.value == 'Linear Log': - return - try: - result = await ui.run_javascript( - 'const v = window.graphSelectedNode;' - 'window.graphSelectedNode = null; v;' - ) - except Exception: - # Deactivate timer if parent slot was deleted - if graph_timer is not None: - graph_timer.active = False - return - if not result: - return - node_id = str(result) - if node_id not in htree.nodes: - return - branch = _find_branch_for_node(htree, node_id) - if branch: - selected['branch'] = branch - selected['node_id'] = node_id - render_timeline.refresh() +# ====================================================================== +# Left panel: snapshot list +# ====================================================================== - graph_timer = ui.timer(0.5, _poll_graph_click) +def _render_snapshot_list(timeline, ui_state, data, file_path, state, + refresh_list, panels): + snapshots = sorted(timeline.snapshots.values(), + key=lambda s: s['timestamp'], reverse=True) - def _cleanup_timer(): - if graph_timer is not None: - graph_timer.active = False - ui.context.client.on_disconnect(_cleanup_timer) + # Apply filters + search_term = ui_state.get('search', '').lower() + filter_mode = ui_state.get('filter', 'All') + if search_term: + snapshots = [s for s in snapshots + if search_term in s.get('note', '').lower()] + if filter_mode == 'Pinned': + snapshots = [s for s in snapshots if s.get('pinned')] + elif filter_mode == 'Auto': + snapshots = [s for s in snapshots if s.get('auto')] -_graphviz_svg_cache: dict[str, str] = {} -_GRAPHVIZ_CACHE_MAX = 20 - - -def _render_graphviz(dot_source: str, selected_node_id: str | None = None): - """Render graphviz DOT source as interactive SVG with click-to-select.""" - try: - import graphviz - t_gv = time.perf_counter() - cache_key = hashlib.md5(dot_source.encode()).hexdigest() - svg = _graphviz_svg_cache.get(cache_key) - if svg is None: - src = graphviz.Source(dot_source) - svg = src.pipe(format='svg').decode('utf-8') - if len(_graphviz_svg_cache) >= _GRAPHVIZ_CACHE_MAX: - _graphviz_svg_cache.pop(next(iter(_graphviz_svg_cache))) - _graphviz_svg_cache[cache_key] = svg - logger.info("_render_graphviz MISS (generated): %.3fs", time.perf_counter() - t_gv) - else: - logger.info("_render_graphviz HIT (cached): %.3fs", time.perf_counter() - t_gv) - - sel_escaped = json.dumps(selected_node_id or '')[1:-1] # strip quotes, get JS-safe content - - # CSS inline (allowed), JS via run_javascript (script tags blocked) - css = '''''' - - ui.html( - f'{css}