Add atomic writes, magic string constants, unit tests, type hints, and fix navigation
- save_json() now writes to a temp file then uses os.replace() for atomic writes - Replace hardcoded "batch_data", "history_tree", "prompt_history", "sequence_number" strings with constants (KEY_BATCH_DATA, etc.) across all modules - Add 29 unit tests for history_tree, utils, and json_loader - Add type hints to public functions in utils.py, json_loader.py, history_tree.py - Remove ALLOWED_BASE_DIR restriction that blocked navigating outside app CWD - Fix path text input not updating on navigation by using session state key - Add unpin button (❌) for removing pinned folders Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
67
app.py
67
app.py
@@ -5,7 +5,8 @@ from pathlib import Path
|
||||
# --- Import Custom Modules ---
|
||||
from utils import (
|
||||
load_config, save_config, load_snippets, save_snippets,
|
||||
load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR
|
||||
load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR,
|
||||
KEY_BATCH_DATA, KEY_PROMPT_HISTORY,
|
||||
)
|
||||
from tab_single import render_single_editor
|
||||
from tab_batch import render_batch_processor
|
||||
@@ -47,37 +48,51 @@ with st.sidebar:
|
||||
st.header("📂 Navigator")
|
||||
|
||||
# --- Path Navigator ---
|
||||
new_path = st.text_input("Current Path", value=str(st.session_state.current_dir))
|
||||
# Sync widget key with current_dir so the text input always reflects the actual path
|
||||
if "nav_path_input" not in st.session_state:
|
||||
st.session_state.nav_path_input = str(st.session_state.current_dir)
|
||||
|
||||
new_path = st.text_input("Current Path", key="nav_path_input")
|
||||
if new_path != str(st.session_state.current_dir):
|
||||
p = Path(new_path).resolve()
|
||||
if p.exists() and p.is_dir():
|
||||
# Restrict navigation to the allowed base directory
|
||||
try:
|
||||
p.relative_to(ALLOWED_BASE_DIR)
|
||||
except ValueError:
|
||||
st.error(f"Access denied: path must be under {ALLOWED_BASE_DIR}")
|
||||
else:
|
||||
st.session_state.current_dir = p
|
||||
st.session_state.config['last_dir'] = str(p)
|
||||
st.session_state.current_dir = p
|
||||
st.session_state.config['last_dir'] = str(p)
|
||||
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||
st.rerun()
|
||||
elif new_path.strip():
|
||||
st.error(f"Path does not exist or is not a directory: {new_path}")
|
||||
|
||||
# --- Favorites System ---
|
||||
pin_col, unpin_col = st.columns(2)
|
||||
with pin_col:
|
||||
if st.button("📌 Pin Folder", use_container_width=True):
|
||||
if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
|
||||
st.session_state.config['favorites'].append(str(st.session_state.current_dir))
|
||||
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||
st.rerun()
|
||||
|
||||
# --- Favorites System ---
|
||||
if st.button("📌 Pin Current Folder"):
|
||||
if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
|
||||
st.session_state.config['favorites'].append(str(st.session_state.current_dir))
|
||||
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||
favorites = st.session_state.config['favorites']
|
||||
if favorites:
|
||||
fav_selection = st.radio(
|
||||
"Jump to:",
|
||||
["Select..."] + favorites,
|
||||
index=0,
|
||||
label_visibility="collapsed"
|
||||
)
|
||||
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir):
|
||||
st.session_state.current_dir = Path(fav_selection)
|
||||
st.session_state.nav_path_input = fav_selection
|
||||
st.rerun()
|
||||
|
||||
fav_selection = st.radio(
|
||||
"Jump to:",
|
||||
["Select..."] + st.session_state.config['favorites'],
|
||||
index=0,
|
||||
label_visibility="collapsed"
|
||||
)
|
||||
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir):
|
||||
st.session_state.current_dir = Path(fav_selection)
|
||||
st.rerun()
|
||||
# Unpin buttons for each favorite
|
||||
for fav in favorites:
|
||||
fc1, fc2 = st.columns([4, 1])
|
||||
fc1.caption(fav)
|
||||
if fc2.button("❌", key=f"unpin_{fav}"):
|
||||
st.session_state.config['favorites'].remove(fav)
|
||||
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||
st.rerun()
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
@@ -123,7 +138,7 @@ with st.sidebar:
|
||||
if not new_filename.endswith(".json"): new_filename += ".json"
|
||||
path = st.session_state.current_dir / new_filename
|
||||
if is_batch:
|
||||
data = {"batch_data": []}
|
||||
data = {KEY_BATCH_DATA: []}
|
||||
else:
|
||||
data = DEFAULTS.copy()
|
||||
if "vace" in new_filename: data.update({"frame_to_skip": 81, "vace schedule": 1, "video file path": ""})
|
||||
@@ -163,7 +178,7 @@ if selected_file_name:
|
||||
st.session_state.edit_history_idx = None
|
||||
|
||||
# --- AUTO-SWITCH TAB LOGIC ---
|
||||
is_batch = "batch_data" in data or isinstance(data, list)
|
||||
is_batch = KEY_BATCH_DATA in data or isinstance(data, list)
|
||||
if is_batch:
|
||||
st.session_state.active_tab_name = "🚀 Batch Processor"
|
||||
else:
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
KEY_PROMPT_HISTORY = "prompt_history"
|
||||
|
||||
|
||||
class HistoryTree:
|
||||
def __init__(self, raw_data):
|
||||
self.nodes = raw_data.get("nodes", {})
|
||||
self.branches = raw_data.get("branches", {"main": None})
|
||||
self.head_id = raw_data.get("head_id", None)
|
||||
|
||||
if "prompt_history" in raw_data and isinstance(raw_data["prompt_history"], list) and not self.nodes:
|
||||
self._migrate_legacy(raw_data["prompt_history"])
|
||||
def __init__(self, raw_data: dict[str, Any]) -> None:
|
||||
self.nodes: dict[str, dict[str, Any]] = raw_data.get("nodes", {})
|
||||
self.branches: dict[str, str | None] = raw_data.get("branches", {"main": None})
|
||||
self.head_id: str | None = raw_data.get("head_id", None)
|
||||
|
||||
def _migrate_legacy(self, old_list):
|
||||
if KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list) and not self.nodes:
|
||||
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
|
||||
|
||||
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
||||
parent = None
|
||||
for item in reversed(old_list):
|
||||
node_id = str(uuid.uuid4())[:8]
|
||||
@@ -22,7 +26,7 @@ class HistoryTree:
|
||||
self.branches["main"] = parent
|
||||
self.head_id = parent
|
||||
|
||||
def commit(self, data, note="Snapshot"):
|
||||
def commit(self, data: dict[str, Any], note: str = "Snapshot") -> str:
|
||||
new_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Cycle detection: walk parent chain from head to verify no cycle
|
||||
@@ -56,17 +60,17 @@ class HistoryTree:
|
||||
self.head_id = new_id
|
||||
return new_id
|
||||
|
||||
def checkout(self, node_id):
|
||||
def checkout(self, node_id: str) -> dict[str, Any] | None:
|
||||
if node_id in self.nodes:
|
||||
self.head_id = node_id
|
||||
return self.nodes[node_id]["data"]
|
||||
return None
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
|
||||
|
||||
# --- UPDATED GRAPH GENERATOR ---
|
||||
def generate_graph(self, direction="LR"):
|
||||
def generate_graph(self, direction: str = "LR") -> str:
|
||||
"""
|
||||
Generates Graphviz source.
|
||||
direction: "LR" (Horizontal) or "TB" (Vertical)
|
||||
|
||||
@@ -1,32 +1,36 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def to_float(val):
|
||||
KEY_BATCH_DATA = "batch_data"
|
||||
|
||||
|
||||
def to_float(val: Any) -> float:
|
||||
try:
|
||||
return float(val)
|
||||
except (ValueError, TypeError):
|
||||
return 0.0
|
||||
|
||||
def to_int(val):
|
||||
def to_int(val: Any) -> int:
|
||||
try:
|
||||
return int(float(val))
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
|
||||
def get_batch_item(data, sequence_number):
|
||||
def get_batch_item(data: dict[str, Any], sequence_number: int) -> dict[str, Any]:
|
||||
"""Resolve batch item by sequence_number, clamping to valid range."""
|
||||
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
|
||||
idx = max(0, min(sequence_number - 1, len(data["batch_data"]) - 1))
|
||||
if KEY_BATCH_DATA in data and isinstance(data[KEY_BATCH_DATA], list) and len(data[KEY_BATCH_DATA]) > 0:
|
||||
idx = max(0, min(sequence_number - 1, len(data[KEY_BATCH_DATA]) - 1))
|
||||
if sequence_number - 1 != idx:
|
||||
logger.warning(f"Sequence {sequence_number} out of range (1-{len(data['batch_data'])}), clamped to {idx + 1}")
|
||||
return data["batch_data"][idx]
|
||||
logger.warning(f"Sequence {sequence_number} out of range (1-{len(data[KEY_BATCH_DATA])}), clamped to {idx + 1}")
|
||||
return data[KEY_BATCH_DATA][idx]
|
||||
return data
|
||||
|
||||
# --- Shared Helper ---
|
||||
def read_json_data(json_path):
|
||||
def read_json_data(json_path: str) -> dict[str, Any]:
|
||||
if not os.path.exists(json_path):
|
||||
logger.warning(f"File not found at {json_path}")
|
||||
return {}
|
||||
|
||||
80
tab_batch.py
80
tab_batch.py
@@ -1,8 +1,8 @@
|
||||
import streamlit as st
|
||||
import random
|
||||
import copy
|
||||
from utils import DEFAULTS, save_json, load_json
|
||||
from history_tree import HistoryTree
|
||||
from utils import DEFAULTS, save_json, load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
|
||||
from history_tree import HistoryTree
|
||||
|
||||
def create_batch_callback(original_filename, current_data, current_dir):
|
||||
new_name = f"batch_{original_filename}"
|
||||
@@ -13,15 +13,15 @@ def create_batch_callback(original_filename, current_data, current_dir):
|
||||
return
|
||||
|
||||
first_item = current_data.copy()
|
||||
if "prompt_history" in first_item: del first_item["prompt_history"]
|
||||
if "history_tree" in first_item: del first_item["history_tree"]
|
||||
|
||||
first_item["sequence_number"] = 1
|
||||
|
||||
if KEY_PROMPT_HISTORY in first_item: del first_item[KEY_PROMPT_HISTORY]
|
||||
if KEY_HISTORY_TREE in first_item: del first_item[KEY_HISTORY_TREE]
|
||||
|
||||
first_item[KEY_SEQUENCE_NUMBER] = 1
|
||||
|
||||
new_data = {
|
||||
"batch_data": [first_item],
|
||||
"history_tree": {},
|
||||
"prompt_history": []
|
||||
KEY_BATCH_DATA: [first_item],
|
||||
KEY_HISTORY_TREE: {},
|
||||
KEY_PROMPT_HISTORY: []
|
||||
}
|
||||
|
||||
save_json(new_path, new_data)
|
||||
@@ -30,7 +30,7 @@ def create_batch_callback(original_filename, current_data, current_dir):
|
||||
|
||||
|
||||
def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
|
||||
is_batch_file = "batch_data" in data or isinstance(data, list)
|
||||
is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
|
||||
|
||||
if not is_batch_file:
|
||||
st.warning("This is a Single file. To use Batch mode, create a copy.")
|
||||
@@ -40,7 +40,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
|
||||
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
|
||||
|
||||
batch_list = data.get("batch_data", [])
|
||||
batch_list = data.get(KEY_BATCH_DATA, [])
|
||||
|
||||
# --- ADD NEW SEQUENCE AREA ---
|
||||
st.subheader("Add New Sequence")
|
||||
@@ -53,7 +53,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
src_data, _ = load_json(current_dir / src_name)
|
||||
|
||||
with ac2:
|
||||
src_hist = src_data.get("prompt_history", [])
|
||||
src_hist = src_data.get(KEY_PROMPT_HISTORY, [])
|
||||
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
|
||||
sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist")
|
||||
|
||||
@@ -62,14 +62,14 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
def add_sequence(new_item):
|
||||
max_seq = 0
|
||||
for s in batch_list:
|
||||
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"]))
|
||||
new_item["sequence_number"] = max_seq + 1
|
||||
|
||||
for k in ["prompt_history", "history_tree", "note", "loras"]:
|
||||
if KEY_SEQUENCE_NUMBER in s: max_seq = max(max_seq, int(s[KEY_SEQUENCE_NUMBER]))
|
||||
new_item[KEY_SEQUENCE_NUMBER] = max_seq + 1
|
||||
|
||||
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, "note", "loras"]:
|
||||
if k in new_item: del new_item[k]
|
||||
|
||||
batch_list.append(new_item)
|
||||
data["batch_data"] = batch_list
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
st.session_state.ui_reset_token += 1
|
||||
st.rerun()
|
||||
@@ -79,7 +79,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
|
||||
if bc2.button("➕ From File", use_container_width=True, help=f"Copy {src_name}"):
|
||||
item = DEFAULTS.copy()
|
||||
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
||||
flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data
|
||||
item.update(flat)
|
||||
add_sequence(item)
|
||||
|
||||
@@ -107,7 +107,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
|
||||
standard_keys = {
|
||||
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
||||
"camera", "flf", "sequence_number"
|
||||
"camera", "flf", KEY_SEQUENCE_NUMBER
|
||||
}
|
||||
standard_keys.update(lora_keys)
|
||||
standard_keys.update([
|
||||
@@ -116,7 +116,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
])
|
||||
|
||||
for i, seq in enumerate(batch_list):
|
||||
seq_num = seq.get("sequence_number", i+1)
|
||||
seq_num = seq.get(KEY_SEQUENCE_NUMBER, i+1)
|
||||
prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}"
|
||||
|
||||
with st.expander(f"🎬 Sequence #{seq_num}", expanded=False):
|
||||
@@ -127,13 +127,13 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
with act_c1:
|
||||
if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True):
|
||||
item = DEFAULTS.copy()
|
||||
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
||||
flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data
|
||||
item.update(flat)
|
||||
item["sequence_number"] = seq_num
|
||||
for k in ["prompt_history", "history_tree"]:
|
||||
item[KEY_SEQUENCE_NUMBER] = seq_num
|
||||
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE]:
|
||||
if k in item: del item[k]
|
||||
batch_list[i] = item
|
||||
data["batch_data"] = batch_list
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
st.session_state.ui_reset_token += 1
|
||||
st.toast("Copied!", icon="📥")
|
||||
@@ -145,10 +145,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True):
|
||||
new_seq = seq.copy()
|
||||
max_sn = 0
|
||||
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
|
||||
new_seq["sequence_number"] = max_sn + 1
|
||||
for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
||||
new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
|
||||
batch_list.insert(i + 1, new_seq)
|
||||
data["batch_data"] = batch_list
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
st.session_state.ui_reset_token += 1
|
||||
st.toast("Cloned to Next!", icon="👯")
|
||||
@@ -157,10 +157,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True):
|
||||
new_seq = seq.copy()
|
||||
max_sn = 0
|
||||
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
|
||||
new_seq["sequence_number"] = max_sn + 1
|
||||
for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
||||
new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
|
||||
batch_list.append(new_seq)
|
||||
data["batch_data"] = batch_list
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
st.session_state.ui_reset_token += 1
|
||||
st.toast("Cloned to End!", icon="⏬")
|
||||
@@ -170,9 +170,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
with act_c3:
|
||||
if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True):
|
||||
single_data = seq.copy()
|
||||
single_data["prompt_history"] = data.get("prompt_history", [])
|
||||
single_data["history_tree"] = data.get("history_tree", {})
|
||||
if "sequence_number" in single_data: del single_data["sequence_number"]
|
||||
single_data[KEY_PROMPT_HISTORY] = data.get(KEY_PROMPT_HISTORY, [])
|
||||
single_data[KEY_HISTORY_TREE] = data.get(KEY_HISTORY_TREE, {})
|
||||
if KEY_SEQUENCE_NUMBER in single_data: del single_data[KEY_SEQUENCE_NUMBER]
|
||||
save_json(file_path, single_data)
|
||||
st.toast("Converted to Single!", icon="✅")
|
||||
st.rerun()
|
||||
@@ -181,7 +181,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
with act_c4:
|
||||
if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
|
||||
batch_list.pop(i)
|
||||
data["batch_data"] = batch_list
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
save_json(file_path, data)
|
||||
st.rerun()
|
||||
|
||||
@@ -194,7 +194,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn")
|
||||
|
||||
with c2:
|
||||
seq["sequence_number"] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val")
|
||||
seq[KEY_SEQUENCE_NUMBER] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val")
|
||||
|
||||
s_row1, s_row2 = st.columns([3, 1])
|
||||
seed_key = f"{prefix}_seed"
|
||||
@@ -320,17 +320,17 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
||||
|
||||
with col_save:
|
||||
if st.button("💾 Save & Snap", use_container_width=True):
|
||||
data["batch_data"] = batch_list
|
||||
data[KEY_BATCH_DATA] = batch_list
|
||||
|
||||
tree_data = data.get("history_tree", {})
|
||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||
htree = HistoryTree(tree_data)
|
||||
|
||||
snapshot_payload = copy.deepcopy(data)
|
||||
if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"]
|
||||
if KEY_HISTORY_TREE in snapshot_payload: del snapshot_payload[KEY_HISTORY_TREE]
|
||||
|
||||
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
|
||||
|
||||
data["history_tree"] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
|
||||
if 'restored_indicator' in st.session_state:
|
||||
|
||||
14
tab_raw.py
14
tab_raw.py
@@ -1,7 +1,7 @@
|
||||
import streamlit as st
|
||||
import json
|
||||
import copy
|
||||
from utils import save_json, get_file_mtime
|
||||
from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
|
||||
|
||||
def render_raw_editor(data, file_path):
|
||||
st.subheader(f"💻 Raw Editor: {file_path.name}")
|
||||
@@ -20,8 +20,8 @@ def render_raw_editor(data, file_path):
|
||||
if hide_history:
|
||||
display_data = copy.deepcopy(data)
|
||||
# Safely remove heavy keys for the view only
|
||||
if "history_tree" in display_data: del display_data["history_tree"]
|
||||
if "prompt_history" in display_data: del display_data["prompt_history"]
|
||||
if KEY_HISTORY_TREE in display_data: del display_data[KEY_HISTORY_TREE]
|
||||
if KEY_PROMPT_HISTORY in display_data: del display_data[KEY_PROMPT_HISTORY]
|
||||
else:
|
||||
display_data = data
|
||||
|
||||
@@ -51,10 +51,10 @@ def render_raw_editor(data, file_path):
|
||||
|
||||
# 2. If we were in Safe Mode, we must merge the hidden history back in
|
||||
if hide_history:
|
||||
if "history_tree" in data:
|
||||
input_data["history_tree"] = data["history_tree"]
|
||||
if "prompt_history" in data:
|
||||
input_data["prompt_history"] = data["prompt_history"]
|
||||
if KEY_HISTORY_TREE in data:
|
||||
input_data[KEY_HISTORY_TREE] = data[KEY_HISTORY_TREE]
|
||||
if KEY_PROMPT_HISTORY in data:
|
||||
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
||||
|
||||
# 3. Save to Disk
|
||||
save_json(file_path, input_data)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import streamlit as st
|
||||
import random
|
||||
from utils import DEFAULTS, save_json, get_file_mtime
|
||||
from utils import DEFAULTS, save_json, get_file_mtime, KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
|
||||
|
||||
def render_single_editor(data, file_path):
|
||||
is_batch_file = "batch_data" in data or isinstance(data, list)
|
||||
is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
|
||||
|
||||
if is_batch_file:
|
||||
st.info("This is a batch file. Switch to the 'Batch Processor' tab.")
|
||||
@@ -63,7 +63,7 @@ def render_single_editor(data, file_path):
|
||||
# Explicitly track standard setting keys to exclude them from custom list
|
||||
standard_keys = {
|
||||
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
||||
"camera", "flf", "batch_data", "prompt_history", "sequence_number", "ui_reset_token",
|
||||
"camera", "flf", KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, "ui_reset_token",
|
||||
"model_name", "vae_name", "steps", "cfg", "denoise", "sampler_name", "scheduler"
|
||||
}
|
||||
standard_keys.update(lora_keys)
|
||||
@@ -169,8 +169,8 @@ def render_single_editor(data, file_path):
|
||||
archive_note = st.text_input("Archive Note")
|
||||
if st.button("📦 Snapshot to History", use_container_width=True):
|
||||
entry = {"note": archive_note if archive_note else "Snapshot", **current_state}
|
||||
if "prompt_history" not in data: data["prompt_history"] = []
|
||||
data["prompt_history"].insert(0, entry)
|
||||
if KEY_PROMPT_HISTORY not in data: data[KEY_PROMPT_HISTORY] = []
|
||||
data[KEY_PROMPT_HISTORY].insert(0, entry)
|
||||
data.update(entry)
|
||||
save_json(file_path, data)
|
||||
st.session_state.last_mtime = get_file_mtime(file_path)
|
||||
@@ -181,7 +181,7 @@ def render_single_editor(data, file_path):
|
||||
# --- FULL HISTORY PANEL ---
|
||||
st.markdown("---")
|
||||
st.subheader("History")
|
||||
history = data.get("prompt_history", [])
|
||||
history = data.get(KEY_PROMPT_HISTORY, [])
|
||||
|
||||
if not history:
|
||||
st.caption("No history yet.")
|
||||
|
||||
@@ -4,10 +4,10 @@ import json
|
||||
import graphviz
|
||||
import time
|
||||
from history_tree import HistoryTree
|
||||
from utils import save_json
|
||||
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||
|
||||
def render_timeline_tab(data, file_path):
|
||||
tree_data = data.get("history_tree", {})
|
||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||
if not tree_data:
|
||||
st.info("No history timeline exists. Make some changes in the Editor first!")
|
||||
return
|
||||
@@ -61,13 +61,13 @@ def render_timeline_tab(data, file_path):
|
||||
if not is_head:
|
||||
if st.button("⏪", key=f"log_rst_{n['id']}", help="Restore this version"):
|
||||
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
||||
if "batch_data" not in n["data"] and "batch_data" in data:
|
||||
del data["batch_data"]
|
||||
if KEY_BATCH_DATA not in n["data"] and KEY_BATCH_DATA in data:
|
||||
del data[KEY_BATCH_DATA]
|
||||
# -------------------------------------------------------------
|
||||
|
||||
data.update(n["data"])
|
||||
htree.head_id = n['id']
|
||||
data["history_tree"] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
st.session_state.ui_reset_token += 1
|
||||
label = f"{n.get('note')} ({n['id'][:4]})"
|
||||
@@ -109,13 +109,13 @@ def render_timeline_tab(data, file_path):
|
||||
st.write(""); st.write("")
|
||||
if st.button("⏪ Restore Version", type="primary", use_container_width=True):
|
||||
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
||||
if "batch_data" not in node_data and "batch_data" in data:
|
||||
del data["batch_data"]
|
||||
if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
|
||||
del data[KEY_BATCH_DATA]
|
||||
# -------------------------------------------------------------
|
||||
|
||||
data.update(node_data)
|
||||
htree.head_id = selected_node['id']
|
||||
data["history_tree"] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
st.session_state.ui_reset_token += 1
|
||||
label = f"{selected_node.get('note')} ({selected_node['id'][:4]})"
|
||||
@@ -128,7 +128,7 @@ def render_timeline_tab(data, file_path):
|
||||
new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", ""))
|
||||
if rn_col2.button("Update Label"):
|
||||
selected_node["note"] = new_label
|
||||
data["history_tree"] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
st.rerun()
|
||||
|
||||
@@ -152,7 +152,7 @@ def render_timeline_tab(data, file_path):
|
||||
htree.head_id = fallback["id"]
|
||||
else:
|
||||
htree.head_id = None
|
||||
data["history_tree"] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
st.toast("Node Deleted", icon="🗑️")
|
||||
st.rerun()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import streamlit as st
|
||||
import json
|
||||
from history_tree import HistoryTree
|
||||
from utils import save_json
|
||||
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||
|
||||
try:
|
||||
from streamlit_agraph import agraph, Node, Edge, Config
|
||||
@@ -13,7 +13,7 @@ def render_timeline_wip(data, file_path):
|
||||
if not _HAS_AGRAPH:
|
||||
st.error("The `streamlit-agraph` package is required for this tab. Install it with: `pip install streamlit-agraph`")
|
||||
return
|
||||
tree_data = data.get("history_tree", {})
|
||||
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||
if not tree_data:
|
||||
st.info("No history timeline exists.")
|
||||
return
|
||||
@@ -108,14 +108,14 @@ def render_timeline_wip(data, file_path):
|
||||
st.write(""); st.write("")
|
||||
if st.button("⏪ Restore This Version", type="primary", use_container_width=True, key=f"rst_{target_node_id}"):
|
||||
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
||||
if "batch_data" not in node_data and "batch_data" in data:
|
||||
del data["batch_data"]
|
||||
if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
|
||||
del data[KEY_BATCH_DATA]
|
||||
# -------------------------------------------------------------
|
||||
|
||||
data.update(node_data)
|
||||
htree.head_id = target_node_id
|
||||
|
||||
data["history_tree"] = htree.to_dict()
|
||||
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||
save_json(file_path, data)
|
||||
|
||||
st.session_state.ui_reset_token += 1
|
||||
@@ -174,7 +174,7 @@ def render_timeline_wip(data, file_path):
|
||||
v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid")
|
||||
|
||||
# --- DETECT BATCH VS SINGLE ---
|
||||
batch_list = node_data.get("batch_data", [])
|
||||
batch_list = node_data.get(KEY_BATCH_DATA, [])
|
||||
|
||||
if batch_list and isinstance(batch_list, list) and len(batch_list) > 0:
|
||||
st.info(f"📚 This snapshot contains {len(batch_list)} sequences.")
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
5
tests/conftest.py
Normal file
5
tests/conftest.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to sys.path so tests can import project modules
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
1
tests/pytest.ini
Normal file
1
tests/pytest.ini
Normal file
@@ -0,0 +1 @@
|
||||
[pytest]
|
||||
67
tests/test_history_tree.py
Normal file
67
tests/test_history_tree.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import pytest
|
||||
from history_tree import HistoryTree
|
||||
|
||||
|
||||
def test_commit_creates_node_with_correct_parent():
|
||||
tree = HistoryTree({})
|
||||
id1 = tree.commit({"a": 1}, note="first")
|
||||
id2 = tree.commit({"b": 2}, note="second")
|
||||
|
||||
assert tree.nodes[id1]["parent"] is None
|
||||
assert tree.nodes[id2]["parent"] == id1
|
||||
|
||||
|
||||
def test_checkout_returns_correct_data():
|
||||
tree = HistoryTree({})
|
||||
id1 = tree.commit({"val": 42}, note="snap")
|
||||
result = tree.checkout(id1)
|
||||
assert result == {"val": 42}
|
||||
|
||||
|
||||
def test_checkout_nonexistent_returns_none():
|
||||
tree = HistoryTree({})
|
||||
assert tree.checkout("nonexistent") is None
|
||||
|
||||
|
||||
def test_cycle_detection_raises():
|
||||
tree = HistoryTree({})
|
||||
id1 = tree.commit({"a": 1})
|
||||
# Manually introduce a cycle
|
||||
tree.nodes[id1]["parent"] = id1
|
||||
with pytest.raises(ValueError, match="Cycle detected"):
|
||||
tree.commit({"b": 2})
|
||||
|
||||
|
||||
def test_branch_creation_on_detached_head():
|
||||
tree = HistoryTree({})
|
||||
id1 = tree.commit({"a": 1})
|
||||
id2 = tree.commit({"b": 2})
|
||||
# Detach head by checking out a non-tip node
|
||||
tree.checkout(id1)
|
||||
# head_id is now id1, which is no longer a branch tip (main points to id2)
|
||||
id3 = tree.commit({"c": 3})
|
||||
# A new branch should have been created
|
||||
assert len(tree.branches) == 2
|
||||
assert tree.nodes[id3]["parent"] == id1
|
||||
|
||||
|
||||
def test_legacy_migration():
|
||||
legacy = {
|
||||
"prompt_history": [
|
||||
{"note": "Entry A", "seed": 1},
|
||||
{"note": "Entry B", "seed": 2},
|
||||
]
|
||||
}
|
||||
tree = HistoryTree(legacy)
|
||||
assert len(tree.nodes) == 2
|
||||
assert tree.head_id is not None
|
||||
assert tree.branches["main"] == tree.head_id
|
||||
|
||||
|
||||
def test_to_dict_roundtrip():
|
||||
tree = HistoryTree({})
|
||||
tree.commit({"x": 1}, note="test")
|
||||
d = tree.to_dict()
|
||||
tree2 = HistoryTree(d)
|
||||
assert tree2.head_id == tree.head_id
|
||||
assert tree2.nodes == tree.nodes
|
||||
68
tests/test_json_loader.py
Normal file
68
tests/test_json_loader.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from json_loader import to_float, to_int, get_batch_item, read_json_data
|
||||
|
||||
|
||||
class TestToFloat:
|
||||
def test_valid(self):
|
||||
assert to_float("3.14") == 3.14
|
||||
assert to_float(5) == 5.0
|
||||
|
||||
def test_invalid(self):
|
||||
assert to_float("abc") == 0.0
|
||||
|
||||
def test_none(self):
|
||||
assert to_float(None) == 0.0
|
||||
|
||||
|
||||
class TestToInt:
|
||||
def test_valid(self):
|
||||
assert to_int("7") == 7
|
||||
assert to_int(3.9) == 3
|
||||
|
||||
def test_invalid(self):
|
||||
assert to_int("xyz") == 0
|
||||
|
||||
def test_none(self):
|
||||
assert to_int(None) == 0
|
||||
|
||||
|
||||
class TestGetBatchItem:
|
||||
def test_valid_index(self):
|
||||
data = {"batch_data": [{"a": 1}, {"a": 2}, {"a": 3}]}
|
||||
assert get_batch_item(data, 2) == {"a": 2}
|
||||
|
||||
def test_clamp_high(self):
|
||||
data = {"batch_data": [{"a": 1}, {"a": 2}]}
|
||||
assert get_batch_item(data, 99) == {"a": 2}
|
||||
|
||||
def test_clamp_low(self):
|
||||
data = {"batch_data": [{"a": 1}, {"a": 2}]}
|
||||
assert get_batch_item(data, 0) == {"a": 1}
|
||||
|
||||
def test_no_batch_data(self):
|
||||
data = {"key": "val"}
|
||||
assert get_batch_item(data, 1) == data
|
||||
|
||||
|
||||
class TestReadJsonData:
|
||||
def test_missing_file(self, tmp_path):
|
||||
assert read_json_data(str(tmp_path / "nope.json")) == {}
|
||||
|
||||
def test_invalid_json(self, tmp_path):
|
||||
p = tmp_path / "bad.json"
|
||||
p.write_text("{broken")
|
||||
assert read_json_data(str(p)) == {}
|
||||
|
||||
def test_non_dict_json(self, tmp_path):
|
||||
p = tmp_path / "list.json"
|
||||
p.write_text(json.dumps([1, 2, 3]))
|
||||
assert read_json_data(str(p)) == {}
|
||||
|
||||
def test_valid(self, tmp_path):
|
||||
p = tmp_path / "ok.json"
|
||||
p.write_text(json.dumps({"key": "val"}))
|
||||
assert read_json_data(str(p)) == {"key": "val"}
|
||||
68
tests/test_utils.py
Normal file
68
tests/test_utils.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Mock streamlit before importing utils
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
sys.modules.setdefault("streamlit", MagicMock())
|
||||
|
||||
from utils import load_json, save_json, get_file_mtime, ALLOWED_BASE_DIR, DEFAULTS
|
||||
|
||||
|
||||
def test_load_json_valid(tmp_path):
|
||||
p = tmp_path / "test.json"
|
||||
data = {"key": "value"}
|
||||
p.write_text(json.dumps(data))
|
||||
result, mtime = load_json(p)
|
||||
assert result == data
|
||||
assert mtime > 0
|
||||
|
||||
|
||||
def test_load_json_missing(tmp_path):
|
||||
p = tmp_path / "nope.json"
|
||||
result, mtime = load_json(p)
|
||||
assert result == DEFAULTS.copy()
|
||||
assert mtime == 0
|
||||
|
||||
|
||||
def test_load_json_invalid(tmp_path):
|
||||
p = tmp_path / "bad.json"
|
||||
p.write_text("{not valid json")
|
||||
result, mtime = load_json(p)
|
||||
assert result == DEFAULTS.copy()
|
||||
assert mtime == 0
|
||||
|
||||
|
||||
def test_save_json_atomic(tmp_path):
|
||||
p = tmp_path / "out.json"
|
||||
data = {"hello": "world"}
|
||||
save_json(p, data)
|
||||
assert p.exists()
|
||||
assert not p.with_suffix(".json.tmp").exists()
|
||||
assert json.loads(p.read_text()) == data
|
||||
|
||||
|
||||
def test_save_json_overwrites(tmp_path):
|
||||
p = tmp_path / "out.json"
|
||||
save_json(p, {"a": 1})
|
||||
save_json(p, {"b": 2})
|
||||
assert json.loads(p.read_text()) == {"b": 2}
|
||||
|
||||
|
||||
def test_get_file_mtime_existing(tmp_path):
|
||||
p = tmp_path / "f.txt"
|
||||
p.write_text("x")
|
||||
assert get_file_mtime(p) > 0
|
||||
|
||||
|
||||
def test_get_file_mtime_missing(tmp_path):
|
||||
assert get_file_mtime(tmp_path / "missing.txt") == 0
|
||||
|
||||
|
||||
def test_allowed_base_dir_is_set():
|
||||
assert ALLOWED_BASE_DIR is not None
|
||||
assert isinstance(ALLOWED_BASE_DIR, Path)
|
||||
30
utils.py
30
utils.py
@@ -1,9 +1,18 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import streamlit as st
|
||||
|
||||
# --- Magic String Keys ---
|
||||
KEY_BATCH_DATA = "batch_data"
|
||||
KEY_HISTORY_TREE = "history_tree"
|
||||
KEY_PROMPT_HISTORY = "prompt_history"
|
||||
KEY_SEQUENCE_NUMBER = "sequence_number"
|
||||
|
||||
# Configure logging for the application
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -52,8 +61,8 @@ DEFAULTS = {
|
||||
CONFIG_FILE = Path(".editor_config.json")
|
||||
SNIPPETS_FILE = Path(".editor_snippets.json")
|
||||
|
||||
# Restrict directory navigation to this base path (resolve symlinks)
|
||||
ALLOWED_BASE_DIR = Path.cwd().resolve()
|
||||
# No restriction on directory navigation
|
||||
ALLOWED_BASE_DIR = Path("/").resolve()
|
||||
|
||||
def load_config():
|
||||
"""Loads the main editor configuration (Favorites, Last Dir, Servers)."""
|
||||
@@ -96,7 +105,7 @@ def save_snippets(snippets):
|
||||
with open(SNIPPETS_FILE, 'w') as f:
|
||||
json.dump(snippets, f, indent=4)
|
||||
|
||||
def load_json(path):
|
||||
def load_json(path: str | Path) -> tuple[dict[str, Any], float]:
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
return DEFAULTS.copy(), 0
|
||||
@@ -108,20 +117,23 @@ def load_json(path):
|
||||
st.error(f"Error loading JSON: {e}")
|
||||
return DEFAULTS.copy(), 0
|
||||
|
||||
def save_json(path, data):
|
||||
with open(path, 'w') as f:
|
||||
def save_json(path: str | Path, data: dict[str, Any]) -> None:
|
||||
path = Path(path)
|
||||
tmp = path.with_suffix('.json.tmp')
|
||||
with open(tmp, 'w') as f:
|
||||
json.dump(data, f, indent=4)
|
||||
os.replace(tmp, path)
|
||||
|
||||
def get_file_mtime(path):
|
||||
def get_file_mtime(path: str | Path) -> float:
|
||||
"""Returns the modification time of a file, or 0 if it doesn't exist."""
|
||||
path = Path(path)
|
||||
if path.exists():
|
||||
return path.stat().st_mtime
|
||||
return 0
|
||||
|
||||
def generate_templates(current_dir):
|
||||
def generate_templates(current_dir: Path) -> None:
|
||||
"""Creates dummy template files if folder is empty."""
|
||||
save_json(current_dir / "template_i2v.json", DEFAULTS)
|
||||
|
||||
batch_data = {"batch_data": [DEFAULTS.copy(), DEFAULTS.copy()]}
|
||||
|
||||
batch_data = {KEY_BATCH_DATA: [DEFAULTS.copy(), DEFAULTS.copy()]}
|
||||
save_json(current_dir / "template_batch.json", batch_data)
|
||||
|
||||
Reference in New Issue
Block a user