feat: replace Git-DAG timeline with flat snapshot browser
Replace HistoryTree (DAG with branches, Graphviz rendering) with a flat chronological SnapshotTimeline. New UI features: split-view layout, snapshot compare/diff, cherry-pick restore of individual sequences or fields, auto-snapshots with debounce, and pin/filter support. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -242,7 +242,6 @@ class ProjectDB:
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
@staticmethod
|
||||
@staticmethod
|
||||
def _migrate_lora_keys(data: dict) -> dict:
|
||||
"""Split combined lora 'name:strength' into separate name and strength keys."""
|
||||
@@ -340,27 +339,34 @@ class ProjectDB:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_history_tree(self, data_file_id: int, tree_data: dict) -> None:
|
||||
"""Save history tree, extracting node snapshots into separate table."""
|
||||
"""Save history tree, extracting snapshot data into separate table.
|
||||
|
||||
Supports both new format (snapshots dict) and old format (nodes dict).
|
||||
"""
|
||||
now = time.time()
|
||||
nodes = tree_data.get("nodes", {})
|
||||
if "snapshots" in tree_data:
|
||||
entries = tree_data.get("snapshots", {})
|
||||
entry_key = "snapshots"
|
||||
else:
|
||||
entries = tree_data.get("nodes", {})
|
||||
entry_key = "nodes"
|
||||
slim_tree = dict(tree_data)
|
||||
slim_nodes = {}
|
||||
for nid, node in nodes.items():
|
||||
slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"}
|
||||
slim_tree["nodes"] = slim_nodes
|
||||
slim_entries = {}
|
||||
for eid, entry in entries.items():
|
||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||
slim_tree[entry_key] = slim_entries
|
||||
|
||||
self.conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
# Extract snapshot data from nodes into history_snapshots table
|
||||
for nid, node in nodes.items():
|
||||
snap = node.get("data")
|
||||
for eid, entry in entries.items():
|
||||
snap = entry.get("data")
|
||||
if snap:
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||
(data_file_id, nid, json.dumps(snap), now),
|
||||
(data_file_id, eid, json.dumps(snap), now),
|
||||
)
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
@@ -463,24 +469,30 @@ class ProjectDB:
|
||||
)
|
||||
|
||||
# Import history tree (extract snapshots into separate table)
|
||||
# Supports both new format (snapshots dict) and old format (nodes dict)
|
||||
history_tree = data.get(KEY_HISTORY_TREE)
|
||||
if history_tree and isinstance(history_tree, dict):
|
||||
now = time.time()
|
||||
nodes = history_tree.get("nodes", {})
|
||||
if "snapshots" in history_tree:
|
||||
entries = history_tree.get("snapshots", {})
|
||||
entry_key = "snapshots"
|
||||
else:
|
||||
entries = history_tree.get("nodes", {})
|
||||
entry_key = "nodes"
|
||||
slim_tree = dict(history_tree)
|
||||
slim_nodes = {}
|
||||
for nid, node in nodes.items():
|
||||
snap = node.get("data")
|
||||
slim_entries = {}
|
||||
for eid, entry in entries.items():
|
||||
snap = entry.get("data")
|
||||
if snap:
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||
(df_id, nid, json.dumps(snap), now),
|
||||
(df_id, eid, json.dumps(snap), now),
|
||||
)
|
||||
slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"}
|
||||
slim_tree["nodes"] = slim_nodes
|
||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||
slim_tree[entry_key] = slim_entries
|
||||
self.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
"VALUES (?, ?, ?) "
|
||||
@@ -540,9 +552,9 @@ class ProjectDB:
|
||||
# Load history tree (metadata only, no snapshot data)
|
||||
tree = self.get_history_tree(df["id"])
|
||||
if tree:
|
||||
# Strip any residual snapshot data from nodes
|
||||
for node in tree.get("nodes", {}).values():
|
||||
node.pop("data", None)
|
||||
# Strip any residual snapshot data (supports both formats)
|
||||
for entry in tree.get("snapshots", tree.get("nodes", {})).values():
|
||||
entry.pop("data", None)
|
||||
data["history_tree"] = tree
|
||||
t3 = time.time()
|
||||
|
||||
|
||||
@@ -295,12 +295,12 @@ def index():
|
||||
sync_to_db, pane_state.db, pane_state.current_project, fp, data)
|
||||
tree = data.get('history_tree')
|
||||
if tree and isinstance(tree, dict):
|
||||
for node in tree.get('nodes', {}).values():
|
||||
node.pop('data', None)
|
||||
for entry in tree.get('snapshots', tree.get('nodes', {})).values():
|
||||
entry.pop('data', None)
|
||||
for backup in data.get('history_tree_backup', []):
|
||||
if isinstance(backup, dict):
|
||||
for node in backup.get('nodes', {}).values():
|
||||
node.pop('data', None)
|
||||
for entry in backup.get('snapshots', backup.get('nodes', {})).values():
|
||||
entry.pop('data', None)
|
||||
pane_state.data_cache = data
|
||||
pane_state.last_mtime = fp.stat().st_mtime if fp.exists() else 0
|
||||
pane_state.loaded_file = str(fp)
|
||||
@@ -339,13 +339,13 @@ def index():
|
||||
sync_to_db, state.db, state.current_project, fp, data)
|
||||
tree = data.get('history_tree')
|
||||
if tree and isinstance(tree, dict):
|
||||
for node in tree.get('nodes', {}).values():
|
||||
node.pop('data', None)
|
||||
for entry in tree.get('snapshots', tree.get('nodes', {})).values():
|
||||
entry.pop('data', None)
|
||||
# Strip snapshot data from history_tree_backup to prevent RAM/disk bloat
|
||||
for backup in data.get('history_tree_backup', []):
|
||||
if isinstance(backup, dict):
|
||||
for node in backup.get('nodes', {}).values():
|
||||
node.pop('data', None)
|
||||
for entry in backup.get('snapshots', backup.get('nodes', {})).values():
|
||||
entry.pop('data', None)
|
||||
state.data_cache = data
|
||||
state.last_mtime = fp.stat().st_mtime if fp.exists() else 0
|
||||
state.loaded_file = str(fp)
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
KEY_PROMPT_HISTORY = "prompt_history"
|
||||
|
||||
|
||||
class SnapshotTimeline:
|
||||
"""Flat chronological snapshot list — replaces the old HistoryTree DAG."""
|
||||
|
||||
def __init__(self, raw_data: dict[str, Any]) -> None:
|
||||
# Detect and migrate old HistoryTree format
|
||||
if "nodes" in raw_data and "branches" in raw_data:
|
||||
self._migrate_from_tree(raw_data)
|
||||
elif KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list):
|
||||
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
|
||||
else:
|
||||
self.snapshots: dict[str, dict[str, Any]] = raw_data.get("snapshots", {})
|
||||
self.current_id: str | None = raw_data.get("current_id", None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Migration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _migrate_from_tree(self, raw_data: dict[str, Any]) -> None:
|
||||
"""Flatten old HistoryTree nodes into snapshot list, discarding DAG info."""
|
||||
self.snapshots = {}
|
||||
nodes = raw_data.get("nodes", {})
|
||||
for nid, node in nodes.items():
|
||||
self.snapshots[nid] = {
|
||||
"id": nid,
|
||||
"timestamp": node.get("timestamp", time.time()),
|
||||
"note": node.get("note", "Migrated"),
|
||||
"pinned": False,
|
||||
"auto": False,
|
||||
"seq_count": self._count_seqs(node.get("data")),
|
||||
}
|
||||
# Preserve snapshot data if present
|
||||
if "data" in node and node["data"]:
|
||||
self.snapshots[nid]["data"] = node["data"]
|
||||
self.current_id = raw_data.get("head_id")
|
||||
|
||||
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
||||
"""Convert ancient prompt_history list into snapshots."""
|
||||
self.snapshots = {}
|
||||
self.current_id = None
|
||||
for item in reversed(old_list):
|
||||
sid = self._make_id()
|
||||
self.snapshots[sid] = {
|
||||
"id": sid,
|
||||
"timestamp": time.time(),
|
||||
"note": item.get("note", "Legacy Import"),
|
||||
"pinned": False,
|
||||
"auto": False,
|
||||
"seq_count": self._count_seqs(item),
|
||||
"data": item,
|
||||
}
|
||||
self.current_id = sid
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def record(self, data: dict[str, Any], note: str = "Snapshot",
|
||||
auto: bool = False) -> str:
|
||||
"""Create a new snapshot and return its ID."""
|
||||
sid = self._make_id()
|
||||
self.snapshots[sid] = {
|
||||
"id": sid,
|
||||
"timestamp": time.time(),
|
||||
"note": note,
|
||||
"pinned": False,
|
||||
"auto": auto,
|
||||
"seq_count": self._count_seqs(data),
|
||||
"data": data,
|
||||
}
|
||||
self.current_id = sid
|
||||
return sid
|
||||
|
||||
def get_snapshot_data(self, snapshot_id: str) -> dict[str, Any] | None:
|
||||
"""Return the inline snapshot data if present."""
|
||||
snap = self.snapshots.get(snapshot_id)
|
||||
if snap:
|
||||
return snap.get("data")
|
||||
return None
|
||||
|
||||
def toggle_pin(self, snapshot_id: str) -> bool:
|
||||
"""Toggle pinned state, return new value."""
|
||||
snap = self.snapshots.get(snapshot_id)
|
||||
if snap:
|
||||
snap["pinned"] = not snap.get("pinned", False)
|
||||
return snap["pinned"]
|
||||
return False
|
||||
|
||||
def delete(self, snapshot_id: str) -> None:
|
||||
"""Remove a snapshot."""
|
||||
self.snapshots.pop(snapshot_id, None)
|
||||
if self.current_id == snapshot_id:
|
||||
# Fall back to most recent remaining
|
||||
if self.snapshots:
|
||||
self.current_id = max(
|
||||
self.snapshots.values(), key=lambda s: s["timestamp"]
|
||||
)["id"]
|
||||
else:
|
||||
self.current_id = None
|
||||
|
||||
def strip_snapshots(self) -> None:
|
||||
"""Remove inline data from all snapshots (for slim JSON storage)."""
|
||||
for snap in self.snapshots.values():
|
||||
snap.pop("data", None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"snapshots": self.snapshots,
|
||||
"current_id": self.current_id,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _make_id(self) -> str:
|
||||
for _ in range(10):
|
||||
sid = str(uuid.uuid4())[:8]
|
||||
if sid not in self.snapshots:
|
||||
return sid
|
||||
raise ValueError("Failed to generate unique snapshot ID after 10 attempts")
|
||||
|
||||
@staticmethod
|
||||
def _count_seqs(data: dict | None) -> int:
|
||||
if not data:
|
||||
return 0
|
||||
from utils import KEY_BATCH_DATA
|
||||
batch = data.get(KEY_BATCH_DATA, [])
|
||||
return len(batch) if isinstance(batch, list) else 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Diff function
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def diff_snapshots(old_batch: list[dict], new_batch: list[dict]) -> list[dict]:
|
||||
"""Compare two batch lists by sequence_number, return per-sequence diffs.
|
||||
|
||||
Returns a list of dicts:
|
||||
{
|
||||
"seq_num": int,
|
||||
"status": "unchanged" | "changed" | "added" | "removed",
|
||||
"changes": [{"field": str, "old": Any, "new": Any}],
|
||||
}
|
||||
"""
|
||||
from utils import KEY_SEQUENCE_NUMBER
|
||||
|
||||
old_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in old_batch}
|
||||
new_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in new_batch}
|
||||
|
||||
all_seqs = sorted(set(old_by_seq) | set(new_by_seq))
|
||||
result = []
|
||||
|
||||
for seq_num in all_seqs:
|
||||
old_item = old_by_seq.get(seq_num)
|
||||
new_item = new_by_seq.get(seq_num)
|
||||
|
||||
if old_item and not new_item:
|
||||
result.append({"seq_num": seq_num, "status": "removed", "changes": []})
|
||||
elif new_item and not old_item:
|
||||
result.append({"seq_num": seq_num, "status": "added", "changes": []})
|
||||
else:
|
||||
# Both exist — field-by-field comparison
|
||||
all_keys = sorted(set(old_item) | set(new_item))
|
||||
changes = []
|
||||
for k in all_keys:
|
||||
old_val = old_item.get(k)
|
||||
new_val = new_item.get(k)
|
||||
if old_val != new_val:
|
||||
changes.append({"field": k, "old": old_val, "new": new_val})
|
||||
status = "changed" if changes else "unchanged"
|
||||
result.append({"seq_num": seq_num, "status": status, "changes": changes})
|
||||
|
||||
return result
|
||||
@@ -13,7 +13,7 @@ class AppState:
|
||||
snippets: dict = field(default_factory=dict)
|
||||
file_path: Path | None = None
|
||||
restored_indicator: str | None = None
|
||||
timeline_selected_nodes: set = field(default_factory=set)
|
||||
timeline_selected_id: str | None = None
|
||||
live_toggles: dict = field(default_factory=dict)
|
||||
show_comfy_monitor: bool = True
|
||||
|
||||
|
||||
+44
-25
@@ -16,9 +16,11 @@ from utils import (
|
||||
DEFAULTS, save_json, load_json, sync_to_db,
|
||||
KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER,
|
||||
)
|
||||
from history_tree import HistoryTree
|
||||
from snapshot_timeline import SnapshotTimeline
|
||||
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'}
|
||||
_AUTO_SNAP_DEBOUNCE = 30 # seconds between auto-snapshots
|
||||
_last_auto_snap: dict[str, float] = {} # file_path -> timestamp
|
||||
SUB_SEGMENT_MULTIPLIER = 1000
|
||||
SUB_SEGMENT_NUM_COLORS = 6
|
||||
FRAME_TO_SKIP_DEFAULT = DEFAULTS['frame_to_skip']
|
||||
@@ -86,18 +88,18 @@ def find_insert_position(batch_list, parent_index, parent_seq_num):
|
||||
|
||||
# --- Auto change note ---
|
||||
|
||||
def _auto_change_note(htree, batch_list, state=None, file_path=None):
|
||||
def _auto_change_note(timeline, batch_list, state=None, file_path=None):
|
||||
"""Compare current batch_list against last snapshot and describe changes."""
|
||||
# Get previous batch data from the current head
|
||||
if not htree.head_id or htree.head_id not in htree.nodes:
|
||||
# Get previous batch data from the current snapshot
|
||||
if not timeline.current_id or timeline.current_id not in timeline.snapshots:
|
||||
return f'Initial save ({len(batch_list)} sequences)'
|
||||
|
||||
# Load previous snapshot from DB (nodes no longer hold data in memory)
|
||||
prev_data = htree.nodes[htree.head_id].get('data')
|
||||
# Load previous snapshot from inline data or DB
|
||||
prev_data = timeline.get_snapshot_data(timeline.current_id)
|
||||
if not prev_data and state and state.db_enabled and state.db and state.current_project and file_path:
|
||||
df = state.db.get_data_file_by_names(state.current_project, file_path.stem)
|
||||
if df:
|
||||
prev_data = state.db.get_node_snapshot(df['id'], htree.head_id)
|
||||
prev_data = state.db.get_node_snapshot(df['id'], timeline.current_id)
|
||||
prev_batch = (prev_data or {}).get(KEY_BATCH_DATA, [])
|
||||
|
||||
prev_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in prev_batch}
|
||||
@@ -363,38 +365,34 @@ def render_batch_processor(state: AppState):
|
||||
logger.info("save_and_snap START")
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||
htree = HistoryTree(tree_data)
|
||||
note = commit_input.value if commit_input.value else _auto_change_note(htree, batch_list, state=state, file_path=file_path)
|
||||
timeline = SnapshotTimeline(tree_data)
|
||||
note = commit_input.value if commit_input.value else _auto_change_note(timeline, batch_list, state=state, file_path=file_path)
|
||||
# Single serialization: json roundtrip gives us an isolated snapshot
|
||||
# without the expensive deepcopy
|
||||
t1 = time.perf_counter()
|
||||
snapshot_json = json.dumps({k: v for k, v in data.items()
|
||||
if k != KEY_HISTORY_TREE})
|
||||
snapshot_payload = json.loads(snapshot_json)
|
||||
logger.info("save_and_snap snapshot %.3fs", time.perf_counter() - t1)
|
||||
try:
|
||||
htree.commit(snapshot_payload, note=note)
|
||||
timeline.record(snapshot_payload, note=note)
|
||||
except ValueError as e:
|
||||
ui.notify(f'Save failed: {e}', type='negative')
|
||||
return
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
# DB path: sync full tree (with snapshots) to DB, then
|
||||
# write slim tree (no snapshots) to JSON and memory
|
||||
full_tree = htree.to_dict()
|
||||
full_tree = timeline.to_dict()
|
||||
data[KEY_HISTORY_TREE] = full_tree
|
||||
t1 = time.perf_counter()
|
||||
db_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot)
|
||||
logger.info("save_and_snap sync_to_db %.3fs", time.perf_counter() - t1)
|
||||
htree.strip_snapshots()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
timeline.strip_snapshots()
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
t1 = time.perf_counter()
|
||||
slim_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, slim_snapshot)
|
||||
logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1)
|
||||
else:
|
||||
# No DB: write full tree (with snapshots) to JSON
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
t1 = time.perf_counter()
|
||||
save_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, save_snapshot)
|
||||
@@ -416,9 +414,30 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state,
|
||||
refresh_list):
|
||||
async def commit(message=None):
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
# Auto-snapshot with debounce
|
||||
fp_key = str(file_path)
|
||||
now = time.time()
|
||||
did_snap = False
|
||||
if now - _last_auto_snap.get(fp_key, 0) >= _AUTO_SNAP_DEBOUNCE:
|
||||
timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {}))
|
||||
snap_json = json.dumps({k: v for k, v in data.items()
|
||||
if k != KEY_HISTORY_TREE})
|
||||
snap_payload = json.loads(snap_json)
|
||||
try:
|
||||
timeline.record(snap_payload, note=message or "Auto-save", auto=True)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
db_snap = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap)
|
||||
timeline.strip_snapshots()
|
||||
did_snap = True
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
_last_auto_snap[fp_key] = now
|
||||
except ValueError:
|
||||
pass # Non-critical: skip auto-snapshot on ID collision
|
||||
snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, snapshot)
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
if state.db_enabled and state.current_project and state.db and not did_snap:
|
||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot)
|
||||
if message:
|
||||
ui.notify(message, type='positive')
|
||||
@@ -845,26 +864,26 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
|
||||
batch_list[idx][key] = copy.deepcopy(source_seq.get(key))
|
||||
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
htree = HistoryTree(data.get(KEY_HISTORY_TREE, {}))
|
||||
timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {}))
|
||||
snapshot_json = json.dumps({k: v for k, v in data.items()
|
||||
if k != KEY_HISTORY_TREE})
|
||||
snapshot = json.loads(snapshot_json)
|
||||
try:
|
||||
htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
||||
timeline.record(snapshot, f"Mass update: {', '.join(selected_keys)}")
|
||||
except ValueError as e:
|
||||
ui.notify(f'Mass update failed: {e}', type='negative')
|
||||
return
|
||||
if state.db_enabled and state.current_project and state.db:
|
||||
full_tree = htree.to_dict()
|
||||
full_tree = timeline.to_dict()
|
||||
data[KEY_HISTORY_TREE] = full_tree
|
||||
db_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot)
|
||||
htree.strip_snapshots()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
timeline.strip_snapshots()
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
slim_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, slim_snapshot)
|
||||
else:
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = timeline.to_dict()
|
||||
save_snapshot = json.loads(json.dumps(data))
|
||||
await asyncio.to_thread(save_json, file_path, save_snapshot)
|
||||
ui.notify(f'Updated {len(targets)} sequences', type='positive')
|
||||
|
||||
+512
-589
File diff suppressed because it is too large
Load Diff
+3
-3
@@ -208,10 +208,10 @@ class TestHistoryTrees:
|
||||
def test_upsert_updates(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
df_id = db.create_data_file(pid, "batch", "generic")
|
||||
db.save_history_tree(df_id, {"v": 1})
|
||||
db.save_history_tree(df_id, {"v": 2})
|
||||
db.save_history_tree(df_id, {"snapshots": {}, "v": 1})
|
||||
db.save_history_tree(df_id, {"snapshots": {}, "v": 2})
|
||||
result = db.get_history_tree(df_id)
|
||||
assert result == {"v": 2}
|
||||
assert result == {"snapshots": {}, "v": 2}
|
||||
|
||||
def test_get_nonexistent(self, db):
|
||||
pid = db.create_project("p1", "/p1")
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
import pytest
|
||||
from snapshot_timeline import SnapshotTimeline, diff_snapshots
|
||||
|
||||
|
||||
def test_record_creates_snapshot():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": [{"seed": 42}]}, note="first")
|
||||
assert sid in tl.snapshots
|
||||
assert tl.current_id == sid
|
||||
assert tl.snapshots[sid]["note"] == "first"
|
||||
assert tl.snapshots[sid]["auto"] is False
|
||||
assert tl.snapshots[sid]["seq_count"] == 1
|
||||
|
||||
|
||||
def test_record_auto_flag():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": []}, note="auto save", auto=True)
|
||||
assert tl.snapshots[sid]["auto"] is True
|
||||
|
||||
|
||||
def test_multiple_records():
|
||||
tl = SnapshotTimeline({})
|
||||
id1 = tl.record({"batch_data": [{"a": 1}]}, note="one")
|
||||
id2 = tl.record({"batch_data": [{"b": 2}]}, note="two")
|
||||
assert len(tl.snapshots) == 2
|
||||
assert tl.current_id == id2
|
||||
|
||||
|
||||
def test_to_dict_roundtrip():
|
||||
tl = SnapshotTimeline({})
|
||||
tl.record({"batch_data": [{"x": 1}]}, note="test")
|
||||
d = tl.to_dict()
|
||||
tl2 = SnapshotTimeline(d)
|
||||
assert tl2.current_id == tl.current_id
|
||||
assert set(tl2.snapshots.keys()) == set(tl.snapshots.keys())
|
||||
|
||||
|
||||
def test_migrate_from_history_tree():
|
||||
"""Old HistoryTree format should be flattened into snapshots."""
|
||||
old_data = {
|
||||
"nodes": {
|
||||
"aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First", "data": {"batch_data": [{"seed": 1}]}},
|
||||
"bbb": {"id": "bbb", "parent": "aaa", "timestamp": 2000, "note": "Second", "data": {"batch_data": [{"seed": 2}]}},
|
||||
},
|
||||
"branches": {"main": "bbb"},
|
||||
"head_id": "bbb",
|
||||
}
|
||||
tl = SnapshotTimeline(old_data)
|
||||
assert len(tl.snapshots) == 2
|
||||
assert tl.current_id == "bbb"
|
||||
assert tl.snapshots["aaa"]["note"] == "First"
|
||||
assert tl.snapshots["bbb"]["note"] == "Second"
|
||||
# Data should be preserved
|
||||
assert tl.snapshots["aaa"]["data"]["batch_data"] == [{"seed": 1}]
|
||||
|
||||
|
||||
def test_migrate_from_history_tree_no_data():
|
||||
"""Slim tree nodes (no inline data) should still migrate."""
|
||||
old_data = {
|
||||
"nodes": {
|
||||
"aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First"},
|
||||
},
|
||||
"branches": {"main": "aaa"},
|
||||
"head_id": "aaa",
|
||||
}
|
||||
tl = SnapshotTimeline(old_data)
|
||||
assert len(tl.snapshots) == 1
|
||||
assert tl.snapshots["aaa"]["seq_count"] == 0
|
||||
|
||||
|
||||
def test_migrate_legacy_prompt_history():
|
||||
legacy = {
|
||||
"prompt_history": [
|
||||
{"note": "A", "seed": 1},
|
||||
{"note": "B", "seed": 2},
|
||||
]
|
||||
}
|
||||
tl = SnapshotTimeline(legacy)
|
||||
assert len(tl.snapshots) == 2
|
||||
assert tl.current_id is not None
|
||||
|
||||
|
||||
def test_toggle_pin():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": []}, note="test")
|
||||
assert tl.snapshots[sid]["pinned"] is False
|
||||
result = tl.toggle_pin(sid)
|
||||
assert result is True
|
||||
assert tl.snapshots[sid]["pinned"] is True
|
||||
result = tl.toggle_pin(sid)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_delete_snapshot():
|
||||
tl = SnapshotTimeline({})
|
||||
id1 = tl.record({"batch_data": []}, note="one")
|
||||
id2 = tl.record({"batch_data": []}, note="two")
|
||||
tl.delete(id2)
|
||||
assert id2 not in tl.snapshots
|
||||
assert tl.current_id == id1
|
||||
|
||||
|
||||
def test_delete_all_snapshots():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": []}, note="only")
|
||||
tl.delete(sid)
|
||||
assert len(tl.snapshots) == 0
|
||||
assert tl.current_id is None
|
||||
|
||||
|
||||
def test_strip_snapshots():
|
||||
tl = SnapshotTimeline({})
|
||||
tl.record({"batch_data": [{"a": 1}]}, note="test")
|
||||
tl.strip_snapshots()
|
||||
for snap in tl.snapshots.values():
|
||||
assert "data" not in snap
|
||||
|
||||
|
||||
def test_get_snapshot_data():
|
||||
tl = SnapshotTimeline({})
|
||||
sid = tl.record({"batch_data": [{"x": 1}]}, note="test")
|
||||
data = tl.get_snapshot_data(sid)
|
||||
assert data == {"batch_data": [{"x": 1}]}
|
||||
assert tl.get_snapshot_data("nonexistent") is None
|
||||
|
||||
|
||||
# --- diff_snapshots tests ---
|
||||
|
||||
def test_diff_unchanged():
|
||||
batch = [{"sequence_number": 1, "seed": 42}]
|
||||
result = diff_snapshots(batch, batch)
|
||||
assert len(result) == 1
|
||||
assert result[0]["status"] == "unchanged"
|
||||
assert result[0]["changes"] == []
|
||||
|
||||
|
||||
def test_diff_changed():
|
||||
old = [{"sequence_number": 1, "seed": 42, "cfg": 1.5}]
|
||||
new = [{"sequence_number": 1, "seed": 99, "cfg": 1.5}]
|
||||
result = diff_snapshots(old, new)
|
||||
assert result[0]["status"] == "changed"
|
||||
assert len(result[0]["changes"]) == 1
|
||||
assert result[0]["changes"][0]["field"] == "seed"
|
||||
assert result[0]["changes"][0]["old"] == 42
|
||||
assert result[0]["changes"][0]["new"] == 99
|
||||
|
||||
|
||||
def test_diff_added_and_removed():
|
||||
old = [{"sequence_number": 1, "seed": 1}]
|
||||
new = [{"sequence_number": 2, "seed": 2}]
|
||||
result = diff_snapshots(old, new)
|
||||
assert len(result) == 2
|
||||
statuses = {r["seq_num"]: r["status"] for r in result}
|
||||
assert statuses[1] == "removed"
|
||||
assert statuses[2] == "added"
|
||||
|
||||
|
||||
def test_diff_empty():
|
||||
assert diff_snapshots([], []) == []
|
||||
@@ -305,38 +305,47 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None:
|
||||
else:
|
||||
db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,))
|
||||
|
||||
# Sync history tree (extract node snapshots into separate table)
|
||||
# Sync history tree (extract snapshot data into separate table)
|
||||
# Supports both new format (snapshots dict) and old format (nodes dict)
|
||||
history_tree = data.get(KEY_HISTORY_TREE)
|
||||
if history_tree and isinstance(history_tree, dict):
|
||||
nodes = history_tree.get("nodes", {})
|
||||
# Detect format: new has "snapshots", old has "nodes"
|
||||
if "snapshots" in history_tree:
|
||||
entries = history_tree.get("snapshots", {})
|
||||
else:
|
||||
entries = history_tree.get("nodes", {})
|
||||
slim_tree = dict(history_tree)
|
||||
slim_nodes = {}
|
||||
for nid, node in nodes.items():
|
||||
snap = node.get("data")
|
||||
slim_entries = {}
|
||||
for eid, entry in entries.items():
|
||||
snap = entry.get("data")
|
||||
if snap:
|
||||
db.conn.execute(
|
||||
"INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) "
|
||||
"VALUES (?, ?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id, node_id) DO UPDATE SET "
|
||||
"snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at",
|
||||
(df_id, nid, json.dumps(snap), now),
|
||||
(df_id, eid, json.dumps(snap), now),
|
||||
)
|
||||
slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"}
|
||||
slim_tree["nodes"] = slim_nodes
|
||||
slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"}
|
||||
# Write back slim version using the correct key
|
||||
if "snapshots" in history_tree:
|
||||
slim_tree["snapshots"] = slim_entries
|
||||
else:
|
||||
slim_tree["nodes"] = slim_entries
|
||||
db.conn.execute(
|
||||
"INSERT INTO history_trees (data_file_id, tree_data, updated_at) "
|
||||
"VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at",
|
||||
(df_id, json.dumps(slim_tree), now),
|
||||
)
|
||||
# Clean up orphaned snapshots for nodes no longer in tree
|
||||
current_node_ids = set(nodes.keys())
|
||||
if current_node_ids:
|
||||
placeholders = ",".join("?" for _ in current_node_ids)
|
||||
# Clean up orphaned snapshots
|
||||
current_ids = set(entries.keys())
|
||||
if current_ids:
|
||||
placeholders = ",".join("?" for _ in current_ids)
|
||||
db.conn.execute(
|
||||
f"DELETE FROM history_snapshots WHERE data_file_id = ? "
|
||||
f"AND node_id NOT IN ({placeholders})",
|
||||
(df_id, *current_node_ids),
|
||||
(df_id, *current_ids),
|
||||
)
|
||||
else:
|
||||
db.conn.execute(
|
||||
|
||||
Reference in New Issue
Block a user