Add atomic writes, magic string constants, unit tests, type hints, and fix navigation

- save_json() now writes to a temp file then uses os.replace() for atomic writes
- Replace hardcoded "batch_data", "history_tree", "prompt_history", "sequence_number"
  strings with constants (KEY_BATCH_DATA, etc.) across all modules
- Add 29 unit tests for history_tree, utils, and json_loader
- Add type hints to public functions in utils.py, json_loader.py, history_tree.py
- Remove ALLOWED_BASE_DIR restriction that blocked navigating outside app CWD
- Fix path text input not updating on navigation by using session state key
- Add unpin button () for removing pinned folders

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-02 12:44:31 +01:00
parent 326ae25ab2
commit b02bf124fb
15 changed files with 368 additions and 124 deletions

67
app.py
View File

@@ -5,7 +5,8 @@ from pathlib import Path
# --- Import Custom Modules ---
from utils import (
load_config, save_config, load_snippets, save_snippets,
load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR
load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR,
KEY_BATCH_DATA, KEY_PROMPT_HISTORY,
)
from tab_single import render_single_editor
from tab_batch import render_batch_processor
@@ -47,37 +48,51 @@ with st.sidebar:
st.header("📂 Navigator")
# --- Path Navigator ---
new_path = st.text_input("Current Path", value=str(st.session_state.current_dir))
# Sync widget key with current_dir so the text input always reflects the actual path
if "nav_path_input" not in st.session_state:
st.session_state.nav_path_input = str(st.session_state.current_dir)
new_path = st.text_input("Current Path", key="nav_path_input")
if new_path != str(st.session_state.current_dir):
p = Path(new_path).resolve()
if p.exists() and p.is_dir():
# Restrict navigation to the allowed base directory
try:
p.relative_to(ALLOWED_BASE_DIR)
except ValueError:
st.error(f"Access denied: path must be under {ALLOWED_BASE_DIR}")
else:
st.session_state.current_dir = p
st.session_state.config['last_dir'] = str(p)
st.session_state.current_dir = p
st.session_state.config['last_dir'] = str(p)
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
st.rerun()
elif new_path.strip():
st.error(f"Path does not exist or is not a directory: {new_path}")
# --- Favorites System ---
pin_col, unpin_col = st.columns(2)
with pin_col:
if st.button("📌 Pin Folder", use_container_width=True):
if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
st.session_state.config['favorites'].append(str(st.session_state.current_dir))
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
st.rerun()
# --- Favorites System ---
if st.button("📌 Pin Current Folder"):
if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
st.session_state.config['favorites'].append(str(st.session_state.current_dir))
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
favorites = st.session_state.config['favorites']
if favorites:
fav_selection = st.radio(
"Jump to:",
["Select..."] + favorites,
index=0,
label_visibility="collapsed"
)
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir):
st.session_state.current_dir = Path(fav_selection)
st.session_state.nav_path_input = fav_selection
st.rerun()
fav_selection = st.radio(
"Jump to:",
["Select..."] + st.session_state.config['favorites'],
index=0,
label_visibility="collapsed"
)
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir):
st.session_state.current_dir = Path(fav_selection)
st.rerun()
# Unpin buttons for each favorite
for fav in favorites:
fc1, fc2 = st.columns([4, 1])
fc1.caption(fav)
if fc2.button("", key=f"unpin_{fav}"):
st.session_state.config['favorites'].remove(fav)
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
st.rerun()
st.markdown("---")
@@ -123,7 +138,7 @@ with st.sidebar:
if not new_filename.endswith(".json"): new_filename += ".json"
path = st.session_state.current_dir / new_filename
if is_batch:
data = {"batch_data": []}
data = {KEY_BATCH_DATA: []}
else:
data = DEFAULTS.copy()
if "vace" in new_filename: data.update({"frame_to_skip": 81, "vace schedule": 1, "video file path": ""})
@@ -163,7 +178,7 @@ if selected_file_name:
st.session_state.edit_history_idx = None
# --- AUTO-SWITCH TAB LOGIC ---
is_batch = "batch_data" in data or isinstance(data, list)
is_batch = KEY_BATCH_DATA in data or isinstance(data, list)
if is_batch:
st.session_state.active_tab_name = "🚀 Batch Processor"
else:

View File

@@ -1,16 +1,20 @@
import time
import uuid
from typing import Any
KEY_PROMPT_HISTORY = "prompt_history"
class HistoryTree:
def __init__(self, raw_data):
self.nodes = raw_data.get("nodes", {})
self.branches = raw_data.get("branches", {"main": None})
self.head_id = raw_data.get("head_id", None)
if "prompt_history" in raw_data and isinstance(raw_data["prompt_history"], list) and not self.nodes:
self._migrate_legacy(raw_data["prompt_history"])
def __init__(self, raw_data: dict[str, Any]) -> None:
self.nodes: dict[str, dict[str, Any]] = raw_data.get("nodes", {})
self.branches: dict[str, str | None] = raw_data.get("branches", {"main": None})
self.head_id: str | None = raw_data.get("head_id", None)
def _migrate_legacy(self, old_list):
if KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list) and not self.nodes:
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
parent = None
for item in reversed(old_list):
node_id = str(uuid.uuid4())[:8]
@@ -22,7 +26,7 @@ class HistoryTree:
self.branches["main"] = parent
self.head_id = parent
def commit(self, data, note="Snapshot"):
def commit(self, data: dict[str, Any], note: str = "Snapshot") -> str:
new_id = str(uuid.uuid4())[:8]
# Cycle detection: walk parent chain from head to verify no cycle
@@ -56,17 +60,17 @@ class HistoryTree:
self.head_id = new_id
return new_id
def checkout(self, node_id):
def checkout(self, node_id: str) -> dict[str, Any] | None:
if node_id in self.nodes:
self.head_id = node_id
return self.nodes[node_id]["data"]
return None
def to_dict(self):
def to_dict(self) -> dict[str, Any]:
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
# --- UPDATED GRAPH GENERATOR ---
def generate_graph(self, direction="LR"):
def generate_graph(self, direction: str = "LR") -> str:
"""
Generates Graphviz source.
direction: "LR" (Horizontal) or "TB" (Vertical)

View File

@@ -1,32 +1,36 @@
import json
import os
import logging
from typing import Any
logger = logging.getLogger(__name__)
def to_float(val):
KEY_BATCH_DATA = "batch_data"
def to_float(val: Any) -> float:
try:
return float(val)
except (ValueError, TypeError):
return 0.0
def to_int(val):
def to_int(val: Any) -> int:
try:
return int(float(val))
except (ValueError, TypeError):
return 0
def get_batch_item(data, sequence_number):
def get_batch_item(data: dict[str, Any], sequence_number: int) -> dict[str, Any]:
"""Resolve batch item by sequence_number, clamping to valid range."""
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
idx = max(0, min(sequence_number - 1, len(data["batch_data"]) - 1))
if KEY_BATCH_DATA in data and isinstance(data[KEY_BATCH_DATA], list) and len(data[KEY_BATCH_DATA]) > 0:
idx = max(0, min(sequence_number - 1, len(data[KEY_BATCH_DATA]) - 1))
if sequence_number - 1 != idx:
logger.warning(f"Sequence {sequence_number} out of range (1-{len(data['batch_data'])}), clamped to {idx + 1}")
return data["batch_data"][idx]
logger.warning(f"Sequence {sequence_number} out of range (1-{len(data[KEY_BATCH_DATA])}), clamped to {idx + 1}")
return data[KEY_BATCH_DATA][idx]
return data
# --- Shared Helper ---
def read_json_data(json_path):
def read_json_data(json_path: str) -> dict[str, Any]:
if not os.path.exists(json_path):
logger.warning(f"File not found at {json_path}")
return {}

View File

@@ -1,8 +1,8 @@
import streamlit as st
import random
import copy
from utils import DEFAULTS, save_json, load_json
from history_tree import HistoryTree
from utils import DEFAULTS, save_json, load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
from history_tree import HistoryTree
def create_batch_callback(original_filename, current_data, current_dir):
new_name = f"batch_{original_filename}"
@@ -13,15 +13,15 @@ def create_batch_callback(original_filename, current_data, current_dir):
return
first_item = current_data.copy()
if "prompt_history" in first_item: del first_item["prompt_history"]
if "history_tree" in first_item: del first_item["history_tree"]
first_item["sequence_number"] = 1
if KEY_PROMPT_HISTORY in first_item: del first_item[KEY_PROMPT_HISTORY]
if KEY_HISTORY_TREE in first_item: del first_item[KEY_HISTORY_TREE]
first_item[KEY_SEQUENCE_NUMBER] = 1
new_data = {
"batch_data": [first_item],
"history_tree": {},
"prompt_history": []
KEY_BATCH_DATA: [first_item],
KEY_HISTORY_TREE: {},
KEY_PROMPT_HISTORY: []
}
save_json(new_path, new_data)
@@ -30,7 +30,7 @@ def create_batch_callback(original_filename, current_data, current_dir):
def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
is_batch_file = "batch_data" in data or isinstance(data, list)
is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
if not is_batch_file:
st.warning("This is a Single file. To use Batch mode, create a copy.")
@@ -40,7 +40,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
batch_list = data.get("batch_data", [])
batch_list = data.get(KEY_BATCH_DATA, [])
# --- ADD NEW SEQUENCE AREA ---
st.subheader("Add New Sequence")
@@ -53,7 +53,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
src_data, _ = load_json(current_dir / src_name)
with ac2:
src_hist = src_data.get("prompt_history", [])
src_hist = src_data.get(KEY_PROMPT_HISTORY, [])
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist")
@@ -62,14 +62,14 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
def add_sequence(new_item):
max_seq = 0
for s in batch_list:
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"]))
new_item["sequence_number"] = max_seq + 1
for k in ["prompt_history", "history_tree", "note", "loras"]:
if KEY_SEQUENCE_NUMBER in s: max_seq = max(max_seq, int(s[KEY_SEQUENCE_NUMBER]))
new_item[KEY_SEQUENCE_NUMBER] = max_seq + 1
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, "note", "loras"]:
if k in new_item: del new_item[k]
batch_list.append(new_item)
data["batch_data"] = batch_list
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
st.session_state.ui_reset_token += 1
st.rerun()
@@ -79,7 +79,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if bc2.button(" From File", use_container_width=True, help=f"Copy {src_name}"):
item = DEFAULTS.copy()
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data
item.update(flat)
add_sequence(item)
@@ -107,7 +107,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
standard_keys = {
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
"camera", "flf", "sequence_number"
"camera", "flf", KEY_SEQUENCE_NUMBER
}
standard_keys.update(lora_keys)
standard_keys.update([
@@ -116,7 +116,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
])
for i, seq in enumerate(batch_list):
seq_num = seq.get("sequence_number", i+1)
seq_num = seq.get(KEY_SEQUENCE_NUMBER, i+1)
prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}"
with st.expander(f"🎬 Sequence #{seq_num}", expanded=False):
@@ -127,13 +127,13 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with act_c1:
if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True):
item = DEFAULTS.copy()
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data
item.update(flat)
item["sequence_number"] = seq_num
for k in ["prompt_history", "history_tree"]:
item[KEY_SEQUENCE_NUMBER] = seq_num
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE]:
if k in item: del item[k]
batch_list[i] = item
data["batch_data"] = batch_list
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
st.session_state.ui_reset_token += 1
st.toast("Copied!", icon="📥")
@@ -145,10 +145,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True):
new_seq = seq.copy()
max_sn = 0
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
new_seq["sequence_number"] = max_sn + 1
for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0)))
new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
batch_list.insert(i + 1, new_seq)
data["batch_data"] = batch_list
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
st.session_state.ui_reset_token += 1
st.toast("Cloned to Next!", icon="👯")
@@ -157,10 +157,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True):
new_seq = seq.copy()
max_sn = 0
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
new_seq["sequence_number"] = max_sn + 1
for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0)))
new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
batch_list.append(new_seq)
data["batch_data"] = batch_list
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
st.session_state.ui_reset_token += 1
st.toast("Cloned to End!", icon="")
@@ -170,9 +170,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with act_c3:
if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True):
single_data = seq.copy()
single_data["prompt_history"] = data.get("prompt_history", [])
single_data["history_tree"] = data.get("history_tree", {})
if "sequence_number" in single_data: del single_data["sequence_number"]
single_data[KEY_PROMPT_HISTORY] = data.get(KEY_PROMPT_HISTORY, [])
single_data[KEY_HISTORY_TREE] = data.get(KEY_HISTORY_TREE, {})
if KEY_SEQUENCE_NUMBER in single_data: del single_data[KEY_SEQUENCE_NUMBER]
save_json(file_path, single_data)
st.toast("Converted to Single!", icon="")
st.rerun()
@@ -181,7 +181,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with act_c4:
if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
batch_list.pop(i)
data["batch_data"] = batch_list
data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data)
st.rerun()
@@ -194,7 +194,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn")
with c2:
seq["sequence_number"] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val")
seq[KEY_SEQUENCE_NUMBER] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val")
s_row1, s_row2 = st.columns([3, 1])
seed_key = f"{prefix}_seed"
@@ -320,17 +320,17 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with col_save:
if st.button("💾 Save & Snap", use_container_width=True):
data["batch_data"] = batch_list
data[KEY_BATCH_DATA] = batch_list
tree_data = data.get("history_tree", {})
tree_data = data.get(KEY_HISTORY_TREE, {})
htree = HistoryTree(tree_data)
snapshot_payload = copy.deepcopy(data)
if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"]
if KEY_HISTORY_TREE in snapshot_payload: del snapshot_payload[KEY_HISTORY_TREE]
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
data["history_tree"] = htree.to_dict()
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
if 'restored_indicator' in st.session_state:

View File

@@ -1,7 +1,7 @@
import streamlit as st
import json
import copy
from utils import save_json, get_file_mtime
from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
def render_raw_editor(data, file_path):
st.subheader(f"💻 Raw Editor: {file_path.name}")
@@ -20,8 +20,8 @@ def render_raw_editor(data, file_path):
if hide_history:
display_data = copy.deepcopy(data)
# Safely remove heavy keys for the view only
if "history_tree" in display_data: del display_data["history_tree"]
if "prompt_history" in display_data: del display_data["prompt_history"]
if KEY_HISTORY_TREE in display_data: del display_data[KEY_HISTORY_TREE]
if KEY_PROMPT_HISTORY in display_data: del display_data[KEY_PROMPT_HISTORY]
else:
display_data = data
@@ -51,10 +51,10 @@ def render_raw_editor(data, file_path):
# 2. If we were in Safe Mode, we must merge the hidden history back in
if hide_history:
if "history_tree" in data:
input_data["history_tree"] = data["history_tree"]
if "prompt_history" in data:
input_data["prompt_history"] = data["prompt_history"]
if KEY_HISTORY_TREE in data:
input_data[KEY_HISTORY_TREE] = data[KEY_HISTORY_TREE]
if KEY_PROMPT_HISTORY in data:
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
# 3. Save to Disk
save_json(file_path, input_data)

View File

@@ -1,9 +1,9 @@
import streamlit as st
import random
from utils import DEFAULTS, save_json, get_file_mtime
from utils import DEFAULTS, save_json, get_file_mtime, KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
def render_single_editor(data, file_path):
is_batch_file = "batch_data" in data or isinstance(data, list)
is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
if is_batch_file:
st.info("This is a batch file. Switch to the 'Batch Processor' tab.")
@@ -63,7 +63,7 @@ def render_single_editor(data, file_path):
# Explicitly track standard setting keys to exclude them from custom list
standard_keys = {
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
"camera", "flf", "batch_data", "prompt_history", "sequence_number", "ui_reset_token",
"camera", "flf", KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, "ui_reset_token",
"model_name", "vae_name", "steps", "cfg", "denoise", "sampler_name", "scheduler"
}
standard_keys.update(lora_keys)
@@ -169,8 +169,8 @@ def render_single_editor(data, file_path):
archive_note = st.text_input("Archive Note")
if st.button("📦 Snapshot to History", use_container_width=True):
entry = {"note": archive_note if archive_note else "Snapshot", **current_state}
if "prompt_history" not in data: data["prompt_history"] = []
data["prompt_history"].insert(0, entry)
if KEY_PROMPT_HISTORY not in data: data[KEY_PROMPT_HISTORY] = []
data[KEY_PROMPT_HISTORY].insert(0, entry)
data.update(entry)
save_json(file_path, data)
st.session_state.last_mtime = get_file_mtime(file_path)
@@ -181,7 +181,7 @@ def render_single_editor(data, file_path):
# --- FULL HISTORY PANEL ---
st.markdown("---")
st.subheader("History")
history = data.get("prompt_history", [])
history = data.get(KEY_PROMPT_HISTORY, [])
if not history:
st.caption("No history yet.")

View File

@@ -4,10 +4,10 @@ import json
import graphviz
import time
from history_tree import HistoryTree
from utils import save_json
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
def render_timeline_tab(data, file_path):
tree_data = data.get("history_tree", {})
tree_data = data.get(KEY_HISTORY_TREE, {})
if not tree_data:
st.info("No history timeline exists. Make some changes in the Editor first!")
return
@@ -61,13 +61,13 @@ def render_timeline_tab(data, file_path):
if not is_head:
if st.button("", key=f"log_rst_{n['id']}", help="Restore this version"):
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
if "batch_data" not in n["data"] and "batch_data" in data:
del data["batch_data"]
if KEY_BATCH_DATA not in n["data"] and KEY_BATCH_DATA in data:
del data[KEY_BATCH_DATA]
# -------------------------------------------------------------
data.update(n["data"])
htree.head_id = n['id']
data["history_tree"] = htree.to_dict()
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
st.session_state.ui_reset_token += 1
label = f"{n.get('note')} ({n['id'][:4]})"
@@ -109,13 +109,13 @@ def render_timeline_tab(data, file_path):
st.write(""); st.write("")
if st.button("⏪ Restore Version", type="primary", use_container_width=True):
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
if "batch_data" not in node_data and "batch_data" in data:
del data["batch_data"]
if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
del data[KEY_BATCH_DATA]
# -------------------------------------------------------------
data.update(node_data)
htree.head_id = selected_node['id']
data["history_tree"] = htree.to_dict()
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
st.session_state.ui_reset_token += 1
label = f"{selected_node.get('note')} ({selected_node['id'][:4]})"
@@ -128,7 +128,7 @@ def render_timeline_tab(data, file_path):
new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", ""))
if rn_col2.button("Update Label"):
selected_node["note"] = new_label
data["history_tree"] = htree.to_dict()
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
st.rerun()
@@ -152,7 +152,7 @@ def render_timeline_tab(data, file_path):
htree.head_id = fallback["id"]
else:
htree.head_id = None
data["history_tree"] = htree.to_dict()
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
st.toast("Node Deleted", icon="🗑️")
st.rerun()

View File

@@ -1,7 +1,7 @@
import streamlit as st
import json
from history_tree import HistoryTree
from utils import save_json
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
try:
from streamlit_agraph import agraph, Node, Edge, Config
@@ -13,7 +13,7 @@ def render_timeline_wip(data, file_path):
if not _HAS_AGRAPH:
st.error("The `streamlit-agraph` package is required for this tab. Install it with: `pip install streamlit-agraph`")
return
tree_data = data.get("history_tree", {})
tree_data = data.get(KEY_HISTORY_TREE, {})
if not tree_data:
st.info("No history timeline exists.")
return
@@ -108,14 +108,14 @@ def render_timeline_wip(data, file_path):
st.write(""); st.write("")
if st.button("⏪ Restore This Version", type="primary", use_container_width=True, key=f"rst_{target_node_id}"):
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
if "batch_data" not in node_data and "batch_data" in data:
del data["batch_data"]
if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
del data[KEY_BATCH_DATA]
# -------------------------------------------------------------
data.update(node_data)
htree.head_id = target_node_id
data["history_tree"] = htree.to_dict()
data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data)
st.session_state.ui_reset_token += 1
@@ -174,7 +174,7 @@ def render_timeline_wip(data, file_path):
v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid")
# --- DETECT BATCH VS SINGLE ---
batch_list = node_data.get("batch_data", [])
batch_list = node_data.get(KEY_BATCH_DATA, [])
if batch_list and isinstance(batch_list, list) and len(batch_list) > 0:
st.info(f"📚 This snapshot contains {len(batch_list)} sequences.")

0
tests/__init__.py Normal file
View File

5
tests/conftest.py Normal file
View File

@@ -0,0 +1,5 @@
import sys
from pathlib import Path
# Add project root to sys.path so tests can import project modules
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

1
tests/pytest.ini Normal file
View File

@@ -0,0 +1 @@
[pytest]

View File

@@ -0,0 +1,67 @@
import pytest
from history_tree import HistoryTree
def test_commit_creates_node_with_correct_parent():
tree = HistoryTree({})
id1 = tree.commit({"a": 1}, note="first")
id2 = tree.commit({"b": 2}, note="second")
assert tree.nodes[id1]["parent"] is None
assert tree.nodes[id2]["parent"] == id1
def test_checkout_returns_correct_data():
tree = HistoryTree({})
id1 = tree.commit({"val": 42}, note="snap")
result = tree.checkout(id1)
assert result == {"val": 42}
def test_checkout_nonexistent_returns_none():
tree = HistoryTree({})
assert tree.checkout("nonexistent") is None
def test_cycle_detection_raises():
tree = HistoryTree({})
id1 = tree.commit({"a": 1})
# Manually introduce a cycle
tree.nodes[id1]["parent"] = id1
with pytest.raises(ValueError, match="Cycle detected"):
tree.commit({"b": 2})
def test_branch_creation_on_detached_head():
tree = HistoryTree({})
id1 = tree.commit({"a": 1})
id2 = tree.commit({"b": 2})
# Detach head by checking out a non-tip node
tree.checkout(id1)
# head_id is now id1, which is no longer a branch tip (main points to id2)
id3 = tree.commit({"c": 3})
# A new branch should have been created
assert len(tree.branches) == 2
assert tree.nodes[id3]["parent"] == id1
def test_legacy_migration():
legacy = {
"prompt_history": [
{"note": "Entry A", "seed": 1},
{"note": "Entry B", "seed": 2},
]
}
tree = HistoryTree(legacy)
assert len(tree.nodes) == 2
assert tree.head_id is not None
assert tree.branches["main"] == tree.head_id
def test_to_dict_roundtrip():
tree = HistoryTree({})
tree.commit({"x": 1}, note="test")
d = tree.to_dict()
tree2 = HistoryTree(d)
assert tree2.head_id == tree.head_id
assert tree2.nodes == tree.nodes

68
tests/test_json_loader.py Normal file
View File

@@ -0,0 +1,68 @@
import json
import os
import pytest
from json_loader import to_float, to_int, get_batch_item, read_json_data
class TestToFloat:
def test_valid(self):
assert to_float("3.14") == 3.14
assert to_float(5) == 5.0
def test_invalid(self):
assert to_float("abc") == 0.0
def test_none(self):
assert to_float(None) == 0.0
class TestToInt:
def test_valid(self):
assert to_int("7") == 7
assert to_int(3.9) == 3
def test_invalid(self):
assert to_int("xyz") == 0
def test_none(self):
assert to_int(None) == 0
class TestGetBatchItem:
def test_valid_index(self):
data = {"batch_data": [{"a": 1}, {"a": 2}, {"a": 3}]}
assert get_batch_item(data, 2) == {"a": 2}
def test_clamp_high(self):
data = {"batch_data": [{"a": 1}, {"a": 2}]}
assert get_batch_item(data, 99) == {"a": 2}
def test_clamp_low(self):
data = {"batch_data": [{"a": 1}, {"a": 2}]}
assert get_batch_item(data, 0) == {"a": 1}
def test_no_batch_data(self):
data = {"key": "val"}
assert get_batch_item(data, 1) == data
class TestReadJsonData:
def test_missing_file(self, tmp_path):
assert read_json_data(str(tmp_path / "nope.json")) == {}
def test_invalid_json(self, tmp_path):
p = tmp_path / "bad.json"
p.write_text("{broken")
assert read_json_data(str(p)) == {}
def test_non_dict_json(self, tmp_path):
p = tmp_path / "list.json"
p.write_text(json.dumps([1, 2, 3]))
assert read_json_data(str(p)) == {}
def test_valid(self, tmp_path):
p = tmp_path / "ok.json"
p.write_text(json.dumps({"key": "val"}))
assert read_json_data(str(p)) == {"key": "val"}

68
tests/test_utils.py Normal file
View File

@@ -0,0 +1,68 @@
import json
import os
from pathlib import Path
from unittest.mock import patch
import pytest
# Mock streamlit before importing utils
import sys
from unittest.mock import MagicMock
sys.modules.setdefault("streamlit", MagicMock())
from utils import load_json, save_json, get_file_mtime, ALLOWED_BASE_DIR, DEFAULTS
def test_load_json_valid(tmp_path):
p = tmp_path / "test.json"
data = {"key": "value"}
p.write_text(json.dumps(data))
result, mtime = load_json(p)
assert result == data
assert mtime > 0
def test_load_json_missing(tmp_path):
p = tmp_path / "nope.json"
result, mtime = load_json(p)
assert result == DEFAULTS.copy()
assert mtime == 0
def test_load_json_invalid(tmp_path):
p = tmp_path / "bad.json"
p.write_text("{not valid json")
result, mtime = load_json(p)
assert result == DEFAULTS.copy()
assert mtime == 0
def test_save_json_atomic(tmp_path):
p = tmp_path / "out.json"
data = {"hello": "world"}
save_json(p, data)
assert p.exists()
assert not p.with_suffix(".json.tmp").exists()
assert json.loads(p.read_text()) == data
def test_save_json_overwrites(tmp_path):
p = tmp_path / "out.json"
save_json(p, {"a": 1})
save_json(p, {"b": 2})
assert json.loads(p.read_text()) == {"b": 2}
def test_get_file_mtime_existing(tmp_path):
p = tmp_path / "f.txt"
p.write_text("x")
assert get_file_mtime(p) > 0
def test_get_file_mtime_missing(tmp_path):
assert get_file_mtime(tmp_path / "missing.txt") == 0
def test_allowed_base_dir_is_set():
assert ALLOWED_BASE_DIR is not None
assert isinstance(ALLOWED_BASE_DIR, Path)

View File

@@ -1,9 +1,18 @@
import json
import logging
import os
import time
from pathlib import Path
from typing import Any
import streamlit as st
# --- Magic String Keys ---
KEY_BATCH_DATA = "batch_data"
KEY_HISTORY_TREE = "history_tree"
KEY_PROMPT_HISTORY = "prompt_history"
KEY_SEQUENCE_NUMBER = "sequence_number"
# Configure logging for the application
logging.basicConfig(
level=logging.INFO,
@@ -52,8 +61,8 @@ DEFAULTS = {
CONFIG_FILE = Path(".editor_config.json")
SNIPPETS_FILE = Path(".editor_snippets.json")
# Restrict directory navigation to this base path (resolve symlinks)
ALLOWED_BASE_DIR = Path.cwd().resolve()
# No restriction on directory navigation
ALLOWED_BASE_DIR = Path("/").resolve()
def load_config():
"""Loads the main editor configuration (Favorites, Last Dir, Servers)."""
@@ -96,7 +105,7 @@ def save_snippets(snippets):
with open(SNIPPETS_FILE, 'w') as f:
json.dump(snippets, f, indent=4)
def load_json(path):
def load_json(path: str | Path) -> tuple[dict[str, Any], float]:
path = Path(path)
if not path.exists():
return DEFAULTS.copy(), 0
@@ -108,20 +117,23 @@ def load_json(path):
st.error(f"Error loading JSON: {e}")
return DEFAULTS.copy(), 0
def save_json(path, data):
with open(path, 'w') as f:
def save_json(path: str | Path, data: dict[str, Any]) -> None:
path = Path(path)
tmp = path.with_suffix('.json.tmp')
with open(tmp, 'w') as f:
json.dump(data, f, indent=4)
os.replace(tmp, path)
def get_file_mtime(path):
def get_file_mtime(path: str | Path) -> float:
"""Returns the modification time of a file, or 0 if it doesn't exist."""
path = Path(path)
if path.exists():
return path.stat().st_mtime
return 0
def generate_templates(current_dir):
def generate_templates(current_dir: Path) -> None:
"""Creates dummy template files if folder is empty."""
save_json(current_dir / "template_i2v.json", DEFAULTS)
batch_data = {"batch_data": [DEFAULTS.copy(), DEFAULTS.copy()]}
batch_data = {KEY_BATCH_DATA: [DEFAULTS.copy(), DEFAULTS.copy()]}
save_json(current_dir / "template_batch.json", batch_data)