Merge branch 'feat/snapshot-timeline'
This commit is contained in:
@@ -242,7 +242,6 @@ class ProjectDB:
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
@staticmethod
|
||||
@staticmethod
|
||||
def _migrate_lora_keys(data: dict) -> dict:
|
||||
"""Split combined lora 'name:strength' into separate name and strength keys."""
|
||||
@@ -340,27 +339,34 @@ class ProjectDB:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
|
||||
"""Save history tree, extracting node snapshots into separate table."""
|
||||
"""Save history tree, extracting snapshot data into separate table.
|
||||
|
||||
Supports both new format (snapshots dict) and old format (nodes dict).
|
||||
"""
|
||||
now = time.time()
|
||||
nodes = tree_data.get("nodes", {})
|
||||
if "snapshots" in tree_data:
|
||||
entries = tree_data.get("snapshots", {})
|
||||
entry_key = "snapshots"
|
||||
else:
|
||||
entries = tree_data.get("nodes", {})
|
||||
entry_key = "nodes"
|
||||
slim_tree = dict(tree_data)
|
||||
slim_nodes = {}
|
||||
for nid, node in nodes.items():
|
||||
slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"}
|
||||
slim_tree["nodes"] = slim_nodes
|
||||
slim_entries = {}
|
||||
for eid, entry in entries.items():
|
||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||
slim_tree[entry_key] = slim_entries
|
||||
|
||||
self.conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
# Extract snapshot data from nodes into history_snapshots table
|
||||
for nid, node in nodes.items():
|
||||
snap = node.get("data")
|
||||
for eid, entry in entries.items():
|
||||
snap = entry.get("data")
|
||||
if snap:
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||
(data_file_id, nid, json.dumps(snap), now),
|
||||
(data_file_id, eid, json.dumps(snap), now),
|
||||
)
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
@@ -463,24 +469,30 @@ class ProjectDB:
|
||||
)
|
||||
|
||||
# Import history tree (extract snapshots into separate table)
|
||||
# Supports both new format (snapshots dict) and old format (nodes dict)
|
||||
history_tree = data.get(KEY_HISTORY_TREE)
|
||||
if history_tree and isinstance(history_tree, dict):
|
||||
now = time.time()
|
||||
nodes = history_tree.get("nodes", {})
|
||||
if "snapshots" in history_tree:
|
||||
entries = history_tree.get("snapshots", {})
|
||||
entry_key = "snapshots"
|
||||
else:
|
||||
entries = history_tree.get("nodes", {})
|
||||
entry_key = "nodes"
|
||||
slim_tree = dict(history_tree)
|
||||
slim_nodes = {}
|
||||
for nid, node in nodes.items():
|
||||
snap = node.get("data")
|
||||
slim_entries = {}
|
||||
for eid, entry in entries.items():
|
||||
snap = entry.get("data")
|
||||
if snap:
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||
(df_id, nid, json.dumps(snap), now),
|
||||
(df_id, eid, json.dumps(snap), now),
|
||||
)
|
||||
slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"}
|
||||
slim_tree["nodes"] = slim_nodes
|
||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||
slim_tree[entry_key] = slim_entries
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
"VALUES (?, ?, ?) "
|
||||
@@ -540,9 +552,9 @@ class ProjectDB:
|
||||
# Load history tree (metadata only, no snapshot data)
|
||||
tree = self.get_history_tree(df["id"])
|
||||
if tree:
|
||||
# Strip any residual snapshot data from nodes
|
||||
for node in tree.get("nodes", {}).values():
|
||||
node.pop("data", None)
|
||||
# Strip any residual snapshot data (supports both formats)
|
||||
for entry in tree.get("snapshots", tree.get("nodes", {})).values():
|
||||
entry.pop("data", None)
|
||||
data["history_tree"] = tree
|
||||
t3 = time.time()
|
||||
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
# Resolution Series Design
|
||||
|
||||
## Problem
|
||||
|
||||
When running ComfyUI loop nodes for multi-step upscaling (e.g. 3+ resolutions at different sizes),
|
||||
managing portrait vs landscape width/height per iteration is tedious. Users need a structured way
|
||||
to define N resolution pairs in the manager UI and retrieve them by loop index in ComfyUI.
|
||||
|
||||
## Design
|
||||
|
||||
### Data Model
|
||||
|
||||
Resolution series are stored as a JSON array under a user-chosen key in the sequence data:
|
||||
|
||||
```json
|
||||
"upscale_resolutions": [[512, 512], [768, 1344], [1344, 768], [2048, 2048]]
|
||||
```
|
||||
|
||||
- Each element is `[width, height]` (both INT)
|
||||
- Key name is chosen by the user (any string)
|
||||
- Number of entries is configurable (add/remove rows)
|
||||
- Stored in the same project JSON file and sequence — no schema change required
|
||||
- Index out of bounds → clamp to last entry
|
||||
|
||||
### NiceGUI UI (tab_batch_ng.py)
|
||||
|
||||
A resolution series editor is rendered in the left column of the sequence card, directly below
|
||||
the "Specific Negative" textarea.
|
||||
|
||||
Layout:
|
||||
|
||||
```
|
||||
── Resolution Series ──────────────────
|
||||
key name: [upscale_resolutions ]
|
||||
# Width Height
|
||||
1 [2048] [2048] [x]
|
||||
2 [768 ] [1344] [x]
|
||||
3 [1344] [768 ] [x]
|
||||
[+ Add row]
|
||||
```
|
||||
|
||||
- Key name is editable (defaults to `resolutions`)
|
||||
- Rows added/removed inline; each change calls `commit()` immediately
|
||||
- Hidden behind an "Add Resolution Series" button when no resolution key exists yet
|
||||
- A value is detected as a resolution series if it is a list of `[int, int]` pairs
|
||||
|
||||
### ComfyUI Node (`ProjectResolution`)
|
||||
|
||||
New node class in `project_loader.py`, sibling to `ProjectKey`.
|
||||
|
||||
**Inputs:**
|
||||
- `source_label` (STRING) — references a `ProjectSource` by label
|
||||
- `key_name` (STRING) — the resolution series key name
|
||||
- `index` (INT, min 0) — wired from loop node's current index output
|
||||
- `manager_url`, `project_name`, `file_name`, `sequence_number` — optional, synced from `ProjectSource` via JS
|
||||
|
||||
**Outputs:** `width` (INT), `height` (INT)
|
||||
|
||||
**Execution:** fetches the sequence data, reads `data[key_name]`, indexes into the array with
|
||||
clamp-to-last on out-of-bounds, returns `(width, height)`.
|
||||
|
||||
**JS (`web/project_resolution.js`):**
|
||||
- Same `_syncFromSource` mechanism as `project_key.js`
|
||||
- `key_name` widget is replaced with a combo dropdown populated with keys whose value is a
|
||||
resolution series (list of `[int, int]` pairs), detected via the existing keys API
|
||||
- Registered in `PROJECT_NODE_CLASS_MAPPINGS` and `PROJECT_NODE_DISPLAY_NAME_MAPPINGS`
|
||||
|
||||
### API
|
||||
|
||||
No new endpoints. Uses existing:
|
||||
- `/json_manager/get_project_keys` — for key discovery (JS combo population)
|
||||
- `_fetch_data()` — for execution-time data fetch
|
||||
|
||||
### Files Changed
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `project_loader.py` | Add `ProjectResolution` class + register in mappings |
|
||||
| `web/project_resolution.js` | New JS extension for the node |
|
||||
| `tab_batch_ng.py` | Resolution series editor below Specific Negative |
|
||||
| `__init__.py` | Register new JS file if needed |
|
||||
@@ -0,0 +1,640 @@
|
||||
# Resolution Series Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Add a `ProjectResolution` ComfyUI node and NiceGUI editor that let users define N `(width, height)` pairs per sequence and retrieve them by loop index.
|
||||
|
||||
**Architecture:** Resolution series are stored as a JSON array of `[width, height]` pairs under a user-chosen key in sequence data (e.g. `"upscale_resolutions": [[512,512],[768,1344]]`). A new `ProjectResolution` ComfyUI node (sibling of `ProjectKey`) accepts a `source_label`, `key_name`, and `index` INT from a loop node, and returns `width` + `height`. The NiceGUI sequence card gets an inline table editor placed directly below the "Specific Negative" textarea.
|
||||
|
||||
**Tech Stack:** Python (ComfyUI node), NiceGUI (UI), JavaScript (ComfyUI frontend extension), pytest
|
||||
|
||||
**Branch:** Create and work on `feat/resolution-series` branched from `main`:
|
||||
```bash
|
||||
git checkout main && git checkout -b feat/resolution-series
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 0: Fix pre-existing test failures on `main`
|
||||
|
||||
When `file_name` was added as a second output to `ProjectSource`, two tests were not updated.
|
||||
They fail on `main` before any new code is written.
|
||||
|
||||
**Files:**
|
||||
- Modify: `tests/test_project_loader.py` (`TestProjectSource` class, lines ~216-231)
|
||||
|
||||
**Step 1: Update the two broken tests**
|
||||
|
||||
```python
|
||||
def test_outputs_sequence_number(self):
|
||||
from project_loader import ProjectSource
|
||||
assert ProjectSource.RETURN_TYPES == ("INT", "STRING",)
|
||||
assert ProjectSource.RETURN_NAMES == ("sequence_number", "file_name",)
|
||||
|
||||
def test_hold_config_returns_sequence_number(self):
|
||||
from project_loader import ProjectSource
|
||||
node = ProjectSource()
|
||||
result = node.hold_config(
|
||||
manager_url="http://localhost:8080",
|
||||
project_name="proj1",
|
||||
file_name="batch_i2v",
|
||||
sequence_number=42,
|
||||
label="my_source"
|
||||
)
|
||||
assert result == (42, "batch_i2v")
|
||||
```
|
||||
|
||||
**Step 2: Verify they now pass**
|
||||
|
||||
```bash
|
||||
pytest tests/test_project_loader.py::TestProjectSource -v
|
||||
```
|
||||
Expected: all 4 PASS
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add tests/test_project_loader.py
|
||||
git commit -m "fix: update ProjectSource tests for file_name output"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Python node — `ProjectResolution`
|
||||
|
||||
**Files:**
|
||||
- Modify: `project_loader.py` (after the `ProjectKey` class, before `# --- Mappings ---`)
|
||||
- Modify: `tests/test_project_loader.py` (add `TestProjectResolution` class)
|
||||
|
||||
**Step 1: Write failing tests**
|
||||
|
||||
Add this class to `tests/test_project_loader.py`:
|
||||
|
||||
```python
|
||||
class TestProjectResolution:
|
||||
def test_input_types(self):
|
||||
from project_loader import ProjectResolution
|
||||
inputs = ProjectResolution.INPUT_TYPES()
|
||||
assert "source_label" in inputs["required"]
|
||||
assert "key_name" in inputs["required"]
|
||||
assert "index" in inputs["required"]
|
||||
assert inputs["required"]["index"][0] == "INT"
|
||||
|
||||
def test_two_outputs(self):
|
||||
from project_loader import ProjectResolution
|
||||
assert ProjectResolution.RETURN_TYPES == ("INT", "INT")
|
||||
assert ProjectResolution.RETURN_NAMES == ("width", "height")
|
||||
|
||||
def test_fetch_resolution_basic(self):
|
||||
from project_loader import ProjectResolution
|
||||
node = ProjectResolution()
|
||||
data = {"resolutions": [[512, 512], [768, 1344], [1344, 768]]}
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.fetch_resolution(
|
||||
source_label="src", key_name="resolutions", index=1,
|
||||
manager_url="http://localhost:8080", project_name="p",
|
||||
file_name="f", sequence_number=1,
|
||||
)
|
||||
assert result == (768, 1344)
|
||||
|
||||
def test_fetch_resolution_index_zero(self):
|
||||
from project_loader import ProjectResolution
|
||||
node = ProjectResolution()
|
||||
data = {"resolutions": [[512, 512], [1024, 1024]]}
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.fetch_resolution(
|
||||
source_label="src", key_name="resolutions", index=0,
|
||||
manager_url="http://localhost:8080", project_name="p",
|
||||
file_name="f", sequence_number=1,
|
||||
)
|
||||
assert result == (512, 512)
|
||||
|
||||
def test_fetch_resolution_clamps_on_out_of_bounds(self):
|
||||
from project_loader import ProjectResolution
|
||||
node = ProjectResolution()
|
||||
data = {"resolutions": [[512, 512], [1024, 1024]]}
|
||||
with patch("project_loader._fetch_data", return_value=data):
|
||||
result = node.fetch_resolution(
|
||||
source_label="src", key_name="resolutions", index=99,
|
||||
manager_url="http://localhost:8080", project_name="p",
|
||||
file_name="f", sequence_number=1,
|
||||
)
|
||||
assert result == (1024, 1024) # last entry
|
||||
|
||||
def test_fetch_resolution_missing_key_returns_defaults(self):
|
||||
from project_loader import ProjectResolution
|
||||
node = ProjectResolution()
|
||||
with patch("project_loader._fetch_data", return_value={}):
|
||||
result = node.fetch_resolution(
|
||||
source_label="src", key_name="nonexistent", index=0,
|
||||
manager_url="http://localhost:8080", project_name="p",
|
||||
file_name="f", sequence_number=1,
|
||||
)
|
||||
assert result == (512, 512)
|
||||
|
||||
def test_fetch_resolution_network_error_returns_defaults(self):
|
||||
from project_loader import ProjectResolution
|
||||
node = ProjectResolution()
|
||||
error_resp = {"error": "network_error", "message": "Connection refused"}
|
||||
with patch("project_loader._fetch_data", return_value=error_resp):
|
||||
result = node.fetch_resolution(
|
||||
source_label="src", key_name="resolutions", index=0,
|
||||
manager_url="http://localhost:8080", project_name="p",
|
||||
file_name="f", sequence_number=1,
|
||||
)
|
||||
assert result == (512, 512)
|
||||
|
||||
def test_category(self):
|
||||
from project_loader import ProjectResolution
|
||||
assert ProjectResolution.CATEGORY == "utils/json/project"
|
||||
```
|
||||
|
||||
**Step 2: Run tests to verify they fail**
|
||||
|
||||
```bash
|
||||
pytest tests/test_project_loader.py::TestProjectResolution -v
|
||||
```
|
||||
Expected: `ImportError: cannot import name 'ProjectResolution'`
|
||||
|
||||
**Step 3: Implement `ProjectResolution` in `project_loader.py`**
|
||||
|
||||
Insert this class after `ProjectKey` (line ~294), before `# --- Mappings ---`:
|
||||
|
||||
```python
|
||||
class ProjectResolution:
|
||||
"""Fetches a (width, height) pair from a resolution series by loop index."""
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"source_label": ("STRING", {"default": "", "multiline": False}),
|
||||
"key_name": ("STRING", {"default": "resolutions", "multiline": False}),
|
||||
"index": ("INT", {"default": 0, "min": 0, "max": 9999}),
|
||||
},
|
||||
"optional": {
|
||||
"manager_url": ("STRING", {"default": "http://localhost:8080", "multiline": False}),
|
||||
"project_name": ("STRING", {"default": "", "multiline": False}),
|
||||
"file_name": ("STRING", {"default": "", "multiline": False}),
|
||||
"sequence_number": ("INT", {"default": 1, "min": 1, "max": 9999}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("INT", "INT")
|
||||
RETURN_NAMES = ("width", "height")
|
||||
FUNCTION = "fetch_resolution"
|
||||
CATEGORY = "utils/json/project"
|
||||
OUTPUT_NODE = False
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, **kwargs):
|
||||
return float("nan")
|
||||
|
||||
def fetch_resolution(self, source_label, key_name, index,
|
||||
manager_url="http://localhost:8080", project_name="",
|
||||
file_name="", sequence_number=1):
|
||||
sequence_number = int(sequence_number)
|
||||
data = _fetch_data(manager_url, project_name, file_name, sequence_number)
|
||||
if data.get("error") in ("http_error", "network_error", "parse_error"):
|
||||
logger.warning("ProjectResolution.fetch_resolution failed: %s", data.get("message"))
|
||||
return (512, 512)
|
||||
|
||||
series = data.get(key_name)
|
||||
if not isinstance(series, list) or len(series) == 0:
|
||||
logger.warning("ProjectResolution: key '%s' is not a resolution series", key_name)
|
||||
return (512, 512)
|
||||
|
||||
clamped = min(index, len(series) - 1)
|
||||
entry = series[clamped]
|
||||
if not isinstance(entry, (list, tuple)) or len(entry) < 2:
|
||||
logger.warning("ProjectResolution: entry at index %d is malformed: %r", clamped, entry)
|
||||
return (512, 512)
|
||||
|
||||
return (to_int(entry[0]), to_int(entry[1]))
|
||||
```
|
||||
|
||||
**Step 4: Run tests to verify they pass**
|
||||
|
||||
```bash
|
||||
pytest tests/test_project_loader.py::TestProjectResolution -v
|
||||
```
|
||||
Expected: all 7 tests PASS
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add project_loader.py tests/test_project_loader.py
|
||||
git commit -m "feat: add ProjectResolution node"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Register `ProjectResolution` in mappings + fix mapping tests
|
||||
|
||||
**Files:**
|
||||
- Modify: `project_loader.py` (mappings section, lines ~297-307)
|
||||
- Modify: `tests/test_project_loader.py` (`TestNodeMappings` class)
|
||||
|
||||
**Step 1: Update mappings in `project_loader.py`**
|
||||
|
||||
Change the mappings at the bottom of the file:
|
||||
|
||||
```python
|
||||
PROJECT_NODE_CLASS_MAPPINGS = {
|
||||
"ProjectLoaderDynamic": ProjectLoaderDynamic,
|
||||
"ProjectSource": ProjectSource,
|
||||
"ProjectKey": ProjectKey,
|
||||
"ProjectResolution": ProjectResolution,
|
||||
}
|
||||
|
||||
PROJECT_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ProjectLoaderDynamic": "Project Loader (Dynamic)",
|
||||
"ProjectSource": "Project Source",
|
||||
"ProjectKey": "Project Key",
|
||||
"ProjectResolution": "Project Resolution",
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Update the mapping test**
|
||||
|
||||
In `tests/test_project_loader.py`, update `TestNodeMappings.test_mappings_exist`:
|
||||
|
||||
```python
|
||||
class TestNodeMappings:
|
||||
def test_mappings_exist(self):
|
||||
from project_loader import PROJECT_NODE_CLASS_MAPPINGS, PROJECT_NODE_DISPLAY_NAME_MAPPINGS
|
||||
assert "ProjectLoaderDynamic" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectSource" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectKey" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert "ProjectResolution" in PROJECT_NODE_CLASS_MAPPINGS
|
||||
assert len(PROJECT_NODE_CLASS_MAPPINGS) == 4
|
||||
assert len(PROJECT_NODE_DISPLAY_NAME_MAPPINGS) == 4
|
||||
```
|
||||
|
||||
**Step 3: Run all project_loader tests**
|
||||
|
||||
```bash
|
||||
pytest tests/test_project_loader.py -v
|
||||
```
|
||||
Expected: all tests PASS
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
git add project_loader.py tests/test_project_loader.py
|
||||
git commit -m "feat: register ProjectResolution in node mappings"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: NiceGUI resolution series editor in `tab_batch_ng.py`
|
||||
|
||||
**Files:**
|
||||
- Modify: `tab_batch_ng.py`
|
||||
|
||||
The resolution series editor goes inside `splitter.before`, directly after the "Specific Negative" textarea (currently line ~552-553). No new file needed.
|
||||
|
||||
**Step 1: Add the helper function**
|
||||
|
||||
Add this function near the other helper functions at the top of the render section (before `_render_sequence_card`):
|
||||
|
||||
```python
|
||||
def _is_resolution_series(val) -> bool:
|
||||
"""Return True if val is a list of [width, height] int pairs."""
|
||||
if not isinstance(val, list) or len(val) == 0:
|
||||
return False
|
||||
return all(
|
||||
isinstance(entry, (list, tuple)) and len(entry) == 2
|
||||
and all(isinstance(v, (int, float)) for v in entry)
|
||||
for entry in val
|
||||
)
|
||||
```
|
||||
|
||||
Note: `Any` is intentionally omitted — `tab_batch_ng.py` does not import `typing.Any`.
|
||||
|
||||
**Step 2: Add the resolution series render section**
|
||||
|
||||
After the "Specific Negative" textarea in `splitter.before` (after line ~553), add:
|
||||
|
||||
```python
|
||||
# --- Resolution Series ---
|
||||
res_keys = [k for k, v in seq.items() if _is_resolution_series(v)]
|
||||
if res_keys:
|
||||
ui.label('Resolution Series').classes('text-caption text-weight-bold q-mt-md')
|
||||
for res_key in res_keys:
|
||||
series: list = seq[res_key]
|
||||
with ui.card().classes('w-full q-pa-sm q-mt-xs').props('flat bordered'):
|
||||
with ui.row().classes('items-center q-mb-xs'):
|
||||
ui.label(res_key).classes('text-caption col')
|
||||
def del_series(k=res_key):
|
||||
del seq[k]
|
||||
commit()
|
||||
ui.button(icon='delete', on_click=del_series).props(
|
||||
'flat dense round size=xs color=negative')
|
||||
with ui.row().classes('text-caption text-grey q-mb-xs'):
|
||||
ui.label('#').style('width:24px')
|
||||
ui.label('Width').classes('col')
|
||||
ui.label('Height').classes('col')
|
||||
ui.label('').style('width:28px')
|
||||
for idx, entry in enumerate(series):
|
||||
with ui.row().classes('items-center w-full'):
|
||||
ui.label(str(idx + 1)).classes('text-caption').style('width:24px')
|
||||
w_inp = ui.number(value=int(entry[0]), min=1, step=1).classes(
|
||||
'col').props('outlined dense hide-bottom-space')
|
||||
h_inp = ui.number(value=int(entry[1]), min=1, step=1).classes(
|
||||
'col').props('outlined dense hide-bottom-space')
|
||||
|
||||
def _sync_wh(i=idx, k=res_key, wi=w_inp, hi=h_inp):
|
||||
seq[k][i] = [
|
||||
int(wi.value) if wi.value else 512,
|
||||
int(hi.value) if hi.value else 512,
|
||||
]
|
||||
commit()
|
||||
|
||||
w_inp.on('blur', lambda _, s=_sync_wh: s())
|
||||
h_inp.on('blur', lambda _, s=_sync_wh: s())
|
||||
|
||||
def del_row(i=idx, k=res_key):
|
||||
seq[k].pop(i)
|
||||
commit()
|
||||
ui.button(icon='remove', on_click=del_row).props(
|
||||
'flat dense round size=xs')
|
||||
|
||||
def add_row(k=res_key):
|
||||
seq[k].append([512, 512])
|
||||
commit()
|
||||
ui.button('+ Add row', icon='add', on_click=add_row).props(
|
||||
'flat dense size=sm').classes('q-mt-xs')
|
||||
|
||||
with ui.expansion('Add Resolution Series', icon='straighten').classes('w-full q-mt-sm'):
|
||||
new_res_key = ui.input('Key name', value='resolutions').props('outlined dense')
|
||||
def add_res_series():
|
||||
k = new_res_key.value.strip()
|
||||
if k and k not in seq:
|
||||
seq[k] = [[512, 512], [1024, 1024]]
|
||||
commit()
|
||||
ui.button('Add', icon='add', on_click=add_res_series).props('outlined dense')
|
||||
```
|
||||
|
||||
**Step 3: Run all tests**
|
||||
|
||||
```bash
|
||||
pytest tests/ -q
|
||||
```
|
||||
Expected: all tests PASS (no Python tests cover the NiceGUI render path, but no regressions)
|
||||
|
||||
**Important:** Also update the `custom_keys` filter in `_render_sequence_card` (line ~648) to exclude
|
||||
resolution series keys — otherwise they'd render in both the resolution editor AND "Custom Parameters":
|
||||
|
||||
```python
|
||||
# Find this line:
|
||||
custom_keys = [k for k in seq.keys() if k not in standard_keys]
|
||||
# Replace with:
|
||||
custom_keys = [k for k in seq.keys() if k not in standard_keys and not _is_resolution_series(seq.get(k))]
|
||||
```
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
git add tab_batch_ng.py
|
||||
git commit -m "feat: resolution series editor in sequence card"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: JS extension `web/project_resolution.js`
|
||||
|
||||
**Files:**
|
||||
- Create: `web/project_resolution.js`
|
||||
|
||||
This file mirrors `web/project_key.js` exactly, with two differences:
|
||||
1. It targets `"ProjectResolution"` instead of `"ProjectKey"`
|
||||
2. `_refreshKeys` filters to only show keys whose value is a resolution series (list of `[int, int]` pairs) — but since the keys API only returns key names (not values), the filter is done by naming convention or we just show all keys and let the user pick. For simplicity, show all keys (same as ProjectKey) and let the user pick.
|
||||
3. The `index` widget is **not** hidden — the user wires it from a loop node
|
||||
4. The node has two outputs (`width`, `height`) so no output slot name update is needed
|
||||
|
||||
**Step 1: Create `web/project_resolution.js`**
|
||||
|
||||
```javascript
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
|
||||
app.registerExtension({
|
||||
name: "json.manager.project.resolution",
|
||||
|
||||
async beforeQueuePrompt() {
|
||||
if (!app.graph?._nodes) return;
|
||||
for (const node of app.graph._nodes) {
|
||||
if (node.type === "ProjectResolution" && node._syncFromSource) {
|
||||
node._syncFromSource();
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||
if (nodeData.name !== "ProjectResolution") return;
|
||||
|
||||
function hideWidget(widget) {
|
||||
if (widget.origType === undefined) widget.origType = widget.type;
|
||||
widget.type = "hidden";
|
||||
widget.hidden = true;
|
||||
widget.computeSize = () => [0, -4];
|
||||
}
|
||||
|
||||
function replaceWithCombo(node, name, values, callback) {
|
||||
const idx = node.widgets?.findIndex(w => w.name === name);
|
||||
if (idx === -1 || idx === undefined) return null;
|
||||
const oldWidget = node.widgets[idx];
|
||||
const savedValue = oldWidget.value || "";
|
||||
const comboValues = values.length > 0 ? values : [""];
|
||||
if (savedValue && !comboValues.includes(savedValue)) {
|
||||
comboValues.unshift(savedValue);
|
||||
}
|
||||
const defaultValue = savedValue || comboValues[0];
|
||||
node.widgets.splice(idx, 1);
|
||||
const combo = node.addWidget("combo", name, defaultValue, callback, { values: comboValues });
|
||||
if (node.widgets.length > 1) {
|
||||
node.widgets.splice(node.widgets.length - 1, 1);
|
||||
node.widgets.splice(idx, 0, combo);
|
||||
}
|
||||
return combo;
|
||||
}
|
||||
|
||||
const origOnNodeCreated = nodeType.prototype.onNodeCreated;
|
||||
nodeType.prototype.onNodeCreated = function () {
|
||||
origOnNodeCreated?.apply(this, arguments);
|
||||
this._configured = false;
|
||||
|
||||
// Hide synced config widgets (index stays visible — user wires it)
|
||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
||||
const w = this.widgets?.find(w => w.name === name);
|
||||
if (w) hideWidget(w);
|
||||
}
|
||||
|
||||
const node = this;
|
||||
const sourceLabels = this._getSourceLabels?.() || [];
|
||||
const srcCombo = replaceWithCombo(this, "source_label", sourceLabels, function (value) {
|
||||
node._syncFromSource();
|
||||
node._refreshKeys();
|
||||
});
|
||||
if (srcCombo) srcCombo.value = sourceLabels[0] || "";
|
||||
|
||||
const keyCombo = replaceWithCombo(this, "key_name", [], function (value) {
|
||||
node.title = value ? `Resolution: ${value}` : "Project Resolution";
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
});
|
||||
if (keyCombo) keyCombo.value = "";
|
||||
|
||||
queueMicrotask(() => {
|
||||
if (!this._configured) {
|
||||
this.setSize(this.computeSize());
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
nodeType.prototype._getSourceLabels = function () {
|
||||
const seen = new Set();
|
||||
const labels = [];
|
||||
if (!this.graph) return labels;
|
||||
for (const node of this.graph._nodes) {
|
||||
if (node.type === "ProjectSource") {
|
||||
const lw = node.widgets?.find(w => w.name === "label");
|
||||
if (lw?.value && !seen.has(lw.value)) {
|
||||
seen.add(lw.value);
|
||||
labels.push(lw.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
return labels;
|
||||
};
|
||||
|
||||
nodeType.prototype._findSource = function (label) {
|
||||
if (!this.graph || !label) return null;
|
||||
for (const node of this.graph._nodes) {
|
||||
if (node.type === "ProjectSource") {
|
||||
const lw = node.widgets?.find(w => w.name === "label");
|
||||
if (lw?.value === label) return node;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
nodeType.prototype._syncFromSource = function () {
|
||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
||||
const source = this._findSource(srcWidget?.value);
|
||||
if (!source) return;
|
||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
||||
const dst = this.widgets?.find(w => w.name === name);
|
||||
const src = source.widgets?.find(w => w.name === name);
|
||||
if (dst && src) dst.value = src.value;
|
||||
}
|
||||
};
|
||||
|
||||
nodeType.prototype._refreshKeys = async function () {
|
||||
const urlW = this.widgets?.find(w => w.name === "manager_url");
|
||||
const projW = this.widgets?.find(w => w.name === "project_name");
|
||||
const fileW = this.widgets?.find(w => w.name === "file_name");
|
||||
const seqW = this.widgets?.find(w => w.name === "sequence_number");
|
||||
if (!urlW?.value || !projW?.value || !fileW?.value) return;
|
||||
|
||||
try {
|
||||
const resp = await api.fetchApi(
|
||||
`/json_manager/get_project_keys?url=${encodeURIComponent(urlW.value)}&project=${encodeURIComponent(projW.value)}&file=${encodeURIComponent(fileW.value)}&seq=${seqW?.value || 1}`
|
||||
);
|
||||
if (!resp.ok) return;
|
||||
const data = await resp.json();
|
||||
if (data.error || !Array.isArray(data.keys)) return;
|
||||
|
||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
||||
if (keyWidget) {
|
||||
keyWidget.options.values = data.keys.length > 0 ? data.keys : [""];
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("[ProjectResolution] Failed to refresh keys:", e);
|
||||
}
|
||||
};
|
||||
|
||||
const origOnMouseDown = nodeType.prototype.onMouseDown;
|
||||
nodeType.prototype.onMouseDown = function (e, localPos, graphCanvas) {
|
||||
origOnMouseDown?.apply(this, arguments);
|
||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
||||
if (srcWidget) srcWidget.options.values = this._getSourceLabels();
|
||||
this._syncFromSource();
|
||||
};
|
||||
|
||||
const origOnConfigure = nodeType.prototype.onConfigure;
|
||||
nodeType.prototype.onConfigure = function (info) {
|
||||
origOnConfigure?.apply(this, arguments);
|
||||
this._configured = true;
|
||||
|
||||
for (const name of ["manager_url", "project_name", "file_name", "sequence_number"]) {
|
||||
const w = this.widgets?.find(w => w.name === name);
|
||||
if (w) hideWidget(w);
|
||||
}
|
||||
|
||||
const srcWidget = this.widgets?.find(w => w.name === "source_label");
|
||||
if (srcWidget && srcWidget.type !== "combo") {
|
||||
const node = this;
|
||||
replaceWithCombo(this, "source_label", this._getSourceLabels(), function (value) {
|
||||
node._syncFromSource();
|
||||
node._refreshKeys();
|
||||
});
|
||||
} else if (srcWidget) {
|
||||
srcWidget.options.values = this._getSourceLabels();
|
||||
}
|
||||
|
||||
const keyWidget = this.widgets?.find(w => w.name === "key_name");
|
||||
if (keyWidget && keyWidget.type !== "combo") {
|
||||
const node = this;
|
||||
replaceWithCombo(this, "key_name", [], function (value) {
|
||||
node.title = value ? `Resolution: ${value}` : "Project Resolution";
|
||||
app.graph?.setDirtyCanvas(true, true);
|
||||
});
|
||||
}
|
||||
|
||||
const finalKeyWidget = this.widgets?.find(w => w.name === "key_name");
|
||||
if (finalKeyWidget?.value) {
|
||||
this.title = `Resolution: ${finalKeyWidget.value}`;
|
||||
}
|
||||
|
||||
this.setSize(this.computeSize());
|
||||
|
||||
const node = this;
|
||||
queueMicrotask(() => {
|
||||
node._syncFromSource();
|
||||
node._refreshKeys();
|
||||
});
|
||||
};
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
**Step 2: Run all tests**
|
||||
|
||||
```bash
|
||||
pytest tests/ -q
|
||||
```
|
||||
Expected: all tests PASS (JS has no Python tests)
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add web/project_resolution.js
|
||||
git commit -m "feat: ProjectResolution JS extension for ComfyUI frontend"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Final verification and push
|
||||
|
||||
**Step 1: Run full test suite**
|
||||
|
||||
```bash
|
||||
pytest tests/ -v
|
||||
```
|
||||
Expected: all tests PASS
|
||||
|
||||
**Step 2: Push branch**
|
||||
|
||||
```bash
|
||||
git push origin HEAD
|
||||
```
|
||||
@@ -295,12 +295,12 @@ def index():
|
||||
sync_to_db, pane_state.db, pane_state.current_project, fp, data)
|
||||
tree = data.get('history_tree')
|
||||
if tree and isinstance(tree, dict):
|
||||
for node in tree.get('nodes', {}).values():
|
||||
node.pop('data', None)
|
||||
for entry in tree.get('snapshots', tree.get('nodes', {})).values():
|
||||
entry.pop('data', None)
|
||||
for backup in data.get('history_tree_backup', []):
|
||||
if isinstance(backup, dict):
|
||||
for node in backup.get('nodes', {}).values():
|
||||
node.pop('data', None)
|
||||
for entry in backup.get('snapshots', backup.get('nodes', {})).values():
|
||||
entry.pop('data', None)
|
||||
pane_state.data_cache = data
|
||||
pane_state.last_mtime = fp.stat().st_mtime if fp.exists() else 0
|
||||
pane_state.loaded_file = str(fp)
|
||||
@@ -339,13 +339,13 @@ def index():
|
||||
sync_to_db, state.db, state.current_project, fp, data)
|
||||
tree = data.get('history_tree')
|
||||
if tree and isinstance(tree, dict):
|
||||
for node in tree.get('nodes', {}).values():
|
||||
node.pop('data', None)
|
||||
for entry in tree.get('snapshots', tree.get('nodes', {})).values():
|
||||
entry.pop('data', None)
|
||||
# Strip snapshot data from history_tree_backup to prevent RAM/disk bloat
|
||||
for backup in data.get('history_tree_backup', []):
|
||||
if isinstance(backup, dict):
|
||||
for node in backup.get('nodes', {}).values():
|
||||
node.pop('data', None)
|
||||
for entry in backup.get('snapshots', backup.get('nodes', {})).values():
|
||||
entry.pop('data', None)
|
||||
state.data_cache = data
|
||||
state.last_mtime = fp.stat().st_mtime if fp.exists() else 0
|
||||
state.loaded_file = str(fp)
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
KEY_PROMPT_HISTORY = "prompt_history"
|
||||
|
||||
|
||||
class SnapshotTimeline:
|
||||
"""Flat chronological snapshot list — replaces the old HistoryTree DAG."""
|
||||
|
||||
def __init__(self, raw_data: dict[str, Any]) -> None:
|
||||
# Detect and migrate old HistoryTree format
|
||||
if "nodes" in raw_data and "branches" in raw_data:
|
||||
self._migrate_from_tree(raw_data)
|
||||
elif KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list):
|
||||
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
|
||||
else:
|
||||
self.snapshots: dict[str, dict[str, Any]] = raw_data.get("snapshots", {})
|
||||
self.current_id: str | None = raw_data.get("current_id", None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Migration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _migrate_from_tree(self, raw_data: dict[str, Any]) -> None:
|
||||
"""Flatten old HistoryTree nodes into snapshot list, discarding DAG info."""
|
||||
self.snapshots = {}
|
||||
nodes = raw_data.get("nodes", {})
|
||||
for nid, node in nodes.items():
|
||||
self.snapshots[nid] = {
|
||||
"id": nid,
|
||||
"timestamp": node.get("timestamp", time.time()),
|
||||
"note": node.get("note", "Migrated"),
|
||||
"pinned": False,
|
||||
"auto": False,
|
||||
"seq_count": self._count_seqs(node.get("data")),
|
||||
}
|
||||
# Preserve snapshot data if present
|
||||
if "data" in node and node["data"]:
|
||||
self.snapshots[nid]["data"] = node["data"]
|
||||
self.current_id = raw_data.get("head_id")
|
||||
|
||||
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
||||
"""Convert ancient prompt_history list into snapshots."""
|
||||
self.snapshots = {}
|
||||
self.current_id = None
|
||||
for item in reversed(old_list):
|
||||
sid = self._make_id()
|
||||
self.snapshots[sid] = {
|
||||
"id": sid,
|
||||
"timestamp": time.time(),
|
||||
"note": item.get("note", "Legacy Import"),
|
||||
"pinned": False,
|
||||
"auto": False,
|
||||
"seq_count": self._count_seqs(item),
|
||||
"data": item,
|
||||
}
|
||||
self.current_id = sid
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def record(self, data: dict[str, Any], note: str = "Snapshot",
|
||||
auto: bool = False) -> str:
|
||||
"""Create a new snapshot and return its ID."""
|
||||
sid = self._make_id()
|
||||
self.snapshots[sid] = {
|
||||
"id": sid,
|
||||
"timestamp": time.time(),
|
||||
"note": note,
|
||||
"pinned": False,
|
||||
"auto": auto,
|
||||
"seq_count": self._count_seqs(data),
|
||||
"data": data,
|
||||
}
|
||||
self.current_id = sid
|
||||
return sid
|
||||
|
||||
def get_snapshot_data(self, snapshot_id: str) -> dict[str, Any] | None:
|
||||
"""Return the inline snapshot data if present."""
|
||||
snap = self.snapshots.get(snapshot_id)
|
||||
if snap:
|
||||
return snap.get("data")
|
||||
return None
|
||||
|
||||
def toggle_pin(self, snapshot_id: str) -> bool:
|
||||
"""Toggle pinned state, return new value."""
|
||||
snap = self.snapshots.get(snapshot_id)
|
||||
if snap:
|
||||
snap["pinned"] = not snap.get("pinned", False)
|
||||
return snap["pinned"]
|
||||
return False
|
||||
|
||||
def delete(self, snapshot_id: str) -> None:
|
||||
"""Remove a snapshot."""
|
||||
self.snapshots.pop(snapshot_id, None)
|
||||
if self.current_id == snapshot_id:
|
||||
# Fall back to most recent remaining
|
||||
if self.snapshots:
|
||||
self.current_id = max(
|
||||
self.snapshots.values(), key=lambda s: s["timestamp"]
|
||||
)["id"]
|
||||
else:
|
||||
self.current_id = None
|
||||
|
||||
def strip_snapshots(self) -> None:
|
||||
"""Remove inline data from all snapshots (for slim JSON storage)."""
|
||||
for snap in self.snapshots.values():
|
||||
snap.pop("data", None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"snapshots": self.snapshots,
|
||||
"current_id": self.current_id,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _make_id(self) -> str:
|
||||
for _ in range(10):
|
||||
sid = str(uuid.uuid4())[:8]
|
||||
if sid not in self.snapshots:
|
||||
return sid
|
||||
raise ValueError("Failed to generate unique snapshot ID after 10 attempts")
|
||||
|
||||
@staticmethod
|
||||
def _count_seqs(data: dict | None) -> int:
|
||||
if not data:
|
||||
return 0
|
||||
from utils import KEY_BATCH_DATA
|
||||
batch = data.get(KEY_BATCH_DATA, [])
|
||||
return len(batch) if isinstance(batch, list) else 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Diff function
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def diff_snapshots(old_batch: list[dict], new_batch: list[dict]) -> list[dict]:
|
||||
"""Compare two batch lists by sequence_number, return per-sequence diffs.
|
||||
|
||||
Returns a list of dicts:
|
||||
{
|
||||
"seq_num": int,
|
||||
"status": "unchanged" | "changed" | "added" | "removed",
|
||||
"changes": [{"field": str, "old": Any, "new": Any}],
|
||||
}
|
||||
"""
|
||||
from utils import KEY_SEQUENCE_NUMBER
|
||||
|
||||
old_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in old_batch}
|
||||
new_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in new_batch}
|
||||
|
||||
all_seqs = sorted(set(old_by_seq) | set(new_by_seq))
|
||||
result = []
|
||||
|
||||
for seq_num in all_seqs:
|
||||
old_item = old_by_seq.get(seq_num)
|
||||
new_item = new_by_seq.get(seq_num)
|
||||
|
||||
if old_item and not new_item:
|
||||
result.append({"seq_num": seq_num, "status": "removed", "changes": []})
|
||||
elif new_item and not old_item:
|
||||
result.append({"seq_num": seq_num, "status": "added", "changes": []})
|
||||
else:
|
||||
# Both exist — field-by-field comparison
|
||||
all_keys = sorted(set(old_item) | set(new_item))
|
||||
changes = []
|
||||
for k in all_keys:
|
||||
old_val = old_item.get(k)
|
||||
new_val = new_item.get(k)
|
||||
if old_val != new_val:
|
||||
changes.append({"field": k, "old": old_val, "new": new_val})
|
||||
status = "changed" if changes else "unchanged"
|
||||
result.append({"seq_num": seq_num, "status": status, "changes": changes})
|
||||
|
||||
return result
|
||||
@@ -13,7 +13,7 @@ class AppState:
|
||||
snippets: dict = field(default_factory=dict)
|
||||
file_path: Path | None = None
|
||||
restored_indicator: str | None = None
|
||||
timeline_selected_nodes: set = field(default_factory=set)
|
||||
timeline_selected_id: str | None = None
|
||||
live_toggles: dict = field(default_factory=dict)
|
||||
show_comfy_monitor: bool = True
|
||||
|
||||
|
||||
+44
-25
@@ -16,9 +16,11 @@ from utils import (
|
||||
DEFAULTS, save_json, load_json, sync_to_db,
|
||||
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
|
||||
)
|
||||
from history_tree import HistoryTree
|
||||
from snapshot_timeline import SnapshotTimeline
|
||||
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'}
|
||||
_AUTO_SNAP_DEBOUNCE = 30 # seconds between auto-snapshots
|
||||
_last_auto_snap: dict[str, float] = {} # file_path -> timestamp
|
||||
SUB_SEGMENT_MULTIPLIER = 1000
|
||||
SUB_SEGMENT_NUM_COLORS = 6
|
||||
FRAME_TO_SKIP_DEFAULT = DEFAULTS['frame_to_skip']
|
||||
@@ -86,18 +88,18 @@ def find_insert_position(batch_list, parent_index, parent_seq_num):
|
||||
|
||||
# --- Auto change note ---
|
||||
|
||||
def _auto_change_note(htree, batch_list, state=None, file_path=None):
|
||||
def _auto_change_note(timeline, batch_list, state=None, file_path=None):
|
||||
"""Compare current batch_list against last snapshot and describe changes."""
|
||||
# Get previous batch data from the current head
|
||||
if not htree.head_id or htree.head_id not in htree.nodes:
|
||||
# Get previous batch data from the current snapshot
|
||||
if not timeline.current_id or timeline.current_id not in timeline.snapshots:
|
||||
return f'Initial save ({len(batch_list)} sequences)'
|
||||
|
||||
# Load previous snapshot from DB (nodes no longer hold data in memory)
|
||||
prev_data = htree.nodes[htree.head_id].get('data')
|
||||
# Load previous snapshot from inline data or DB
|
||||
prev_data = timeline.get_snapshot_data(timeline.current_id)
|
||||
if not prev_data and state and state.db_enabled and state.db and state.current_project and file_path:
|
||||
df = state.db.get_data_file_by_names(state.current_project, file_path.stem)
|
||||
if df:
|
||||
prev_data = state.db.get_node_snapshot(df['id'], htree.head_id)
|
||||
prev_data = state.db.get_node_snapshot(df['id'], timeline.current_id)
|
||||
prev_batch = (prev_data or {}).get(KEY_BATCH_DATA, [])
|
||||
|
||||
prev_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in prev_batch}
|
||||
@@ -363,38 +365,34 @@ def render_batch_processor(state: AppState):
|
||||
logger.info("save_and_snap START")
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||
htree = HistoryTree(tree_data)
|
||||
note = commit_input.value if commit_input.value else _auto_change_note(htree, batch_list, state=state, file_path=file_path)
|
||||
timeline = SnapshotTimeline(tree_data)
|
||||
note = commit_input.value if commit_input.value else _auto_change_note(timeline, batch_list, state=state, file_path=file_path)
|
||||
# Single serialization: json roundtrip gives us an isolated snapshot
|
||||
# without the expensive deepcopy
|
||||
t1 = time.perf_counter()
|
||||
snapshot_json = json.dumps({k: v for k, v in data.items()
|
||||
if k != KEY_HISTORY_TREE})
|
||||
snapshot_payload = json.loads(snapshot_json)
|
||||
logger.info("save_and_snap snapshot %.3fs", time.perf_counter() - t1)
|
||||
try:
|
||||
htree.commit(snapshot_payload, note=note)
|
||||
timeline.record(snapshot_payload, note=note)
|
||||
except ValueError as e:
|
||||
ui.notify(f'Save failed: {e}', type='negative')
|
||||
return
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
# DB path: sync full tree (with snapshots) to DB, then
|
||||
# write slim tree (no snapshots) to JSON and memory
|
||||
full_tree = htree.to_dict()
|
||||
full_tree = timeline.to_dict()
|
||||
data[KEY_HISTORY_TREE] = full_tree
|
||||
t1 = time.perf_counter()
|
||||
db_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot)
|
||||
logger.info("save_and_snap sync_to_db %.3fs", time.perf_counter() - t1)
|
||||
htree.strip_snapshots()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
timeline.strip_snapshots()
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
t1 = time.perf_counter()
|
||||
slim_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, slim_snapshot)
|
||||
logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1)
|
||||
else:
|
||||
# No DB: write full tree (with snapshots) to JSON
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
t1 = time.perf_counter()
|
||||
save_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, save_snapshot)
|
||||
@@ -416,9 +414,30 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
||||
refresh_list):
|
||||
async def commit(message=None):
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
# Auto-snapshot with debounce
|
||||
fp_key = str(file_path)
|
||||
now = time.time()
|
||||
did_snap = False
|
||||
if now - _last_auto_snap.get(fp_key, 0) >= _AUTO_SNAP_DEBOUNCE:
|
||||
timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {}))
|
||||
snap_json = json.dumps({k: v for k, v in data.items()
|
||||
if k != KEY_HISTORY_TREE})
|
||||
snap_payload = json.loads(snap_json)
|
||||
try:
|
||||
timeline.record(snap_payload, note=message or "Auto-save", auto=True)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
db_snap = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap)
|
||||
timeline.strip_snapshots()
|
||||
did_snap = True
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
_last_auto_snap[fp_key] = now
|
||||
except ValueError:
|
||||
pass # Non-critical: skip auto-snapshot on ID collision
|
||||
snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
if state.db_enabled and state.current_project and state.db and not did_snap:
|
||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
||||
if message:
|
||||
ui.notify(message, type='positive')
|
||||
@@ -845,26 +864,26 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
|
||||
batch_list[idx][key] = copy.deepcopy(source_seq.get(key))
|
||||
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
htree = HistoryTree(data.get(KEY_HISTORY_TREE, {}))
|
||||
timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {}))
|
||||
snapshot_json = json.dumps({k: v for k, v in data.items()
|
||||
if k != KEY_HISTORY_TREE})
|
||||
snapshot = json.loads(snapshot_json)
|
||||
try:
|
||||
htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
||||
timeline.record(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
||||
except ValueError as e:
|
||||
ui.notify(f'Mass update failed: {e}', type='negative')
|
||||
return
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
full_tree = htree.to_dict()
|
||||
full_tree = timeline.to_dict()
|
||||
data[KEY_HISTORY_TREE] = full_tree
|
||||
db_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot)
|
||||
htree.strip_snapshots()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
timeline.strip_snapshots()
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
slim_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, slim_snapshot)
|
||||
else:
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
save_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, save_snapshot)
|
||||
ui.notify(f'Updated {len(targets)} sequences', type='positive')
|
||||
|
||||
+512
-589
File diff suppressed because it is too large
Load Diff
+3
-3
@@ -208,10 +208,10 @@ class TestHistoryTrees:
|
||||
def test_upsert_updates(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.save_history_tree(df_id, {"v": 1})
|
||||
db.save_history_tree(df_id, {"v": 2})
|
||||
db.save_history_tree(df_id, {"snapshots": {}, "v": 1})
|
||||
db.save_history_tree(df_id, {"snapshots": {}, "v": 2})
|
||||
result = db.get_history_tree(df_id)
|
||||
assert result == {"v": 2}
|
||||
assert result == {"snapshots": {}, "v": 2}
|
||||
|
||||
def test_get_nonexistent(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
import pytest
|
||||
from snapshot_timeline import SnapshotTimeline, diff_snapshots
|
||||
|
||||
|
||||
def test_record_creates_snapshot():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": [{"seed": 42}]}, note="first")
|
||||
assert sid in tl.snapshots
|
||||
assert tl.current_id == sid
|
||||
assert tl.snapshots[sid]["note"] == "first"
|
||||
assert tl.snapshots[sid]["auto"] is False
|
||||
assert tl.snapshots[sid]["seq_count"] == 1
|
||||
|
||||
|
||||
def test_record_auto_flag():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": []}, note="auto save", auto=True)
|
||||
assert tl.snapshots[sid]["auto"] is True
|
||||
|
||||
|
||||
def test_multiple_records():
|
||||
tl = SnapshotTimeline({})
|
||||
id1 = tl.record({"batch_data": [{"a": 1}]}, note="one")
|
||||
id2 = tl.record({"batch_data": [{"b": 2}]}, note="two")
|
||||
assert len(tl.snapshots) == 2
|
||||
assert tl.current_id == id2
|
||||
|
||||
|
||||
def test_to_dict_roundtrip():
|
||||
tl = SnapshotTimeline({})
|
||||
tl.record({"batch_data": [{"x": 1}]}, note="test")
|
||||
d = tl.to_dict()
|
||||
tl2 = SnapshotTimeline(d)
|
||||
assert tl2.current_id == tl.current_id
|
||||
assert set(tl2.snapshots.keys()) == set(tl.snapshots.keys())
|
||||
|
||||
|
||||
def test_migrate_from_history_tree():
|
||||
"""Old HistoryTree format should be flattened into snapshots."""
|
||||
old_data = {
|
||||
"nodes": {
|
||||
"aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First", "data": {"batch_data": [{"seed": 1}]}},
|
||||
"bbb": {"id": "bbb", "parent": "aaa", "timestamp": 2000, "note": "Second", "data": {"batch_data": [{"seed": 2}]}},
|
||||
},
|
||||
"branches": {"main": "bbb"},
|
||||
"head_id": "bbb",
|
||||
}
|
||||
tl = SnapshotTimeline(old_data)
|
||||
assert len(tl.snapshots) == 2
|
||||
assert tl.current_id == "bbb"
|
||||
assert tl.snapshots["aaa"]["note"] == "First"
|
||||
assert tl.snapshots["bbb"]["note"] == "Second"
|
||||
# Data should be preserved
|
||||
assert tl.snapshots["aaa"]["data"]["batch_data"] == [{"seed": 1}]
|
||||
|
||||
|
||||
def test_migrate_from_history_tree_no_data():
|
||||
"""Slim tree nodes (no inline data) should still migrate."""
|
||||
old_data = {
|
||||
"nodes": {
|
||||
"aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First"},
|
||||
},
|
||||
"branches": {"main": "aaa"},
|
||||
"head_id": "aaa",
|
||||
}
|
||||
tl = SnapshotTimeline(old_data)
|
||||
assert len(tl.snapshots) == 1
|
||||
assert tl.snapshots["aaa"]["seq_count"] == 0
|
||||
|
||||
|
||||
def test_migrate_legacy_prompt_history():
|
||||
legacy = {
|
||||
"prompt_history": [
|
||||
{"note": "A", "seed": 1},
|
||||
{"note": "B", "seed": 2},
|
||||
]
|
||||
}
|
||||
tl = SnapshotTimeline(legacy)
|
||||
assert len(tl.snapshots) == 2
|
||||
assert tl.current_id is not None
|
||||
|
||||
|
||||
def test_toggle_pin():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": []}, note="test")
|
||||
assert tl.snapshots[sid]["pinned"] is False
|
||||
result = tl.toggle_pin(sid)
|
||||
assert result is True
|
||||
assert tl.snapshots[sid]["pinned"] is True
|
||||
result = tl.toggle_pin(sid)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_delete_snapshot():
|
||||
tl = SnapshotTimeline({})
|
||||
id1 = tl.record({"batch_data": []}, note="one")
|
||||
id2 = tl.record({"batch_data": []}, note="two")
|
||||
tl.delete(id2)
|
||||
assert id2 not in tl.snapshots
|
||||
assert tl.current_id == id1
|
||||
|
||||
|
||||
def test_delete_all_snapshots():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": []}, note="only")
|
||||
tl.delete(sid)
|
||||
assert len(tl.snapshots) == 0
|
||||
assert tl.current_id is None
|
||||
|
||||
|
||||
def test_strip_snapshots():
|
||||
tl = SnapshotTimeline({})
|
||||
tl.record({"batch_data": [{"a": 1}]}, note="test")
|
||||
tl.strip_snapshots()
|
||||
for snap in tl.snapshots.values():
|
||||
assert "data" not in snap
|
||||
|
||||
|
||||
def test_get_snapshot_data():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": [{"x": 1}]}, note="test")
|
||||
data = tl.get_snapshot_data(sid)
|
||||
assert data == {"batch_data": [{"x": 1}]}
|
||||
assert tl.get_snapshot_data("nonexistent") is None
|
||||
|
||||
|
||||
# --- diff_snapshots tests ---
|
||||
|
||||
def test_diff_unchanged():
|
||||
batch = [{"sequence_number": 1, "seed": 42}]
|
||||
result = diff_snapshots(batch, batch)
|
||||
assert len(result) == 1
|
||||
assert result[0]["status"] == "unchanged"
|
||||
assert result[0]["changes"] == []
|
||||
|
||||
|
||||
def test_diff_changed():
|
||||
old = [{"sequence_number": 1, "seed": 42, "cfg": 1.5}]
|
||||
new = [{"sequence_number": 1, "seed": 99, "cfg": 1.5}]
|
||||
result = diff_snapshots(old, new)
|
||||
assert result[0]["status"] == "changed"
|
||||
assert len(result[0]["changes"]) == 1
|
||||
assert result[0]["changes"][0]["field"] == "seed"
|
||||
assert result[0]["changes"][0]["old"] == 42
|
||||
assert result[0]["changes"][0]["new"] == 99
|
||||
|
||||
|
||||
def test_diff_added_and_removed():
|
||||
old = [{"sequence_number": 1, "seed": 1}]
|
||||
new = [{"sequence_number": 2, "seed": 2}]
|
||||
result = diff_snapshots(old, new)
|
||||
assert len(result) == 2
|
||||
statuses = {r["seq_num"]: r["status"] for r in result}
|
||||
assert statuses[1] == "removed"
|
||||
assert statuses[2] == "added"
|
||||
|
||||
|
||||
def test_diff_empty():
|
||||
assert diff_snapshots([], []) == []
|
||||
@@ -305,38 +305,47 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
||||
else:
|
||||
db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
||||
|
||||
# Sync history tree (extract node snapshots into separate table)
|
||||
# Sync history tree (extract snapshot data into separate table)
|
||||
# Supports both new format (snapshots dict) and old format (nodes dict)
|
||||
history_tree = data.get(KEY_HISTORY_TREE)
|
||||
if history_tree and isinstance(history_tree, dict):
|
||||
nodes = history_tree.get("nodes", {})
|
||||
# Detect format: new has "snapshots", old has "nodes"
|
||||
if "snapshots" in history_tree:
|
||||
entries = history_tree.get("snapshots", {})
|
||||
else:
|
||||
entries = history_tree.get("nodes", {})
|
||||
slim_tree = dict(history_tree)
|
||||
slim_nodes = {}
|
||||
for nid, node in nodes.items():
|
||||
snap = node.get("data")
|
||||
slim_entries = {}
|
||||
for eid, entry in entries.items():
|
||||
snap = entry.get("data")
|
||||
if snap:
|
||||
db.conn.execute(
|
||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||
(df_id, nid, json.dumps(snap), now),
|
||||
(df_id, eid, json.dumps(snap), now),
|
||||
)
|
||||
slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"}
|
||||
slim_tree["nodes"] = slim_nodes
|
||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||
# Write back slim version using the correct key
|
||||
if "snapshots" in history_tree:
|
||||
slim_tree["snapshots"] = slim_entries
|
||||
else:
|
||||
slim_tree["nodes"] = slim_entries
|
||||
db.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
"VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||
(df_id, json.dumps(slim_tree), now),
|
||||
)
|
||||
# Clean up orphaned snapshots for nodes no longer in tree
|
||||
current_node_ids = set(nodes.keys())
|
||||
if current_node_ids:
|
||||
placeholders = ",".join("?" for _ in current_node_ids)
|
||||
# Clean up orphaned snapshots
|
||||
current_ids = set(entries.keys())
|
||||
if current_ids:
|
||||
placeholders = ",".join("?" for _ in current_ids)
|
||||
db.conn.execute(
|
||||
f"DELETE FROM history_snapshots WHERE data_file_id = ? "
|
||||
f"AND node_id NOT IN ({placeholders})",
|
||||
(df_id, *current_node_ids),
|
||||
(df_id, *current_ids),
|
||||
)
|
||||
else:
|
||||
db.conn.execute(
|
||||
|
||||
Reference in New Issue
Block a user