diff --git a/db.py b/db.py index c870e8f..81e9461 100644 --- a/db.py +++ b/db.py @@ -242,7 +242,6 @@ class ProjectDB: ) self.conn.commit() - @staticmethod @staticmethod def _migrate_lora_keys(data: dict) -> dict: """Split combined lora 'name:strength' into separate name and strength keys.""" @@ -340,27 +339,34 @@ class ProjectDB: # ------------------------------------------------------------------ def save_history_tree(self, data_file_id: int, tree_data: dict) -> None: - """Save history tree, extracting node snapshots into separate table.""" + """Save history tree, extracting snapshot data into separate table. + + Supports both new format (snapshots dict) and old format (nodes dict). + """ now = time.time() - nodes = tree_data.get("nodes", {}) + if "snapshots" in tree_data: + entries = tree_data.get("snapshots", {}) + entry_key = "snapshots" + else: + entries = tree_data.get("nodes", {}) + entry_key = "nodes" slim_tree = dict(tree_data) - slim_nodes = {} - for nid, node in nodes.items(): - slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"} - slim_tree["nodes"] = slim_nodes + slim_entries = {} + for eid, entry in entries.items(): + slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"} + slim_tree[entry_key] = slim_entries self.conn.execute("BEGIN IMMEDIATE") try: - # Extract snapshot data from nodes into history_snapshots table - for nid, node in nodes.items(): - snap = node.get("data") + for eid, entry in entries.items(): + snap = entry.get("data") if snap: self.conn.execute( "INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) " "VALUES (?, ?, ?, ?) " "ON CONFLICT(data_file_id, node_id) DO UPDATE SET " "snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at", - (data_file_id, nid, json.dumps(snap), now), + (data_file_id, eid, json.dumps(snap), now), ) self.conn.execute( "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " @@ -463,24 +469,30 @@ class ProjectDB: ) # Import history tree (extract snapshots into separate table) + # Supports both new format (snapshots dict) and old format (nodes dict) history_tree = data.get(KEY_HISTORY_TREE) if history_tree and isinstance(history_tree, dict): now = time.time() - nodes = history_tree.get("nodes", {}) + if "snapshots" in history_tree: + entries = history_tree.get("snapshots", {}) + entry_key = "snapshots" + else: + entries = history_tree.get("nodes", {}) + entry_key = "nodes" slim_tree = dict(history_tree) - slim_nodes = {} - for nid, node in nodes.items(): - snap = node.get("data") + slim_entries = {} + for eid, entry in entries.items(): + snap = entry.get("data") if snap: self.conn.execute( "INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) " "VALUES (?, ?, ?, ?) " "ON CONFLICT(data_file_id, node_id) DO UPDATE SET " "snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at", - (df_id, nid, json.dumps(snap), now), + (df_id, eid, json.dumps(snap), now), ) - slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"} - slim_tree["nodes"] = slim_nodes + slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"} + slim_tree[entry_key] = slim_entries self.conn.execute( "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " "VALUES (?, ?, ?) " @@ -540,9 +552,9 @@ class ProjectDB: # Load history tree (metadata only, no snapshot data) tree = self.get_history_tree(df["id"]) if tree: - # Strip any residual snapshot data from nodes - for node in tree.get("nodes", {}).values(): - node.pop("data", None) + # Strip any residual snapshot data (supports both formats) + for entry in tree.get("snapshots", tree.get("nodes", {})).values(): + entry.pop("data", None) data["history_tree"] = tree t3 = time.time() diff --git a/main.py b/main.py index d0ba139..3ce2e93 100644 --- a/main.py +++ b/main.py @@ -295,12 +295,12 @@ def index(): sync_to_db, pane_state.db, pane_state.current_project, fp, data) tree = data.get('history_tree') if tree and isinstance(tree, dict): - for node in tree.get('nodes', {}).values(): - node.pop('data', None) + for entry in tree.get('snapshots', tree.get('nodes', {})).values(): + entry.pop('data', None) for backup in data.get('history_tree_backup', []): if isinstance(backup, dict): - for node in backup.get('nodes', {}).values(): - node.pop('data', None) + for entry in backup.get('snapshots', backup.get('nodes', {})).values(): + entry.pop('data', None) pane_state.data_cache = data pane_state.last_mtime = fp.stat().st_mtime if fp.exists() else 0 pane_state.loaded_file = str(fp) @@ -339,13 +339,13 @@ def index(): sync_to_db, state.db, state.current_project, fp, data) tree = data.get('history_tree') if tree and isinstance(tree, dict): - for node in tree.get('nodes', {}).values(): - node.pop('data', None) + for entry in tree.get('snapshots', tree.get('nodes', {})).values(): + entry.pop('data', None) # Strip snapshot data from history_tree_backup to prevent RAM/disk bloat for backup in data.get('history_tree_backup', []): if isinstance(backup, dict): - for node in backup.get('nodes', {}).values(): - node.pop('data', None) + for entry in backup.get('snapshots', backup.get('nodes', {})).values(): + entry.pop('data', None) state.data_cache = data state.last_mtime = fp.stat().st_mtime if fp.exists() else 0 state.loaded_file = str(fp) diff --git a/snapshot_timeline.py b/snapshot_timeline.py new file mode 100644 index 0000000..2a8e312 --- /dev/null +++ b/snapshot_timeline.py @@ -0,0 +1,184 @@ +import time +import uuid +from typing import Any + +KEY_PROMPT_HISTORY = "prompt_history" + + +class SnapshotTimeline: + """Flat chronological snapshot list — replaces the old HistoryTree DAG.""" + + def __init__(self, raw_data: dict[str, Any]) -> None: + # Detect and migrate old HistoryTree format + if "nodes" in raw_data and "branches" in raw_data: + self._migrate_from_tree(raw_data) + elif KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list): + self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY]) + else: + self.snapshots: dict[str, dict[str, Any]] = raw_data.get("snapshots", {}) + self.current_id: str | None = raw_data.get("current_id", None) + + # ------------------------------------------------------------------ + # Migration + # ------------------------------------------------------------------ + + def _migrate_from_tree(self, raw_data: dict[str, Any]) -> None: + """Flatten old HistoryTree nodes into snapshot list, discarding DAG info.""" + self.snapshots = {} + nodes = raw_data.get("nodes", {}) + for nid, node in nodes.items(): + self.snapshots[nid] = { + "id": nid, + "timestamp": node.get("timestamp", time.time()), + "note": node.get("note", "Migrated"), + "pinned": False, + "auto": False, + "seq_count": self._count_seqs(node.get("data")), + } + # Preserve snapshot data if present + if "data" in node and node["data"]: + self.snapshots[nid]["data"] = node["data"] + self.current_id = raw_data.get("head_id") + + def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None: + """Convert ancient prompt_history list into snapshots.""" + self.snapshots = {} + self.current_id = None + for item in reversed(old_list): + sid = self._make_id() + self.snapshots[sid] = { + "id": sid, + "timestamp": time.time(), + "note": item.get("note", "Legacy Import"), + "pinned": False, + "auto": False, + "seq_count": self._count_seqs(item), + "data": item, + } + self.current_id = sid + + # ------------------------------------------------------------------ + # Core operations + # ------------------------------------------------------------------ + + def record(self, data: dict[str, Any], note: str = "Snapshot", + auto: bool = False) -> str: + """Create a new snapshot and return its ID.""" + sid = self._make_id() + self.snapshots[sid] = { + "id": sid, + "timestamp": time.time(), + "note": note, + "pinned": False, + "auto": auto, + "seq_count": self._count_seqs(data), + "data": data, + } + self.current_id = sid + return sid + + def get_snapshot_data(self, snapshot_id: str) -> dict[str, Any] | None: + """Return the inline snapshot data if present.""" + snap = self.snapshots.get(snapshot_id) + if snap: + return snap.get("data") + return None + + def toggle_pin(self, snapshot_id: str) -> bool: + """Toggle pinned state, return new value.""" + snap = self.snapshots.get(snapshot_id) + if snap: + snap["pinned"] = not snap.get("pinned", False) + return snap["pinned"] + return False + + def delete(self, snapshot_id: str) -> None: + """Remove a snapshot.""" + self.snapshots.pop(snapshot_id, None) + if self.current_id == snapshot_id: + # Fall back to most recent remaining + if self.snapshots: + self.current_id = max( + self.snapshots.values(), key=lambda s: s["timestamp"] + )["id"] + else: + self.current_id = None + + def strip_snapshots(self) -> None: + """Remove inline data from all snapshots (for slim JSON storage).""" + for snap in self.snapshots.values(): + snap.pop("data", None) + + # ------------------------------------------------------------------ + # Serialization + # ------------------------------------------------------------------ + + def to_dict(self) -> dict[str, Any]: + return { + "snapshots": self.snapshots, + "current_id": self.current_id, + } + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _make_id(self) -> str: + for _ in range(10): + sid = str(uuid.uuid4())[:8] + if sid not in self.snapshots: + return sid + raise ValueError("Failed to generate unique snapshot ID after 10 attempts") + + @staticmethod + def _count_seqs(data: dict | None) -> int: + if not data: + return 0 + from utils import KEY_BATCH_DATA + batch = data.get(KEY_BATCH_DATA, []) + return len(batch) if isinstance(batch, list) else 0 + + +# ------------------------------------------------------------------ +# Diff function +# ------------------------------------------------------------------ + +def diff_snapshots(old_batch: list[dict], new_batch: list[dict]) -> list[dict]: + """Compare two batch lists by sequence_number, return per-sequence diffs. + + Returns a list of dicts: + { + "seq_num": int, + "status": "unchanged" | "changed" | "added" | "removed", + "changes": [{"field": str, "old": Any, "new": Any}], + } + """ + from utils import KEY_SEQUENCE_NUMBER + + old_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in old_batch} + new_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in new_batch} + + all_seqs = sorted(set(old_by_seq) | set(new_by_seq)) + result = [] + + for seq_num in all_seqs: + old_item = old_by_seq.get(seq_num) + new_item = new_by_seq.get(seq_num) + + if old_item and not new_item: + result.append({"seq_num": seq_num, "status": "removed", "changes": []}) + elif new_item and not old_item: + result.append({"seq_num": seq_num, "status": "added", "changes": []}) + else: + # Both exist — field-by-field comparison + all_keys = sorted(set(old_item) | set(new_item)) + changes = [] + for k in all_keys: + old_val = old_item.get(k) + new_val = new_item.get(k) + if old_val != new_val: + changes.append({"field": k, "old": old_val, "new": new_val}) + status = "changed" if changes else "unchanged" + result.append({"seq_num": seq_num, "status": status, "changes": changes}) + + return result diff --git a/state.py b/state.py index 4f8d7a4..86f236a 100644 --- a/state.py +++ b/state.py @@ -13,7 +13,7 @@ class AppState: snippets: dict = field(default_factory=dict) file_path: Path | None = None restored_indicator: str | None = None - timeline_selected_nodes: set = field(default_factory=set) + timeline_selected_id: str | None = None live_toggles: dict = field(default_factory=dict) show_comfy_monitor: bool = True diff --git a/tab_batch_ng.py b/tab_batch_ng.py index 7bf4a59..c0bd53c 100644 --- a/tab_batch_ng.py +++ b/tab_batch_ng.py @@ -16,9 +16,11 @@ from utils import ( DEFAULTS, save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, ) -from history_tree import HistoryTree +from snapshot_timeline import SnapshotTimeline IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif'} +_AUTO_SNAP_DEBOUNCE = 30 # seconds between auto-snapshots +_last_auto_snap: dict[str, float] = {} # file_path -> timestamp SUB_SEGMENT_MULTIPLIER = 1000 SUB_SEGMENT_NUM_COLORS = 6 FRAME_TO_SKIP_DEFAULT = DEFAULTS['frame_to_skip'] @@ -86,18 +88,18 @@ def find_insert_position(batch_list, parent_index, parent_seq_num): # --- Auto change note --- -def _auto_change_note(htree, batch_list, state=None, file_path=None): +def _auto_change_note(timeline, batch_list, state=None, file_path=None): """Compare current batch_list against last snapshot and describe changes.""" - # Get previous batch data from the current head - if not htree.head_id or htree.head_id not in htree.nodes: + # Get previous batch data from the current snapshot + if not timeline.current_id or timeline.current_id not in timeline.snapshots: return f'Initial save ({len(batch_list)} sequences)' - # Load previous snapshot from DB (nodes no longer hold data in memory) - prev_data = htree.nodes[htree.head_id].get('data') + # Load previous snapshot from inline data or DB + prev_data = timeline.get_snapshot_data(timeline.current_id) if not prev_data and state and state.db_enabled and state.db and state.current_project and file_path: df = state.db.get_data_file_by_names(state.current_project, file_path.stem) if df: - prev_data = state.db.get_node_snapshot(df['id'], htree.head_id) + prev_data = state.db.get_node_snapshot(df['id'], timeline.current_id) prev_batch = (prev_data or {}).get(KEY_BATCH_DATA, []) prev_by_seq = {int(s.get(KEY_SEQUENCE_NUMBER, 0)): s for s in prev_batch} @@ -363,38 +365,34 @@ def render_batch_processor(state: AppState): logger.info("save_and_snap START") data[KEY_BATCH_DATA] = batch_list tree_data = data.get(KEY_HISTORY_TREE, {}) - htree = HistoryTree(tree_data) - note = commit_input.value if commit_input.value else _auto_change_note(htree, batch_list, state=state, file_path=file_path) + timeline = SnapshotTimeline(tree_data) + note = commit_input.value if commit_input.value else _auto_change_note(timeline, batch_list, state=state, file_path=file_path) # Single serialization: json roundtrip gives us an isolated snapshot - # without the expensive deepcopy t1 = time.perf_counter() snapshot_json = json.dumps({k: v for k, v in data.items() if k != KEY_HISTORY_TREE}) snapshot_payload = json.loads(snapshot_json) logger.info("save_and_snap snapshot %.3fs", time.perf_counter() - t1) try: - htree.commit(snapshot_payload, note=note) + timeline.record(snapshot_payload, note=note) except ValueError as e: ui.notify(f'Save failed: {e}', type='negative') return if state.db_enabled and state.current_project and state.db: - # DB path: sync full tree (with snapshots) to DB, then - # write slim tree (no snapshots) to JSON and memory - full_tree = htree.to_dict() + full_tree = timeline.to_dict() data[KEY_HISTORY_TREE] = full_tree t1 = time.perf_counter() db_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot) logger.info("save_and_snap sync_to_db %.3fs", time.perf_counter() - t1) - htree.strip_snapshots() - data[KEY_HISTORY_TREE] = htree.to_dict() + timeline.strip_snapshots() + data[KEY_HISTORY_TREE] = timeline.to_dict() t1 = time.perf_counter() slim_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, slim_snapshot) logger.info("save_and_snap save_json %.3fs", time.perf_counter() - t1) else: - # No DB: write full tree (with snapshots) to JSON - data[KEY_HISTORY_TREE] = htree.to_dict() + data[KEY_HISTORY_TREE] = timeline.to_dict() t1 = time.perf_counter() save_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, save_snapshot) @@ -416,9 +414,30 @@ def _render_sequence_card(i, seq, batch_list, data, file_path, state, refresh_list): async def commit(message=None): data[KEY_BATCH_DATA] = batch_list + # Auto-snapshot with debounce + fp_key = str(file_path) + now = time.time() + did_snap = False + if now - _last_auto_snap.get(fp_key, 0) >= _AUTO_SNAP_DEBOUNCE: + timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {})) + snap_json = json.dumps({k: v for k, v in data.items() + if k != KEY_HISTORY_TREE}) + snap_payload = json.loads(snap_json) + try: + timeline.record(snap_payload, note=message or "Auto-save", auto=True) + if state.db_enabled and state.current_project and state.db: + data[KEY_HISTORY_TREE] = timeline.to_dict() + db_snap = json.loads(json.dumps(data)) + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap) + timeline.strip_snapshots() + did_snap = True + data[KEY_HISTORY_TREE] = timeline.to_dict() + _last_auto_snap[fp_key] = now + except ValueError: + pass # Non-critical: skip auto-snapshot on ID collision snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, snapshot) - if state.db_enabled and state.current_project and state.db: + if state.db_enabled and state.current_project and state.db and not did_snap: await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) if message: ui.notify(message, type='positive') @@ -845,26 +864,26 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li batch_list[idx][key] = copy.deepcopy(source_seq.get(key)) data[KEY_BATCH_DATA] = batch_list - htree = HistoryTree(data.get(KEY_HISTORY_TREE, {})) + timeline = SnapshotTimeline(data.get(KEY_HISTORY_TREE, {})) snapshot_json = json.dumps({k: v for k, v in data.items() if k != KEY_HISTORY_TREE}) snapshot = json.loads(snapshot_json) try: - htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}") + timeline.record(snapshot, f"Mass update: {', '.join(selected_keys)}") except ValueError as e: ui.notify(f'Mass update failed: {e}', type='negative') return if state.db_enabled and state.current_project and state.db: - full_tree = htree.to_dict() + full_tree = timeline.to_dict() data[KEY_HISTORY_TREE] = full_tree db_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snapshot) - htree.strip_snapshots() - data[KEY_HISTORY_TREE] = htree.to_dict() + timeline.strip_snapshots() + data[KEY_HISTORY_TREE] = timeline.to_dict() slim_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, slim_snapshot) else: - data[KEY_HISTORY_TREE] = htree.to_dict() + data[KEY_HISTORY_TREE] = timeline.to_dict() save_snapshot = json.loads(json.dumps(data)) await asyncio.to_thread(save_json, file_path, save_snapshot) ui.notify(f'Updated {len(targets)} sequences', type='positive') diff --git a/tab_timeline_ng.py b/tab_timeline_ng.py index 7e89b40..b7860f7 100644 --- a/tab_timeline_ng.py +++ b/tab_timeline_ng.py @@ -1,5 +1,5 @@ import asyncio -import hashlib +import copy import json import logging import time @@ -7,373 +7,15 @@ import time from nicegui import ui from state import AppState -from history_tree import HistoryTree +from snapshot_timeline import SnapshotTimeline, diff_snapshots from utils import save_json, load_json, sync_to_db, KEY_BATCH_DATA, KEY_HISTORY_TREE logger = logging.getLogger(__name__) -def _delete_nodes(htree, data, file_path, node_ids, state=None): - """Delete nodes with backup, branch cleanup, re-parenting, and head fallback.""" - if 'history_tree_backup' not in data: - data['history_tree_backup'] = [] - # Back up tree metadata only (no snapshot data) to avoid bloating JSON - backup = json.loads(json.dumps(htree.to_dict())) - for node in backup.get('nodes', {}).values(): - node.pop('data', None) - data['history_tree_backup'].append(backup) - data['history_tree_backup'] = data['history_tree_backup'][-10:] - # Save deleted node parents before removal (needed for branch re-pointing) - deleted_parents = {} - for nid in node_ids: - deleted_node = htree.nodes.get(nid) - if deleted_node: - deleted_parents[nid] = deleted_node.get('parent') - # Re-parent children of deleted nodes — walk up to find a surviving ancestor - for nid in node_ids: - surviving_parent = deleted_parents.get(nid) - while surviving_parent in node_ids: - surviving_parent = deleted_parents.get(surviving_parent) - for child in htree.nodes.values(): - if child.get('parent') == nid: - child['parent'] = surviving_parent - for nid in node_ids: - htree.nodes.pop(nid, None) - # Re-point branches whose tip was deleted to a surviving ancestor - for b, tip in list(htree.branches.items()): - if tip in node_ids: - new_tip = deleted_parents.get(tip) - while new_tip in node_ids: - new_tip = deleted_parents.get(new_tip) - if new_tip and new_tip in htree.nodes: - htree.branches[b] = new_tip - else: - del htree.branches[b] - if htree.head_id in node_ids: - if htree.nodes: - htree.head_id = sorted(htree.nodes.values(), - key=lambda x: x['timestamp'])[-1]['id'] - else: - htree.head_id = None - data[KEY_HISTORY_TREE] = htree.to_dict() - # Clean up DB snapshots for deleted nodes - if state and state.db_enabled and state.db and state.current_project: - df = state.db.get_data_file_by_names(state.current_project, file_path.stem) - if df: - state.db.delete_node_snapshots(df['id'], set(node_ids)) - - -def _render_selection_picker(all_nodes, htree, state, refresh_fn): - """Multi-select picker for batch-deleting timeline nodes.""" - all_ids = [n['id'] for n in all_nodes] - - def fmt_option(nid): - n = htree.nodes[nid] - ts = time.strftime('%b %d %H:%M', time.localtime(n['timestamp'])) - note = n.get('note', 'Step') - head = ' (HEAD)' if nid == htree.head_id else '' - return f'{note} - {ts} ({nid[:6]}){head}' - - options = {nid: fmt_option(nid) for nid in all_ids} - - def on_selection_change(e): - state.timeline_selected_nodes = set(e.value) if e.value else set() - - ui.select( - options, - value=list(state.timeline_selected_nodes), - multiple=True, - label='Select nodes to delete:', - on_change=on_selection_change, - ).classes('w-full') - - with ui.row(): - def select_all(): - state.timeline_selected_nodes = set(all_ids) - refresh_fn() - def deselect_all(): - state.timeline_selected_nodes = set() - refresh_fn() - ui.button('Select All', on_click=select_all).props('flat dense') - ui.button('Deselect All', on_click=deselect_all).props('flat dense') - - -def _render_graph_or_log(mode, all_nodes, htree, selected_nodes, - selection_mode_on, toggle_select_fn, restore_fn, - selected=None): - """Render graph visualization or linear log view.""" - if mode in ('Horizontal', 'Vertical'): - direction = 'LR' if mode == 'Horizontal' else 'TB' - with ui.card().classes('w-full q-pa-md'): - try: - graph_dot = htree.generate_graph(direction=direction) - sel_id = selected.get('node_id') if selected else None - _render_graphviz(graph_dot, selected_node_id=sel_id) - except Exception as e: - ui.label(f'Graph Error: {e}').classes('text-negative') - - elif mode == 'Linear Log': - ui.label('Chronological list of all snapshots.').classes('text-caption') - for n in all_nodes: - is_head = n['id'] == htree.head_id - is_selected = n['id'] in selected_nodes - - card_style = '' - if is_selected: - card_style = 'background: rgba(239, 68, 68, 0.1) !important; border-left: 3px solid var(--negative);' - elif is_head: - card_style = 'background: var(--accent-subtle) !important; border-left: 3px solid var(--accent);' - with ui.card().classes('w-full q-mb-sm').style(card_style): - with ui.row().classes('w-full items-center'): - if selection_mode_on: - ui.checkbox( - '', - value=is_selected, - on_change=lambda e, nid=n['id']: toggle_select_fn( - nid, e.value), - ) - - icon = 'location_on' if is_head else 'circle' - ui.icon(icon).classes( - 'text-primary' if is_head else 'text-grey') - - with ui.column().classes('col'): - note = n.get('note', 'Step') - ts = time.strftime('%b %d %H:%M', - time.localtime(n['timestamp'])) - label = f'{note} (Current)' if is_head else note - ui.label(label).classes('text-bold') - ui.label( - f'ID: {n["id"][:6]} - {ts}').classes('text-caption') - - if not is_head and not selection_mode_on: - ui.button( - 'Restore', - icon='restore', - on_click=lambda node=n: restore_fn(node), - ).props('flat dense color=primary') - - -def _render_batch_delete(htree, data, file_path, state, refresh_fn): - """Render batch delete controls for selected timeline nodes.""" - valid = state.timeline_selected_nodes & set(htree.nodes.keys()) - state.timeline_selected_nodes = valid - count = len(valid) - if count == 0: - return - - ui.label( - f'{count} node{"s" if count != 1 else ""} selected for deletion.' - ).classes('text-warning q-mt-md') - - async def do_batch_delete(): - current_valid = state.timeline_selected_nodes & set(htree.nodes.keys()) - _delete_nodes(htree, data, file_path, current_valid, state=state) - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - state.timeline_selected_nodes = set() - ui.notify( - f'Deleted {len(current_valid)} node{"s" if len(current_valid) != 1 else ""}!', - type='positive') - refresh_fn() - - ui.button( - f'Delete {count} Node{"s" if count != 1 else ""}', - icon='delete', - on_click=do_batch_delete, - ).props('color=negative') - - -def _walk_branch_nodes(htree, tip_id): - """Walk parent pointers from tip, returning nodes newest-first.""" - nodes = [] - visited = set() - current = tip_id - while current and current in htree.nodes: - if current in visited: - break - visited.add(current) - nodes.append(htree.nodes[current]) - current = htree.nodes[current].get('parent') - return nodes - - -def _find_active_branch(htree): - """Return branch name whose tip == head_id, or None if detached.""" - if not htree.head_id: - return None - for b_name, tip_id in htree.branches.items(): - if tip_id == htree.head_id: - return b_name - return None - - -def _find_branch_for_node(htree, node_id): - """Return the branch name whose ancestry contains node_id, or None.""" - for b_name, tip_id in htree.branches.items(): - visited = set() - current = tip_id - while current and current in htree.nodes: - if current in visited: - break - if current == node_id: - return b_name - visited.add(current) - current = htree.nodes[current].get('parent') - return None - - -def _render_node_manager(all_nodes, htree, data, file_path, restore_fn, refresh_fn, - selected, state=None): - """Render branch-grouped node manager with restore, rename, delete, and preview.""" - ui.label('Manage Version').classes('section-header') - - active_branch = _find_active_branch(htree) - - # --- (a) Branch selector --- - def fmt_branch(b_name): - count = len(_walk_branch_nodes(htree, htree.branches.get(b_name))) - suffix = ' (active)' if b_name == active_branch else '' - return f'{b_name} ({count} nodes){suffix}' - - branch_options = {b: fmt_branch(b) for b in htree.branches} - - def on_branch_change(e): - selected['branch'] = e.value - tip = htree.branches.get(e.value) - if tip: - selected['node_id'] = tip - render_branch_nodes.refresh() - - ui.select( - branch_options, - value=selected['branch'], - label='Branch:', - on_change=on_branch_change, - ).classes('w-full') - - # --- (b) Node list + (c) Actions panel --- - @ui.refreshable - def render_branch_nodes(): - branch_name = selected['branch'] - tip_id = htree.branches.get(branch_name) - nodes = _walk_branch_nodes(htree, tip_id) if tip_id else [] - - if not nodes: - ui.label('No nodes on this branch.').classes('text-caption q-pa-sm') - return - - with ui.scroll_area().classes('w-full').style('max-height: 350px'): - for n in nodes: - nid = n['id'] - is_head = nid == htree.head_id - is_tip = nid == tip_id - is_selected = nid == selected['node_id'] - - card_style = '' - if is_selected: - card_style = 'border-left: 3px solid var(--primary);' - elif is_head: - card_style = 'border-left: 3px solid var(--accent);' - - with ui.card().classes('w-full q-mb-xs q-pa-xs').style(card_style): - with ui.row().classes('w-full items-center no-wrap'): - icon = 'location_on' if is_head else 'circle' - icon_size = 'sm' if is_head else 'xs' - ui.icon(icon, size=icon_size).classes( - 'text-primary' if is_head else 'text-grey') - - with ui.column().classes('col q-ml-xs').style('min-width: 0'): - note = n.get('note', 'Step') - ts = time.strftime('%b %d %H:%M', - time.localtime(n['timestamp'])) - label_text = note - lbl = ui.label(label_text).classes('text-body2 ellipsis') - if is_head: - lbl.classes('text-bold') - ui.label(f'{ts} \u2022 {nid[:6]}').classes( - 'text-caption text-grey') - - if is_head: - ui.badge('HEAD', color='amber').props('dense') - if is_tip and not is_head: - ui.badge('tip', color='green', outline=True).props('dense') - - def select_node(node_id=nid): - selected['node_id'] = node_id - render_branch_nodes.refresh() - - ui.button(icon='check_circle', on_click=select_node).props( - 'flat dense round size=sm' - ).tooltip('Select this node') - - # --- (c) Actions panel --- - sel_id = selected['node_id'] - if not sel_id or sel_id not in htree.nodes: - return - - sel_node = htree.nodes[sel_id] - sel_note = sel_node.get('note', 'Step') - is_head = sel_id == htree.head_id - - ui.separator().classes('q-my-sm') - ui.label(f'Selected: {sel_note} ({sel_id[:6]})').classes( - 'text-caption text-bold') - - with ui.row().classes('w-full items-end q-gutter-sm'): - if not is_head: - def restore_selected(): - if sel_id in htree.nodes: - restore_fn(htree.nodes[sel_id]) - ui.button('Restore', icon='restore', - on_click=restore_selected).props('color=primary dense') - - # Rename - rename_input = ui.input('Rename Label').classes('col').props('dense') - - async def rename_node(): - if sel_id in htree.nodes and rename_input.value: - htree.nodes[sel_id]['note'] = rename_input.value - data[KEY_HISTORY_TREE] = htree.to_dict() - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state and state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - ui.notify('Label updated', type='positive') - refresh_fn() - - ui.button('Update Label', on_click=rename_node).props('flat dense') - - # Danger zone - with ui.expansion('Danger Zone', icon='warning').classes( - 'w-full q-mt-sm').style('border-left: 3px solid var(--negative)'): - ui.label('Deleting a node cannot be undone.').classes('text-warning') - - async def delete_selected(): - if sel_id in htree.nodes: - _delete_nodes(htree, data, file_path, {sel_id}, state=state) - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state and state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - # Reset selection if branch was removed - if selected['branch'] not in htree.branches: - selected['branch'] = next(iter(htree.branches), None) - selected['node_id'] = htree.head_id - ui.notify('Node Deleted', type='positive') - refresh_fn() - - ui.button('Delete This Node', icon='delete', - on_click=delete_selected).props('color=negative dense') - - # Data preview - with ui.expansion('Data Preview', icon='preview').classes('w-full q-mt-sm'): - _render_data_preview(sel_id, htree, state=state, file_path=file_path) - - render_branch_nodes() - +# ====================================================================== +# Main entry point +# ====================================================================== def render_timeline_tab(state: AppState): t0 = time.perf_counter() @@ -383,266 +25,541 @@ def render_timeline_tab(state: AppState): tree_data = data.get(KEY_HISTORY_TREE, {}) if not tree_data: - ui.label('No history timeline exists. Make some changes in the Editor first!').classes( + ui.label('No version history exists. Make some changes in the Editor first!').classes( 'text-subtitle1 q-pa-md') return - htree = HistoryTree(tree_data) + timeline = SnapshotTimeline(tree_data) + if not timeline.snapshots: + ui.label('No snapshots found in history.').classes('text-subtitle1 q-pa-md') + return - # --- Shared selected-node state (survives refreshes, shared by graph + manager) --- - active_branch = _find_active_branch(htree) - default_branch = active_branch - if not default_branch and htree.head_id: - for b_name, tip_id in htree.branches.items(): - for n in _walk_branch_nodes(htree, tip_id): - if n['id'] == htree.head_id: - default_branch = b_name - break - if default_branch: - break - if not default_branch and htree.branches: - default_branch = next(iter(htree.branches)) - selected = {'node_id': htree.head_id, 'branch': default_branch} + # Local UI state + ui_state = { + 'selected_id': state.timeline_selected_id or timeline.current_id, + 'search': '', + 'filter': 'All', # All | Pinned | Auto + } if state.restored_indicator: ui.label(f'Editing Restored Version: {state.restored_indicator}').classes( 'text-info q-pa-sm') - # --- View mode + Selection toggle --- - with ui.row().classes('w-full items-center q-gutter-md q-mb-md'): - ui.label('Version History').classes('text-h6 col') - view_mode = ui.toggle( - ['Horizontal', 'Vertical', 'Linear Log'], - value='Horizontal', - ) - selection_mode = ui.switch('Select to Delete') + ui.label('Version History').classes('text-h6 q-mb-sm') - @ui.refreshable - def render_timeline(): - t_rt = time.perf_counter() - logger.info("render_timeline START (%d nodes)", len(htree.nodes)) - all_nodes = sorted(htree.nodes.values(), key=lambda x: x['timestamp'], reverse=True) - selected_nodes = state.timeline_selected_nodes if selection_mode.value else set() + # Mutable container so left/right panels can cross-reference each other's refreshables + panels: dict = {} - if selection_mode.value: - _render_selection_picker(all_nodes, htree, state, render_timeline.refresh) + # ====================================================================== + # Splitter layout: 35% left (list) / 65% right (detail) + # ====================================================================== + with ui.splitter(value=35).classes('w-full').style('min-height: 600px') as splitter: - _render_graph_or_log( - view_mode.value, all_nodes, htree, selected_nodes, - selection_mode.value, _toggle_select, _restore_and_refresh, - selected=selected) + # ============================================================== + # LEFT PANEL — Snapshot list + # ============================================================== + with splitter.before: + with ui.column().classes('w-full q-pa-sm'): + # Search + filter + search_input = ui.input( + placeholder='Search notes...', + ).classes('w-full').props('dense outlined clearable') - if selection_mode.value and state.timeline_selected_nodes: - _render_batch_delete(htree, data, file_path, state, render_timeline.refresh) + with ui.row().classes('w-full q-gutter-xs'): + filter_toggle = ui.toggle( + ['All', 'Pinned', 'Auto'], value='All', + ).props('dense no-caps') - with ui.card().classes('w-full q-pa-md q-mt-md'): - _render_node_manager( - all_nodes, htree, data, file_path, - _restore_and_refresh, render_timeline.refresh, - selected, state=state) - logger.info("render_timeline END (%.3fs)", time.perf_counter() - t_rt) + @ui.refreshable + def render_snapshot_list(): + _render_snapshot_list( + timeline, ui_state, data, file_path, state, + render_snapshot_list, panels) - def _toggle_select(nid, checked): - if checked: - state.timeline_selected_nodes.add(nid) - else: - state.timeline_selected_nodes.discard(nid) - render_timeline.refresh() + panels['list'] = render_snapshot_list - async def _restore_and_refresh(node): - await _restore_node(data, node, htree, file_path, state) - # Refresh all tabs (batch, raw, timeline) so they pick up the restored data - state._render_main.refresh() + def _on_search(e): + ui_state['search'] = search_input.value or '' + render_snapshot_list.refresh() + + def _on_filter(e): + ui_state['filter'] = e.value + render_snapshot_list.refresh() + + search_input.on('update:model-value', _on_search) + filter_toggle.on_value_change(_on_filter) + + render_snapshot_list() + + # ============================================================== + # RIGHT PANEL — Detail tabs + # ============================================================== + with splitter.after: + @ui.refreshable + def render_detail_panel(): + _render_detail_panel(timeline, ui_state, data, file_path, state, + panels) + + panels['detail'] = render_detail_panel + render_detail_panel() - view_mode.on_value_change(lambda _: render_timeline.refresh()) - selection_mode.on_value_change(lambda _: render_timeline.refresh()) - render_timeline() logger.info("render_timeline_tab END (%.3fs)", time.perf_counter() - t0) - # --- Poll for graph node clicks (JS → Python bridge) --- - graph_timer = None - async def _poll_graph_click(): - if view_mode.value == 'Linear Log': - return - try: - result = await ui.run_javascript( - 'const v = window.graphSelectedNode;' - 'window.graphSelectedNode = null; v;' - ) - except Exception: - # Deactivate timer if parent slot was deleted - if graph_timer is not None: - graph_timer.active = False - return - if not result: - return - node_id = str(result) - if node_id not in htree.nodes: - return - branch = _find_branch_for_node(htree, node_id) - if branch: - selected['branch'] = branch - selected['node_id'] = node_id - render_timeline.refresh() +# ====================================================================== +# Left panel: snapshot list +# ====================================================================== - graph_timer = ui.timer(0.5, _poll_graph_click) +def _render_snapshot_list(timeline, ui_state, data, file_path, state, + refresh_list, panels): + snapshots = sorted(timeline.snapshots.values(), + key=lambda s: s['timestamp'], reverse=True) - def _cleanup_timer(): - if graph_timer is not None: - graph_timer.active = False - ui.context.client.on_disconnect(_cleanup_timer) + # Apply filters + search_term = ui_state.get('search', '').lower() + filter_mode = ui_state.get('filter', 'All') + if search_term: + snapshots = [s for s in snapshots + if search_term in s.get('note', '').lower()] + if filter_mode == 'Pinned': + snapshots = [s for s in snapshots if s.get('pinned')] + elif filter_mode == 'Auto': + snapshots = [s for s in snapshots if s.get('auto')] -_graphviz_svg_cache: dict[str, str] = {} -_GRAPHVIZ_CACHE_MAX = 20 - - -def _render_graphviz(dot_source: str, selected_node_id: str | None = None): - """Render graphviz DOT source as interactive SVG with click-to-select.""" - try: - import graphviz - t_gv = time.perf_counter() - cache_key = hashlib.md5(dot_source.encode()).hexdigest() - svg = _graphviz_svg_cache.get(cache_key) - if svg is None: - src = graphviz.Source(dot_source) - svg = src.pipe(format='svg').decode('utf-8') - if len(_graphviz_svg_cache) >= _GRAPHVIZ_CACHE_MAX: - _graphviz_svg_cache.pop(next(iter(_graphviz_svg_cache))) - _graphviz_svg_cache[cache_key] = svg - logger.info("_render_graphviz MISS (generated): %.3fs", time.perf_counter() - t_gv) - else: - logger.info("_render_graphviz HIT (cached): %.3fs", time.perf_counter() - t_gv) - - sel_escaped = json.dumps(selected_node_id or '')[1:-1] # strip quotes, get JS-safe content - - # CSS inline (allowed), JS via run_javascript (script tags blocked) - css = '''''' - - ui.html( - f'{css}
' - f'{svg}
' - ) - - # Find container by class with retry for Vue async render - ui.run_javascript(f''' - (function attempt(tries) {{ - var container = document.querySelector('.timeline-graph'); - if (!container || !container.querySelector('g.node')) {{ - if (tries < 20) setTimeout(function() {{ attempt(tries + 1); }}, 100); - return; - }} - container.querySelectorAll('g.node').forEach(function(g) {{ - g.addEventListener('click', function() {{ - var title = g.querySelector('title'); - if (title) {{ - window.graphSelectedNode = title.textContent.trim(); - container.querySelectorAll('g.node.selected').forEach( - function(el) {{ el.classList.remove('selected'); }}); - g.classList.add('selected'); - }} - }}); - }}); - var selId = '{sel_escaped}'; - if (selId) {{ - container.querySelectorAll('g.node').forEach(function(g) {{ - var title = g.querySelector('title'); - if (title && title.textContent.trim() === selId) {{ - g.classList.add('selected'); - }} - }}); - }} - }})(0); - ''') - except ImportError: - ui.label('Install graphviz Python package for graph rendering.').classes('text-warning') - ui.code(dot_source).classes('w-full') - except Exception as e: - ui.label(f'Graph rendering error: {e}').classes('text-negative') - - -async def _restore_node(data, node, htree, file_path, state: AppState): - """Restore a history node as the current version (full replace, not merge).""" - t0 = time.perf_counter() - logger.info("_restore_node START: %s", node.get('note', 'Step')) - # Load snapshot from DB on demand (nodes no longer hold data in memory) - raw_snap = node.get('data') - if not raw_snap and state.db_enabled and state.db and state.current_project: - df = state.db.get_data_file_by_names(state.current_project, file_path.stem) - if df: - raw_snap = await asyncio.to_thread( - state.db.get_node_snapshot, df['id'], node['id']) - if not raw_snap: - # Last resort: read from JSON file on disk - raw_file, _ = await asyncio.to_thread(load_json, file_path) - tree_on_disk = raw_file.get(KEY_HISTORY_TREE, {}) - raw_snap = tree_on_disk.get('nodes', {}).get(node['id'], {}).get('data', {}) - node_data = json.loads(json.dumps(raw_snap)) if raw_snap else {} - # Preserve the history tree before clearing - preserved_tree = data.get(KEY_HISTORY_TREE) - preserved_backup = data.get('history_tree_backup') - data.clear() - data.update(node_data) - # Re-attach history tree (not part of snapshot data) - if preserved_tree is not None: - data[KEY_HISTORY_TREE] = preserved_tree - if preserved_backup is not None: - data['history_tree_backup'] = preserved_backup - htree.head_id = node['id'] - data[KEY_HISTORY_TREE] = htree.to_dict() - snapshot = json.loads(json.dumps(data)) - await asyncio.to_thread(save_json, file_path, snapshot) - if state.db_enabled and state.current_project and state.db: - await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) - label = f"{node.get('note', 'Step')} ({node['id'][:4]})" - state.restored_indicator = label - logger.info("_restore_node END (%.3fs)", time.perf_counter() - t0) - ui.notify('Restored!', type='positive') - - -def _render_data_preview(nid, htree, state: AppState = None, file_path=None): - """Render a read-only preview of the selected node's data.""" - if not nid or nid not in htree.nodes: - ui.label('No node selected.').classes('text-caption') + if not snapshots: + ui.label('No snapshots match your filter.').classes('text-caption q-pa-md') return - # Load snapshot from DB on demand (not stored in memory) - node_data = htree.nodes[nid].get('data') - if not node_data and state and state.db_enabled and state.db and state.current_project and file_path: - df = state.db.get_data_file_by_names(state.current_project, file_path.stem) - if df: - node_data = state.db.get_node_snapshot(df['id'], nid) - if not node_data and file_path: - # Disk fallback: read snapshot from JSON file - try: - raw_data, _ = load_json(file_path) - tree_on_disk = raw_data.get(KEY_HISTORY_TREE, {}) - node_data = tree_on_disk.get('nodes', {}).get(nid, {}).get('data') - except Exception: - pass - if not node_data: + with ui.scroll_area().classes('w-full').style('max-height: 520px'): + for snap in snapshots: + sid = snap['id'] + is_current = sid == timeline.current_id + is_selected = sid == ui_state.get('selected_id') + is_pinned = snap.get('pinned', False) + is_auto = snap.get('auto', False) + + # Card styling + border = '' + if is_current: + border = 'border-left: 4px solid #eebb00;' + if is_selected: + border = 'border-left: 4px solid #4caf50;' + bg = 'background: rgba(76,175,80,0.08) !important;' if is_selected else '' + + def select_snap(snap_id=sid): + ui_state['selected_id'] = snap_id + state.timeline_selected_id = snap_id + refresh_list.refresh() + detail = panels.get('detail') + if detail is not None: + detail.refresh() + + with ui.card().classes('w-full q-mb-xs q-pa-xs cursor-pointer').style( + f'{border} {bg}').on('click', select_snap): + with ui.row().classes('w-full items-center no-wrap'): + # Icon + if is_pinned: + icon_name = 'push_pin' + icon_cls = 'text-amber' + elif is_auto: + icon_name = 'bolt' + icon_cls = 'text-grey' + else: + icon_name = 'save' + icon_cls = 'text-primary' + ui.icon(icon_name, size='sm').classes(icon_cls) + + # Text + with ui.column().classes('col q-ml-xs').style('min-width: 0'): + note = snap.get('note', 'Snapshot') + lbl = ui.label(note).classes('text-body2 ellipsis') + if is_current: + lbl.classes('text-bold') + ts = time.strftime('%b %d %H:%M', + time.localtime(snap['timestamp'])) + seq_count = snap.get('seq_count', '?') + ui.label(f'{ts} \u00b7 {seq_count} seqs').classes( + 'text-caption text-grey') + + # Badges + if is_current: + ui.badge('current', color='amber').props('dense') + + # Pin toggle + async def toggle_pin(snap_id=sid): + timeline.toggle_pin(snap_id) + data[KEY_HISTORY_TREE] = timeline.to_dict() + snapshot = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, snapshot) + refresh_list.refresh() + + pin_icon = 'push_pin' if is_pinned else 'o_push_pin' + ui.button(icon=pin_icon, on_click=toggle_pin).props( + 'flat dense round size=xs').on('click.stop', lambda: None) + + +# ====================================================================== +# Right panel: detail tabs (Preview / Compare / Cherry-pick) +# ====================================================================== + +def _render_detail_panel(timeline, ui_state, data, file_path, state, + panels): + sel_id = ui_state.get('selected_id') + if not sel_id or sel_id not in timeline.snapshots: + ui.label('Select a snapshot from the list.').classes('text-caption q-pa-lg') + return + + def _refresh_both(): + """Refresh both list and detail panels.""" + lp = panels.get('list') + dp = panels.get('detail') + if lp: + lp.refresh() + if dp: + dp.refresh() + + snap = timeline.snapshots[sel_id] + note = snap.get('note', 'Snapshot') + ts = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(snap['timestamp'])) + ui.label(f'{note}').classes('text-subtitle1 text-bold') + ui.label(f'{ts} \u2022 ID: {sel_id}').classes('text-caption text-grey q-mb-sm') + + # Action buttons + with ui.row().classes('q-gutter-sm q-mb-sm'): + is_current = sel_id == timeline.current_id + + if not is_current: + async def restore_full(): + await _restore_snapshot(data, sel_id, timeline, file_path, state) + state._render_main.refresh() + + ui.button('Restore Full', icon='restore', + on_click=restore_full).props('color=primary dense') + + # Rename + rename_input = ui.input(placeholder='New note...').props('dense outlined').classes('w-48') + + async def rename(): + if rename_input.value and sel_id in timeline.snapshots: + timeline.snapshots[sel_id]['note'] = rename_input.value + data[KEY_HISTORY_TREE] = timeline.to_dict() + snapshot = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, snapshot) + if state.db_enabled and state.current_project and state.db: + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) + ui.notify('Note updated', type='positive') + _refresh_both() + + ui.button('Rename', on_click=rename).props('flat dense') + + # Delete + async def delete_snap(): + timeline.delete(sel_id) + # Clean up DB snapshots + if state.db_enabled and state.db and state.current_project: + df = state.db.get_data_file_by_names(state.current_project, file_path.stem) + if df: + state.db.delete_node_snapshots(df['id'], {sel_id}) + data[KEY_HISTORY_TREE] = timeline.to_dict() + snapshot = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, snapshot) + if state.db_enabled and state.current_project and state.db: + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) + ui_state['selected_id'] = timeline.current_id + state.timeline_selected_id = timeline.current_id + ui.notify('Snapshot deleted', type='positive') + _refresh_both() + + ui.button(icon='delete', on_click=delete_snap).props('flat dense color=negative') + + # Sub-tabs + with ui.tabs().classes('w-full') as tabs: + preview_tab = ui.tab('Preview', icon='visibility') + compare_tab = ui.tab('Compare', icon='compare') + cherry_tab = ui.tab('Cherry-pick', icon='content_paste') + + with ui.tab_panels(tabs, value=preview_tab).classes('w-full'): + with ui.tab_panel(preview_tab): + _render_preview_tab(sel_id, timeline, state, file_path) + + with ui.tab_panel(compare_tab): + _render_compare_tab(sel_id, timeline, data, state, file_path) + + with ui.tab_panel(cherry_tab): + _render_cherry_pick_tab(sel_id, timeline, data, file_path, state, + panels) + + +# ====================================================================== +# Tab 1: Preview +# ====================================================================== + +def _render_preview_tab(sel_id, timeline, state, file_path): + snap_data = _load_snapshot_data(sel_id, timeline, state, file_path) + if not snap_data: ui.label('Snapshot data not available.').classes('text-caption text-warning') return - batch_list = node_data.get(KEY_BATCH_DATA, []) - if batch_list and isinstance(batch_list, list) and len(batch_list) > 0: - ui.label(f'This snapshot contains {len(batch_list)} sequences.').classes('text-caption') + batch_list = snap_data.get(KEY_BATCH_DATA, []) + if batch_list and isinstance(batch_list, list): + ui.label(f'{len(batch_list)} sequences in this snapshot.').classes('text-caption') for i, seq_data in enumerate(batch_list): seq_num = seq_data.get('sequence_number', i + 1) with ui.expansion(f'Sequence #{seq_num}', value=(i == 0)): _render_preview_fields(seq_data) else: - _render_preview_fields(node_data) + _render_preview_fields(snap_data) + + +# ====================================================================== +# Tab 2: Compare +# ====================================================================== + +def _render_compare_tab(sel_id, timeline, data, state, file_path): + snap_data = _load_snapshot_data(sel_id, timeline, state, file_path) + if not snap_data: + ui.label('Snapshot data not available.').classes('text-caption text-warning') + return + + old_batch = snap_data.get(KEY_BATCH_DATA, []) + new_batch = data.get(KEY_BATCH_DATA, []) + + if not old_batch and not new_batch: + ui.label('No batch data to compare.').classes('text-caption') + return + + diffs = diff_snapshots(old_batch, new_batch) + + show_all = ui.switch('Show unchanged', value=False) + + @ui.refreshable + def render_diff(): + any_diff = False + for d in diffs: + if d['status'] == 'unchanged' and not show_all.value: + continue + any_diff = True + seq_num = d['seq_num'] + status = d['status'] + + # Header styling + if status == 'added': + icon = 'add_circle' + color = 'text-positive' + label = f'Sequence #{seq_num} \u2014 ADDED (not in snapshot)' + elif status == 'removed': + icon = 'remove_circle' + color = 'text-negative' + label = f'Sequence #{seq_num} \u2014 REMOVED (not in current)' + elif status == 'changed': + icon = 'change_circle' + color = 'text-warning' + label = f'Sequence #{seq_num} \u2014 {len(d["changes"])} field{"s" if len(d["changes"]) != 1 else ""} changed' + else: + icon = 'check_circle' + color = 'text-grey' + label = f'Sequence #{seq_num} \u2014 No changes' + + with ui.expansion(label, icon=icon).classes(f'w-full {color}'): + if status == 'changed' and d['changes']: + # Table of field changes + columns = [ + {'name': 'field', 'label': 'Field', 'field': 'field', 'align': 'left'}, + {'name': 'old', 'label': 'Snapshot', 'field': 'old', 'align': 'left'}, + {'name': 'new', 'label': 'Current', 'field': 'new', 'align': 'left'}, + ] + rows = [] + for c in d['changes']: + rows.append({ + 'field': c['field'], + 'old': _truncate(c['old']), + 'new': _truncate(c['new']), + }) + ui.table(columns=columns, rows=rows, row_key='field').classes( + 'w-full').props('dense flat bordered') + elif status in ('added', 'removed'): + ui.label('Entire sequence differs.').classes('text-caption') + + if not any_diff: + ui.label('All sequences are identical.').classes('text-caption q-pa-md') + + show_all.on_value_change(lambda _: render_diff.refresh()) + render_diff() + + +# ====================================================================== +# Tab 3: Cherry-pick Restore +# ====================================================================== + +def _render_cherry_pick_tab(sel_id, timeline, data, file_path, state, + panels): + snap_data = _load_snapshot_data(sel_id, timeline, state, file_path) + if not snap_data: + ui.label('Snapshot data not available.').classes('text-caption text-warning') + return + + old_batch = snap_data.get(KEY_BATCH_DATA, []) + if not old_batch: + ui.label('No sequences in this snapshot.').classes('text-caption') + return + + ui.label('Select sequences and fields to restore from this snapshot.').classes( + 'text-caption q-mb-sm') + + mode = ui.toggle(['Whole sequences', 'Selected fields'], value='Whole sequences').props( + 'dense no-caps') + + # Build checkboxes per sequence + seq_checks: dict[int, ui.checkbox] = {} + field_checks: dict[int, dict[str, ui.checkbox]] = {} + + for seq_item in old_batch: + seq_num = int(seq_item.get('sequence_number', 0)) + seq_cb = ui.checkbox(f'Sequence #{seq_num}') + seq_checks[seq_num] = seq_cb + + with ui.expansion(f'Fields for #{seq_num}').classes('w-full q-ml-lg'): + field_checks[seq_num] = {} + for k in sorted(seq_item.keys()): + if k == 'sequence_number': + continue + val_str = _truncate(seq_item.get(k)) + fcb = ui.checkbox(f'{k}: {val_str}') + field_checks[seq_num][k] = fcb + + async def apply_cherry_pick(): + current_batch = data.get(KEY_BATCH_DATA, []) + curr_by_seq = {int(s.get('sequence_number', 0)): s for s in current_batch} + old_by_seq = {int(s.get('sequence_number', 0)): s for s in old_batch} + + applied = 0 + for seq_num, cb in seq_checks.items(): + if not cb.value: + continue + if seq_num not in old_by_seq: + continue + + if mode.value == 'Whole sequences': + # Replace or add entire sequence + restored = copy.deepcopy(old_by_seq[seq_num]) + if seq_num in curr_by_seq: + # Find and replace in-place + for i, s in enumerate(current_batch): + if int(s.get('sequence_number', 0)) == seq_num: + current_batch[i] = restored + break + else: + current_batch.append(restored) + applied += 1 + else: + # Selected fields only + if seq_num not in curr_by_seq: + continue + target = curr_by_seq[seq_num] + fields = field_checks.get(seq_num, {}) + for field_name, fcb in fields.items(): + if fcb.value and field_name in old_by_seq[seq_num]: + target[field_name] = copy.deepcopy(old_by_seq[seq_num][field_name]) + applied += 1 + + if applied == 0: + ui.notify('Nothing selected to restore.', type='warning') + return + + data[KEY_BATCH_DATA] = current_batch + + # Auto-snapshot noting the cherry-pick + snap_note = timeline.snapshots.get(sel_id, {}).get('note', 'unknown') + snap_json = json.dumps({k: v for k, v in data.items() + if k != KEY_HISTORY_TREE}) + snap_payload = json.loads(snap_json) + timeline.record(snap_payload, note=f'Cherry-pick from "{snap_note}"') + if state.db_enabled and state.current_project and state.db: + data[KEY_HISTORY_TREE] = timeline.to_dict() + db_snap = json.loads(json.dumps(data)) + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, db_snap) + timeline.strip_snapshots() + data[KEY_HISTORY_TREE] = timeline.to_dict() + save_snap = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, save_snap) + ui.notify(f'Applied {applied} item{"s" if applied != 1 else ""}!', type='positive') + for p in ('list', 'detail'): + ref = panels.get(p) + if ref: + ref.refresh() + + ui.button('Apply Selected', icon='check', on_click=apply_cherry_pick).props( + 'color=primary q-mt-md') + + +# ====================================================================== +# Shared helpers +# ====================================================================== + +def _load_snapshot_data(snap_id, timeline, state, file_path): + """Load snapshot data from inline, DB, or disk fallback.""" + snap_data = timeline.get_snapshot_data(snap_id) + if snap_data: + return snap_data + + # Try DB + if state and state.db_enabled and state.db and state.current_project and file_path: + df = state.db.get_data_file_by_names(state.current_project, file_path.stem) + if df: + snap_data = state.db.get_node_snapshot(df['id'], snap_id) + if snap_data: + return snap_data + + # Disk fallback + if file_path: + try: + raw_data, _ = load_json(file_path) + tree_on_disk = raw_data.get(KEY_HISTORY_TREE, {}) + # New format + entry = tree_on_disk.get('snapshots', {}).get(snap_id) + if entry and 'data' in entry: + return entry['data'] + # Old format + entry = tree_on_disk.get('nodes', {}).get(snap_id) + if entry and 'data' in entry: + return entry['data'] + except Exception as e: + logger.warning("Failed to load snapshot %s from disk: %s", snap_id, e) + return None + + +async def _restore_snapshot(data, snap_id, timeline, file_path, state): + """Restore a snapshot as the current version (full replace).""" + snap_data = _load_snapshot_data(snap_id, timeline, state, file_path) + if not snap_data: + ui.notify('Snapshot data not available', type='negative') + return + + node_data = json.loads(json.dumps(snap_data)) + + # Preserve history tree + preserved_tree = data.get(KEY_HISTORY_TREE) + preserved_backup = data.get('history_tree_backup') + data.clear() + data.update(node_data) + if preserved_tree is not None: + data[KEY_HISTORY_TREE] = preserved_tree + if preserved_backup is not None: + data['history_tree_backup'] = preserved_backup + + timeline.current_id = snap_id + data[KEY_HISTORY_TREE] = timeline.to_dict() + + snapshot = json.loads(json.dumps(data)) + await asyncio.to_thread(save_json, file_path, snapshot) + if state.db_enabled and state.current_project and state.db: + await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, snapshot) + + note = timeline.snapshots.get(snap_id, {}).get('note', 'Snapshot') + label = f"{note} ({snap_id[:4]})" + state.restored_indicator = label + ui.notify('Restored!', type='positive') def _render_preview_fields(item_data: dict): @@ -684,3 +601,9 @@ def _render_preview_fields(item_data: dict): value=str(item_data.get('vace schedule', ''))).props('readonly outlined') ui.input('Video Path', value=str(item_data.get('video file path', ''))).props('readonly outlined') + + +def _truncate(val, max_len=60): + """Truncate a value for display.""" + s = str(val) if val is not None else '' + return (s[:max_len] + '...') if len(s) > max_len else s diff --git a/tests/test_db.py b/tests/test_db.py index bea102f..e027ea6 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -208,10 +208,10 @@ class TestHistoryTrees: def test_upsert_updates(self, db): pid = db.create_project("p1", "/p1") df_id = db.create_data_file(pid, "batch", "generic") - db.save_history_tree(df_id, {"v": 1}) - db.save_history_tree(df_id, {"v": 2}) + db.save_history_tree(df_id, {"snapshots": {}, "v": 1}) + db.save_history_tree(df_id, {"snapshots": {}, "v": 2}) result = db.get_history_tree(df_id) - assert result == {"v": 2} + assert result == {"snapshots": {}, "v": 2} def test_get_nonexistent(self, db): pid = db.create_project("p1", "/p1") diff --git a/tests/test_snapshot_timeline.py b/tests/test_snapshot_timeline.py new file mode 100644 index 0000000..db02b1d --- /dev/null +++ b/tests/test_snapshot_timeline.py @@ -0,0 +1,159 @@ +import pytest +from snapshot_timeline import SnapshotTimeline, diff_snapshots + + +def test_record_creates_snapshot(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": [{"seed": 42}]}, note="first") + assert sid in tl.snapshots + assert tl.current_id == sid + assert tl.snapshots[sid]["note"] == "first" + assert tl.snapshots[sid]["auto"] is False + assert tl.snapshots[sid]["seq_count"] == 1 + + +def test_record_auto_flag(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": []}, note="auto save", auto=True) + assert tl.snapshots[sid]["auto"] is True + + +def test_multiple_records(): + tl = SnapshotTimeline({}) + id1 = tl.record({"batch_data": [{"a": 1}]}, note="one") + id2 = tl.record({"batch_data": [{"b": 2}]}, note="two") + assert len(tl.snapshots) == 2 + assert tl.current_id == id2 + + +def test_to_dict_roundtrip(): + tl = SnapshotTimeline({}) + tl.record({"batch_data": [{"x": 1}]}, note="test") + d = tl.to_dict() + tl2 = SnapshotTimeline(d) + assert tl2.current_id == tl.current_id + assert set(tl2.snapshots.keys()) == set(tl.snapshots.keys()) + + +def test_migrate_from_history_tree(): + """Old HistoryTree format should be flattened into snapshots.""" + old_data = { + "nodes": { + "aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First", "data": {"batch_data": [{"seed": 1}]}}, + "bbb": {"id": "bbb", "parent": "aaa", "timestamp": 2000, "note": "Second", "data": {"batch_data": [{"seed": 2}]}}, + }, + "branches": {"main": "bbb"}, + "head_id": "bbb", + } + tl = SnapshotTimeline(old_data) + assert len(tl.snapshots) == 2 + assert tl.current_id == "bbb" + assert tl.snapshots["aaa"]["note"] == "First" + assert tl.snapshots["bbb"]["note"] == "Second" + # Data should be preserved + assert tl.snapshots["aaa"]["data"]["batch_data"] == [{"seed": 1}] + + +def test_migrate_from_history_tree_no_data(): + """Slim tree nodes (no inline data) should still migrate.""" + old_data = { + "nodes": { + "aaa": {"id": "aaa", "parent": None, "timestamp": 1000, "note": "First"}, + }, + "branches": {"main": "aaa"}, + "head_id": "aaa", + } + tl = SnapshotTimeline(old_data) + assert len(tl.snapshots) == 1 + assert tl.snapshots["aaa"]["seq_count"] == 0 + + +def test_migrate_legacy_prompt_history(): + legacy = { + "prompt_history": [ + {"note": "A", "seed": 1}, + {"note": "B", "seed": 2}, + ] + } + tl = SnapshotTimeline(legacy) + assert len(tl.snapshots) == 2 + assert tl.current_id is not None + + +def test_toggle_pin(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": []}, note="test") + assert tl.snapshots[sid]["pinned"] is False + result = tl.toggle_pin(sid) + assert result is True + assert tl.snapshots[sid]["pinned"] is True + result = tl.toggle_pin(sid) + assert result is False + + +def test_delete_snapshot(): + tl = SnapshotTimeline({}) + id1 = tl.record({"batch_data": []}, note="one") + id2 = tl.record({"batch_data": []}, note="two") + tl.delete(id2) + assert id2 not in tl.snapshots + assert tl.current_id == id1 + + +def test_delete_all_snapshots(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": []}, note="only") + tl.delete(sid) + assert len(tl.snapshots) == 0 + assert tl.current_id is None + + +def test_strip_snapshots(): + tl = SnapshotTimeline({}) + tl.record({"batch_data": [{"a": 1}]}, note="test") + tl.strip_snapshots() + for snap in tl.snapshots.values(): + assert "data" not in snap + + +def test_get_snapshot_data(): + tl = SnapshotTimeline({}) + sid = tl.record({"batch_data": [{"x": 1}]}, note="test") + data = tl.get_snapshot_data(sid) + assert data == {"batch_data": [{"x": 1}]} + assert tl.get_snapshot_data("nonexistent") is None + + +# --- diff_snapshots tests --- + +def test_diff_unchanged(): + batch = [{"sequence_number": 1, "seed": 42}] + result = diff_snapshots(batch, batch) + assert len(result) == 1 + assert result[0]["status"] == "unchanged" + assert result[0]["changes"] == [] + + +def test_diff_changed(): + old = [{"sequence_number": 1, "seed": 42, "cfg": 1.5}] + new = [{"sequence_number": 1, "seed": 99, "cfg": 1.5}] + result = diff_snapshots(old, new) + assert result[0]["status"] == "changed" + assert len(result[0]["changes"]) == 1 + assert result[0]["changes"][0]["field"] == "seed" + assert result[0]["changes"][0]["old"] == 42 + assert result[0]["changes"][0]["new"] == 99 + + +def test_diff_added_and_removed(): + old = [{"sequence_number": 1, "seed": 1}] + new = [{"sequence_number": 2, "seed": 2}] + result = diff_snapshots(old, new) + assert len(result) == 2 + statuses = {r["seq_num"]: r["status"] for r in result} + assert statuses[1] == "removed" + assert statuses[2] == "added" + + +def test_diff_empty(): + assert diff_snapshots([], []) == [] diff --git a/utils.py b/utils.py index 91f151a..ab5ed85 100644 --- a/utils.py +++ b/utils.py @@ -305,38 +305,47 @@ def sync_to_db(db, project_name: str, file_path: Path, data: dict) -> None: else: db.conn.execute("DELETE FROM sequences WHERE data_file_id = ?", (df_id,)) - # Sync history tree (extract node snapshots into separate table) + # Sync history tree (extract snapshot data into separate table) + # Supports both new format (snapshots dict) and old format (nodes dict) history_tree = data.get(KEY_HISTORY_TREE) if history_tree and isinstance(history_tree, dict): - nodes = history_tree.get("nodes", {}) + # Detect format: new has "snapshots", old has "nodes" + if "snapshots" in history_tree: + entries = history_tree.get("snapshots", {}) + else: + entries = history_tree.get("nodes", {}) slim_tree = dict(history_tree) - slim_nodes = {} - for nid, node in nodes.items(): - snap = node.get("data") + slim_entries = {} + for eid, entry in entries.items(): + snap = entry.get("data") if snap: db.conn.execute( "INSERT INTO history_snapshots (data_file_id, node_id, snapshot_data, updated_at) " "VALUES (?, ?, ?, ?) " "ON CONFLICT(data_file_id, node_id) DO UPDATE SET " "snapshot_data=excluded.snapshot_data, updated_at=excluded.updated_at", - (df_id, nid, json.dumps(snap), now), + (df_id, eid, json.dumps(snap), now), ) - slim_nodes[nid] = {k: v for k, v in node.items() if k != "data"} - slim_tree["nodes"] = slim_nodes + slim_entries[eid] = {k: v for k, v in entry.items() if k != "data"} + # Write back slim version using the correct key + if "snapshots" in history_tree: + slim_tree["snapshots"] = slim_entries + else: + slim_tree["nodes"] = slim_entries db.conn.execute( "INSERT INTO history_trees (data_file_id, tree_data, updated_at) " "VALUES (?, ?, ?) " "ON CONFLICT(data_file_id) DO UPDATE SET tree_data=excluded.tree_data, updated_at=excluded.updated_at", (df_id, json.dumps(slim_tree), now), ) - # Clean up orphaned snapshots for nodes no longer in tree - current_node_ids = set(nodes.keys()) - if current_node_ids: - placeholders = ",".join("?" for _ in current_node_ids) + # Clean up orphaned snapshots + current_ids = set(entries.keys()) + if current_ids: + placeholders = ",".join("?" for _ in current_ids) db.conn.execute( f"DELETE FROM history_snapshots WHERE data_file_id = ? " f"AND node_id NOT IN ({placeholders})", - (df_id, *current_node_ids), + (df_id, *current_ids), ) else: db.conn.execute(