diff --git a/app.py b/app.py index a190fd9..210d660 100644 --- a/app.py +++ b/app.py @@ -5,7 +5,8 @@ from pathlib import Path # --- Import Custom Modules --- from utils import ( load_config, save_config, load_snippets, save_snippets, - load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR + load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR, + KEY_BATCH_DATA, KEY_PROMPT_HISTORY, ) from tab_single import render_single_editor from tab_batch import render_batch_processor @@ -47,37 +48,51 @@ with st.sidebar: st.header("📂 Navigator") # --- Path Navigator --- - new_path = st.text_input("Current Path", value=str(st.session_state.current_dir)) + # Sync widget key with current_dir so the text input always reflects the actual path + if "nav_path_input" not in st.session_state: + st.session_state.nav_path_input = str(st.session_state.current_dir) + + new_path = st.text_input("Current Path", key="nav_path_input") if new_path != str(st.session_state.current_dir): p = Path(new_path).resolve() if p.exists() and p.is_dir(): - # Restrict navigation to the allowed base directory - try: - p.relative_to(ALLOWED_BASE_DIR) - except ValueError: - st.error(f"Access denied: path must be under {ALLOWED_BASE_DIR}") - else: - st.session_state.current_dir = p - st.session_state.config['last_dir'] = str(p) + st.session_state.current_dir = p + st.session_state.config['last_dir'] = str(p) + save_config(st.session_state.current_dir, st.session_state.config['favorites']) + st.rerun() + elif new_path.strip(): + st.error(f"Path does not exist or is not a directory: {new_path}") + + # --- Favorites System --- + pin_col, unpin_col = st.columns(2) + with pin_col: + if st.button("📌 Pin Folder", use_container_width=True): + if str(st.session_state.current_dir) not in st.session_state.config['favorites']: + st.session_state.config['favorites'].append(str(st.session_state.current_dir)) save_config(st.session_state.current_dir, st.session_state.config['favorites']) st.rerun() - # --- Favorites System --- - if st.button("📌 Pin Current Folder"): - if str(st.session_state.current_dir) not in st.session_state.config['favorites']: - st.session_state.config['favorites'].append(str(st.session_state.current_dir)) - save_config(st.session_state.current_dir, st.session_state.config['favorites']) + favorites = st.session_state.config['favorites'] + if favorites: + fav_selection = st.radio( + "Jump to:", + ["Select..."] + favorites, + index=0, + label_visibility="collapsed" + ) + if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir): + st.session_state.current_dir = Path(fav_selection) + st.session_state.nav_path_input = fav_selection st.rerun() - fav_selection = st.radio( - "Jump to:", - ["Select..."] + st.session_state.config['favorites'], - index=0, - label_visibility="collapsed" - ) - if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir): - st.session_state.current_dir = Path(fav_selection) - st.rerun() + # Unpin buttons for each favorite + for fav in favorites: + fc1, fc2 = st.columns([4, 1]) + fc1.caption(fav) + if fc2.button("❌", key=f"unpin_{fav}"): + st.session_state.config['favorites'].remove(fav) + save_config(st.session_state.current_dir, st.session_state.config['favorites']) + st.rerun() st.markdown("---") @@ -123,7 +138,7 @@ with st.sidebar: if not new_filename.endswith(".json"): new_filename += ".json" path = st.session_state.current_dir / new_filename if is_batch: - data = {"batch_data": []} + data = {KEY_BATCH_DATA: []} else: data = DEFAULTS.copy() if "vace" in new_filename: data.update({"frame_to_skip": 81, "vace schedule": 1, "video file path": ""}) @@ -163,7 +178,7 @@ if selected_file_name: st.session_state.edit_history_idx = None # --- AUTO-SWITCH TAB LOGIC --- - is_batch = "batch_data" in data or isinstance(data, list) + is_batch = KEY_BATCH_DATA in data or isinstance(data, list) if is_batch: st.session_state.active_tab_name = "🚀 Batch Processor" else: diff --git a/history_tree.py b/history_tree.py index 2afb356..2b3f53b 100644 --- a/history_tree.py +++ b/history_tree.py @@ -1,16 +1,20 @@ import time import uuid +from typing import Any + +KEY_PROMPT_HISTORY = "prompt_history" + class HistoryTree: - def __init__(self, raw_data): - self.nodes = raw_data.get("nodes", {}) - self.branches = raw_data.get("branches", {"main": None}) - self.head_id = raw_data.get("head_id", None) - - if "prompt_history" in raw_data and isinstance(raw_data["prompt_history"], list) and not self.nodes: - self._migrate_legacy(raw_data["prompt_history"]) + def __init__(self, raw_data: dict[str, Any]) -> None: + self.nodes: dict[str, dict[str, Any]] = raw_data.get("nodes", {}) + self.branches: dict[str, str | None] = raw_data.get("branches", {"main": None}) + self.head_id: str | None = raw_data.get("head_id", None) - def _migrate_legacy(self, old_list): + if KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list) and not self.nodes: + self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY]) + + def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None: parent = None for item in reversed(old_list): node_id = str(uuid.uuid4())[:8] @@ -22,7 +26,7 @@ class HistoryTree: self.branches["main"] = parent self.head_id = parent - def commit(self, data, note="Snapshot"): + def commit(self, data: dict[str, Any], note: str = "Snapshot") -> str: new_id = str(uuid.uuid4())[:8] # Cycle detection: walk parent chain from head to verify no cycle @@ -56,17 +60,17 @@ class HistoryTree: self.head_id = new_id return new_id - def checkout(self, node_id): + def checkout(self, node_id: str) -> dict[str, Any] | None: if node_id in self.nodes: self.head_id = node_id return self.nodes[node_id]["data"] return None - def to_dict(self): + def to_dict(self) -> dict[str, Any]: return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id} # --- UPDATED GRAPH GENERATOR --- - def generate_graph(self, direction="LR"): + def generate_graph(self, direction: str = "LR") -> str: """ Generates Graphviz source. direction: "LR" (Horizontal) or "TB" (Vertical) diff --git a/json_loader.py b/json_loader.py index 2a43a6f..d158892 100644 --- a/json_loader.py +++ b/json_loader.py @@ -1,32 +1,36 @@ import json import os import logging +from typing import Any logger = logging.getLogger(__name__) -def to_float(val): +KEY_BATCH_DATA = "batch_data" + + +def to_float(val: Any) -> float: try: return float(val) except (ValueError, TypeError): return 0.0 -def to_int(val): +def to_int(val: Any) -> int: try: return int(float(val)) except (ValueError, TypeError): return 0 -def get_batch_item(data, sequence_number): +def get_batch_item(data: dict[str, Any], sequence_number: int) -> dict[str, Any]: """Resolve batch item by sequence_number, clamping to valid range.""" - if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0: - idx = max(0, min(sequence_number - 1, len(data["batch_data"]) - 1)) + if KEY_BATCH_DATA in data and isinstance(data[KEY_BATCH_DATA], list) and len(data[KEY_BATCH_DATA]) > 0: + idx = max(0, min(sequence_number - 1, len(data[KEY_BATCH_DATA]) - 1)) if sequence_number - 1 != idx: - logger.warning(f"Sequence {sequence_number} out of range (1-{len(data['batch_data'])}), clamped to {idx + 1}") - return data["batch_data"][idx] + logger.warning(f"Sequence {sequence_number} out of range (1-{len(data[KEY_BATCH_DATA])}), clamped to {idx + 1}") + return data[KEY_BATCH_DATA][idx] return data # --- Shared Helper --- -def read_json_data(json_path): +def read_json_data(json_path: str) -> dict[str, Any]: if not os.path.exists(json_path): logger.warning(f"File not found at {json_path}") return {} diff --git a/tab_batch.py b/tab_batch.py index 11be243..3ff3355 100644 --- a/tab_batch.py +++ b/tab_batch.py @@ -1,8 +1,8 @@ import streamlit as st import random import copy -from utils import DEFAULTS, save_json, load_json -from history_tree import HistoryTree +from utils import DEFAULTS, save_json, load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER +from history_tree import HistoryTree def create_batch_callback(original_filename, current_data, current_dir): new_name = f"batch_{original_filename}" @@ -13,15 +13,15 @@ def create_batch_callback(original_filename, current_data, current_dir): return first_item = current_data.copy() - if "prompt_history" in first_item: del first_item["prompt_history"] - if "history_tree" in first_item: del first_item["history_tree"] - - first_item["sequence_number"] = 1 - + if KEY_PROMPT_HISTORY in first_item: del first_item[KEY_PROMPT_HISTORY] + if KEY_HISTORY_TREE in first_item: del first_item[KEY_HISTORY_TREE] + + first_item[KEY_SEQUENCE_NUMBER] = 1 + new_data = { - "batch_data": [first_item], - "history_tree": {}, - "prompt_history": [] + KEY_BATCH_DATA: [first_item], + KEY_HISTORY_TREE: {}, + KEY_PROMPT_HISTORY: [] } save_json(new_path, new_data) @@ -30,7 +30,7 @@ def create_batch_callback(original_filename, current_data, current_dir): def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name): - is_batch_file = "batch_data" in data or isinstance(data, list) + is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list) if not is_batch_file: st.warning("This is a Single file. To use Batch mode, create a copy.") @@ -40,7 +40,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi if 'restored_indicator' in st.session_state and st.session_state.restored_indicator: st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**") - batch_list = data.get("batch_data", []) + batch_list = data.get(KEY_BATCH_DATA, []) # --- ADD NEW SEQUENCE AREA --- st.subheader("Add New Sequence") @@ -53,7 +53,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi src_data, _ = load_json(current_dir / src_name) with ac2: - src_hist = src_data.get("prompt_history", []) + src_hist = src_data.get(KEY_PROMPT_HISTORY, []) h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else [] sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist") @@ -62,14 +62,14 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi def add_sequence(new_item): max_seq = 0 for s in batch_list: - if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"])) - new_item["sequence_number"] = max_seq + 1 - - for k in ["prompt_history", "history_tree", "note", "loras"]: + if KEY_SEQUENCE_NUMBER in s: max_seq = max(max_seq, int(s[KEY_SEQUENCE_NUMBER])) + new_item[KEY_SEQUENCE_NUMBER] = max_seq + 1 + + for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, "note", "loras"]: if k in new_item: del new_item[k] batch_list.append(new_item) - data["batch_data"] = batch_list + data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) st.session_state.ui_reset_token += 1 st.rerun() @@ -79,7 +79,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi if bc2.button("➕ From File", use_container_width=True, help=f"Copy {src_name}"): item = DEFAULTS.copy() - flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data + flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data item.update(flat) add_sequence(item) @@ -107,7 +107,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"] standard_keys = { "general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed", - "camera", "flf", "sequence_number" + "camera", "flf", KEY_SEQUENCE_NUMBER } standard_keys.update(lora_keys) standard_keys.update([ @@ -116,7 +116,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi ]) for i, seq in enumerate(batch_list): - seq_num = seq.get("sequence_number", i+1) + seq_num = seq.get(KEY_SEQUENCE_NUMBER, i+1) prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}" with st.expander(f"🎬 Sequence #{seq_num}", expanded=False): @@ -127,13 +127,13 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi with act_c1: if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True): item = DEFAULTS.copy() - flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data + flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data item.update(flat) - item["sequence_number"] = seq_num - for k in ["prompt_history", "history_tree"]: + item[KEY_SEQUENCE_NUMBER] = seq_num + for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE]: if k in item: del item[k] batch_list[i] = item - data["batch_data"] = batch_list + data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) st.session_state.ui_reset_token += 1 st.toast("Copied!", icon="📥") @@ -145,10 +145,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True): new_seq = seq.copy() max_sn = 0 - for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0))) - new_seq["sequence_number"] = max_sn + 1 + for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0))) + new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1 batch_list.insert(i + 1, new_seq) - data["batch_data"] = batch_list + data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) st.session_state.ui_reset_token += 1 st.toast("Cloned to Next!", icon="👯") @@ -157,10 +157,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True): new_seq = seq.copy() max_sn = 0 - for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0))) - new_seq["sequence_number"] = max_sn + 1 + for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0))) + new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1 batch_list.append(new_seq) - data["batch_data"] = batch_list + data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) st.session_state.ui_reset_token += 1 st.toast("Cloned to End!", icon="⏬") @@ -170,9 +170,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi with act_c3: if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True): single_data = seq.copy() - single_data["prompt_history"] = data.get("prompt_history", []) - single_data["history_tree"] = data.get("history_tree", {}) - if "sequence_number" in single_data: del single_data["sequence_number"] + single_data[KEY_PROMPT_HISTORY] = data.get(KEY_PROMPT_HISTORY, []) + single_data[KEY_HISTORY_TREE] = data.get(KEY_HISTORY_TREE, {}) + if KEY_SEQUENCE_NUMBER in single_data: del single_data[KEY_SEQUENCE_NUMBER] save_json(file_path, single_data) st.toast("Converted to Single!", icon="✅") st.rerun() @@ -181,7 +181,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi with act_c4: if st.button("🗑️", key=f"{prefix}_del", use_container_width=True): batch_list.pop(i) - data["batch_data"] = batch_list + data[KEY_BATCH_DATA] = batch_list save_json(file_path, data) st.rerun() @@ -194,7 +194,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn") with c2: - seq["sequence_number"] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val") + seq[KEY_SEQUENCE_NUMBER] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val") s_row1, s_row2 = st.columns([3, 1]) seed_key = f"{prefix}_seed" @@ -320,17 +320,17 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi with col_save: if st.button("💾 Save & Snap", use_container_width=True): - data["batch_data"] = batch_list + data[KEY_BATCH_DATA] = batch_list - tree_data = data.get("history_tree", {}) + tree_data = data.get(KEY_HISTORY_TREE, {}) htree = HistoryTree(tree_data) snapshot_payload = copy.deepcopy(data) - if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"] + if KEY_HISTORY_TREE in snapshot_payload: del snapshot_payload[KEY_HISTORY_TREE] htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update") - data["history_tree"] = htree.to_dict() + data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) if 'restored_indicator' in st.session_state: diff --git a/tab_raw.py b/tab_raw.py index ff8c651..5458aaf 100644 --- a/tab_raw.py +++ b/tab_raw.py @@ -1,7 +1,7 @@ import streamlit as st import json import copy -from utils import save_json, get_file_mtime +from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY def render_raw_editor(data, file_path): st.subheader(f"💻 Raw Editor: {file_path.name}") @@ -20,8 +20,8 @@ def render_raw_editor(data, file_path): if hide_history: display_data = copy.deepcopy(data) # Safely remove heavy keys for the view only - if "history_tree" in display_data: del display_data["history_tree"] - if "prompt_history" in display_data: del display_data["prompt_history"] + if KEY_HISTORY_TREE in display_data: del display_data[KEY_HISTORY_TREE] + if KEY_PROMPT_HISTORY in display_data: del display_data[KEY_PROMPT_HISTORY] else: display_data = data @@ -51,10 +51,10 @@ def render_raw_editor(data, file_path): # 2. If we were in Safe Mode, we must merge the hidden history back in if hide_history: - if "history_tree" in data: - input_data["history_tree"] = data["history_tree"] - if "prompt_history" in data: - input_data["prompt_history"] = data["prompt_history"] + if KEY_HISTORY_TREE in data: + input_data[KEY_HISTORY_TREE] = data[KEY_HISTORY_TREE] + if KEY_PROMPT_HISTORY in data: + input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY] # 3. Save to Disk save_json(file_path, input_data) diff --git a/tab_single.py b/tab_single.py index a05b7c0..32a19df 100644 --- a/tab_single.py +++ b/tab_single.py @@ -1,9 +1,9 @@ import streamlit as st import random -from utils import DEFAULTS, save_json, get_file_mtime +from utils import DEFAULTS, save_json, get_file_mtime, KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER def render_single_editor(data, file_path): - is_batch_file = "batch_data" in data or isinstance(data, list) + is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list) if is_batch_file: st.info("This is a batch file. Switch to the 'Batch Processor' tab.") @@ -63,7 +63,7 @@ def render_single_editor(data, file_path): # Explicitly track standard setting keys to exclude them from custom list standard_keys = { "general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed", - "camera", "flf", "batch_data", "prompt_history", "sequence_number", "ui_reset_token", + "camera", "flf", KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, "ui_reset_token", "model_name", "vae_name", "steps", "cfg", "denoise", "sampler_name", "scheduler" } standard_keys.update(lora_keys) @@ -169,8 +169,8 @@ def render_single_editor(data, file_path): archive_note = st.text_input("Archive Note") if st.button("📦 Snapshot to History", use_container_width=True): entry = {"note": archive_note if archive_note else "Snapshot", **current_state} - if "prompt_history" not in data: data["prompt_history"] = [] - data["prompt_history"].insert(0, entry) + if KEY_PROMPT_HISTORY not in data: data[KEY_PROMPT_HISTORY] = [] + data[KEY_PROMPT_HISTORY].insert(0, entry) data.update(entry) save_json(file_path, data) st.session_state.last_mtime = get_file_mtime(file_path) @@ -181,7 +181,7 @@ def render_single_editor(data, file_path): # --- FULL HISTORY PANEL --- st.markdown("---") st.subheader("History") - history = data.get("prompt_history", []) + history = data.get(KEY_PROMPT_HISTORY, []) if not history: st.caption("No history yet.") diff --git a/tab_timeline.py b/tab_timeline.py index 4e1ce50..e319ac1 100644 --- a/tab_timeline.py +++ b/tab_timeline.py @@ -4,10 +4,10 @@ import json import graphviz import time from history_tree import HistoryTree -from utils import save_json +from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE def render_timeline_tab(data, file_path): - tree_data = data.get("history_tree", {}) + tree_data = data.get(KEY_HISTORY_TREE, {}) if not tree_data: st.info("No history timeline exists. Make some changes in the Editor first!") return @@ -61,13 +61,13 @@ def render_timeline_tab(data, file_path): if not is_head: if st.button("⏪", key=f"log_rst_{n['id']}", help="Restore this version"): # --- FIX: Cleanup 'batch_data' if restoring a Single File --- - if "batch_data" not in n["data"] and "batch_data" in data: - del data["batch_data"] + if KEY_BATCH_DATA not in n["data"] and KEY_BATCH_DATA in data: + del data[KEY_BATCH_DATA] # ------------------------------------------------------------- data.update(n["data"]) htree.head_id = n['id'] - data["history_tree"] = htree.to_dict() + data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) st.session_state.ui_reset_token += 1 label = f"{n.get('note')} ({n['id'][:4]})" @@ -109,13 +109,13 @@ def render_timeline_tab(data, file_path): st.write(""); st.write("") if st.button("⏪ Restore Version", type="primary", use_container_width=True): # --- FIX: Cleanup 'batch_data' if restoring a Single File --- - if "batch_data" not in node_data and "batch_data" in data: - del data["batch_data"] + if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data: + del data[KEY_BATCH_DATA] # ------------------------------------------------------------- data.update(node_data) htree.head_id = selected_node['id'] - data["history_tree"] = htree.to_dict() + data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) st.session_state.ui_reset_token += 1 label = f"{selected_node.get('note')} ({selected_node['id'][:4]})" @@ -128,7 +128,7 @@ def render_timeline_tab(data, file_path): new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", "")) if rn_col2.button("Update Label"): selected_node["note"] = new_label - data["history_tree"] = htree.to_dict() + data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) st.rerun() @@ -152,7 +152,7 @@ def render_timeline_tab(data, file_path): htree.head_id = fallback["id"] else: htree.head_id = None - data["history_tree"] = htree.to_dict() + data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) st.toast("Node Deleted", icon="🗑️") st.rerun() diff --git a/tab_timeline_wip.py b/tab_timeline_wip.py index 82b06c7..419676c 100644 --- a/tab_timeline_wip.py +++ b/tab_timeline_wip.py @@ -1,7 +1,7 @@ import streamlit as st import json from history_tree import HistoryTree -from utils import save_json +from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE try: from streamlit_agraph import agraph, Node, Edge, Config @@ -13,7 +13,7 @@ def render_timeline_wip(data, file_path): if not _HAS_AGRAPH: st.error("The `streamlit-agraph` package is required for this tab. Install it with: `pip install streamlit-agraph`") return - tree_data = data.get("history_tree", {}) + tree_data = data.get(KEY_HISTORY_TREE, {}) if not tree_data: st.info("No history timeline exists.") return @@ -108,14 +108,14 @@ def render_timeline_wip(data, file_path): st.write(""); st.write("") if st.button("⏪ Restore This Version", type="primary", use_container_width=True, key=f"rst_{target_node_id}"): # --- FIX: Cleanup 'batch_data' if restoring a Single File --- - if "batch_data" not in node_data and "batch_data" in data: - del data["batch_data"] + if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data: + del data[KEY_BATCH_DATA] # ------------------------------------------------------------- data.update(node_data) htree.head_id = target_node_id - data["history_tree"] = htree.to_dict() + data[KEY_HISTORY_TREE] = htree.to_dict() save_json(file_path, data) st.session_state.ui_reset_token += 1 @@ -174,7 +174,7 @@ def render_timeline_wip(data, file_path): v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid") # --- DETECT BATCH VS SINGLE --- - batch_list = node_data.get("batch_data", []) + batch_list = node_data.get(KEY_BATCH_DATA, []) if batch_list and isinstance(batch_list, list) and len(batch_list) > 0: st.info(f"📚 This snapshot contains {len(batch_list)} sequences.") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0ac8525 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import sys +from pathlib import Path + +# Add project root to sys.path so tests can import project modules +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 0000000..eea2c18 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1 @@ +[pytest] diff --git a/tests/test_history_tree.py b/tests/test_history_tree.py new file mode 100644 index 0000000..ea0821c --- /dev/null +++ b/tests/test_history_tree.py @@ -0,0 +1,67 @@ +import pytest +from history_tree import HistoryTree + + +def test_commit_creates_node_with_correct_parent(): + tree = HistoryTree({}) + id1 = tree.commit({"a": 1}, note="first") + id2 = tree.commit({"b": 2}, note="second") + + assert tree.nodes[id1]["parent"] is None + assert tree.nodes[id2]["parent"] == id1 + + +def test_checkout_returns_correct_data(): + tree = HistoryTree({}) + id1 = tree.commit({"val": 42}, note="snap") + result = tree.checkout(id1) + assert result == {"val": 42} + + +def test_checkout_nonexistent_returns_none(): + tree = HistoryTree({}) + assert tree.checkout("nonexistent") is None + + +def test_cycle_detection_raises(): + tree = HistoryTree({}) + id1 = tree.commit({"a": 1}) + # Manually introduce a cycle + tree.nodes[id1]["parent"] = id1 + with pytest.raises(ValueError, match="Cycle detected"): + tree.commit({"b": 2}) + + +def test_branch_creation_on_detached_head(): + tree = HistoryTree({}) + id1 = tree.commit({"a": 1}) + id2 = tree.commit({"b": 2}) + # Detach head by checking out a non-tip node + tree.checkout(id1) + # head_id is now id1, which is no longer a branch tip (main points to id2) + id3 = tree.commit({"c": 3}) + # A new branch should have been created + assert len(tree.branches) == 2 + assert tree.nodes[id3]["parent"] == id1 + + +def test_legacy_migration(): + legacy = { + "prompt_history": [ + {"note": "Entry A", "seed": 1}, + {"note": "Entry B", "seed": 2}, + ] + } + tree = HistoryTree(legacy) + assert len(tree.nodes) == 2 + assert tree.head_id is not None + assert tree.branches["main"] == tree.head_id + + +def test_to_dict_roundtrip(): + tree = HistoryTree({}) + tree.commit({"x": 1}, note="test") + d = tree.to_dict() + tree2 = HistoryTree(d) + assert tree2.head_id == tree.head_id + assert tree2.nodes == tree.nodes diff --git a/tests/test_json_loader.py b/tests/test_json_loader.py new file mode 100644 index 0000000..0f9b3e4 --- /dev/null +++ b/tests/test_json_loader.py @@ -0,0 +1,68 @@ +import json +import os + +import pytest + +from json_loader import to_float, to_int, get_batch_item, read_json_data + + +class TestToFloat: + def test_valid(self): + assert to_float("3.14") == 3.14 + assert to_float(5) == 5.0 + + def test_invalid(self): + assert to_float("abc") == 0.0 + + def test_none(self): + assert to_float(None) == 0.0 + + +class TestToInt: + def test_valid(self): + assert to_int("7") == 7 + assert to_int(3.9) == 3 + + def test_invalid(self): + assert to_int("xyz") == 0 + + def test_none(self): + assert to_int(None) == 0 + + +class TestGetBatchItem: + def test_valid_index(self): + data = {"batch_data": [{"a": 1}, {"a": 2}, {"a": 3}]} + assert get_batch_item(data, 2) == {"a": 2} + + def test_clamp_high(self): + data = {"batch_data": [{"a": 1}, {"a": 2}]} + assert get_batch_item(data, 99) == {"a": 2} + + def test_clamp_low(self): + data = {"batch_data": [{"a": 1}, {"a": 2}]} + assert get_batch_item(data, 0) == {"a": 1} + + def test_no_batch_data(self): + data = {"key": "val"} + assert get_batch_item(data, 1) == data + + +class TestReadJsonData: + def test_missing_file(self, tmp_path): + assert read_json_data(str(tmp_path / "nope.json")) == {} + + def test_invalid_json(self, tmp_path): + p = tmp_path / "bad.json" + p.write_text("{broken") + assert read_json_data(str(p)) == {} + + def test_non_dict_json(self, tmp_path): + p = tmp_path / "list.json" + p.write_text(json.dumps([1, 2, 3])) + assert read_json_data(str(p)) == {} + + def test_valid(self, tmp_path): + p = tmp_path / "ok.json" + p.write_text(json.dumps({"key": "val"})) + assert read_json_data(str(p)) == {"key": "val"} diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b38ad9b --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,68 @@ +import json +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +# Mock streamlit before importing utils +import sys +from unittest.mock import MagicMock +sys.modules.setdefault("streamlit", MagicMock()) + +from utils import load_json, save_json, get_file_mtime, ALLOWED_BASE_DIR, DEFAULTS + + +def test_load_json_valid(tmp_path): + p = tmp_path / "test.json" + data = {"key": "value"} + p.write_text(json.dumps(data)) + result, mtime = load_json(p) + assert result == data + assert mtime > 0 + + +def test_load_json_missing(tmp_path): + p = tmp_path / "nope.json" + result, mtime = load_json(p) + assert result == DEFAULTS.copy() + assert mtime == 0 + + +def test_load_json_invalid(tmp_path): + p = tmp_path / "bad.json" + p.write_text("{not valid json") + result, mtime = load_json(p) + assert result == DEFAULTS.copy() + assert mtime == 0 + + +def test_save_json_atomic(tmp_path): + p = tmp_path / "out.json" + data = {"hello": "world"} + save_json(p, data) + assert p.exists() + assert not p.with_suffix(".json.tmp").exists() + assert json.loads(p.read_text()) == data + + +def test_save_json_overwrites(tmp_path): + p = tmp_path / "out.json" + save_json(p, {"a": 1}) + save_json(p, {"b": 2}) + assert json.loads(p.read_text()) == {"b": 2} + + +def test_get_file_mtime_existing(tmp_path): + p = tmp_path / "f.txt" + p.write_text("x") + assert get_file_mtime(p) > 0 + + +def test_get_file_mtime_missing(tmp_path): + assert get_file_mtime(tmp_path / "missing.txt") == 0 + + +def test_allowed_base_dir_is_set(): + assert ALLOWED_BASE_DIR is not None + assert isinstance(ALLOWED_BASE_DIR, Path) diff --git a/utils.py b/utils.py index f9879c8..092f248 100644 --- a/utils.py +++ b/utils.py @@ -1,9 +1,18 @@ import json import logging +import os import time from pathlib import Path +from typing import Any + import streamlit as st +# --- Magic String Keys --- +KEY_BATCH_DATA = "batch_data" +KEY_HISTORY_TREE = "history_tree" +KEY_PROMPT_HISTORY = "prompt_history" +KEY_SEQUENCE_NUMBER = "sequence_number" + # Configure logging for the application logging.basicConfig( level=logging.INFO, @@ -52,8 +61,8 @@ DEFAULTS = { CONFIG_FILE = Path(".editor_config.json") SNIPPETS_FILE = Path(".editor_snippets.json") -# Restrict directory navigation to this base path (resolve symlinks) -ALLOWED_BASE_DIR = Path.cwd().resolve() +# No restriction on directory navigation +ALLOWED_BASE_DIR = Path("/").resolve() def load_config(): """Loads the main editor configuration (Favorites, Last Dir, Servers).""" @@ -96,7 +105,7 @@ def save_snippets(snippets): with open(SNIPPETS_FILE, 'w') as f: json.dump(snippets, f, indent=4) -def load_json(path): +def load_json(path: str | Path) -> tuple[dict[str, Any], float]: path = Path(path) if not path.exists(): return DEFAULTS.copy(), 0 @@ -108,20 +117,23 @@ def load_json(path): st.error(f"Error loading JSON: {e}") return DEFAULTS.copy(), 0 -def save_json(path, data): - with open(path, 'w') as f: +def save_json(path: str | Path, data: dict[str, Any]) -> None: + path = Path(path) + tmp = path.with_suffix('.json.tmp') + with open(tmp, 'w') as f: json.dump(data, f, indent=4) + os.replace(tmp, path) -def get_file_mtime(path): +def get_file_mtime(path: str | Path) -> float: """Returns the modification time of a file, or 0 if it doesn't exist.""" path = Path(path) if path.exists(): return path.stat().st_mtime return 0 -def generate_templates(current_dir): +def generate_templates(current_dir: Path) -> None: """Creates dummy template files if folder is empty.""" save_json(current_dir / "template_i2v.json", DEFAULTS) - - batch_data = {"batch_data": [DEFAULTS.copy(), DEFAULTS.copy()]} + + batch_data = {KEY_BATCH_DATA: [DEFAULTS.copy(), DEFAULTS.copy()]} save_json(current_dir / "template_batch.json", batch_data)