Add atomic writes, magic string constants, unit tests, type hints, and fix navigation

- save_json() now writes to a temp file then uses os.replace() for atomic writes
- Replace hardcoded "batch_data", "history_tree", "prompt_history", "sequence_number"
  strings with constants (KEY_BATCH_DATA, etc.) across all modules
- Add 29 unit tests for history_tree, utils, and json_loader
- Add type hints to public functions in utils.py, json_loader.py, history_tree.py
- Remove ALLOWED_BASE_DIR restriction that blocked navigating outside app CWD
- Fix path text input not updating on navigation by using session state key
- Add an unpin button for removing pinned folders

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-02 12:44:31 +01:00
parent 326ae25ab2
commit b02bf124fb
15 changed files with 368 additions and 124 deletions

67
app.py
View File

@@ -5,7 +5,8 @@ from pathlib import Path
# --- Import Custom Modules --- # --- Import Custom Modules ---
from utils import ( from utils import (
load_config, save_config, load_snippets, save_snippets, load_config, save_config, load_snippets, save_snippets,
load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR,
KEY_BATCH_DATA, KEY_PROMPT_HISTORY,
) )
from tab_single import render_single_editor from tab_single import render_single_editor
from tab_batch import render_batch_processor from tab_batch import render_batch_processor
@@ -47,37 +48,51 @@ with st.sidebar:
st.header("📂 Navigator") st.header("📂 Navigator")
# --- Path Navigator --- # --- Path Navigator ---
new_path = st.text_input("Current Path", value=str(st.session_state.current_dir)) # Sync widget key with current_dir so the text input always reflects the actual path
if "nav_path_input" not in st.session_state:
st.session_state.nav_path_input = str(st.session_state.current_dir)
new_path = st.text_input("Current Path", key="nav_path_input")
if new_path != str(st.session_state.current_dir): if new_path != str(st.session_state.current_dir):
p = Path(new_path).resolve() p = Path(new_path).resolve()
if p.exists() and p.is_dir(): if p.exists() and p.is_dir():
# Restrict navigation to the allowed base directory st.session_state.current_dir = p
try: st.session_state.config['last_dir'] = str(p)
p.relative_to(ALLOWED_BASE_DIR) save_config(st.session_state.current_dir, st.session_state.config['favorites'])
except ValueError: st.rerun()
st.error(f"Access denied: path must be under {ALLOWED_BASE_DIR}") elif new_path.strip():
else: st.error(f"Path does not exist or is not a directory: {new_path}")
st.session_state.current_dir = p
st.session_state.config['last_dir'] = str(p) # --- Favorites System ---
pin_col, unpin_col = st.columns(2)
with pin_col:
if st.button("📌 Pin Folder", use_container_width=True):
if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
st.session_state.config['favorites'].append(str(st.session_state.current_dir))
save_config(st.session_state.current_dir, st.session_state.config['favorites']) save_config(st.session_state.current_dir, st.session_state.config['favorites'])
st.rerun() st.rerun()
# --- Favorites System --- favorites = st.session_state.config['favorites']
if st.button("📌 Pin Current Folder"): if favorites:
if str(st.session_state.current_dir) not in st.session_state.config['favorites']: fav_selection = st.radio(
st.session_state.config['favorites'].append(str(st.session_state.current_dir)) "Jump to:",
save_config(st.session_state.current_dir, st.session_state.config['favorites']) ["Select..."] + favorites,
index=0,
label_visibility="collapsed"
)
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir):
st.session_state.current_dir = Path(fav_selection)
st.session_state.nav_path_input = fav_selection
st.rerun() st.rerun()
fav_selection = st.radio( # Unpin buttons for each favorite
"Jump to:", for fav in favorites:
["Select..."] + st.session_state.config['favorites'], fc1, fc2 = st.columns([4, 1])
index=0, fc1.caption(fav)
label_visibility="collapsed" if fc2.button("", key=f"unpin_{fav}"):
) st.session_state.config['favorites'].remove(fav)
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir): save_config(st.session_state.current_dir, st.session_state.config['favorites'])
st.session_state.current_dir = Path(fav_selection) st.rerun()
st.rerun()
st.markdown("---") st.markdown("---")
@@ -123,7 +138,7 @@ with st.sidebar:
if not new_filename.endswith(".json"): new_filename += ".json" if not new_filename.endswith(".json"): new_filename += ".json"
path = st.session_state.current_dir / new_filename path = st.session_state.current_dir / new_filename
if is_batch: if is_batch:
data = {"batch_data": []} data = {KEY_BATCH_DATA: []}
else: else:
data = DEFAULTS.copy() data = DEFAULTS.copy()
if "vace" in new_filename: data.update({"frame_to_skip": 81, "vace schedule": 1, "video file path": ""}) if "vace" in new_filename: data.update({"frame_to_skip": 81, "vace schedule": 1, "video file path": ""})
@@ -163,7 +178,7 @@ if selected_file_name:
st.session_state.edit_history_idx = None st.session_state.edit_history_idx = None
# --- AUTO-SWITCH TAB LOGIC --- # --- AUTO-SWITCH TAB LOGIC ---
is_batch = "batch_data" in data or isinstance(data, list) is_batch = KEY_BATCH_DATA in data or isinstance(data, list)
if is_batch: if is_batch:
st.session_state.active_tab_name = "🚀 Batch Processor" st.session_state.active_tab_name = "🚀 Batch Processor"
else: else:

View File

@@ -1,16 +1,20 @@
import time import time
import uuid import uuid
from typing import Any
KEY_PROMPT_HISTORY = "prompt_history"
class HistoryTree: class HistoryTree:
def __init__(self, raw_data): def __init__(self, raw_data: dict[str, Any]) -> None:
self.nodes = raw_data.get("nodes", {}) self.nodes: dict[str, dict[str, Any]] = raw_data.get("nodes", {})
self.branches = raw_data.get("branches", {"main": None}) self.branches: dict[str, str | None] = raw_data.get("branches", {"main": None})
self.head_id = raw_data.get("head_id", None) self.head_id: str | None = raw_data.get("head_id", None)
if "prompt_history" in raw_data and isinstance(raw_data["prompt_history"], list) and not self.nodes:
self._migrate_legacy(raw_data["prompt_history"])
def _migrate_legacy(self, old_list): if KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list) and not self.nodes:
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
parent = None parent = None
for item in reversed(old_list): for item in reversed(old_list):
node_id = str(uuid.uuid4())[:8] node_id = str(uuid.uuid4())[:8]
@@ -22,7 +26,7 @@ class HistoryTree:
self.branches["main"] = parent self.branches["main"] = parent
self.head_id = parent self.head_id = parent
def commit(self, data, note="Snapshot"): def commit(self, data: dict[str, Any], note: str = "Snapshot") -> str:
new_id = str(uuid.uuid4())[:8] new_id = str(uuid.uuid4())[:8]
# Cycle detection: walk parent chain from head to verify no cycle # Cycle detection: walk parent chain from head to verify no cycle
@@ -56,17 +60,17 @@ class HistoryTree:
self.head_id = new_id self.head_id = new_id
return new_id return new_id
def checkout(self, node_id): def checkout(self, node_id: str) -> dict[str, Any] | None:
if node_id in self.nodes: if node_id in self.nodes:
self.head_id = node_id self.head_id = node_id
return self.nodes[node_id]["data"] return self.nodes[node_id]["data"]
return None return None
def to_dict(self): def to_dict(self) -> dict[str, Any]:
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id} return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
# --- UPDATED GRAPH GENERATOR --- # --- UPDATED GRAPH GENERATOR ---
def generate_graph(self, direction="LR"): def generate_graph(self, direction: str = "LR") -> str:
""" """
Generates Graphviz source. Generates Graphviz source.
direction: "LR" (Horizontal) or "TB" (Vertical) direction: "LR" (Horizontal) or "TB" (Vertical)

View File

@@ -1,32 +1,36 @@
import json import json
import os import os
import logging import logging
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def to_float(val): KEY_BATCH_DATA = "batch_data"
def to_float(val: Any) -> float:
try: try:
return float(val) return float(val)
except (ValueError, TypeError): except (ValueError, TypeError):
return 0.0 return 0.0
def to_int(val): def to_int(val: Any) -> int:
try: try:
return int(float(val)) return int(float(val))
except (ValueError, TypeError): except (ValueError, TypeError):
return 0 return 0
def get_batch_item(data, sequence_number): def get_batch_item(data: dict[str, Any], sequence_number: int) -> dict[str, Any]:
"""Resolve batch item by sequence_number, clamping to valid range.""" """Resolve batch item by sequence_number, clamping to valid range."""
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0: if KEY_BATCH_DATA in data and isinstance(data[KEY_BATCH_DATA], list) and len(data[KEY_BATCH_DATA]) > 0:
idx = max(0, min(sequence_number - 1, len(data["batch_data"]) - 1)) idx = max(0, min(sequence_number - 1, len(data[KEY_BATCH_DATA]) - 1))
if sequence_number - 1 != idx: if sequence_number - 1 != idx:
logger.warning(f"Sequence {sequence_number} out of range (1-{len(data['batch_data'])}), clamped to {idx + 1}") logger.warning(f"Sequence {sequence_number} out of range (1-{len(data[KEY_BATCH_DATA])}), clamped to {idx + 1}")
return data["batch_data"][idx] return data[KEY_BATCH_DATA][idx]
return data return data
# --- Shared Helper --- # --- Shared Helper ---
def read_json_data(json_path): def read_json_data(json_path: str) -> dict[str, Any]:
if not os.path.exists(json_path): if not os.path.exists(json_path):
logger.warning(f"File not found at {json_path}") logger.warning(f"File not found at {json_path}")
return {} return {}

View File

@@ -1,8 +1,8 @@
import streamlit as st import streamlit as st
import random import random
import copy import copy
from utils import DEFAULTS, save_json, load_json from utils import DEFAULTS, save_json, load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
from history_tree import HistoryTree from history_tree import HistoryTree
def create_batch_callback(original_filename, current_data, current_dir): def create_batch_callback(original_filename, current_data, current_dir):
new_name = f"batch_{original_filename}" new_name = f"batch_{original_filename}"
@@ -13,15 +13,15 @@ def create_batch_callback(original_filename, current_data, current_dir):
return return
first_item = current_data.copy() first_item = current_data.copy()
if "prompt_history" in first_item: del first_item["prompt_history"] if KEY_PROMPT_HISTORY in first_item: del first_item[KEY_PROMPT_HISTORY]
if "history_tree" in first_item: del first_item["history_tree"] if KEY_HISTORY_TREE in first_item: del first_item[KEY_HISTORY_TREE]
first_item["sequence_number"] = 1 first_item[KEY_SEQUENCE_NUMBER] = 1
new_data = { new_data = {
"batch_data": [first_item], KEY_BATCH_DATA: [first_item],
"history_tree": {}, KEY_HISTORY_TREE: {},
"prompt_history": [] KEY_PROMPT_HISTORY: []
} }
save_json(new_path, new_data) save_json(new_path, new_data)
@@ -30,7 +30,7 @@ def create_batch_callback(original_filename, current_data, current_dir):
def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name): def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
is_batch_file = "batch_data" in data or isinstance(data, list) is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
if not is_batch_file: if not is_batch_file:
st.warning("This is a Single file. To use Batch mode, create a copy.") st.warning("This is a Single file. To use Batch mode, create a copy.")
@@ -40,7 +40,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator: if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**") st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
batch_list = data.get("batch_data", []) batch_list = data.get(KEY_BATCH_DATA, [])
# --- ADD NEW SEQUENCE AREA --- # --- ADD NEW SEQUENCE AREA ---
st.subheader("Add New Sequence") st.subheader("Add New Sequence")
@@ -53,7 +53,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
src_data, _ = load_json(current_dir / src_name) src_data, _ = load_json(current_dir / src_name)
with ac2: with ac2:
src_hist = src_data.get("prompt_history", []) src_hist = src_data.get(KEY_PROMPT_HISTORY, [])
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else [] h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist") sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist")
@@ -62,14 +62,14 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
def add_sequence(new_item): def add_sequence(new_item):
max_seq = 0 max_seq = 0
for s in batch_list: for s in batch_list:
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"])) if KEY_SEQUENCE_NUMBER in s: max_seq = max(max_seq, int(s[KEY_SEQUENCE_NUMBER]))
new_item["sequence_number"] = max_seq + 1 new_item[KEY_SEQUENCE_NUMBER] = max_seq + 1
for k in ["prompt_history", "history_tree", "note", "loras"]: for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, "note", "loras"]:
if k in new_item: del new_item[k] if k in new_item: del new_item[k]
batch_list.append(new_item) batch_list.append(new_item)
data["batch_data"] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) save_json(file_path, data)
st.session_state.ui_reset_token += 1 st.session_state.ui_reset_token += 1
st.rerun() st.rerun()
@@ -79,7 +79,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if bc2.button(" From File", use_container_width=True, help=f"Copy {src_name}"): if bc2.button(" From File", use_container_width=True, help=f"Copy {src_name}"):
item = DEFAULTS.copy() item = DEFAULTS.copy()
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data
item.update(flat) item.update(flat)
add_sequence(item) add_sequence(item)
@@ -107,7 +107,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"] lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
standard_keys = { standard_keys = {
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed", "general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
"camera", "flf", "sequence_number" "camera", "flf", KEY_SEQUENCE_NUMBER
} }
standard_keys.update(lora_keys) standard_keys.update(lora_keys)
standard_keys.update([ standard_keys.update([
@@ -116,7 +116,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
]) ])
for i, seq in enumerate(batch_list): for i, seq in enumerate(batch_list):
seq_num = seq.get("sequence_number", i+1) seq_num = seq.get(KEY_SEQUENCE_NUMBER, i+1)
prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}" prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}"
with st.expander(f"🎬 Sequence #{seq_num}", expanded=False): with st.expander(f"🎬 Sequence #{seq_num}", expanded=False):
@@ -127,13 +127,13 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with act_c1: with act_c1:
if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True): if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True):
item = DEFAULTS.copy() item = DEFAULTS.copy()
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data
item.update(flat) item.update(flat)
item["sequence_number"] = seq_num item[KEY_SEQUENCE_NUMBER] = seq_num
for k in ["prompt_history", "history_tree"]: for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE]:
if k in item: del item[k] if k in item: del item[k]
batch_list[i] = item batch_list[i] = item
data["batch_data"] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) save_json(file_path, data)
st.session_state.ui_reset_token += 1 st.session_state.ui_reset_token += 1
st.toast("Copied!", icon="📥") st.toast("Copied!", icon="📥")
@@ -145,10 +145,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True): if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True):
new_seq = seq.copy() new_seq = seq.copy()
max_sn = 0 max_sn = 0
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0))) for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0)))
new_seq["sequence_number"] = max_sn + 1 new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
batch_list.insert(i + 1, new_seq) batch_list.insert(i + 1, new_seq)
data["batch_data"] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) save_json(file_path, data)
st.session_state.ui_reset_token += 1 st.session_state.ui_reset_token += 1
st.toast("Cloned to Next!", icon="👯") st.toast("Cloned to Next!", icon="👯")
@@ -157,10 +157,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True): if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True):
new_seq = seq.copy() new_seq = seq.copy()
max_sn = 0 max_sn = 0
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0))) for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0)))
new_seq["sequence_number"] = max_sn + 1 new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
batch_list.append(new_seq) batch_list.append(new_seq)
data["batch_data"] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) save_json(file_path, data)
st.session_state.ui_reset_token += 1 st.session_state.ui_reset_token += 1
st.toast("Cloned to End!", icon="") st.toast("Cloned to End!", icon="")
@@ -170,9 +170,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with act_c3: with act_c3:
if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True): if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True):
single_data = seq.copy() single_data = seq.copy()
single_data["prompt_history"] = data.get("prompt_history", []) single_data[KEY_PROMPT_HISTORY] = data.get(KEY_PROMPT_HISTORY, [])
single_data["history_tree"] = data.get("history_tree", {}) single_data[KEY_HISTORY_TREE] = data.get(KEY_HISTORY_TREE, {})
if "sequence_number" in single_data: del single_data["sequence_number"] if KEY_SEQUENCE_NUMBER in single_data: del single_data[KEY_SEQUENCE_NUMBER]
save_json(file_path, single_data) save_json(file_path, single_data)
st.toast("Converted to Single!", icon="") st.toast("Converted to Single!", icon="")
st.rerun() st.rerun()
@@ -181,7 +181,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with act_c4: with act_c4:
if st.button("🗑️", key=f"{prefix}_del", use_container_width=True): if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
batch_list.pop(i) batch_list.pop(i)
data["batch_data"] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) save_json(file_path, data)
st.rerun() st.rerun()
@@ -194,7 +194,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn") seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn")
with c2: with c2:
seq["sequence_number"] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val") seq[KEY_SEQUENCE_NUMBER] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val")
s_row1, s_row2 = st.columns([3, 1]) s_row1, s_row2 = st.columns([3, 1])
seed_key = f"{prefix}_seed" seed_key = f"{prefix}_seed"
@@ -320,17 +320,17 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with col_save: with col_save:
if st.button("💾 Save & Snap", use_container_width=True): if st.button("💾 Save & Snap", use_container_width=True):
data["batch_data"] = batch_list data[KEY_BATCH_DATA] = batch_list
tree_data = data.get("history_tree", {}) tree_data = data.get(KEY_HISTORY_TREE, {})
htree = HistoryTree(tree_data) htree = HistoryTree(tree_data)
snapshot_payload = copy.deepcopy(data) snapshot_payload = copy.deepcopy(data)
if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"] if KEY_HISTORY_TREE in snapshot_payload: del snapshot_payload[KEY_HISTORY_TREE]
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update") htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
data["history_tree"] = htree.to_dict() data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data) save_json(file_path, data)
if 'restored_indicator' in st.session_state: if 'restored_indicator' in st.session_state:

View File

@@ -1,7 +1,7 @@
import streamlit as st import streamlit as st
import json import json
import copy import copy
from utils import save_json, get_file_mtime from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
def render_raw_editor(data, file_path): def render_raw_editor(data, file_path):
st.subheader(f"💻 Raw Editor: {file_path.name}") st.subheader(f"💻 Raw Editor: {file_path.name}")
@@ -20,8 +20,8 @@ def render_raw_editor(data, file_path):
if hide_history: if hide_history:
display_data = copy.deepcopy(data) display_data = copy.deepcopy(data)
# Safely remove heavy keys for the view only # Safely remove heavy keys for the view only
if "history_tree" in display_data: del display_data["history_tree"] if KEY_HISTORY_TREE in display_data: del display_data[KEY_HISTORY_TREE]
if "prompt_history" in display_data: del display_data["prompt_history"] if KEY_PROMPT_HISTORY in display_data: del display_data[KEY_PROMPT_HISTORY]
else: else:
display_data = data display_data = data
@@ -51,10 +51,10 @@ def render_raw_editor(data, file_path):
# 2. If we were in Safe Mode, we must merge the hidden history back in # 2. If we were in Safe Mode, we must merge the hidden history back in
if hide_history: if hide_history:
if "history_tree" in data: if KEY_HISTORY_TREE in data:
input_data["history_tree"] = data["history_tree"] input_data[KEY_HISTORY_TREE] = data[KEY_HISTORY_TREE]
if "prompt_history" in data: if KEY_PROMPT_HISTORY in data:
input_data["prompt_history"] = data["prompt_history"] input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
# 3. Save to Disk # 3. Save to Disk
save_json(file_path, input_data) save_json(file_path, input_data)

View File

@@ -1,9 +1,9 @@
import streamlit as st import streamlit as st
import random import random
from utils import DEFAULTS, save_json, get_file_mtime from utils import DEFAULTS, save_json, get_file_mtime, KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
def render_single_editor(data, file_path): def render_single_editor(data, file_path):
is_batch_file = "batch_data" in data or isinstance(data, list) is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
if is_batch_file: if is_batch_file:
st.info("This is a batch file. Switch to the 'Batch Processor' tab.") st.info("This is a batch file. Switch to the 'Batch Processor' tab.")
@@ -63,7 +63,7 @@ def render_single_editor(data, file_path):
# Explicitly track standard setting keys to exclude them from custom list # Explicitly track standard setting keys to exclude them from custom list
standard_keys = { standard_keys = {
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed", "general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
"camera", "flf", "batch_data", "prompt_history", "sequence_number", "ui_reset_token", "camera", "flf", KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, "ui_reset_token",
"model_name", "vae_name", "steps", "cfg", "denoise", "sampler_name", "scheduler" "model_name", "vae_name", "steps", "cfg", "denoise", "sampler_name", "scheduler"
} }
standard_keys.update(lora_keys) standard_keys.update(lora_keys)
@@ -169,8 +169,8 @@ def render_single_editor(data, file_path):
archive_note = st.text_input("Archive Note") archive_note = st.text_input("Archive Note")
if st.button("📦 Snapshot to History", use_container_width=True): if st.button("📦 Snapshot to History", use_container_width=True):
entry = {"note": archive_note if archive_note else "Snapshot", **current_state} entry = {"note": archive_note if archive_note else "Snapshot", **current_state}
if "prompt_history" not in data: data["prompt_history"] = [] if KEY_PROMPT_HISTORY not in data: data[KEY_PROMPT_HISTORY] = []
data["prompt_history"].insert(0, entry) data[KEY_PROMPT_HISTORY].insert(0, entry)
data.update(entry) data.update(entry)
save_json(file_path, data) save_json(file_path, data)
st.session_state.last_mtime = get_file_mtime(file_path) st.session_state.last_mtime = get_file_mtime(file_path)
@@ -181,7 +181,7 @@ def render_single_editor(data, file_path):
# --- FULL HISTORY PANEL --- # --- FULL HISTORY PANEL ---
st.markdown("---") st.markdown("---")
st.subheader("History") st.subheader("History")
history = data.get("prompt_history", []) history = data.get(KEY_PROMPT_HISTORY, [])
if not history: if not history:
st.caption("No history yet.") st.caption("No history yet.")

View File

@@ -4,10 +4,10 @@ import json
import graphviz import graphviz
import time import time
from history_tree import HistoryTree from history_tree import HistoryTree
from utils import save_json from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
def render_timeline_tab(data, file_path): def render_timeline_tab(data, file_path):
tree_data = data.get("history_tree", {}) tree_data = data.get(KEY_HISTORY_TREE, {})
if not tree_data: if not tree_data:
st.info("No history timeline exists. Make some changes in the Editor first!") st.info("No history timeline exists. Make some changes in the Editor first!")
return return
@@ -61,13 +61,13 @@ def render_timeline_tab(data, file_path):
if not is_head: if not is_head:
if st.button("", key=f"log_rst_{n['id']}", help="Restore this version"): if st.button("", key=f"log_rst_{n['id']}", help="Restore this version"):
# --- FIX: Cleanup 'batch_data' if restoring a Single File --- # --- FIX: Cleanup 'batch_data' if restoring a Single File ---
if "batch_data" not in n["data"] and "batch_data" in data: if KEY_BATCH_DATA not in n["data"] and KEY_BATCH_DATA in data:
del data["batch_data"] del data[KEY_BATCH_DATA]
# ------------------------------------------------------------- # -------------------------------------------------------------
data.update(n["data"]) data.update(n["data"])
htree.head_id = n['id'] htree.head_id = n['id']
data["history_tree"] = htree.to_dict() data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data) save_json(file_path, data)
st.session_state.ui_reset_token += 1 st.session_state.ui_reset_token += 1
label = f"{n.get('note')} ({n['id'][:4]})" label = f"{n.get('note')} ({n['id'][:4]})"
@@ -109,13 +109,13 @@ def render_timeline_tab(data, file_path):
st.write(""); st.write("") st.write(""); st.write("")
if st.button("⏪ Restore Version", type="primary", use_container_width=True): if st.button("⏪ Restore Version", type="primary", use_container_width=True):
# --- FIX: Cleanup 'batch_data' if restoring a Single File --- # --- FIX: Cleanup 'batch_data' if restoring a Single File ---
if "batch_data" not in node_data and "batch_data" in data: if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
del data["batch_data"] del data[KEY_BATCH_DATA]
# ------------------------------------------------------------- # -------------------------------------------------------------
data.update(node_data) data.update(node_data)
htree.head_id = selected_node['id'] htree.head_id = selected_node['id']
data["history_tree"] = htree.to_dict() data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data) save_json(file_path, data)
st.session_state.ui_reset_token += 1 st.session_state.ui_reset_token += 1
label = f"{selected_node.get('note')} ({selected_node['id'][:4]})" label = f"{selected_node.get('note')} ({selected_node['id'][:4]})"
@@ -128,7 +128,7 @@ def render_timeline_tab(data, file_path):
new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", "")) new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", ""))
if rn_col2.button("Update Label"): if rn_col2.button("Update Label"):
selected_node["note"] = new_label selected_node["note"] = new_label
data["history_tree"] = htree.to_dict() data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data) save_json(file_path, data)
st.rerun() st.rerun()
@@ -152,7 +152,7 @@ def render_timeline_tab(data, file_path):
htree.head_id = fallback["id"] htree.head_id = fallback["id"]
else: else:
htree.head_id = None htree.head_id = None
data["history_tree"] = htree.to_dict() data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data) save_json(file_path, data)
st.toast("Node Deleted", icon="🗑️") st.toast("Node Deleted", icon="🗑️")
st.rerun() st.rerun()

View File

@@ -1,7 +1,7 @@
import streamlit as st import streamlit as st
import json import json
from history_tree import HistoryTree from history_tree import HistoryTree
from utils import save_json from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
try: try:
from streamlit_agraph import agraph, Node, Edge, Config from streamlit_agraph import agraph, Node, Edge, Config
@@ -13,7 +13,7 @@ def render_timeline_wip(data, file_path):
if not _HAS_AGRAPH: if not _HAS_AGRAPH:
st.error("The `streamlit-agraph` package is required for this tab. Install it with: `pip install streamlit-agraph`") st.error("The `streamlit-agraph` package is required for this tab. Install it with: `pip install streamlit-agraph`")
return return
tree_data = data.get("history_tree", {}) tree_data = data.get(KEY_HISTORY_TREE, {})
if not tree_data: if not tree_data:
st.info("No history timeline exists.") st.info("No history timeline exists.")
return return
@@ -108,14 +108,14 @@ def render_timeline_wip(data, file_path):
st.write(""); st.write("") st.write(""); st.write("")
if st.button("⏪ Restore This Version", type="primary", use_container_width=True, key=f"rst_{target_node_id}"): if st.button("⏪ Restore This Version", type="primary", use_container_width=True, key=f"rst_{target_node_id}"):
# --- FIX: Cleanup 'batch_data' if restoring a Single File --- # --- FIX: Cleanup 'batch_data' if restoring a Single File ---
if "batch_data" not in node_data and "batch_data" in data: if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
del data["batch_data"] del data[KEY_BATCH_DATA]
# ------------------------------------------------------------- # -------------------------------------------------------------
data.update(node_data) data.update(node_data)
htree.head_id = target_node_id htree.head_id = target_node_id
data["history_tree"] = htree.to_dict() data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data) save_json(file_path, data)
st.session_state.ui_reset_token += 1 st.session_state.ui_reset_token += 1
@@ -174,7 +174,7 @@ def render_timeline_wip(data, file_path):
v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid") v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid")
# --- DETECT BATCH VS SINGLE --- # --- DETECT BATCH VS SINGLE ---
batch_list = node_data.get("batch_data", []) batch_list = node_data.get(KEY_BATCH_DATA, [])
if batch_list and isinstance(batch_list, list) and len(batch_list) > 0: if batch_list and isinstance(batch_list, list) and len(batch_list) > 0:
st.info(f"📚 This snapshot contains {len(batch_list)} sequences.") st.info(f"📚 This snapshot contains {len(batch_list)} sequences.")

0
tests/__init__.py Normal file
View File

5
tests/conftest.py Normal file
View File

@@ -0,0 +1,5 @@
# conftest.py — make the project root importable from within the tests directory.
import sys
from pathlib import Path

_PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(_PROJECT_ROOT))

1
tests/pytest.ini Normal file
View File

@@ -0,0 +1 @@
[pytest]

View File

@@ -0,0 +1,67 @@
import pytest
from history_tree import HistoryTree
def test_commit_creates_node_with_correct_parent():
    """The first commit is a root node; each later commit points at its predecessor."""
    tree = HistoryTree({})
    first = tree.commit({"a": 1}, note="first")
    second = tree.commit({"b": 2}, note="second")
    assert tree.nodes[first]["parent"] is None
    assert tree.nodes[second]["parent"] == first
def test_checkout_returns_correct_data():
    """checkout() hands back exactly the payload stored at the requested node."""
    tree = HistoryTree({})
    snapshot_id = tree.commit({"val": 42}, note="snap")
    assert tree.checkout(snapshot_id) == {"val": 42}
def test_checkout_nonexistent_returns_none():
    """Looking up an unknown node id must not raise — it yields None."""
    empty_tree = HistoryTree({})
    missing = empty_tree.checkout("nonexistent")
    assert missing is None
def test_cycle_detection_raises():
    """A node made its own ancestor must be rejected by the next commit."""
    tree = HistoryTree({})
    node_id = tree.commit({"a": 1})
    # Forge a self-referential parent link to corrupt the tree deliberately.
    tree.nodes[node_id]["parent"] = node_id
    with pytest.raises(ValueError, match="Cycle detected"):
        tree.commit({"b": 2})
def test_branch_creation_on_detached_head():
    """Committing from a non-tip (detached) head should spawn a fresh branch."""
    tree = HistoryTree({})
    root = tree.commit({"a": 1})
    tree.commit({"b": 2})  # "main" now points at this tip
    # Move the head back to the root, which is no longer a branch tip.
    tree.checkout(root)
    detached_commit = tree.commit({"c": 3})
    # A second branch must exist, parented off the old root.
    assert len(tree.branches) == 2
    assert tree.nodes[detached_commit]["parent"] == root
def test_legacy_migration():
    """A legacy 'prompt_history' list is migrated into tree nodes on load."""
    legacy_payload = {
        "prompt_history": [
            {"note": "Entry A", "seed": 1},
            {"note": "Entry B", "seed": 2},
        ]
    }
    migrated = HistoryTree(legacy_payload)
    assert len(migrated.nodes) == 2
    assert migrated.head_id is not None
    assert migrated.branches["main"] == migrated.head_id
def test_to_dict_roundtrip():
    """Serialising a tree and reloading it preserves head and node contents."""
    original = HistoryTree({})
    original.commit({"x": 1}, note="test")
    restored = HistoryTree(original.to_dict())
    assert restored.head_id == original.head_id
    assert restored.nodes == original.nodes

68
tests/test_json_loader.py Normal file
View File

@@ -0,0 +1,68 @@
import json
import os
import pytest
from json_loader import to_float, to_int, get_batch_item, read_json_data
class TestToFloat:
    """to_float(): lenient float coercion that falls back to 0.0."""

    def test_valid(self):
        for raw, expected in (("3.14", 3.14), (5, 5.0)):
            assert to_float(raw) == expected

    def test_invalid(self):
        assert to_float("abc") == 0.0

    def test_none(self):
        assert to_float(None) == 0.0
class TestToInt:
    """to_int(): lenient int coercion (truncating floats) that falls back to 0."""

    def test_valid(self):
        for raw, expected in (("7", 7), (3.9, 3)):
            assert to_int(raw) == expected

    def test_invalid(self):
        assert to_int("xyz") == 0

    def test_none(self):
        assert to_int(None) == 0
class TestGetBatchItem:
    """get_batch_item(): 1-based lookup into 'batch_data' with index clamping
    (per these fixtures: index 2 -> second item; out-of-range clamps to the ends)."""

    def test_valid_index(self):
        payload = {"batch_data": [{"a": 1}, {"a": 2}, {"a": 3}]}
        assert get_batch_item(payload, 2) == {"a": 2}

    def test_clamp_high(self):
        payload = {"batch_data": [{"a": 1}, {"a": 2}]}
        assert get_batch_item(payload, 99) == {"a": 2}

    def test_clamp_low(self):
        payload = {"batch_data": [{"a": 1}, {"a": 2}]}
        assert get_batch_item(payload, 0) == {"a": 1}

    def test_no_batch_data(self):
        plain = {"key": "val"}
        assert get_batch_item(plain, 1) == plain
class TestReadJsonData:
    """read_json_data(): always yields a dict — {} for any failure mode."""

    def test_missing_file(self, tmp_path):
        assert read_json_data(str(tmp_path / "nope.json")) == {}

    def test_invalid_json(self, tmp_path):
        broken = tmp_path / "bad.json"
        broken.write_text("{broken")
        assert read_json_data(str(broken)) == {}

    def test_non_dict_json(self, tmp_path):
        as_list = tmp_path / "list.json"
        as_list.write_text(json.dumps([1, 2, 3]))
        assert read_json_data(str(as_list)) == {}

    def test_valid(self, tmp_path):
        ok = tmp_path / "ok.json"
        ok.write_text(json.dumps({"key": "val"}))
        assert read_json_data(str(ok)) == {"key": "val"}

68
tests/test_utils.py Normal file
View File

@@ -0,0 +1,68 @@
import json
import os
from pathlib import Path
from unittest.mock import patch
import pytest
# Mock streamlit before importing utils
import sys
from unittest.mock import MagicMock
sys.modules.setdefault("streamlit", MagicMock())
from utils import load_json, save_json, get_file_mtime, ALLOWED_BASE_DIR, DEFAULTS
def test_load_json_valid(tmp_path):
    """A well-formed file yields its payload and a positive mtime."""
    target = tmp_path / "test.json"
    payload = {"key": "value"}
    target.write_text(json.dumps(payload))
    loaded, mtime = load_json(target)
    assert loaded == payload
    assert mtime > 0
def test_load_json_missing(tmp_path):
    """A nonexistent file falls back to a copy of DEFAULTS with mtime 0."""
    loaded, mtime = load_json(tmp_path / "nope.json")
    assert loaded == DEFAULTS.copy()
    assert mtime == 0
def test_load_json_invalid(tmp_path):
    """Unparseable JSON also falls back to DEFAULTS with mtime 0."""
    broken = tmp_path / "bad.json"
    broken.write_text("{not valid json")
    loaded, mtime = load_json(broken)
    assert loaded == DEFAULTS.copy()
    assert mtime == 0
def test_save_json_atomic(tmp_path):
    """save_json() leaves the final file in place and no .tmp residue behind."""
    target = tmp_path / "out.json"
    payload = {"hello": "world"}
    save_json(target, payload)
    assert target.exists()
    # The intermediate temp file must have been renamed away.
    assert not target.with_suffix(".json.tmp").exists()
    assert json.loads(target.read_text()) == payload
def test_save_json_overwrites(tmp_path):
    """A second save fully replaces the first file's contents."""
    target = tmp_path / "out.json"
    save_json(target, {"a": 1})
    save_json(target, {"b": 2})
    assert json.loads(target.read_text()) == {"b": 2}
def test_get_file_mtime_existing(tmp_path):
    """An existing file reports a positive modification time."""
    sample = tmp_path / "f.txt"
    sample.write_text("x")
    assert get_file_mtime(sample) > 0
def test_get_file_mtime_missing(tmp_path):
    """A nonexistent path reports an mtime of 0 rather than raising."""
    absent = tmp_path / "missing.txt"
    assert get_file_mtime(absent) == 0
def test_allowed_base_dir_is_set():
    """The navigation base-dir constant exists and is a pathlib.Path."""
    assert ALLOWED_BASE_DIR is not None
    assert isinstance(ALLOWED_BASE_DIR, Path)

View File

@@ -1,9 +1,18 @@
import json import json
import logging import logging
import os
import time import time
from pathlib import Path from pathlib import Path
from typing import Any
import streamlit as st import streamlit as st
# --- Magic String Keys ---
KEY_BATCH_DATA = "batch_data"
KEY_HISTORY_TREE = "history_tree"
KEY_PROMPT_HISTORY = "prompt_history"
KEY_SEQUENCE_NUMBER = "sequence_number"
# Configure logging for the application # Configure logging for the application
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@@ -52,8 +61,8 @@ DEFAULTS = {
CONFIG_FILE = Path(".editor_config.json") CONFIG_FILE = Path(".editor_config.json")
SNIPPETS_FILE = Path(".editor_snippets.json") SNIPPETS_FILE = Path(".editor_snippets.json")
# Restrict directory navigation to this base path (resolve symlinks) # No restriction on directory navigation
ALLOWED_BASE_DIR = Path.cwd().resolve() ALLOWED_BASE_DIR = Path("/").resolve()
def load_config(): def load_config():
"""Loads the main editor configuration (Favorites, Last Dir, Servers).""" """Loads the main editor configuration (Favorites, Last Dir, Servers)."""
@@ -96,7 +105,7 @@ def save_snippets(snippets):
with open(SNIPPETS_FILE, 'w') as f: with open(SNIPPETS_FILE, 'w') as f:
json.dump(snippets, f, indent=4) json.dump(snippets, f, indent=4)
def load_json(path): def load_json(path: str | Path) -> tuple[dict[str, Any], float]:
path = Path(path) path = Path(path)
if not path.exists(): if not path.exists():
return DEFAULTS.copy(), 0 return DEFAULTS.copy(), 0
@@ -108,20 +117,23 @@ def load_json(path):
st.error(f"Error loading JSON: {e}") st.error(f"Error loading JSON: {e}")
return DEFAULTS.copy(), 0 return DEFAULTS.copy(), 0
def save_json(path, data): def save_json(path: str | Path, data: dict[str, Any]) -> None:
with open(path, 'w') as f: path = Path(path)
tmp = path.with_suffix('.json.tmp')
with open(tmp, 'w') as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
os.replace(tmp, path)
def get_file_mtime(path): def get_file_mtime(path: str | Path) -> float:
"""Returns the modification time of a file, or 0 if it doesn't exist.""" """Returns the modification time of a file, or 0 if it doesn't exist."""
path = Path(path) path = Path(path)
if path.exists(): if path.exists():
return path.stat().st_mtime return path.stat().st_mtime
return 0 return 0
def generate_templates(current_dir): def generate_templates(current_dir: Path) -> None:
"""Creates dummy template files if folder is empty.""" """Creates dummy template files if folder is empty."""
save_json(current_dir / "template_i2v.json", DEFAULTS) save_json(current_dir / "template_i2v.json", DEFAULTS)
batch_data = {"batch_data": [DEFAULTS.copy(), DEFAULTS.copy()]} batch_data = {KEY_BATCH_DATA: [DEFAULTS.copy(), DEFAULTS.copy()]}
save_json(current_dir / "template_batch.json", batch_data) save_json(current_dir / "template_batch.json", batch_data)