diff --git a/db.py b/db.py index c870e8f..81e9461 100644 --- a/db.py +++ b/db.py @@ -242,7 +242,6 @@ class ProjectDB: ) self.conn.commit() - @staticmethod @staticmethod def _migrate_lora_keys(data: dict) -> dict: """Split combined lora 'name:strength' into separate name and strength keys.""" @@ -340,27 +339,34 @@ class ProjectDB: # ------------------------------------------------------------------ def save_history_tree(self, data_file_id: int, tree_data: dict) -> None: - """Save history tree, extracting node snapshots into separate table.""" + """Save history tree, extracting snapshot data into separate table. + + Supports both new format (snapshots dict) and old format (nodes dict). + """ now = time.time() - nodes = tree_data.get("nodes", {}) + if "snapshots" in tree_data: + entries = tree_data.get("snapshots", {}) + entry_key = "snapshots" + else: + entries = tree_data.get("nodes", {}) + entry_key = "nodes" slim_tree = dict(tree_data) - slim_nodes = {} - for nid, node in nodes.items(): - slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"} - slim_tree["nodes"] = slim_nodes + slim_entries = {} + for eid, entry in entries.items(): + slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"} + slim_tree[entry_key] = slim_entries self.conn.execute("BEGIN IMMEDIATE") try: - # Extract snapshot data from nodes into history_snapshots table - for nid, node in nodes.items(): - snap = node.get("data") + for eid, entry in entries.items(): + snap = entry.get("data") if snap: self.conn.execute( "INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) " "VALUES (?, ?, ?, ?) " "ON CONFLICT(data_file_id, node_id) DO UPDATE SET " "snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at", - (data_file_id, nid, json.dumps(snap), now), + (data_file_id, eid, json.dumps(snap), now), ) self.conn.execute( "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " @@ -463,24 +469,30 @@ class ProjectDB: ) # Import history tree (extract snapshots into separate table) + # Supports both new format (snapshots dict) and old format (nodes dict) history_tree = data.get(KEY_HISTORY_TREE) if history_tree and isinstance(history_tree, dict): now = time.time() - nodes = history_tree.get("nodes", {}) + if "snapshots" in history_tree: + entries = history_tree.get("snapshots", {}) + entry_key = "snapshots" + else: + entries = history_tree.get("nodes", {}) + entry_key = "nodes" slim_tree = dict(history_tree) - slim_nodes = {} - for nid, node in nodes.items(): - snap = node.get("data") + slim_entries = {} + for eid, entry in entries.items(): + snap = entry.get("data") if snap: self.conn.execute( "INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) " "VALUES (?, ?, ?, ?) " "ON CONFLICT(data_file_id, node_id) DO UPDATE SET " "snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at", - (df_id, nid, json.dumps(snap), now), + (df_id, eid, json.dumps(snap), now), ) - slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"} - slim_tree["nodes"] = slim_nodes + slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"} + slim_tree[entry_key] = slim_entries self.conn.execute( "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " "VALUES (?, ?, ?) " @@ -540,9 +552,9 @@ class ProjectDB: # Load history tree (metadata only, no snapshot data) tree = self.get_history_tree(df["id"]) if tree: - # Strip any residual snapshot data from nodes - for node in tree.get("nodes", {}).values(): - node.pop("data", None) + # Strip any residual snapshot data (supports both formats) + for entry in tree.get("snapshots", tree.get("nodes", {})).values(): + entry.pop("data", None) data["history_tree"] = tree t3 = time.time() diff --git a/docs/plans/2026-04-03-resolution-series-design.md b/docs/plans/2026-04-03-resolution-series-design.md new file mode 100644 index 0000000..683d591 --- /dev/null +++ b/docs/plans/2026-04-03-resolution-series-design.md @@ -0,0 +1,81 @@ +# Resolution Series Design + +## Problem + +When running ComfyUI loop nodes for multi-step upscaling (e.g. 3+ resolutions at different sizes), +managing portrait vs landscape width/height per iteration is tedious. Users need a structured way +to define N resolution pairs in the manager UI and retrieve them by loop index in ComfyUI. + +## Design + +### Data Model + +Resolution series are stored as a JSON array under a user-chosen key in the sequence data: + +```json +"upscale_resolutions": [[512, 512], [768, 1344], [1344, 768], [2048, 2048]] +``` + +- Each element is `[width, height]` (both INT) +- Key name is chosen by the user (any string) +- Number of entries is configurable (add/remove rows) +- Stored in the same project JSON file and sequence — no schema change required +- Index out of bounds → clamp to last entry + +### NiceGUI UI (tab_batch_ng.py) + +A resolution series editor is rendered in the left column of the sequence card, directly below +the "Specific Negative" textarea. + +Layout: + +``` +── Resolution Series ────────────────── + key name: [upscale_resolutions ] + # Width Height + 1 [2048] [2048] [x] + 2 [768 ] [1344] [x] + 3 [1344] [768 ] [x] + [+ Add row] +``` + +- Key name is editable (defaults to `resolutions`) +- Rows added/removed inline; each change calls `commit()` immediately +- Hidden behind an "Add Resolution Series" button when no resolution key exists yet +- A value is detected as a resolution series if it is a list of `[int, int]` pairs + +### ComfyUI Node (`ProjectResolution`) + +New node class in `project_loader.py`, sibling to `ProjectKey`. + +**Inputs:** +- `source_label` (STRING) — references a `ProjectSource` by label +- `key_name` (STRING) — the resolution series key name +- `index` (INT, min 0) — wired from loop node's current index output +- `manager_url`, `project_name`, `file_name`, `sequence_number` — optional, synced from `ProjectSource` via JS + +**Outputs:** `width` (INT), `height` (INT) + +**Execution:** fetches the sequence data, reads `data[key_name]`, indexes into the array with +clamp-to-last on out-of-bounds, returns `(width, height)`. + +**JS (`web/project_resolution.js`):** +- Same `_syncFromSource` mechanism as `project_key.js` +- `key_name` widget is replaced with a combo dropdown populated with keys whose value is a + resolution series (list of `[int, int]` pairs), detected via the existing keys API +- Registered in `PROJECT_NODE_CLASS_MAPPINGS` and `PROJECT_NODE_DISPLAY_NAME_MAPPINGS` + +### API + +No new endpoints. Uses existing: +- `/json_manager/get_project_keys` — for key discovery (JS combo population) +- `_fetch_data()` — for execution-time data fetch + +### Files Changed + +| File | Change | +|------|--------| +| `project_loader.py` | Add `ProjectResolution` class + register in mappings | +| `web/project_resolution.js` | New JS extension for the node | +| `tab_batch_ng.py` | Resolution series editor below Specific Negative | +| `__init__.py` | Register new JS file if needed | diff --git a/docs/plans/2026-04-03-resolution-series-plan.md b/docs/plans/2026-04-03-resolution-series-plan.md new file mode 100644 index 0000000..9e2cc9d --- /dev/null +++ b/docs/plans/2026-04-03-resolution-series-plan.md @@ -0,0 +1,640 @@ +# Resolution Series Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add a `ProjectResolution` ComfyUI node and NiceGUI editor that let users define N `(width, height)` pairs per sequence and retrieve them by loop index. + +**Architecture:** Resolution series are stored as a JSON array of `[width, height]` pairs under a user-chosen key in sequence data (e.g. `"upscale_resolutions": [[512,512],[768,1344]]`). A new `ProjectResolution` ComfyUI node (sibling of `ProjectKey`) accepts a `source_label`, `key_name`, and `index` INT from a loop node, and returns `width` + `height`. The NiceGUI sequence card gets an inline table editor placed directly below the "Specific Negative" textarea. + +**Tech Stack:** Python (ComfyUI node), NiceGUI (UI), JavaScript (ComfyUI frontend extension), pytest + +**Branch:** Create and work on `feat/resolution-series` branched from `main`: +```bash +git checkout main && git checkout -b feat/resolution-series +``` + +--- + +### Task 0: Fix pre-existing test failures on `main` + +When `file_name` was added as a second output to `ProjectSource`, two tests were not updated. +They fail on `main` before any new code is written. + +**Files:** +- Modify: `tests/test_project_loader.py` (`TestProjectSource` class, lines ~216-231) + +**Step 1: Update the two broken tests** + +```python +def test_outputs_sequence_number(self): + from project_loader import ProjectSource + assert ProjectSource.RETURN_TYPES == ("INT", "STRING",) + assert ProjectSource.RETURN_NAMES == ("sequence_number", "file_name",) + +def test_hold_config_returns_sequence_number(self): + from project_loader import ProjectSource + node = ProjectSource() + result = node.hold_config( + manager_url="http://localhost:8080", + project_name="proj1", + file_name="batch_i2v", + sequence_number=42, + label="my_source" + ) + assert result == (42, "batch_i2v") +``` + +**Step 2: Verify they now pass** + +```bash +pytest tests/test_project_loader.py::TestProjectSource -v +``` +Expected: all 4 PASS + +**Step 3: Commit** + +```bash +git add tests/test_project_loader.py +git commit -m "fix: update ProjectSource tests for file_name output" +``` + +--- + +### Task 1: Python node — `ProjectResolution` + +**Files:** +- Modify: `project_loader.py` (after the `ProjectKey` class, before `# --- Mappings ---`) +- Modify: `tests/test_project_loader.py` (add `TestProjectResolution` class) + +**Step 1: Write failing tests** + +Add this class to `tests/test_project_loader.py`: + +```python +class TestProjectResolution: + def test_input_types(self): + from project_loader import ProjectResolution + inputs = ProjectResolution.INPUT_TYPES() + assert "source_label" in inputs["required"] + assert "key_name" in inputs["required"] + assert "index" in inputs["required"] + assert inputs["required"]["index"][0] == "INT" + + def test_two_outputs(self): + from project_loader import ProjectResolution + assert ProjectResolution.RETURN_TYPES == ("INT", "INT") + assert ProjectResolution.RETURN_NAMES == ("width", "height") + + def test_fetch_resolution_basic(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512], [768, 1344], [1344, 768]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=1, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (768, 1344) + + def test_fetch_resolution_index_zero(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512], [1024, 1024]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512) + + def test_fetch_resolution_clamps_on_out_of_bounds(self): + from project_loader import ProjectResolution + node = ProjectResolution() + data = {"resolutions": [[512, 512], [1024, 1024]]} + with patch("project_loader._fetch_data", return_value=data): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=99, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (1024, 1024) # last entry + + def test_fetch_resolution_missing_key_returns_defaults(self): + from project_loader import ProjectResolution + node = ProjectResolution() + with patch("project_loader._fetch_data", return_value={}): + result = node.fetch_resolution( + source_label="src", key_name="nonexistent", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512) + + def test_fetch_resolution_network_error_returns_defaults(self): + from project_loader import ProjectResolution + node = ProjectResolution() + error_resp = {"error": "network_error", "message": "Connection refused"} + with patch("project_loader._fetch_data", return_value=error_resp): + result = node.fetch_resolution( + source_label="src", key_name="resolutions", index=0, + manager_url="http://localhost:8080", project_name="p", + file_name="f", sequence_number=1, + ) + assert result == (512, 512) + + def test_category(self): + from project_loader import ProjectResolution + assert ProjectResolution.CATEGORY == "utils/json/project" +``` + +**Step 2: Run tests to verify they fail** + +```bash +pytest tests/test_project_loader.py::TestProjectResolution -v +``` +Expected: `ImportError: cannot import name 'ProjectResolution'` + +**Step 3: Implement `ProjectResolution` in `project_loader.py`** + +Insert this class after `ProjectKey` (line ~294), before `# --- Mappings ---`: + +```python +class ProjectResolution: + """Fetches a (width, height) pair from a resolution series by loop index.""" + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source_label": ("STRING", {"default": "", "multiline": False}), + "key_name": ("STRING", {"default": "resolutions", "multiline": False}), + "index": ("INT", {"default": 0, "min": 0, "max": 9999}), + }, + "optional": { + "manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}), + "project_name": ("STRING", {"default": "", "multiline": False}), + "file_name": ("STRING", {"default": "", "multiline": False}), + "sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}), + }, + } + + RETURN_TYPES = ("INT", "INT") + RETURN_NAMES = ("width", "height") + FUNCTION = "fetch_resolution" + CATEGORY = "utils/json/project" + OUTPUT_NODE = False + + @classmethod + def IS_CHANGED(cls, **kwargs): + return float("nan") + + def fetch_resolution(self, source_label, key_name, index, + manager_url="http://localhost:8080", project_name="", + file_name="", sequence_number=1): + sequence_number = int(sequence_number) + data = _fetch_data(manager_url, project_name, file_name, sequence_number) + if data.get("error") in ("http_error", "network_error", "parse_error"): + logger.warning("ProjectResolution.fetch_resolution failed: %s", data.get("message")) + return (512, 512) + + series = data.get(key_name) + if not isinstance(series, list) or len(series) == 0: + logger.warning("ProjectResolution: key '%s' is not a resolution series", key_name) + return (512, 512) + + clamped = min(index, len(series) - 1) + entry = series[clamped] + if not isinstance(entry, (list, tuple)) or len(entry) < 2: + logger.warning("ProjectResolution: entry at index %d is malformed: %r", clamped, entry) + return (512, 512) + + return (to_int(entry[0]), to_int(entry[1])) +``` + +**Step 4: Run tests to verify they pass** + +```bash +pytest tests/test_project_loader.py::TestProjectResolution -v +``` +Expected: all 7 tests PASS + +**Step 5: Commit** + +```bash +git add project_loader.py tests/test_project_loader.py +git commit -m "feat: add ProjectResolution node" +``` + +--- + +### Task 2: Register `ProjectResolution` in mappings + fix mapping tests + +**Files:** +- Modify: `project_loader.py` (mappings section, lines ~297-307) +- Modify: `tests/test_project_loader.py` (`TestNodeMappings` class) + +**Step 1: Update mappings in `project_loader.py`** + +Change the mappings at the bottom of the file: + +```python +PROJECT_NODE_CLASS_MAPPINGS = { + "ProjectLoaderDynamic": ProjectLoaderDynamic, + "ProjectSource": ProjectSource, + "ProjectKey": ProjectKey, + "ProjectResolution": ProjectResolution, +} + +PROJECT_NODE_DISPLAY_NAME_MAPPINGS = { + "ProjectLoaderDynamic": "Project Loader (Dynamic)", + "ProjectSource": "Project Source", + "ProjectKey": "Project Key", + "ProjectResolution": "Project Resolution", +} +``` + +**Step 2: Update the mapping test** + +In `tests/test_project_loader.py`, update `TestNodeMappings.test_mappings_exist`: + +```python +class TestNodeMappings: + def test_mappings_exist(self): + from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS + assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS + assert "ProjectResolution" in PROJECT_NODE_CLASS_MAPPINGS + assert len(PROJECT_NODE_CLASS_MAPPINGS) == 4 + assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4 +``` + +**Step 3: Run all project_loader tests** + +```bash +pytest tests/test_project_loader.py -v +``` +Expected: all tests PASS + +**Step 4: Commit** + +```bash +git add project_loader.py tests/test_project_loader.py +git commit -m "feat: register ProjectResolution in node mappings" +``` + +--- + +### Task 3: NiceGUI resolution series editor in `tab_batch_ng.py` + +**Files:** +- Modify: `tab_batch_ng.py` + +The resolution series editor goes inside `splitter.before`, directly after the "Specific Negative" textarea (currently line ~552-553). No new file needed. + +**Step 1: Add the helper function** + +Add this function near the other helper functions at the top of the render section (before `_render_sequence_card`): + +```python +def _is_resolution_series(val) -> bool: + """Return True if val is a list of [width, height] int pairs.""" + if not isinstance(val, list) or len(val) == 0: + return False + return all( + isinstance(entry, (list, tuple)) and len(entry) == 2 + and all(isinstance(v, (int, float)) for v in entry) + for entry in val + ) +``` + +Note: `Any` is intentionally omitted — `tab_batch_ng.py` does not import `typing.Any`. + +**Step 2: Add the resolution series render section** + +After the "Specific Negative" textarea in `splitter.before` (after line ~553), add: + +```python + # --- Resolution Series --- + res_keys = [k for k, v in seq.items() if _is_resolution_series(v)] + if res_keys: + ui.label('Resolution Series').classes('text-caption text-weight-bold q-mt-md') + for res_key in res_keys: + series: list = seq[res_key] + with ui.card().classes('w-full q-pa-sm q-mt-xs').props('flat bordered'): + with ui.row().classes('items-center q-mb-xs'): + ui.label(res_key).classes('text-caption col') + def del_series(k=res_key): + del seq[k] + commit() + ui.button(icon='delete', on_click=del_series).props( + 'flat dense round size=xs color=negative') + with ui.row().classes('text-caption text-grey q-mb-xs'): + ui.label('#').style('width:24px') + ui.label('Width').classes('col') + ui.label('Height').classes('col') + ui.label('').style('width:28px') + for idx, entry in enumerate(series): + with ui.row().classes('items-center w-full'): + ui.label(str(idx + 1)).classes('text-caption').style('width:24px') + w_inp = ui.number(value=int(entry[0]), min=1, step=1).classes( + 'col').props('outlined dense hide-bottom-space') + h_inp = ui.number(value=int(entry[1]), min=1, step=1).classes( + 'col').props('outlined dense hide-bottom-space') + + def _sync_wh(i=idx, k=res_key, wi=w_inp, hi=h_inp): + seq[k][i] = [ + int(wi.value) if wi.value else 512, + int(hi.value) if hi.value else 512, + ] + commit() + + w_inp.on('blur', lambda _, s=_sync_wh: s()) + h_inp.on('blur', lambda _, s=_sync_wh: s()) + + def del_row(i=idx, k=res_key): + seq[k].pop(i) + commit() + ui.button(icon='remove', on_click=del_row).props( + 'flat dense round size=xs') + + def add_row(k=res_key): + seq[k].append([512, 512]) + commit() + ui.button('+ Add row', icon='add', on_click=add_row).props( + 'flat dense size=sm').classes('q-mt-xs') + + with ui.expansion('Add Resolution Series', icon='straighten').classes('w-full q-mt-sm'): + new_res_key = ui.input('Key name', value='resolutions').props('outlined dense') + def add_res_series(): + k = new_res_key.value.strip() + if k and k not in seq: + seq[k] = [[512, 512], [1024, 1024]] + commit() + ui.button('Add', icon='add', on_click=add_res_series).props('outlined dense') +``` + +**Step 3: Run all tests** + +```bash +pytest tests/ -q +``` +Expected: all tests PASS (no Python tests cover the NiceGUI render path, but no regressions) + +**Important:** Also update the `custom_keys` filter in `_render_sequence_card` (line ~648) to exclude +resolution series keys — otherwise they'd render in both the resolution editor AND "Custom Parameters": + +```python +# Find this line: +custom_keys = [k for k in seq.keys() if k not in standard_keys] +# Replace with: +custom_keys = [k for k in seq.keys() if k not in standard_keys and not _is_resolution_series(seq.get(k))] +``` + +**Step 4: Commit** + +```bash +git add tab_batch_ng.py +git commit -m "feat: resolution series editor in sequence card" +``` + +--- + +### Task 4: JS extension `web/project_resolution.js` + +**Files:** +- Create: `web/project_resolution.js` + +This file mirrors `web/project_key.js` exactly, with two differences: +1. It targets `"ProjectResolution"` instead of `"ProjectKey"` +2. `_refreshKeys` filters to only show keys whose value is a resolution series (list of `[int, int]` pairs) — but since the keys API only returns key names (not values), the filter is done by naming convention or we just show all keys and let the user pick. For simplicity, show all keys (same as ProjectKey) and let the user pick. +3. The `index` widget is **not** hidden — the user wires it from a loop node +4. The node has two outputs (`width`, `height`) so no output slot name update is needed + +**Step 1: Create `web/project_resolution.js`** + +```javascript +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; + +app.registerExtension({ + name: "json.manager.project.resolution", + + async beforeQueuePrompt() { + if (!app.graph?._nodes) return; + for (const node of app.graph._nodes) { + if (node.type === "ProjectResolution" && node._syncFromSource) { + node._syncFromSource(); + } + } + }, + + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name !== "ProjectResolution") return; + + function hideWidget(widget) { + if (widget.origType === undefined) widget.origType = widget.type; + widget.type = "hidden"; + widget.hidden = true; + widget.computeSize = () => [0, -4]; + } + + function replaceWithCombo(node, name, values, callback) { + const idx = node.widgets?.findIndex(w => w.name === name); + if (idx === -1 || idx === undefined) return null; + const oldWidget = node.widgets[idx]; + const savedValue = oldWidget.value || ""; + const comboValues = values.length > 0 ? values : [""]; + if (savedValue && !comboValues.includes(savedValue)) { + comboValues.unshift(savedValue); + } + const defaultValue = savedValue || comboValues[0]; + node.widgets.splice(idx, 1); + const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues }); + if (node.widgets.length > 1) { + node.widgets.splice(node.widgets.length - 1, 1); + node.widgets.splice(idx, 0, combo); + } + return combo; + } + + const origOnNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + origOnNodeCreated?.apply(this, arguments); + this._configured = false; + + // Hide synced config widgets (index stays visible — user wires it) + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) hideWidget(w); + } + + const node = this; + const sourceLabels = this._getSourceLabels?.() || []; + const srcCombo = replaceWithCombo(this, "source_label", sourceLabels, function (value) { + node._syncFromSource(); + node._refreshKeys(); + }); + if (srcCombo) srcCombo.value = sourceLabels[0] || ""; + + const keyCombo = replaceWithCombo(this, "key_name", [], function (value) { + node.title = value ? `Resolution: ${value}` : "Project Resolution"; + app.graph?.setDirtyCanvas(true, true); + }); + if (keyCombo) keyCombo.value = ""; + + queueMicrotask(() => { + if (!this._configured) { + this.setSize(this.computeSize()); + } + }); + }; + + nodeType.prototype._getSourceLabels = function () { + const seen = new Set(); + const labels = []; + if (!this.graph) return labels; + for (const node of this.graph._nodes) { + if (node.type === "ProjectSource") { + const lw = node.widgets?.find(w => w.name === "label"); + if (lw?.value && !seen.has(lw.value)) { + seen.add(lw.value); + labels.push(lw.value); + } + } + } + return labels; + }; + + nodeType.prototype._findSource = function (label) { + if (!this.graph || !label) return null; + for (const node of this.graph._nodes) { + if (node.type === "ProjectSource") { + const lw = node.widgets?.find(w => w.name === "label"); + if (lw?.value === label) return node; + } + } + return null; + }; + + nodeType.prototype._syncFromSource = function () { + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + const source = this._findSource(srcWidget?.value); + if (!source) return; + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const dst = this.widgets?.find(w => w.name === name); + const src = source.widgets?.find(w => w.name === name); + if (dst && src) dst.value = src.value; + } + }; + + nodeType.prototype._refreshKeys = async function () { + const urlW = this.widgets?.find(w => w.name === "manager_url"); + const projW = this.widgets?.find(w => w.name === "project_name"); + const fileW = this.widgets?.find(w => w.name === "file_name"); + const seqW = this.widgets?.find(w => w.name === "sequence_number"); + if (!urlW?.value || !projW?.value || !fileW?.value) return; + + try { + const resp = await api.fetchApi( + `/json_manager/get_project_keys?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}&file=${encodeURIComponent(fileW.value)}&seq=${seqW?.value || 1}` + ); + if (!resp.ok) return; + const data = await resp.json(); + if (data.error || !Array.isArray(data.keys)) return; + + const keyWidget = this.widgets?.find(w => w.name === "key_name"); + if (keyWidget) { + keyWidget.options.values = data.keys.length > 0 ? data.keys : [""]; + } + } catch (e) { + console.error("[ProjectResolution] Failed to refresh keys:", e); + } + }; + + const origOnMouseDown = nodeType.prototype.onMouseDown; + nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) { + origOnMouseDown?.apply(this, arguments); + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + if (srcWidget) srcWidget.options.values = this._getSourceLabels(); + this._syncFromSource(); + }; + + const origOnConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function (info) { + origOnConfigure?.apply(this, arguments); + this._configured = true; + + for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) { + const w = this.widgets?.find(w => w.name === name); + if (w) hideWidget(w); + } + + const srcWidget = this.widgets?.find(w => w.name === "source_label"); + if (srcWidget && srcWidget.type !== "combo") { + const node = this; + replaceWithCombo(this, "source_label", this._getSourceLabels(), function (value) { + node._syncFromSource(); + node._refreshKeys(); + }); + } else if (srcWidget) { + srcWidget.options.values = this._getSourceLabels(); + } + + const keyWidget = this.widgets?.find(w => w.name === "key_name"); + if (keyWidget && keyWidget.type !== "combo") { + const node = this; + replaceWithCombo(this, "key_name", [], function (value) { + node.title = value ? `Resolution: ${value}` : "Project Resolution"; + app.graph?.setDirtyCanvas(true, true); + }); + } + + const finalKeyWidget = this.widgets?.find(w => w.name === "key_name"); + if (finalKeyWidget?.value) { + this.title = `Resolution: ${finalKeyWidget.value}`; + } + + this.setSize(this.computeSize()); + + const node = this; + queueMicrotask(() => { + node._syncFromSource(); + node._refreshKeys(); + }); + }; + }, +}); +``` + +**Step 2: Run all tests** + +```bash +pytest tests/ -q +``` +Expected: all tests PASS (JS has no Python tests) + +**Step 3: Commit** + +```bash +git add web/project_resolution.js +git commit -m "feat: ProjectResolution JS extension for ComfyUI frontend" +``` + +--- + +### Task 5: Final verification and push + +**Step 1: Run full test suite** + +```bash +pytest tests/ -v +``` +Expected: all tests PASS + +**Step 2: Push branch** + +```bash +git push origin HEAD +``` diff --git a/main.py b/main.py index d0ba139..3ce2e93 100644 --- a/main.py +++ b/main.py @@ -295,12 +295,12 @@ def index(): sync_to_db, pane_state.db, pane_state.current_project, fp, data) tree = data.get('history_tree') if tree and isinstance(tree, dict): - for node in tree.get('nodes', {}).values(): - node.pop('data', None) + for entry in tree.get('snapshots', tree.get('nodes', {})).values(): + entry.pop('data', None) for backup in data.get('history_tree_backup', []): if isinstance(backup, dict): - for node in backup.get('nodes', {}).values(): - node.pop('data', None) + for entry in backup.get('snapshots', backup.get('nodes', {})).values(): + entry.pop('data', None) pane_state.data_cache = data pane_state.last_mtime = fp.stat().st_mtime if fp.exists() else 0 pane_state.loaded_file = str(fp) @@ -339,13 +339,13 @@ def index(): sync_to_db, state.db, state.current_project, fp, data) tree = data.get('history_tree') if tree and isinstance(tree, dict): - for node in tree.get('nodes', {}).values(): - node.pop('data', None) + for entry in tree.get('snapshots', tree.get('nodes', {})).values(): + entry.pop('data', None) # Strip snapshot data from history_tree_backup to prevent RAM/disk bloat for backup in data.get('history_tree_backup', []): if isinstance(backup, dict): - for node in backup.get('nodes', {}).values(): - node.pop('data', None) + for entry in backup.get('snapshots', backup.get('nodes', {})).values(): + entry.pop('data', None) state.data_cache = data state.last_mtime = fp.stat().st_mtime if fp.exists() else 0 state.loaded_file = str(fp) diff --git a/snapshot_timeline.py b/snapshot_timeline.py new file mode 100644 index 0000000..2a8e312 --- /dev/null +++ b/snapshot_timeline.py @@ -0,0 +1,184 @@ +import time +import uuid +from typing import Any + +KEY_PROMPT_HISTORY = "prompt_history" + + +class SnapshotTimeline: + """Flat chronological snapshot list — replaces the old HistoryTree DAG.""" + + def __init__(self, raw_data: dict[str, Any]) -> None: + # Detect and migrate old HistoryTree format + if "nodes" in raw_data and "branches" in raw_data: + self._migrate_from_tree(raw_data) + elif KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list): + self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY]) + else: + self.snapshots: dict[str, dict[str, Any]] = raw_data.get("snapshots", {}) + self.current_id: str | None = raw_data.get("current_id", None) + + # ------------------------------------------------------------------ + # Migration + # ------------------------------------------------------------------ + + def _migrate_from_tree(self, raw_data: dict[str, Any]) -> None: + """Flatten old HistoryTree nodes into snapshot list, discarding DAG info.""" + self.snapshots = {} + nodes = raw_data.get("nodes", {}) + for nid, node in nodes.items(): + self.snapshots[nid] = { + "id": nid, + "timestamp": node.get("timestamp", time.time()), + "note": node.get("note", "Migrated"), + "pinned": False, + "auto": False, + "seq_count": self._count_seqs(node.get("data")), + } + # Preserve snapshot data if present + if "data" in node and node["data"]: + self.snapshots[nid]["data"] = node["data"] + self.current_id = raw_data.get("head_id") + + def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None: + """Convert ancient prompt_history list into snapshots.""" + self.snapshots = {} + self.current_id = None + for item in reversed(old_list): + sid = self._make_id() + self.snapshots[sid] = { + "id": sid, + "timestamp": time.time(), + "note": item.get("note", "Legacy Import"), + "pinned": False, + "auto": False, + "seq_count": self._count_seqs(item), + "data": item, + } + self.current_id = sid + + # ------------------------------------------------------------------ + # Core operations + # ------------------------------------------------------------------ + + def record(self, data: dict[str, Any], note: str = "Snapshot", + auto: bool = False) -> str: + """Create a new snapshot and return its ID.""" + sid = self._make_id() + self.snapshots[sid] = { + "id": sid, + "timestamp": time.time(), + "note": note, + "pinned": False, + "auto": auto, + "seq_count": self._count_seqs(data), + "data": data, + } + self.current_id = sid + return sid + + def get_snapshot_data(self, snapshot_id: str) -> dict[str, Any] | None: + """Return the inline snapshot data if present.""" + snap = self.snapshots.get(snapshot_id) + if snap: + return snap.get("data") + return None + + def toggle_pin(self, snapshot_id: str) -> bool: + """Toggle pinned state, return new value.""" + snap = self.snapshots.get(snapshot_id) + if snap: + snap["pinned"] = not snap.get("pinned", False) + return snap["pinned"] + return False + + def delete(self, snapshot_id: str) -> None: + """Remove a snapshot.""" + self.snapshots.pop(snapshot_id, None) + if self.current_id == snapshot_id: + # Fall back to most recent remaining + if self.snapshots: + self.current_id = max( + self.snapshots.values(), key=lambda s: s["timestamp"] + )["id"] + else: + self.current_id = None + + def strip_snapshots(self) -> None: + """Remove inline data from all snapshots (for slim JSON storage).""" + for snap in self.snapshots.values(): + snap.pop("data", None) + + # ------------------------------------------------------------------ + # Serialization + # ------------------------------------------------------------------ + + def to_dict(self) -> dict[str, Any]: + return { + "snapshots": self.snapshots, + "current_id": self.current_id, + } + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _make_id(self) -> str: + for _ in range(10): + sid = str(uuid.uuid4())[:8] + if sid not in self.snapshots: + return sid + raise ValueError("Failed to generate unique snapshot ID after 10 attempts") + + @staticmethod + def _count_seqs(data: dict | None) -> int: + if not data: + return 0 + from utils import KEY_BATCH_DATA + batch = data.get(KEY_BATCH_DATA, []) + return len(batch) if isinstance(batch, list) else 0 + + +# ------------------------------------------------------------------ +# Diff function +# ------------------------------------------------------------------ + +def diff_snapshots(old_batch: list[dict], new_batch: list[dict]) -> list[dict]: + """Compare two batch lists by sequence_number, return per-sequence diffs. + + Returns a list of dicts: + { + "seq_num": int, + "status": "unchanged" | "changed" | "added" | "removed", + "changes": [{"field": str, "old": Any, "new": Any}], + } + """ + from utils import KEY_SEQUENCE_NUMBER + + old_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in old_batch} + new_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in new_batch} + + all_seqs = sorted(set(old_by_seq) | set(new_by_seq)) + result = [] + + for seq_num in all_seqs: + old_item = old_by_seq.get(seq_num) + new_item = new_by_seq.get(seq_num) + + if old_item and not new_item: + result.append({"seq_num": seq_num, "status": "removed", "changes": []}) + elif new_item and not old_item: + result.append({"seq_num": seq_num, "status": "added", "changes": []}) + else: + # Both exist — field-by-field comparison + all_keys = sorted(set(old_item) | set(new_item)) + changes = [] + for k in all_keys: + old_val = old_item.get(k) + new_val = new_item.get(k) + if old_val != new_val: + changes.append({"field": k, "old": old_val, "new": new_val}) + status = "changed" if changes else "unchanged" + result.append({"seq_num": seq_num, "status": status, "changes": changes}) + + return result diff --git a/state.py b/state.py index 4f8d7a4..86f236a 100644 --- a/state.py +++ b/state.py @@ -13,7 +13,7 @@ class AppState: snippets: dict = field(default_factory=dict) file_path: Path | None = None restored_indicator: str | None = None - timeline_selected_nodes: set = field(default_factory=set) + timeline_selected_id: str | None = None live_toggles: dict = field(default_factory=dict) show_comfy_monitor: bool = True diff --git a/tab_batch_ng.py b/tab_batch_ng.py index 7bf4a59..c0bd53c 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -16,9 +16,11 @@ from utils import ( DEFAULTS, save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, ) -from history_tree import HistoryTree +from snapshot_timeline import SnapshotTimeline IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'} +_AUTO_SNAP_DEBOUNCE = 30 # seconds between auto-snapshots +_last_auto_snap: dict[str, float] = {} # file_path -> timestamp SUB_SEGMENT_MULTIPLIER = 1000 SUB_SEGMENT_NUM_COLORS = 6 FRAME_TO_SKIP_DEFAULT = DEFAULTS['frame_to_skip'] @@ -86,18 +88,18 @@ def find_insert_position(batch_list, parent_index, parent_seq_num): # --- Auto change note --- -def _auto_change_note(htree, batch_list, state=None, file_path=None): +def _auto_change_note(timeline, batch_list, state=None, file_path=None): """Compare current batch_list against last snapshot and describe changes.""" - # Get previous batch data from the current head - if not htree.head_id or htree.head_id not in htree.nodes: + # Get previous batch data from the current snapshot + if not timeline.current_id or timeline.current_id not in timeline.snapshots: return f'Initial save ({len(batch_list)} sequences)' - # Load previous snapshot from DB (nodes no longer hold data in memory) - prev_data = htree.nodes[htree.head_id].get('data') + # Load previous snapshot from inline data or DB + prev_data = timeline.get_snapshot_data(timeline.current_id) if not prev_data and state and state.db_enabled and state.db and state.current_project and file_path: df = state.db.get_data_file_by_names(state.current_project, file_path.stem) if df: - prev_data = state.db.get_node_snapshot(df['id'], htree.head_id) + prev_data = state.db.get_node_snapshot(df['id'], timeline.current_id) prev_batch = (prev_data or {}).get(KEY_BATCH_DATA, []) prev_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in prev_batch} @@ -363,38 +365,34 @@ def render_batch_processor(state: AppState): logger.info("save_and_snap START") data[KEY_BATCH_DATA] = batch_list tree_data = data.get(KEY_HISTORY_TREE, {}) - htree = HistoryTree(tree_data) - note = commit_input.value if commit_input.value else _auto_change_note(htree, batch_list, state=state, file_path=file_path) + timeline = SnapshotTimeline(tree_data) + note = commit_input.value if commit_input.value else _auto_change_note(timeline, batch_list, state=state, file_path=file_path) # Single serialization: json roundtrip gives us an isolated snapshot - # without the expensive deepcopy t1 = time.perf_counter() snapshot_json = json.dumps({k: v for k, v in data.items() if k != KEY_HISTORY_TREE}) snapshot_payload = json.loads(snapshot_json) logger.info("save_and_snap snapshot %.3fs", time.perf_counter() - t1) try: - htree.commit(snapshot_payload, note=note) + timeline.record(snapshot_payload, note=note) except ValueError as e: ui.notify(f'Save failed: {e}', type='negative') return if state.db_enabled and state.current_project and state.db: - # DB path: sync full tree (with snapshots) to DB, then - # write slim tree (no snapshots) to JSON and memory - full_tree = htree.to_dict() + full_tree = timeline.to_dict() data[KEY_HISTORY_TREE] = full_tree t1 = time.perf_counter() db_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot) logger.info("save_and_snap sync_to_db %.3fs", time.perf_counter() - t1) - htree.strip_snapshots() - data[KEY_HISTORY_TREE] = htree.to_dict() + timeline.strip_snapshots() + data[KEY_HISTORY_TREE] = timeline.to_dict() t1 = time.perf_counter() slim_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, slim_snapshot) logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1) else: - # No DB: write full tree (with snapshots) to JSON - data[KEY_HISTORY_TREE] = htree.to_dict() + data[KEY_HISTORY_TREE] = timeline.to_dict() t1 = time.perf_counter() save_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, save_snapshot) @@ -416,9 +414,30 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, refresh_list): async def commit(message=None): data[KEY_BATCH_DATA] = batch_list + # Auto-snapshot with debounce + fp_key = str(file_path) + now = time.time() + did_snap = False + if now - _last_auto_snap.get(fp_key, 0) >= _AUTO_SNAP_DEBOUNCE: + timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {})) + snap_json = json.dumps({k: v for k, v in data.items() + if k != KEY_HISTORY_TREE}) + snap_payload = json.loads(snap_json) + try: + timeline.record(snap_payload, note=message or "Auto-save", auto=True) + if state.db_enabled and state.current_project and state.db: + data[KEY_HISTORY_TREE] = timeline.to_dict() + db_snap = json.loads(json.dumps(data)) + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap) + timeline.strip_snapshots() + did_snap = True + data[KEY_HISTORY_TREE] = timeline.to_dict() + _last_auto_snap[fp_key] = now + except ValueError: + pass # Non-critical: skip auto-snapshot on ID collision snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, snapshot) - if state.db_enabled and state.current_project and state.db: + if state.db_enabled and state.current_project and state.db and not did_snap: await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) if message: ui.notify(message, type='positive') @@ -845,26 +864,26 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li batch_list[idx][key] = copy.deepcopy(source_seq.get(key)) data[KEY_BATCH_DATA] = batch_list - htree = HistoryTree(data.get(KEY_HISTORY_TREE, {})) + timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {})) snapshot_json = json.dumps({k: v for k, v in data.items() if k != KEY_HISTORY_TREE}) snapshot = json.loads(snapshot_json) try: - htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}") + timeline.record(snapshot, f"Mass update: {', '.join(selected_keys)}") except ValueError as e: ui.notify(f'Mass update failed: {e}', type='negative') return if state.db_enabled and state.current_project and state.db: - full_tree = htree.to_dict() + full_tree = timeline.to_dict() data[KEY_HISTORY_TREE] = full_tree db_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot) - htree.strip_snapshots() - data[KEY_HISTORY_TREE] = htree.to_dict() + timeline.strip_snapshots() + data[KEY_HISTORY_TREE] = timeline.to_dict() slim_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, slim_snapshot) else: - data[KEY_HISTORY_TREE] = htree.to_dict() + data[KEY_HISTORY_TREE] = timeline.to_dict() save_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, save_snapshot) ui.notify(f'Updated {len(targets)} sequences', type='positive') diff --git a/tab_timeline_ng.py b/tab_timeline_ng.py index 7e89b40..3986a2e 100644 --- a/tab_timeline_ng.py +++ b/tab_timeline_ng.py @@ -1,5 +1,5 @@ import asyncio -import hashlib +import copy import json import logging import time @@ -7,373 +7,15 @@ import time from nicegui import ui from state import AppState -from history_tree import HistoryTree +from snapshot_timeline import SnapshotTimeline, diff_snapshots from utils import save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE logger = logging.getLogger(__name__) -def _delete_nodes(htree, data, file_path, node_ids, state=None): - """Delete nodes with backup, branch cleanup, re-parenting, and head fallback.""" - if 'history_tree_backup' not in data: - data['history_tree_backup'] = [] - # Back up tree metadata only (no snapshot data) to avoid bloating JSON - backup = json.loads(json.dumps(htree.to_dict())) - for node in backup.get('nodes', {}).values(): - node.pop('data', None) - data['history_tree_backup'].append(backup) - data['history_tree_backup'] = data['history_tree_backup'][-10:] - # Save deleted node parents before removal (needed for branch re-pointing) - deleted_parents = {} - for nid in node_ids: - deleted_node = htree.nodes.get(nid) - if deleted_node: - deleted_parents[nid] = deleted_node.get('parent') - # Re-parent children of deleted nodes — walk up to find a surviving ancestor - for nid in node_ids: - surviving_parent = deleted_parents.get(nid) - while surviving_parent in node_ids: - surviving_parent = deleted_parents.get(surviving_parent) - for child in htree.nodes.values(): - if child.get('parent') == nid: - child['parent'] = surviving_parent - for nid in node_ids: - htree.nodes.pop(nid, None) - # Re-point branches whose tip was deleted to a surviving ancestor - for b, tip in list(htree.branches.items()): - if tip in node_ids: - new_tip = deleted_parents.get(tip) - while new_tip in node_ids: - new_tip = deleted_parents.get(new_tip) - if new_tip and new_tip in htree.nodes: - htree.branches[b] = new_tip - else: - del htree.branches[b] - if htree.head_id in node_ids: - if htree.nodes: - htree.head_id = sorted(htree.nodes.values(), - key=lambda x: x['timestamp'])[-1]['id'] - else: - htree.head_id = None - data[KEY_HISTORY_TREE] = htree.to_dict() - # Clean up DB snapshots for deleted nodes - if state and state.db_enabled and state.db and state.current_project: - df = state.db.get_data_file_by_names(state.current_project, file_path.stem) - if df: - state.db.delete_node_snapshots(df['id'], set(node_ids)) - - -def _render_selection_picker(all_nodes, htree, state, refresh_fn): - """Multi-select picker for batch-deleting timeline nodes.""" - all_ids = [n['id'] for n in all_nodes] - - def fmt_option(nid): - n = htree.nodes[nid] - ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp'])) - note = n.get('note', 'Step') - head = ' (HEAD)' if nid == htree.head_id else '' - return f'{note} - {ts} ({nid[:6]}){head}' - - options = {nid: fmt_option(nid) for nid in all_ids} - - def on_selection_change(e): - state.timeline_selected_nodes = set(e.value) if e.value else set() - - ui.select( - options, - value=list(state.timeline_selected_nodes), - multiple=True, - label='Select nodes to delete:', - on_change=on_selection_change, - ).classes('w-full') - - with ui.row(): - def select_all(): - state.timeline_selected_nodes = set(all_ids) - refresh_fn() - def deselect_all(): - state.timeline_selected_nodes = set() - refresh_fn() - ui.button('Select All', on_click=select_all).props('flat dense') - ui.button('Deselect All', on_click=deselect_all).props('flat dense') - - -def _render_graph_or_log(mode, all_nodes, htree, selected_nodes, - selection_mode_on, toggle_select_fn, restore_fn, - selected=None): - """Render graph visualization or linear log view.""" - if mode in ('Horizontal', 'Vertical'): - direction = 'LR' if mode == 'Horizontal' else 'TB' - with ui.card().classes('w-full q-pa-md'): - try: - graph_dot = htree.generate_graph(direction=direction) - sel_id = selected.get('node_id') if selected else None - _render_graphviz(graph_dot, selected_node_id=sel_id) - except Exception as e: - ui.label(f'Graph Error: {e}').classes('text-negative') - - elif mode == 'Linear Log': - ui.label('Chronological list of all snapshots.').classes('text-caption') - for n in all_nodes: - is_head = n['id'] == htree.head_id - is_selected = n['id'] in selected_nodes - - card_style = '' - if is_selected: - card_style = 'background: rgba(239, 68, 68, 0.1) !important; border-left: 3px solid var(--negative);' - elif is_head: - card_style = 'background: var(--accent-subtle) !important; border-left: 3px solid var(--accent);' - with ui.card().classes('w-full q-mb-sm').style(card_style): - with ui.row().classes('w-full items-center'): - if selection_mode_on: - ui.checkbox( - '', - value=is_selected, - on_change=lambda e, nid=n['id']: toggle_select_fn( - nid, e.value), - ) - - icon = 'location_on' if is_head else 'circle' - ui.icon(icon).classes( - 'text-primary' if is_head else 'text-grey') - - with ui.column().classes('col'): - note = n.get('note', 'Step') - ts = time.strftime('%b %d %H:%M', - time.localtime(n['timestamp'])) - label = f'{note} (Current)' if is_head else note - ui.label(label).classes('text-bold') - ui.label( - f'ID: {n["id"][:6]} - {ts}').classes('text-caption') - - if not is_head and not selection_mode_on: - ui.button( - 'Restore', - icon='restore', - on_click=lambda node=n: restore_fn(node), - ).props('flat dense color=primary') - - -def _render_batch_delete(htree, data, file_path, state, refresh_fn): - """Render batch delete controls for selected timeline nodes.""" - valid = state.timeline_selected_nodes & set(htree.nodes.keys()) - state.timeline_selected_nodes = valid - count = len(valid) - if count == 0: - return - - ui.label( - f'{count} node{"s" if count != 1 else ""} selected for deletion.' - ).classes('text-warning q-mt-md') - - async def do_batch_delete(): - current_valid = state.timeline_selected_nodes & set(htree.nodes.keys()) - _delete_nodes(htree, data, file_path, current_valid, state=state) - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - state.timeline_selected_nodes = set() - ui.notify( - f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!', - type='positive') - refresh_fn() - - ui.button( - f'Delete {count} Node{"s" if count != 1 else ""}', - icon='delete', - on_click=do_batch_delete, - ).props('color=negative') - - -def _walk_branch_nodes(htree, tip_id): - """Walk parent pointers from tip, returning nodes newest-first.""" - nodes = [] - visited = set() - current = tip_id - while current and current in htree.nodes: - if current in visited: - break - visited.add(current) - nodes.append(htree.nodes[current]) - current = htree.nodes[current].get('parent') - return nodes - - -def _find_active_branch(htree): - """Return branch name whose tip == head_id, or None if detached.""" - if not htree.head_id: - return None - for b_name, tip_id in htree.branches.items(): - if tip_id == htree.head_id: - return b_name - return None - - -def _find_branch_for_node(htree, node_id): - """Return the branch name whose ancestry contains node_id, or None.""" - for b_name, tip_id in htree.branches.items(): - visited = set() - current = tip_id - while current and current in htree.nodes: - if current in visited: - break - if current == node_id: - return b_name - visited.add(current) - current = htree.nodes[current].get('parent') - return None - - -def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn, - selected, state=None): - """Render branch-grouped node manager with restore, rename, delete, and preview.""" - ui.label('Manage Version').classes('section-header') - - active_branch = _find_active_branch(htree) - - # --- (a) Branch selector --- - def fmt_branch(b_name): - count = len(_walk_branch_nodes(htree, htree.branches.get(b_name))) - suffix = ' (active)' if b_name == active_branch else '' - return f'{b_name} ({count} nodes){suffix}' - - branch_options = {b: fmt_branch(b) for b in htree.branches} - - def on_branch_change(e): - selected['branch'] = e.value - tip = htree.branches.get(e.value) - if tip: - selected['node_id'] = tip - render_branch_nodes.refresh() - - ui.select( - branch_options, - value=selected['branch'], - label='Branch:', - on_change=on_branch_change, - ).classes('w-full') - - # --- (b) Node list + (c) Actions panel --- - @ui.refreshable - def render_branch_nodes(): - branch_name = selected['branch'] - tip_id = htree.branches.get(branch_name) - nodes = _walk_branch_nodes(htree, tip_id) if tip_id else [] - - if not nodes: - ui.label('No nodes on this branch.').classes('text-caption q-pa-sm') - return - - with ui.scroll_area().classes('w-full').style('max-height: 350px'): - for n in nodes: - nid = n['id'] - is_head = nid == htree.head_id - is_tip = nid == tip_id - is_selected = nid == selected['node_id'] - - card_style = '' - if is_selected: - card_style = 'border-left: 3px solid var(--primary);' - elif is_head: - card_style = 'border-left: 3px solid var(--accent);' - - with ui.card().classes('w-full q-mb-xs q-pa-xs').style(card_style): - with ui.row().classes('w-full items-center no-wrap'): - icon = 'location_on' if is_head else 'circle' - icon_size = 'sm' if is_head else 'xs' - ui.icon(icon, size=icon_size).classes( - 'text-primary' if is_head else 'text-grey') - - with ui.column().classes('col q-ml-xs').style('min-width: 0'): - note = n.get('note', 'Step') - ts = time.strftime('%b %d %H:%M', - time.localtime(n['timestamp'])) - label_text = note - lbl = ui.label(label_text).classes('text-body2 ellipsis') - if is_head: - lbl.classes('text-bold') - ui.label(f'{ts} \u2022 {nid[:6]}').classes( - 'text-caption text-grey') - - if is_head: - ui.badge('HEAD', color='amber').props('dense') - if is_tip and not is_head: - ui.badge('tip', color='green', outline=True).props('dense') - - def select_node(node_id=nid): - selected['node_id'] = node_id - render_branch_nodes.refresh() - - ui.button(icon='check_circle', on_click=select_node).props( - 'flat dense round size=sm' - ).tooltip('Select this node') - - # --- (c) Actions panel --- - sel_id = selected['node_id'] - if not sel_id or sel_id not in htree.nodes: - return - - sel_node = htree.nodes[sel_id] - sel_note = sel_node.get('note', 'Step') - is_head = sel_id == htree.head_id - - ui.separator().classes('q-my-sm') - ui.label(f'Selected: {sel_note} ({sel_id[:6]})').classes( - 'text-caption text-bold') - - with ui.row().classes('w-full items-end q-gutter-sm'): - if not is_head: - def restore_selected(): - if sel_id in htree.nodes: - restore_fn(htree.nodes[sel_id]) - ui.button('Restore', icon='restore', - on_click=restore_selected).props('color=primary dense') - - # Rename - rename_input = ui.input('Rename Label').classes('col').props('dense') - - async def rename_node(): - if sel_id in htree.nodes and rename_input.value: - htree.nodes[sel_id]['note'] = rename_input.value - data[KEY_HISTORY_TREE] = htree.to_dict() - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state and state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - ui.notify('Label updated', type='positive') - refresh_fn() - - ui.button('Update Label', on_click=rename_node).props('flat dense') - - # Danger zone - with ui.expansion('Danger Zone', icon='warning').classes( - 'w-full q-mt-sm').style('border-left: 3px solid var(--negative)'): - ui.label('Deleting a node cannot be undone.').classes('text-warning') - - async def delete_selected(): - if sel_id in htree.nodes: - _delete_nodes(htree, data, file_path, {sel_id}, state=state) - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state and state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - # Reset selection if branch was removed - if selected['branch'] not in htree.branches: - selected['branch'] = next(iter(htree.branches), None) - selected['node_id'] = htree.head_id - ui.notify('Node Deleted', type='positive') - refresh_fn() - - ui.button('Delete This Node', icon='delete', - on_click=delete_selected).props('color=negative dense') - - # Data preview - with ui.expansion('Data Preview', icon='preview').classes('w-full q-mt-sm'): - _render_data_preview(sel_id, htree, state=state, file_path=file_path) - - render_branch_nodes() - +# ====================================================================== +# Main entry point +# ====================================================================== def render_timeline_tab(state: AppState): t0 = time.perf_counter() @@ -383,266 +25,541 @@ def render_timeline_tab(state: AppState): tree_data = data.get(KEY_HISTORY_TREE, {}) if not tree_data: - ui.label('No history timeline exists. Make some changes in the Editor first!').classes( + ui.label('No version history exists. Make some changes in the Editor first!').classes( 'text-subtitle1 q-pa-md') return - htree = HistoryTree(tree_data) + timeline = SnapshotTimeline(tree_data) + if not timeline.snapshots: + ui.label('No snapshots found in history.').classes('text-subtitle1 q-pa-md') + return - # --- Shared selected-node state (survives refreshes, shared by graph + manager) --- - active_branch = _find_active_branch(htree) - default_branch = active_branch - if not default_branch and htree.head_id: - for b_name, tip_id in htree.branches.items(): - for n in _walk_branch_nodes(htree, tip_id): - if n['id'] == htree.head_id: - default_branch = b_name - break - if default_branch: - break - if not default_branch and htree.branches: - default_branch = next(iter(htree.branches)) - selected = {'node_id': htree.head_id, 'branch': default_branch} + # Local UI state + ui_state = { + 'selected_id': state.timeline_selected_id or timeline.current_id, + 'search': '', + 'filter': 'All', # All | Pinned | Auto + } if state.restored_indicator: ui.label(f'Editing Restored Version: {state.restored_indicator}').classes( 'text-info q-pa-sm') - # --- View mode + Selection toggle --- - with ui.row().classes('w-full items-center q-gutter-md q-mb-md'): - ui.label('Version History').classes('text-h6 col') - view_mode = ui.toggle( - ['Horizontal', 'Vertical', 'Linear Log'], - value='Horizontal', - ) - selection_mode = ui.switch('Select to Delete') + ui.label('Version History').classes('text-h6 q-mb-sm') - @ui.refreshable - def render_timeline(): - t_rt = time.perf_counter() - logger.info("render_timeline START (%d nodes)", len(htree.nodes)) - all_nodes = sorted(htree.nodes.values(), key=lambda x: x['timestamp'], reverse=True) - selected_nodes = state.timeline_selected_nodes if selection_mode.value else set() + # Mutable container so left/right panels can cross-reference each other's refreshables + panels: dict = {} - if selection_mode.value: - _render_selection_picker(all_nodes, htree, state, render_timeline.refresh) + # ====================================================================== + # Splitter layout: 35% left (list) / 65% right (detail) + # ====================================================================== + with ui.splitter(value=35).classes('w-full').style('height: calc(100vh - 200px); min-height: 600px') as splitter: - _render_graph_or_log( - view_mode.value, all_nodes, htree, selected_nodes, - selection_mode.value, _toggle_select, _restore_and_refresh, - selected=selected) + # ============================================================== + # LEFT PANEL — Snapshot list + # ============================================================== + with splitter.before: + with ui.column().classes('w-full q-pa-sm').style('height: 100%'): + # Search + filter + search_input = ui.input( + placeholder='Search notes...', + ).classes('w-full').props('dense outlined clearable') - if selection_mode.value and state.timeline_selected_nodes: - _render_batch_delete(htree, data, file_path, state, render_timeline.refresh) + with ui.row().classes('w-full q-gutter-xs'): + filter_toggle = ui.toggle( + ['All', 'Pinned', 'Auto'], value='All', + ).props('dense no-caps') - with ui.card().classes('w-full q-pa-md q-mt-md'): - _render_node_manager( - all_nodes, htree, data, file_path, - _restore_and_refresh, render_timeline.refresh, - selected, state=state) - logger.info("render_timeline END (%.3fs)", time.perf_counter() - t_rt) + @ui.refreshable + def render_snapshot_list(): + _render_snapshot_list( + timeline, ui_state, data, file_path, state, + render_snapshot_list, panels) - def _toggle_select(nid, checked): - if checked: - state.timeline_selected_nodes.add(nid) - else: - state.timeline_selected_nodes.discard(nid) - render_timeline.refresh() + panels['list'] = render_snapshot_list - async def _restore_and_refresh(node): - await _restore_node(data, node, htree, file_path, state) - # Refresh all tabs (batch, raw, timeline) so they pick up the restored data - state._render_main.refresh() + def _on_search(e): + ui_state['search'] = search_input.value or '' + render_snapshot_list.refresh() + + def _on_filter(e): + ui_state['filter'] = e.value + render_snapshot_list.refresh() + + search_input.on('update:model-value', _on_search) + filter_toggle.on_value_change(_on_filter) + + render_snapshot_list() + + # ============================================================== + # RIGHT PANEL — Detail tabs + # ============================================================== + with splitter.after: + @ui.refreshable + def render_detail_panel(): + _render_detail_panel(timeline, ui_state, data, file_path, state, + panels) + + panels['detail'] = render_detail_panel + render_detail_panel() - view_mode.on_value_change(lambda _: render_timeline.refresh()) - selection_mode.on_value_change(lambda _: render_timeline.refresh()) - render_timeline() logger.info("render_timeline_tab END (%.3fs)", time.perf_counter() - t0) - # --- Poll for graph node clicks (JS → Python bridge) --- - graph_timer = None - async def _poll_graph_click(): - if view_mode.value == 'Linear Log': - return - try: - result = await ui.run_javascript( - 'const v = window.graphSelectedNode;' - 'window.graphSelectedNode = null; v;' - ) - except Exception: - # Deactivate timer if parent slot was deleted - if graph_timer is not None: - graph_timer.active = False - return - if not result: - return - node_id = str(result) - if node_id not in htree.nodes: - return - branch = _find_branch_for_node(htree, node_id) - if branch: - selected['branch'] = branch - selected['node_id'] = node_id - render_timeline.refresh() +# ====================================================================== +# Left panel: snapshot list +# ====================================================================== - graph_timer = ui.timer(0.5, _poll_graph_click) +def _render_snapshot_list(timeline, ui_state, data, file_path, state, + refresh_list, panels): + snapshots = sorted(timeline.snapshots.values(), + key=lambda s: s['timestamp'], reverse=True) - def _cleanup_timer(): - if graph_timer is not None: - graph_timer.active = False - ui.context.client.on_disconnect(_cleanup_timer) + # Apply filters + search_term = ui_state.get('search', '').lower() + filter_mode = ui_state.get('filter', 'All') + if search_term: + snapshots = [s for s in snapshots + if search_term in s.get('note', '').lower()] + if filter_mode == 'Pinned': + snapshots = [s for s in snapshots if s.get('pinned')] + elif filter_mode == 'Auto': + snapshots = [s for s in snapshots if s.get('auto')] -_graphviz_svg_cache: dict[str, str] = {} -_GRAPHVIZ_CACHE_MAX = 20 - - -def _render_graphviz(dot_source: str, selected_node_id: str | None = None): - """Render graphviz DOT source as interactive SVG with click-to-select.""" - try: - import graphviz - t_gv = time.perf_counter() - cache_key = hashlib.md5(dot_source.encode()).hexdigest() - svg = _graphviz_svg_cache.get(cache_key) - if svg is None: - src = graphviz.Source(dot_source) - svg = src.pipe(format='svg').decode('utf-8') - if len(_graphviz_svg_cache) >= _GRAPHVIZ_CACHE_MAX: - _graphviz_svg_cache.pop(next(iter(_graphviz_svg_cache))) - _graphviz_svg_cache[cache_key] = svg - logger.info("_render_graphviz MISS (generated): %.3fs", time.perf_counter() - t_gv) - else: - logger.info("_render_graphviz HIT (cached): %.3fs", time.perf_counter() - t_gv) - - sel_escaped = json.dumps(selected_node_id or '')[1:-1] # strip quotes, get JS-safe content - - # CSS inline (allowed), JS via run_javascript (script tags blocked) - css = '''''' - - ui.html( - f'{css}
' - f'{svg}
' - ) - - # Find container by class with retry for Vue async render - ui.run_javascript(f''' - (function attempt(tries) {{ - var container = document.querySelector('.timeline-graph'); - if (!container || !container.querySelector('g.node')) {{ - if (tries < 20) setTimeout(function() {{ attempt(tries + 1); }}, 100); - return; - }} - container.querySelectorAll('g.node').forEach(function(g) {{ - g.addEventListener('click', function() {{ - var title = g.querySelector('title'); - if (title) {{ - window.graphSelectedNode = title.textContent.trim(); - container.querySelectorAll('g.node.selected').forEach( - function(el) {{ el.classList.remove('selected'); }}); - g.classList.add('selected'); - }} - }}); - }}); - var selId = '{sel_escaped}'; - if (selId) {{ - container.querySelectorAll('g.node').forEach(function(g) {{ - var title = g.querySelector('title'); - if (title && title.textContent.trim() === selId) {{ - g.classList.add('selected'); - }} - }}); - }} - }})(0); - ''') - except ImportError: - ui.label('Install graphviz Python package for graph rendering.').classes('text-warning') - ui.code(dot_source).classes('w-full') - except Exception as e: - ui.label(f'Graph rendering error: {e}').classes('text-negative') - - -async def _restore_node(data, node, htree, file_path, state: AppState): - """Restore a history node as the current version (full replace, not merge).""" - t0 = time.perf_counter() - logger.info("_restore_node START: %s", node.get('note', 'Step')) - # Load snapshot from DB on demand (nodes no longer hold data in memory) - raw_snap = node.get('data') - if not raw_snap and state.db_enabled and state.db and state.current_project: - df = state.db.get_data_file_by_names(state.current_project, file_path.stem) - if df: - raw_snap = await asyncio.to_thread( - state.db.get_node_snapshot, df['id'], node['id']) - if not raw_snap: - # Last resort: read from JSON file on disk - raw_file, _ = await asyncio.to_thread(load_json, file_path) - tree_on_disk = raw_file.get(KEY_HISTORY_TREE, {}) - raw_snap = tree_on_disk.get('nodes', {}).get(node['id'], {}).get('data', {}) - node_data = json.loads(json.dumps(raw_snap)) if raw_snap else {} - # Preserve the history tree before clearing - preserved_tree = data.get(KEY_HISTORY_TREE) - preserved_backup = data.get('history_tree_backup') - data.clear() - data.update(node_data) - # Re-attach history tree (not part of snapshot data) - if preserved_tree is not None: - data[KEY_HISTORY_TREE] = preserved_tree - if preserved_backup is not None: - data['history_tree_backup'] = preserved_backup - htree.head_id = node['id'] - data[KEY_HISTORY_TREE] = htree.to_dict() - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - label = f"{node.get('note', 'Step')} ({node['id'][:4]})" - state.restored_indicator = label - logger.info("_restore_node END (%.3fs)", time.perf_counter() - t0) - ui.notify('Restored!', type='positive') - - -def _render_data_preview(nid, htree, state: AppState = None, file_path=None): - """Render a read-only preview of the selected node's data.""" - if not nid or nid not in htree.nodes: - ui.label('No node selected.').classes('text-caption') + if not snapshots: + ui.label('No snapshots match your filter.').classes('text-caption q-pa-md') return - # Load snapshot from DB on demand (not stored in memory) - node_data = htree.nodes[nid].get('data') - if not node_data and state and state.db_enabled and state.db and state.current_project and file_path: - df = state.db.get_data_file_by_names(state.current_project, file_path.stem) - if df: - node_data = state.db.get_node_snapshot(df['id'], nid) - if not node_data and file_path: - # Disk fallback: read snapshot from JSON file - try: - raw_data, _ = load_json(file_path) - tree_on_disk = raw_data.get(KEY_HISTORY_TREE, {}) - node_data = tree_on_disk.get('nodes', {}).get(nid, {}).get('data') - except Exception: - pass - if not node_data: + with ui.scroll_area().classes('w-full').style('flex: 1; min-height: 0'): + for snap in snapshots: + sid = snap['id'] + is_current = sid == timeline.current_id + is_selected = sid == ui_state.get('selected_id') + is_pinned = snap.get('pinned', False) + is_auto = snap.get('auto', False) + + # Card styling + border = '' + if is_current: + border = 'border-left: 4px solid #eebb00;' + if is_selected: + border = 'border-left: 4px solid #4caf50;' + bg = 'background: rgba(76,175,80,0.08) !important;' if is_selected else '' + + def select_snap(snap_id=sid): + ui_state['selected_id'] = snap_id + state.timeline_selected_id = snap_id + refresh_list.refresh() + detail = panels.get('detail') + if detail is not None: + detail.refresh() + + with ui.card().classes('w-full q-mb-xs q-pa-xs cursor-pointer').style( + f'{border} {bg}').on('click', select_snap): + with ui.row().classes('w-full items-center no-wrap'): + # Icon + if is_pinned: + icon_name = 'push_pin' + icon_cls = 'text-amber' + elif is_auto: + icon_name = 'bolt' + icon_cls = 'text-grey' + else: + icon_name = 'save' + icon_cls = 'text-primary' + ui.icon(icon_name, size='sm').classes(icon_cls) + + # Text + with ui.column().classes('col q-ml-xs').style('min-width: 0'): + note = snap.get('note', 'Snapshot') + lbl = ui.label(note).classes('text-body2 ellipsis') + if is_current: + lbl.classes('text-bold') + ts = time.strftime('%b %d %H:%M', + time.localtime(snap['timestamp'])) + seq_count = snap.get('seq_count', '?') + ui.label(f'{ts} \u00b7 {seq_count} seqs').classes( + 'text-caption text-grey') + + # Badges + if is_current: + ui.badge('current', color='amber').props('dense') + + # Pin toggle + async def toggle_pin(snap_id=sid): + timeline.toggle_pin(snap_id) + data[KEY_HISTORY_TREE] = timeline.to_dict() + snapshot = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, snapshot) + refresh_list.refresh() + + pin_icon = 'push_pin' if is_pinned else 'o_push_pin' + ui.button(icon=pin_icon, on_click=toggle_pin).props( + 'flat dense round size=xs').on('click.stop', lambda: None) + + +# ====================================================================== +# Right panel: detail tabs (Preview / Compare / Cherry-pick) +# ====================================================================== + +def _render_detail_panel(timeline, ui_state, data, file_path, state, + panels): + sel_id = ui_state.get('selected_id') + if not sel_id or sel_id not in timeline.snapshots: + ui.label('Select a snapshot from the list.').classes('text-caption q-pa-lg') + return + + def _refresh_both(): + """Refresh both list and detail panels.""" + lp = panels.get('list') + dp = panels.get('detail') + if lp: + lp.refresh() + if dp: + dp.refresh() + + snap = timeline.snapshots[sel_id] + note = snap.get('note', 'Snapshot') + ts = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(snap['timestamp'])) + ui.label(f'{note}').classes('text-subtitle1 text-bold') + ui.label(f'{ts} \u2022 ID: {sel_id}').classes('text-caption text-grey q-mb-sm') + + # Action buttons + with ui.row().classes('q-gutter-sm q-mb-sm'): + is_current = sel_id == timeline.current_id + + if not is_current: + async def restore_full(): + await _restore_snapshot(data, sel_id, timeline, file_path, state) + state._render_main.refresh() + + ui.button('Restore Full', icon='restore', + on_click=restore_full).props('color=primary dense') + + # Rename + rename_input = ui.input(placeholder='New note...').props('dense outlined').classes('w-48') + + async def rename(): + if rename_input.value and sel_id in timeline.snapshots: + timeline.snapshots[sel_id]['note'] = rename_input.value + data[KEY_HISTORY_TREE] = timeline.to_dict() + snapshot = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, snapshot) + if state.db_enabled and state.current_project and state.db: + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) + ui.notify('Note updated', type='positive') + _refresh_both() + + ui.button('Rename', on_click=rename).props('flat dense') + + # Delete + async def delete_snap(): + timeline.delete(sel_id) + # Clean up DB snapshots + if state.db_enabled and state.db and state.current_project: + df = state.db.get_data_file_by_names(state.current_project, file_path.stem) + if df: + state.db.delete_node_snapshots(df['id'], {sel_id}) + data[KEY_HISTORY_TREE] = timeline.to_dict() + snapshot = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, snapshot) + if state.db_enabled and state.current_project and state.db: + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) + ui_state['selected_id'] = timeline.current_id + state.timeline_selected_id = timeline.current_id + ui.notify('Snapshot deleted', type='positive') + _refresh_both() + + ui.button(icon='delete', on_click=delete_snap).props('flat dense color=negative') + + # Sub-tabs + with ui.tabs().classes('w-full') as tabs: + preview_tab = ui.tab('Preview', icon='visibility') + compare_tab = ui.tab('Compare', icon='compare') + cherry_tab = ui.tab('Cherry-pick', icon='content_paste') + + with ui.tab_panels(tabs, value=preview_tab).classes('w-full'): + with ui.tab_panel(preview_tab): + _render_preview_tab(sel_id, timeline, state, file_path) + + with ui.tab_panel(compare_tab): + _render_compare_tab(sel_id, timeline, data, state, file_path) + + with ui.tab_panel(cherry_tab): + _render_cherry_pick_tab(sel_id, timeline, data, file_path, state, + panels) + + +# ====================================================================== +# Tab 1: Preview +# ====================================================================== + +def _render_preview_tab(sel_id, timeline, state, file_path): + snap_data = _load_snapshot_data(sel_id, timeline, state, file_path) + if not snap_data: ui.label('Snapshot data not available.').classes('text-caption text-warning') return - batch_list = node_data.get(KEY_BATCH_DATA, []) - if batch_list and isinstance(batch_list, list) and len(batch_list) > 0: - ui.label(f'This snapshot contains {len(batch_list)} sequences.').classes('text-caption') + batch_list = snap_data.get(KEY_BATCH_DATA, []) + if batch_list and isinstance(batch_list, list): + ui.label(f'{len(batch_list)} sequences in this snapshot.').classes('text-caption') for i, seq_data in enumerate(batch_list): seq_num = seq_data.get('sequence_number', i + 1) with ui.expansion(f'Sequence #{seq_num}', value=(i == 0)): _render_preview_fields(seq_data) else: - _render_preview_fields(node_data) + _render_preview_fields(snap_data) + + +# ====================================================================== +# Tab 2: Compare +# ====================================================================== + +def _render_compare_tab(sel_id, timeline, data, state, file_path): + snap_data = _load_snapshot_data(sel_id, timeline, state, file_path) + if not snap_data: + ui.label('Snapshot data not available.').classes('text-caption text-warning') + return + + old_batch = snap_data.get(KEY_BATCH_DATA, []) + new_batch = data.get(KEY_BATCH_DATA, []) + + if not old_batch and not new_batch: + ui.label('No batch data to compare.').classes('text-caption') + return + + diffs = diff_snapshots(old_batch, new_batch) + + show_all = ui.switch('Show unchanged', value=False) + + @ui.refreshable + def render_diff(): + any_diff = False + for d in diffs: + if d['status'] == 'unchanged' and not show_all.value: + continue + any_diff = True + seq_num = d['seq_num'] + status = d['status'] + + # Header styling + if status == 'added': + icon = 'add_circle' + color = 'text-positive' + label = f'Sequence #{seq_num} \u2014 ADDED (not in snapshot)' + elif status == 'removed': + icon = 'remove_circle' + color = 'text-negative' + label = f'Sequence #{seq_num} \u2014 REMOVED (not in current)' + elif status == 'changed': + icon = 'change_circle' + color = 'text-warning' + label = f'Sequence #{seq_num} \u2014 {len(d["changes"])} field{"s" if len(d["changes"]) != 1 else ""} changed' + else: + icon = 'check_circle' + color = 'text-grey' + label = f'Sequence #{seq_num} \u2014 No changes' + + with ui.expansion(label, icon=icon).classes(f'w-full {color}'): + if status == 'changed' and d['changes']: + # Table of field changes + columns = [ + {'name': 'field', 'label': 'Field', 'field': 'field', 'align': 'left'}, + {'name': 'old', 'label': 'Snapshot', 'field': 'old', 'align': 'left'}, + {'name': 'new', 'label': 'Current', 'field': 'new', 'align': 'left'}, + ] + rows = [] + for c in d['changes']: + rows.append({ + 'field': c['field'], + 'old': _truncate(c['old']), + 'new': _truncate(c['new']), + }) + ui.table(columns=columns, rows=rows, row_key='field').classes( + 'w-full').props('dense flat bordered') + elif status in ('added', 'removed'): + ui.label('Entire sequence differs.').classes('text-caption') + + if not any_diff: + ui.label('All sequences are identical.').classes('text-caption q-pa-md') + + show_all.on_value_change(lambda _: render_diff.refresh()) + render_diff() + + +# ====================================================================== +# Tab 3: Cherry-pick Restore +# ====================================================================== + +def _render_cherry_pick_tab(sel_id, timeline, data, file_path, state, + panels): + snap_data = _load_snapshot_data(sel_id, timeline, state, file_path) + if not snap_data: + ui.label('Snapshot data not available.').classes('text-caption text-warning') + return + + old_batch = snap_data.get(KEY_BATCH_DATA, []) + if not old_batch: + ui.label('No sequences in this snapshot.').classes('text-caption') + return + + ui.label('Select sequences and fields to restore from this snapshot.').classes( + 'text-caption q-mb-sm') + + mode = ui.toggle(['Whole sequences', 'Selected fields'], value='Whole sequences').props( + 'dense no-caps') + + # Build checkboxes per sequence + seq_checks: dict[int, ui.checkbox] = {} + field_checks: dict[int, dict[str, ui.checkbox]] = {} + + for seq_item in old_batch: + seq_num = int(seq_item.get('sequence_number', 0)) + seq_cb = ui.checkbox(f'Sequence #{seq_num}') + seq_checks[seq_num] = seq_cb + + with ui.expansion(f'Fields for #{seq_num}').classes('w-full q-ml-lg'): + field_checks[seq_num] = {} + for k in sorted(seq_item.keys()): + if k == 'sequence_number': + continue + val_str = _truncate(seq_item.get(k)) + fcb = ui.checkbox(f'{k}: {val_str}') + field_checks[seq_num][k] = fcb + + async def apply_cherry_pick(): + current_batch = data.get(KEY_BATCH_DATA, []) + curr_by_seq = {int(s.get('sequence_number', 0)): s for s in current_batch} + old_by_seq = {int(s.get('sequence_number', 0)): s for s in old_batch} + + applied = 0 + for seq_num, cb in seq_checks.items(): + if not cb.value: + continue + if seq_num not in old_by_seq: + continue + + if mode.value == 'Whole sequences': + # Replace or add entire sequence + restored = copy.deepcopy(old_by_seq[seq_num]) + if seq_num in curr_by_seq: + # Find and replace in-place + for i, s in enumerate(current_batch): + if int(s.get('sequence_number', 0)) == seq_num: + current_batch[i] = restored + break + else: + current_batch.append(restored) + applied += 1 + else: + # Selected fields only + if seq_num not in curr_by_seq: + continue + target = curr_by_seq[seq_num] + fields = field_checks.get(seq_num, {}) + for field_name, fcb in fields.items(): + if fcb.value and field_name in old_by_seq[seq_num]: + target[field_name] = copy.deepcopy(old_by_seq[seq_num][field_name]) + applied += 1 + + if applied == 0: + ui.notify('Nothing selected to restore.', type='warning') + return + + data[KEY_BATCH_DATA] = current_batch + + # Auto-snapshot noting the cherry-pick + snap_note = timeline.snapshots.get(sel_id, {}).get('note', 'unknown') + snap_json = json.dumps({k: v for k, v in data.items() + if k != KEY_HISTORY_TREE}) + snap_payload = json.loads(snap_json) + timeline.record(snap_payload, note=f'Cherry-pick from "{snap_note}"') + if state.db_enabled and state.current_project and state.db: + data[KEY_HISTORY_TREE] = timeline.to_dict() + db_snap = json.loads(json.dumps(data)) + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap) + timeline.strip_snapshots() + data[KEY_HISTORY_TREE] = timeline.to_dict() + save_snap = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, save_snap) + ui.notify(f'Applied {applied} item{"s" if applied != 1 else ""}!', type='positive') + for p in ('list', 'detail'): + ref = panels.get(p) + if ref: + ref.refresh() + + ui.button('Apply Selected', icon='check', on_click=apply_cherry_pick).props( + 'color=primary q-mt-md') + + +# ====================================================================== +# Shared helpers +# ====================================================================== + +def _load_snapshot_data(snap_id, timeline, state, file_path): + """Load snapshot data from inline, DB, or disk fallback.""" + snap_data = timeline.get_snapshot_data(snap_id) + if snap_data: + return snap_data + + # Try DB + if state and state.db_enabled and state.db and state.current_project and file_path: + df = state.db.get_data_file_by_names(state.current_project, file_path.stem) + if df: + snap_data = state.db.get_node_snapshot(df['id'], snap_id) + if snap_data: + return snap_data + + # Disk fallback + if file_path: + try: + raw_data, _ = load_json(file_path) + tree_on_disk = raw_data.get(KEY_HISTORY_TREE, {}) + # New format + entry = tree_on_disk.get('snapshots', {}).get(snap_id) + if entry and 'data' in entry: + return entry['data'] + # Old format + entry = tree_on_disk.get('nodes', {}).get(snap_id) + if entry and 'data' in entry: + return entry['data'] + except Exception as e: + logger.warning("Failed to load snapshot %s from disk: %s", snap_id, e) + return None + + +async def _restore_snapshot(data, snap_id, timeline, file_path, state): + """Restore a snapshot as the current version (full replace).""" + snap_data = _load_snapshot_data(snap_id, timeline, state, file_path) + if not snap_data: + ui.notify('Snapshot data not available', type='negative') + return + + node_data = json.loads(json.dumps(snap_data)) + + # Preserve history tree + preserved_tree = data.get(KEY_HISTORY_TREE) + preserved_backup = data.get('history_tree_backup') + data.clear() + data.update(node_data) + if preserved_tree is not None: + data[KEY_HISTORY_TREE] = preserved_tree + if preserved_backup is not None: + data['history_tree_backup'] = preserved_backup + + timeline.current_id = snap_id + data[KEY_HISTORY_TREE] = timeline.to_dict() + + snapshot = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, snapshot) + if state.db_enabled and state.current_project and state.db: + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) + + note = timeline.snapshots.get(snap_id, {}).get('note', 'Snapshot') + label = f"{note} ({snap_id[:4]})" + state.restored_indicator = label + ui.notify('Restored!', type='positive') def _render_preview_fields(item_data: dict): @@ -684,3 +601,9 @@ def _render_preview_fields(item_data: dict): value=str(item_data.get('vace schedule', ''))).props('readonly outlined') ui.input('Video Path', value=str(item_data.get('video file path', ''))).props('readonly outlined') + + +def _truncate(val, max_len=60): + """Truncate a value for display.""" + s = str(val) if val is not None else '' + return (s[:max_len] + '...') if len(s) > max_len else s diff --git a/tests/test_db.py b/tests/test_db.py index bea102f..e027ea6 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -208,10 +208,10 @@ class TestHistoryTrees: def test_upsert_updates(self, db): pid = db.create_project("p1", "/p1") df_id = db.create_data_file(pid, "batch", "generic") - db.save_history_tree(df_id, {"v": 1}) - db.save_history_tree(df_id, {"v": 2}) + db.save_history_tree(df_id, {"snapshots": {}, "v": 1}) + db.save_history_tree(df_id, {"snapshots": {}, "v": 2}) result = db.get_history_tree(df_id) - assert result == {"v": 2} + assert result == {"snapshots": {}, "v": 2} def test_get_nonexistent(self, db): pid = db.create_project("p1", "/p1") diff --git a/tests/test_snapshot_timeline.py b/tests/test_snapshot_timeline.py new file mode 100644 index 0000000..db02b1d --- /dev/null +++ b/tests/test_snapshot_timeline.py @@ -0,0 +1,159 @@ +import pytest +from snapshot_timeline import SnapshotTimeline, diff_snapshots + + +def test_record_creates_snapshot(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": [{"seed": 42}]}, note="first") + assert sid in tl.snapshots + assert tl.current_id == sid + assert tl.snapshots[sid]["note"] == "first" + assert tl.snapshots[sid]["auto"] is False + assert tl.snapshots[sid]["seq_count"] == 1 + + +def test_record_auto_flag(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": []}, note="auto save", auto=True) + assert tl.snapshots[sid]["auto"] is True + + +def test_multiple_records(): + tl = SnapshotTimeline({}) + id1 = tl.record({"batch_data": [{"a": 1}]}, note="one") + id2 = tl.record({"batch_data": [{"b": 2}]}, note="two") + assert len(tl.snapshots) == 2 + assert tl.current_id == id2 + + +def test_to_dict_roundtrip(): + tl = SnapshotTimeline({}) + tl.record({"batch_data": [{"x": 1}]}, note="test") + d = tl.to_dict() + tl2 = SnapshotTimeline(d) + assert tl2.current_id == tl.current_id + assert set(tl2.snapshots.keys()) == set(tl.snapshots.keys()) + + +def test_migrate_from_history_tree(): + """Old HistoryTree format should be flattened into snapshots.""" + old_data = { + "nodes": { + "aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First", "data": {"batch_data": [{"seed": 1}]}}, + "bbb": {"id": "bbb", "parent": "aaa", "timestamp": 2000, "note": "Second", "data": {"batch_data": [{"seed": 2}]}}, + }, + "branches": {"main": "bbb"}, + "head_id": "bbb", + } + tl = SnapshotTimeline(old_data) + assert len(tl.snapshots) == 2 + assert tl.current_id == "bbb" + assert tl.snapshots["aaa"]["note"] == "First" + assert tl.snapshots["bbb"]["note"] == "Second" + # Data should be preserved + assert tl.snapshots["aaa"]["data"]["batch_data"] == [{"seed": 1}] + + +def test_migrate_from_history_tree_no_data(): + """Slim tree nodes (no inline data) should still migrate.""" + old_data = { + "nodes": { + "aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First"}, + }, + "branches": {"main": "aaa"}, + "head_id": "aaa", + } + tl = SnapshotTimeline(old_data) + assert len(tl.snapshots) == 1 + assert tl.snapshots["aaa"]["seq_count"] == 0 + + +def test_migrate_legacy_prompt_history(): + legacy = { + "prompt_history": [ + {"note": "A", "seed": 1}, + {"note": "B", "seed": 2}, + ] + } + tl = SnapshotTimeline(legacy) + assert len(tl.snapshots) == 2 + assert tl.current_id is not None + + +def test_toggle_pin(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": []}, note="test") + assert tl.snapshots[sid]["pinned"] is False + result = tl.toggle_pin(sid) + assert result is True + assert tl.snapshots[sid]["pinned"] is True + result = tl.toggle_pin(sid) + assert result is False + + +def test_delete_snapshot(): + tl = SnapshotTimeline({}) + id1 = tl.record({"batch_data": []}, note="one") + id2 = tl.record({"batch_data": []}, note="two") + tl.delete(id2) + assert id2 not in tl.snapshots + assert tl.current_id == id1 + + +def test_delete_all_snapshots(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": []}, note="only") + tl.delete(sid) + assert len(tl.snapshots) == 0 + assert tl.current_id is None + + +def test_strip_snapshots(): + tl = SnapshotTimeline({}) + tl.record({"batch_data": [{"a": 1}]}, note="test") + tl.strip_snapshots() + for snap in tl.snapshots.values(): + assert "data" not in snap + + +def test_get_snapshot_data(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": [{"x": 1}]}, note="test") + data = tl.get_snapshot_data(sid) + assert data == {"batch_data": [{"x": 1}]} + assert tl.get_snapshot_data("nonexistent") is None + + +# --- diff_snapshots tests --- + +def test_diff_unchanged(): + batch = [{"sequence_number": 1, "seed": 42}] + result = diff_snapshots(batch, batch) + assert len(result) == 1 + assert result[0]["status"] == "unchanged" + assert result[0]["changes"] == [] + + +def test_diff_changed(): + old = [{"sequence_number": 1, "seed": 42, "cfg": 1.5}] + new = [{"sequence_number": 1, "seed": 99, "cfg": 1.5}] + result = diff_snapshots(old, new) + assert result[0]["status"] == "changed" + assert len(result[0]["changes"]) == 1 + assert result[0]["changes"][0]["field"] == "seed" + assert result[0]["changes"][0]["old"] == 42 + assert result[0]["changes"][0]["new"] == 99 + + +def test_diff_added_and_removed(): + old = [{"sequence_number": 1, "seed": 1}] + new = [{"sequence_number": 2, "seed": 2}] + result = diff_snapshots(old, new) + assert len(result) == 2 + statuses = {r["seq_num"]: r["status"] for r in result} + assert statuses[1] == "removed" + assert statuses[2] == "added" + + +def test_diff_empty(): + assert diff_snapshots([], []) == [] diff --git a/utils.py b/utils.py index 91f151a..ab5ed85 100644 --- a/utils.py +++ b/utils.py @@ -305,38 +305,47 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: else: db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) - # Sync history tree (extract node snapshots into separate table) + # Sync history tree (extract snapshot data into separate table) + # Supports both new format (snapshots dict) and old format (nodes dict) history_tree = data.get(KEY_HISTORY_TREE) if history_tree and isinstance(history_tree, dict): - nodes = history_tree.get("nodes", {}) + # Detect format: new has "snapshots", old has "nodes" + if "snapshots" in history_tree: + entries = history_tree.get("snapshots", {}) + else: + entries = history_tree.get("nodes", {}) slim_tree = dict(history_tree) - slim_nodes = {} - for nid, node in nodes.items(): - snap = node.get("data") + slim_entries = {} + for eid, entry in entries.items(): + snap = entry.get("data") if snap: db.conn.execute( "INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) " "VALUES (?, ?, ?, ?) " "ON CONFLICT(data_file_id, node_id) DO UPDATE SET " "snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at", - (df_id, nid, json.dumps(snap), now), + (df_id, eid, json.dumps(snap), now), ) - slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"} - slim_tree["nodes"] = slim_nodes + slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"} + # Write back slim version using the correct key + if "snapshots" in history_tree: + slim_tree["snapshots"] = slim_entries + else: + slim_tree["nodes"] = slim_entries db.conn.execute( "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " "VALUES (?, ?, ?) " "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", (df_id, json.dumps(slim_tree), now), ) - # Clean up orphaned snapshots for nodes no longer in tree - current_node_ids = set(nodes.keys()) - if current_node_ids: - placeholders = ",".join("?" for _ in current_node_ids) + # Clean up orphaned snapshots + current_ids = set(entries.keys()) + if current_ids: + placeholders = ",".join("?" for _ in current_ids) db.conn.execute( f"DELETE FROM history_snapshots WHERE data_file_id = ? " f"AND node_id NOT IN ({placeholders})", - (df_id, *current_node_ids), + (df_id, *current_ids), ) else: db.conn.execute(