Add atomic writes, magic string constants, unit tests, type hints, and fix navigation
- save_json() now writes to a temp file then uses os.replace() for atomic writes - Replace hardcoded "batch_data", "history_tree", "prompt_history", "sequence_number" strings with constants (KEY_BATCH_DATA, etc.) across all modules - Add 29 unit tests for history_tree, utils, and json_loader - Add type hints to public functions in utils.py, json_loader.py, history_tree.py - Remove ALLOWED_BASE_DIR restriction that blocked navigating outside app CWD - Fix path text input not updating on navigation by using session state key - Add unpin button (❌) for removing pinned folders Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
39
app.py
39
app.py
@@ -5,7 +5,8 @@ from pathlib import Path
|
|||||||
# --- Import Custom Modules ---
|
# --- Import Custom Modules ---
|
||||||
from utils import (
|
from utils import (
|
||||||
load_config, save_config, load_snippets, save_snippets,
|
load_config, save_config, load_snippets, save_snippets,
|
||||||
load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR
|
load_json, save_json, generate_templates, DEFAULTS, ALLOWED_BASE_DIR,
|
||||||
|
KEY_BATCH_DATA, KEY_PROMPT_HISTORY,
|
||||||
)
|
)
|
||||||
from tab_single import render_single_editor
|
from tab_single import render_single_editor
|
||||||
from tab_batch import render_batch_processor
|
from tab_batch import render_batch_processor
|
||||||
@@ -47,36 +48,50 @@ with st.sidebar:
|
|||||||
st.header("📂 Navigator")
|
st.header("📂 Navigator")
|
||||||
|
|
||||||
# --- Path Navigator ---
|
# --- Path Navigator ---
|
||||||
new_path = st.text_input("Current Path", value=str(st.session_state.current_dir))
|
# Sync widget key with current_dir so the text input always reflects the actual path
|
||||||
|
if "nav_path_input" not in st.session_state:
|
||||||
|
st.session_state.nav_path_input = str(st.session_state.current_dir)
|
||||||
|
|
||||||
|
new_path = st.text_input("Current Path", key="nav_path_input")
|
||||||
if new_path != str(st.session_state.current_dir):
|
if new_path != str(st.session_state.current_dir):
|
||||||
p = Path(new_path).resolve()
|
p = Path(new_path).resolve()
|
||||||
if p.exists() and p.is_dir():
|
if p.exists() and p.is_dir():
|
||||||
# Restrict navigation to the allowed base directory
|
|
||||||
try:
|
|
||||||
p.relative_to(ALLOWED_BASE_DIR)
|
|
||||||
except ValueError:
|
|
||||||
st.error(f"Access denied: path must be under {ALLOWED_BASE_DIR}")
|
|
||||||
else:
|
|
||||||
st.session_state.current_dir = p
|
st.session_state.current_dir = p
|
||||||
st.session_state.config['last_dir'] = str(p)
|
st.session_state.config['last_dir'] = str(p)
|
||||||
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
elif new_path.strip():
|
||||||
|
st.error(f"Path does not exist or is not a directory: {new_path}")
|
||||||
|
|
||||||
# --- Favorites System ---
|
# --- Favorites System ---
|
||||||
if st.button("📌 Pin Current Folder"):
|
pin_col, unpin_col = st.columns(2)
|
||||||
|
with pin_col:
|
||||||
|
if st.button("📌 Pin Folder", use_container_width=True):
|
||||||
if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
|
if str(st.session_state.current_dir) not in st.session_state.config['favorites']:
|
||||||
st.session_state.config['favorites'].append(str(st.session_state.current_dir))
|
st.session_state.config['favorites'].append(str(st.session_state.current_dir))
|
||||||
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
|
favorites = st.session_state.config['favorites']
|
||||||
|
if favorites:
|
||||||
fav_selection = st.radio(
|
fav_selection = st.radio(
|
||||||
"Jump to:",
|
"Jump to:",
|
||||||
["Select..."] + st.session_state.config['favorites'],
|
["Select..."] + favorites,
|
||||||
index=0,
|
index=0,
|
||||||
label_visibility="collapsed"
|
label_visibility="collapsed"
|
||||||
)
|
)
|
||||||
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir):
|
if fav_selection != "Select..." and fav_selection != str(st.session_state.current_dir):
|
||||||
st.session_state.current_dir = Path(fav_selection)
|
st.session_state.current_dir = Path(fav_selection)
|
||||||
|
st.session_state.nav_path_input = fav_selection
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# Unpin buttons for each favorite
|
||||||
|
for fav in favorites:
|
||||||
|
fc1, fc2 = st.columns([4, 1])
|
||||||
|
fc1.caption(fav)
|
||||||
|
if fc2.button("❌", key=f"unpin_{fav}"):
|
||||||
|
st.session_state.config['favorites'].remove(fav)
|
||||||
|
save_config(st.session_state.current_dir, st.session_state.config['favorites'])
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
@@ -123,7 +138,7 @@ with st.sidebar:
|
|||||||
if not new_filename.endswith(".json"): new_filename += ".json"
|
if not new_filename.endswith(".json"): new_filename += ".json"
|
||||||
path = st.session_state.current_dir / new_filename
|
path = st.session_state.current_dir / new_filename
|
||||||
if is_batch:
|
if is_batch:
|
||||||
data = {"batch_data": []}
|
data = {KEY_BATCH_DATA: []}
|
||||||
else:
|
else:
|
||||||
data = DEFAULTS.copy()
|
data = DEFAULTS.copy()
|
||||||
if "vace" in new_filename: data.update({"frame_to_skip": 81, "vace schedule": 1, "video file path": ""})
|
if "vace" in new_filename: data.update({"frame_to_skip": 81, "vace schedule": 1, "video file path": ""})
|
||||||
@@ -163,7 +178,7 @@ if selected_file_name:
|
|||||||
st.session_state.edit_history_idx = None
|
st.session_state.edit_history_idx = None
|
||||||
|
|
||||||
# --- AUTO-SWITCH TAB LOGIC ---
|
# --- AUTO-SWITCH TAB LOGIC ---
|
||||||
is_batch = "batch_data" in data or isinstance(data, list)
|
is_batch = KEY_BATCH_DATA in data or isinstance(data, list)
|
||||||
if is_batch:
|
if is_batch:
|
||||||
st.session_state.active_tab_name = "🚀 Batch Processor"
|
st.session_state.active_tab_name = "🚀 Batch Processor"
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,16 +1,20 @@
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
KEY_PROMPT_HISTORY = "prompt_history"
|
||||||
|
|
||||||
|
|
||||||
class HistoryTree:
|
class HistoryTree:
|
||||||
def __init__(self, raw_data):
|
def __init__(self, raw_data: dict[str, Any]) -> None:
|
||||||
self.nodes = raw_data.get("nodes", {})
|
self.nodes: dict[str, dict[str, Any]] = raw_data.get("nodes", {})
|
||||||
self.branches = raw_data.get("branches", {"main": None})
|
self.branches: dict[str, str | None] = raw_data.get("branches", {"main": None})
|
||||||
self.head_id = raw_data.get("head_id", None)
|
self.head_id: str | None = raw_data.get("head_id", None)
|
||||||
|
|
||||||
if "prompt_history" in raw_data and isinstance(raw_data["prompt_history"], list) and not self.nodes:
|
if KEY_PROMPT_HISTORY in raw_data and isinstance(raw_data[KEY_PROMPT_HISTORY], list) and not self.nodes:
|
||||||
self._migrate_legacy(raw_data["prompt_history"])
|
self._migrate_legacy(raw_data[KEY_PROMPT_HISTORY])
|
||||||
|
|
||||||
def _migrate_legacy(self, old_list):
|
def _migrate_legacy(self, old_list: list[dict[str, Any]]) -> None:
|
||||||
parent = None
|
parent = None
|
||||||
for item in reversed(old_list):
|
for item in reversed(old_list):
|
||||||
node_id = str(uuid.uuid4())[:8]
|
node_id = str(uuid.uuid4())[:8]
|
||||||
@@ -22,7 +26,7 @@ class HistoryTree:
|
|||||||
self.branches["main"] = parent
|
self.branches["main"] = parent
|
||||||
self.head_id = parent
|
self.head_id = parent
|
||||||
|
|
||||||
def commit(self, data, note="Snapshot"):
|
def commit(self, data: dict[str, Any], note: str = "Snapshot") -> str:
|
||||||
new_id = str(uuid.uuid4())[:8]
|
new_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
# Cycle detection: walk parent chain from head to verify no cycle
|
# Cycle detection: walk parent chain from head to verify no cycle
|
||||||
@@ -56,17 +60,17 @@ class HistoryTree:
|
|||||||
self.head_id = new_id
|
self.head_id = new_id
|
||||||
return new_id
|
return new_id
|
||||||
|
|
||||||
def checkout(self, node_id):
|
def checkout(self, node_id: str) -> dict[str, Any] | None:
|
||||||
if node_id in self.nodes:
|
if node_id in self.nodes:
|
||||||
self.head_id = node_id
|
self.head_id = node_id
|
||||||
return self.nodes[node_id]["data"]
|
return self.nodes[node_id]["data"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
|
return {"nodes": self.nodes, "branches": self.branches, "head_id": self.head_id}
|
||||||
|
|
||||||
# --- UPDATED GRAPH GENERATOR ---
|
# --- UPDATED GRAPH GENERATOR ---
|
||||||
def generate_graph(self, direction="LR"):
|
def generate_graph(self, direction: str = "LR") -> str:
|
||||||
"""
|
"""
|
||||||
Generates Graphviz source.
|
Generates Graphviz source.
|
||||||
direction: "LR" (Horizontal) or "TB" (Vertical)
|
direction: "LR" (Horizontal) or "TB" (Vertical)
|
||||||
|
|||||||
@@ -1,32 +1,36 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def to_float(val):
|
KEY_BATCH_DATA = "batch_data"
|
||||||
|
|
||||||
|
|
||||||
|
def to_float(val: Any) -> float:
|
||||||
try:
|
try:
|
||||||
return float(val)
|
return float(val)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
def to_int(val):
|
def to_int(val: Any) -> int:
|
||||||
try:
|
try:
|
||||||
return int(float(val))
|
return int(float(val))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def get_batch_item(data, sequence_number):
|
def get_batch_item(data: dict[str, Any], sequence_number: int) -> dict[str, Any]:
|
||||||
"""Resolve batch item by sequence_number, clamping to valid range."""
|
"""Resolve batch item by sequence_number, clamping to valid range."""
|
||||||
if "batch_data" in data and isinstance(data["batch_data"], list) and len(data["batch_data"]) > 0:
|
if KEY_BATCH_DATA in data and isinstance(data[KEY_BATCH_DATA], list) and len(data[KEY_BATCH_DATA]) > 0:
|
||||||
idx = max(0, min(sequence_number - 1, len(data["batch_data"]) - 1))
|
idx = max(0, min(sequence_number - 1, len(data[KEY_BATCH_DATA]) - 1))
|
||||||
if sequence_number - 1 != idx:
|
if sequence_number - 1 != idx:
|
||||||
logger.warning(f"Sequence {sequence_number} out of range (1-{len(data['batch_data'])}), clamped to {idx + 1}")
|
logger.warning(f"Sequence {sequence_number} out of range (1-{len(data[KEY_BATCH_DATA])}), clamped to {idx + 1}")
|
||||||
return data["batch_data"][idx]
|
return data[KEY_BATCH_DATA][idx]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
# --- Shared Helper ---
|
# --- Shared Helper ---
|
||||||
def read_json_data(json_path):
|
def read_json_data(json_path: str) -> dict[str, Any]:
|
||||||
if not os.path.exists(json_path):
|
if not os.path.exists(json_path):
|
||||||
logger.warning(f"File not found at {json_path}")
|
logger.warning(f"File not found at {json_path}")
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
72
tab_batch.py
72
tab_batch.py
@@ -1,7 +1,7 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
import random
|
import random
|
||||||
import copy
|
import copy
|
||||||
from utils import DEFAULTS, save_json, load_json
|
from utils import DEFAULTS, save_json, load_json, KEY_BATCH_DATA, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
|
||||||
from history_tree import HistoryTree
|
from history_tree import HistoryTree
|
||||||
|
|
||||||
def create_batch_callback(original_filename, current_data, current_dir):
|
def create_batch_callback(original_filename, current_data, current_dir):
|
||||||
@@ -13,15 +13,15 @@ def create_batch_callback(original_filename, current_data, current_dir):
|
|||||||
return
|
return
|
||||||
|
|
||||||
first_item = current_data.copy()
|
first_item = current_data.copy()
|
||||||
if "prompt_history" in first_item: del first_item["prompt_history"]
|
if KEY_PROMPT_HISTORY in first_item: del first_item[KEY_PROMPT_HISTORY]
|
||||||
if "history_tree" in first_item: del first_item["history_tree"]
|
if KEY_HISTORY_TREE in first_item: del first_item[KEY_HISTORY_TREE]
|
||||||
|
|
||||||
first_item["sequence_number"] = 1
|
first_item[KEY_SEQUENCE_NUMBER] = 1
|
||||||
|
|
||||||
new_data = {
|
new_data = {
|
||||||
"batch_data": [first_item],
|
KEY_BATCH_DATA: [first_item],
|
||||||
"history_tree": {},
|
KEY_HISTORY_TREE: {},
|
||||||
"prompt_history": []
|
KEY_PROMPT_HISTORY: []
|
||||||
}
|
}
|
||||||
|
|
||||||
save_json(new_path, new_data)
|
save_json(new_path, new_data)
|
||||||
@@ -30,7 +30,7 @@ def create_batch_callback(original_filename, current_data, current_dir):
|
|||||||
|
|
||||||
|
|
||||||
def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
|
def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
|
||||||
is_batch_file = "batch_data" in data or isinstance(data, list)
|
is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
|
||||||
|
|
||||||
if not is_batch_file:
|
if not is_batch_file:
|
||||||
st.warning("This is a Single file. To use Batch mode, create a copy.")
|
st.warning("This is a Single file. To use Batch mode, create a copy.")
|
||||||
@@ -40,7 +40,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
|
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
|
||||||
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
|
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
|
||||||
|
|
||||||
batch_list = data.get("batch_data", [])
|
batch_list = data.get(KEY_BATCH_DATA, [])
|
||||||
|
|
||||||
# --- ADD NEW SEQUENCE AREA ---
|
# --- ADD NEW SEQUENCE AREA ---
|
||||||
st.subheader("Add New Sequence")
|
st.subheader("Add New Sequence")
|
||||||
@@ -53,7 +53,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
src_data, _ = load_json(current_dir / src_name)
|
src_data, _ = load_json(current_dir / src_name)
|
||||||
|
|
||||||
with ac2:
|
with ac2:
|
||||||
src_hist = src_data.get("prompt_history", [])
|
src_hist = src_data.get(KEY_PROMPT_HISTORY, [])
|
||||||
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
|
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
|
||||||
sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist")
|
sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist")
|
||||||
|
|
||||||
@@ -62,14 +62,14 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
def add_sequence(new_item):
|
def add_sequence(new_item):
|
||||||
max_seq = 0
|
max_seq = 0
|
||||||
for s in batch_list:
|
for s in batch_list:
|
||||||
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"]))
|
if KEY_SEQUENCE_NUMBER in s: max_seq = max(max_seq, int(s[KEY_SEQUENCE_NUMBER]))
|
||||||
new_item["sequence_number"] = max_seq + 1
|
new_item[KEY_SEQUENCE_NUMBER] = max_seq + 1
|
||||||
|
|
||||||
for k in ["prompt_history", "history_tree", "note", "loras"]:
|
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, "note", "loras"]:
|
||||||
if k in new_item: del new_item[k]
|
if k in new_item: del new_item[k]
|
||||||
|
|
||||||
batch_list.append(new_item)
|
batch_list.append(new_item)
|
||||||
data["batch_data"] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.session_state.ui_reset_token += 1
|
st.session_state.ui_reset_token += 1
|
||||||
st.rerun()
|
st.rerun()
|
||||||
@@ -79,7 +79,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
|
|
||||||
if bc2.button("➕ From File", use_container_width=True, help=f"Copy {src_name}"):
|
if bc2.button("➕ From File", use_container_width=True, help=f"Copy {src_name}"):
|
||||||
item = DEFAULTS.copy()
|
item = DEFAULTS.copy()
|
||||||
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data
|
||||||
item.update(flat)
|
item.update(flat)
|
||||||
add_sequence(item)
|
add_sequence(item)
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
|
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
|
||||||
standard_keys = {
|
standard_keys = {
|
||||||
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
||||||
"camera", "flf", "sequence_number"
|
"camera", "flf", KEY_SEQUENCE_NUMBER
|
||||||
}
|
}
|
||||||
standard_keys.update(lora_keys)
|
standard_keys.update(lora_keys)
|
||||||
standard_keys.update([
|
standard_keys.update([
|
||||||
@@ -116,7 +116,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
])
|
])
|
||||||
|
|
||||||
for i, seq in enumerate(batch_list):
|
for i, seq in enumerate(batch_list):
|
||||||
seq_num = seq.get("sequence_number", i+1)
|
seq_num = seq.get(KEY_SEQUENCE_NUMBER, i+1)
|
||||||
prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}"
|
prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}"
|
||||||
|
|
||||||
with st.expander(f"🎬 Sequence #{seq_num}", expanded=False):
|
with st.expander(f"🎬 Sequence #{seq_num}", expanded=False):
|
||||||
@@ -127,13 +127,13 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
with act_c1:
|
with act_c1:
|
||||||
if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True):
|
if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True):
|
||||||
item = DEFAULTS.copy()
|
item = DEFAULTS.copy()
|
||||||
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
flat = src_data[KEY_BATCH_DATA][0] if KEY_BATCH_DATA in src_data and src_data[KEY_BATCH_DATA] else src_data
|
||||||
item.update(flat)
|
item.update(flat)
|
||||||
item["sequence_number"] = seq_num
|
item[KEY_SEQUENCE_NUMBER] = seq_num
|
||||||
for k in ["prompt_history", "history_tree"]:
|
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE]:
|
||||||
if k in item: del item[k]
|
if k in item: del item[k]
|
||||||
batch_list[i] = item
|
batch_list[i] = item
|
||||||
data["batch_data"] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.session_state.ui_reset_token += 1
|
st.session_state.ui_reset_token += 1
|
||||||
st.toast("Copied!", icon="📥")
|
st.toast("Copied!", icon="📥")
|
||||||
@@ -145,10 +145,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True):
|
if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True):
|
||||||
new_seq = seq.copy()
|
new_seq = seq.copy()
|
||||||
max_sn = 0
|
max_sn = 0
|
||||||
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
|
for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
||||||
new_seq["sequence_number"] = max_sn + 1
|
new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
|
||||||
batch_list.insert(i + 1, new_seq)
|
batch_list.insert(i + 1, new_seq)
|
||||||
data["batch_data"] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.session_state.ui_reset_token += 1
|
st.session_state.ui_reset_token += 1
|
||||||
st.toast("Cloned to Next!", icon="👯")
|
st.toast("Cloned to Next!", icon="👯")
|
||||||
@@ -157,10 +157,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True):
|
if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True):
|
||||||
new_seq = seq.copy()
|
new_seq = seq.copy()
|
||||||
max_sn = 0
|
max_sn = 0
|
||||||
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
|
for s in batch_list: max_sn = max(max_sn, int(s.get(KEY_SEQUENCE_NUMBER, 0)))
|
||||||
new_seq["sequence_number"] = max_sn + 1
|
new_seq[KEY_SEQUENCE_NUMBER] = max_sn + 1
|
||||||
batch_list.append(new_seq)
|
batch_list.append(new_seq)
|
||||||
data["batch_data"] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.session_state.ui_reset_token += 1
|
st.session_state.ui_reset_token += 1
|
||||||
st.toast("Cloned to End!", icon="⏬")
|
st.toast("Cloned to End!", icon="⏬")
|
||||||
@@ -170,9 +170,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
with act_c3:
|
with act_c3:
|
||||||
if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True):
|
if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True):
|
||||||
single_data = seq.copy()
|
single_data = seq.copy()
|
||||||
single_data["prompt_history"] = data.get("prompt_history", [])
|
single_data[KEY_PROMPT_HISTORY] = data.get(KEY_PROMPT_HISTORY, [])
|
||||||
single_data["history_tree"] = data.get("history_tree", {})
|
single_data[KEY_HISTORY_TREE] = data.get(KEY_HISTORY_TREE, {})
|
||||||
if "sequence_number" in single_data: del single_data["sequence_number"]
|
if KEY_SEQUENCE_NUMBER in single_data: del single_data[KEY_SEQUENCE_NUMBER]
|
||||||
save_json(file_path, single_data)
|
save_json(file_path, single_data)
|
||||||
st.toast("Converted to Single!", icon="✅")
|
st.toast("Converted to Single!", icon="✅")
|
||||||
st.rerun()
|
st.rerun()
|
||||||
@@ -181,7 +181,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
with act_c4:
|
with act_c4:
|
||||||
if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
|
if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
|
||||||
batch_list.pop(i)
|
batch_list.pop(i)
|
||||||
data["batch_data"] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
@@ -194,7 +194,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn")
|
seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn")
|
||||||
|
|
||||||
with c2:
|
with c2:
|
||||||
seq["sequence_number"] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val")
|
seq[KEY_SEQUENCE_NUMBER] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val")
|
||||||
|
|
||||||
s_row1, s_row2 = st.columns([3, 1])
|
s_row1, s_row2 = st.columns([3, 1])
|
||||||
seed_key = f"{prefix}_seed"
|
seed_key = f"{prefix}_seed"
|
||||||
@@ -320,17 +320,17 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
|
|
||||||
with col_save:
|
with col_save:
|
||||||
if st.button("💾 Save & Snap", use_container_width=True):
|
if st.button("💾 Save & Snap", use_container_width=True):
|
||||||
data["batch_data"] = batch_list
|
data[KEY_BATCH_DATA] = batch_list
|
||||||
|
|
||||||
tree_data = data.get("history_tree", {})
|
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||||
htree = HistoryTree(tree_data)
|
htree = HistoryTree(tree_data)
|
||||||
|
|
||||||
snapshot_payload = copy.deepcopy(data)
|
snapshot_payload = copy.deepcopy(data)
|
||||||
if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"]
|
if KEY_HISTORY_TREE in snapshot_payload: del snapshot_payload[KEY_HISTORY_TREE]
|
||||||
|
|
||||||
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
|
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
|
||||||
|
|
||||||
data["history_tree"] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
|
||||||
if 'restored_indicator' in st.session_state:
|
if 'restored_indicator' in st.session_state:
|
||||||
|
|||||||
14
tab_raw.py
14
tab_raw.py
@@ -1,7 +1,7 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
import json
|
import json
|
||||||
import copy
|
import copy
|
||||||
from utils import save_json, get_file_mtime
|
from utils import save_json, get_file_mtime, KEY_HISTORY_TREE, KEY_PROMPT_HISTORY
|
||||||
|
|
||||||
def render_raw_editor(data, file_path):
|
def render_raw_editor(data, file_path):
|
||||||
st.subheader(f"💻 Raw Editor: {file_path.name}")
|
st.subheader(f"💻 Raw Editor: {file_path.name}")
|
||||||
@@ -20,8 +20,8 @@ def render_raw_editor(data, file_path):
|
|||||||
if hide_history:
|
if hide_history:
|
||||||
display_data = copy.deepcopy(data)
|
display_data = copy.deepcopy(data)
|
||||||
# Safely remove heavy keys for the view only
|
# Safely remove heavy keys for the view only
|
||||||
if "history_tree" in display_data: del display_data["history_tree"]
|
if KEY_HISTORY_TREE in display_data: del display_data[KEY_HISTORY_TREE]
|
||||||
if "prompt_history" in display_data: del display_data["prompt_history"]
|
if KEY_PROMPT_HISTORY in display_data: del display_data[KEY_PROMPT_HISTORY]
|
||||||
else:
|
else:
|
||||||
display_data = data
|
display_data = data
|
||||||
|
|
||||||
@@ -51,10 +51,10 @@ def render_raw_editor(data, file_path):
|
|||||||
|
|
||||||
# 2. If we were in Safe Mode, we must merge the hidden history back in
|
# 2. If we were in Safe Mode, we must merge the hidden history back in
|
||||||
if hide_history:
|
if hide_history:
|
||||||
if "history_tree" in data:
|
if KEY_HISTORY_TREE in data:
|
||||||
input_data["history_tree"] = data["history_tree"]
|
input_data[KEY_HISTORY_TREE] = data[KEY_HISTORY_TREE]
|
||||||
if "prompt_history" in data:
|
if KEY_PROMPT_HISTORY in data:
|
||||||
input_data["prompt_history"] = data["prompt_history"]
|
input_data[KEY_PROMPT_HISTORY] = data[KEY_PROMPT_HISTORY]
|
||||||
|
|
||||||
# 3. Save to Disk
|
# 3. Save to Disk
|
||||||
save_json(file_path, input_data)
|
save_json(file_path, input_data)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
import random
|
import random
|
||||||
from utils import DEFAULTS, save_json, get_file_mtime
|
from utils import DEFAULTS, save_json, get_file_mtime, KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER
|
||||||
|
|
||||||
def render_single_editor(data, file_path):
|
def render_single_editor(data, file_path):
|
||||||
is_batch_file = "batch_data" in data or isinstance(data, list)
|
is_batch_file = KEY_BATCH_DATA in data or isinstance(data, list)
|
||||||
|
|
||||||
if is_batch_file:
|
if is_batch_file:
|
||||||
st.info("This is a batch file. Switch to the 'Batch Processor' tab.")
|
st.info("This is a batch file. Switch to the 'Batch Processor' tab.")
|
||||||
@@ -63,7 +63,7 @@ def render_single_editor(data, file_path):
|
|||||||
# Explicitly track standard setting keys to exclude them from custom list
|
# Explicitly track standard setting keys to exclude them from custom list
|
||||||
standard_keys = {
|
standard_keys = {
|
||||||
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
||||||
"camera", "flf", "batch_data", "prompt_history", "sequence_number", "ui_reset_token",
|
"camera", "flf", KEY_BATCH_DATA, KEY_PROMPT_HISTORY, KEY_SEQUENCE_NUMBER, "ui_reset_token",
|
||||||
"model_name", "vae_name", "steps", "cfg", "denoise", "sampler_name", "scheduler"
|
"model_name", "vae_name", "steps", "cfg", "denoise", "sampler_name", "scheduler"
|
||||||
}
|
}
|
||||||
standard_keys.update(lora_keys)
|
standard_keys.update(lora_keys)
|
||||||
@@ -169,8 +169,8 @@ def render_single_editor(data, file_path):
|
|||||||
archive_note = st.text_input("Archive Note")
|
archive_note = st.text_input("Archive Note")
|
||||||
if st.button("📦 Snapshot to History", use_container_width=True):
|
if st.button("📦 Snapshot to History", use_container_width=True):
|
||||||
entry = {"note": archive_note if archive_note else "Snapshot", **current_state}
|
entry = {"note": archive_note if archive_note else "Snapshot", **current_state}
|
||||||
if "prompt_history" not in data: data["prompt_history"] = []
|
if KEY_PROMPT_HISTORY not in data: data[KEY_PROMPT_HISTORY] = []
|
||||||
data["prompt_history"].insert(0, entry)
|
data[KEY_PROMPT_HISTORY].insert(0, entry)
|
||||||
data.update(entry)
|
data.update(entry)
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.session_state.last_mtime = get_file_mtime(file_path)
|
st.session_state.last_mtime = get_file_mtime(file_path)
|
||||||
@@ -181,7 +181,7 @@ def render_single_editor(data, file_path):
|
|||||||
# --- FULL HISTORY PANEL ---
|
# --- FULL HISTORY PANEL ---
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
st.subheader("History")
|
st.subheader("History")
|
||||||
history = data.get("prompt_history", [])
|
history = data.get(KEY_PROMPT_HISTORY, [])
|
||||||
|
|
||||||
if not history:
|
if not history:
|
||||||
st.caption("No history yet.")
|
st.caption("No history yet.")
|
||||||
|
|||||||
@@ -4,10 +4,10 @@ import json
|
|||||||
import graphviz
|
import graphviz
|
||||||
import time
|
import time
|
||||||
from history_tree import HistoryTree
|
from history_tree import HistoryTree
|
||||||
from utils import save_json
|
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||||
|
|
||||||
def render_timeline_tab(data, file_path):
|
def render_timeline_tab(data, file_path):
|
||||||
tree_data = data.get("history_tree", {})
|
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||||
if not tree_data:
|
if not tree_data:
|
||||||
st.info("No history timeline exists. Make some changes in the Editor first!")
|
st.info("No history timeline exists. Make some changes in the Editor first!")
|
||||||
return
|
return
|
||||||
@@ -61,13 +61,13 @@ def render_timeline_tab(data, file_path):
|
|||||||
if not is_head:
|
if not is_head:
|
||||||
if st.button("⏪", key=f"log_rst_{n['id']}", help="Restore this version"):
|
if st.button("⏪", key=f"log_rst_{n['id']}", help="Restore this version"):
|
||||||
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
||||||
if "batch_data" not in n["data"] and "batch_data" in data:
|
if KEY_BATCH_DATA not in n["data"] and KEY_BATCH_DATA in data:
|
||||||
del data["batch_data"]
|
del data[KEY_BATCH_DATA]
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
|
|
||||||
data.update(n["data"])
|
data.update(n["data"])
|
||||||
htree.head_id = n['id']
|
htree.head_id = n['id']
|
||||||
data["history_tree"] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.session_state.ui_reset_token += 1
|
st.session_state.ui_reset_token += 1
|
||||||
label = f"{n.get('note')} ({n['id'][:4]})"
|
label = f"{n.get('note')} ({n['id'][:4]})"
|
||||||
@@ -109,13 +109,13 @@ def render_timeline_tab(data, file_path):
|
|||||||
st.write(""); st.write("")
|
st.write(""); st.write("")
|
||||||
if st.button("⏪ Restore Version", type="primary", use_container_width=True):
|
if st.button("⏪ Restore Version", type="primary", use_container_width=True):
|
||||||
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
||||||
if "batch_data" not in node_data and "batch_data" in data:
|
if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
|
||||||
del data["batch_data"]
|
del data[KEY_BATCH_DATA]
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
|
|
||||||
data.update(node_data)
|
data.update(node_data)
|
||||||
htree.head_id = selected_node['id']
|
htree.head_id = selected_node['id']
|
||||||
data["history_tree"] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.session_state.ui_reset_token += 1
|
st.session_state.ui_reset_token += 1
|
||||||
label = f"{selected_node.get('note')} ({selected_node['id'][:4]})"
|
label = f"{selected_node.get('note')} ({selected_node['id'][:4]})"
|
||||||
@@ -128,7 +128,7 @@ def render_timeline_tab(data, file_path):
|
|||||||
new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", ""))
|
new_label = rn_col1.text_input("Rename Label", value=selected_node.get("note", ""))
|
||||||
if rn_col2.button("Update Label"):
|
if rn_col2.button("Update Label"):
|
||||||
selected_node["note"] = new_label
|
selected_node["note"] = new_label
|
||||||
data["history_tree"] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ def render_timeline_tab(data, file_path):
|
|||||||
htree.head_id = fallback["id"]
|
htree.head_id = fallback["id"]
|
||||||
else:
|
else:
|
||||||
htree.head_id = None
|
htree.head_id = None
|
||||||
data["history_tree"] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.toast("Node Deleted", icon="🗑️")
|
st.toast("Node Deleted", icon="🗑️")
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
import json
|
import json
|
||||||
from history_tree import HistoryTree
|
from history_tree import HistoryTree
|
||||||
from utils import save_json
|
from utils import save_json, KEY_BATCH_DATA, KEY_HISTORY_TREE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from streamlit_agraph import agraph, Node, Edge, Config
|
from streamlit_agraph import agraph, Node, Edge, Config
|
||||||
@@ -13,7 +13,7 @@ def render_timeline_wip(data, file_path):
|
|||||||
if not _HAS_AGRAPH:
|
if not _HAS_AGRAPH:
|
||||||
st.error("The `streamlit-agraph` package is required for this tab. Install it with: `pip install streamlit-agraph`")
|
st.error("The `streamlit-agraph` package is required for this tab. Install it with: `pip install streamlit-agraph`")
|
||||||
return
|
return
|
||||||
tree_data = data.get("history_tree", {})
|
tree_data = data.get(KEY_HISTORY_TREE, {})
|
||||||
if not tree_data:
|
if not tree_data:
|
||||||
st.info("No history timeline exists.")
|
st.info("No history timeline exists.")
|
||||||
return
|
return
|
||||||
@@ -108,14 +108,14 @@ def render_timeline_wip(data, file_path):
|
|||||||
st.write(""); st.write("")
|
st.write(""); st.write("")
|
||||||
if st.button("⏪ Restore This Version", type="primary", use_container_width=True, key=f"rst_{target_node_id}"):
|
if st.button("⏪ Restore This Version", type="primary", use_container_width=True, key=f"rst_{target_node_id}"):
|
||||||
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
# --- FIX: Cleanup 'batch_data' if restoring a Single File ---
|
||||||
if "batch_data" not in node_data and "batch_data" in data:
|
if KEY_BATCH_DATA not in node_data and KEY_BATCH_DATA in data:
|
||||||
del data["batch_data"]
|
del data[KEY_BATCH_DATA]
|
||||||
# -------------------------------------------------------------
|
# -------------------------------------------------------------
|
||||||
|
|
||||||
data.update(node_data)
|
data.update(node_data)
|
||||||
htree.head_id = target_node_id
|
htree.head_id = target_node_id
|
||||||
|
|
||||||
data["history_tree"] = htree.to_dict()
|
data[KEY_HISTORY_TREE] = htree.to_dict()
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
|
|
||||||
st.session_state.ui_reset_token += 1
|
st.session_state.ui_reset_token += 1
|
||||||
@@ -174,7 +174,7 @@ def render_timeline_wip(data, file_path):
|
|||||||
v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid")
|
v3.text_input("Video Path", value=str(item_data.get("video file path", "")), disabled=True, key=f"{prefix}_vid")
|
||||||
|
|
||||||
# --- DETECT BATCH VS SINGLE ---
|
# --- DETECT BATCH VS SINGLE ---
|
||||||
batch_list = node_data.get("batch_data", [])
|
batch_list = node_data.get(KEY_BATCH_DATA, [])
|
||||||
|
|
||||||
if batch_list and isinstance(batch_list, list) and len(batch_list) > 0:
|
if batch_list and isinstance(batch_list, list) and len(batch_list) > 0:
|
||||||
st.info(f"📚 This snapshot contains {len(batch_list)} sequences.")
|
st.info(f"📚 This snapshot contains {len(batch_list)} sequences.")
|
||||||
|
|||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
5
tests/conftest.py
Normal file
5
tests/conftest.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to sys.path so tests can import project modules
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
1
tests/pytest.ini
Normal file
1
tests/pytest.ini
Normal file
@@ -0,0 +1 @@
|
|||||||
|
[pytest]
|
||||||
67
tests/test_history_tree.py
Normal file
67
tests/test_history_tree.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import pytest
|
||||||
|
from history_tree import HistoryTree
|
||||||
|
|
||||||
|
|
||||||
|
def test_commit_creates_node_with_correct_parent():
|
||||||
|
tree = HistoryTree({})
|
||||||
|
id1 = tree.commit({"a": 1}, note="first")
|
||||||
|
id2 = tree.commit({"b": 2}, note="second")
|
||||||
|
|
||||||
|
assert tree.nodes[id1]["parent"] is None
|
||||||
|
assert tree.nodes[id2]["parent"] == id1
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkout_returns_correct_data():
|
||||||
|
tree = HistoryTree({})
|
||||||
|
id1 = tree.commit({"val": 42}, note="snap")
|
||||||
|
result = tree.checkout(id1)
|
||||||
|
assert result == {"val": 42}
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkout_nonexistent_returns_none():
|
||||||
|
tree = HistoryTree({})
|
||||||
|
assert tree.checkout("nonexistent") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cycle_detection_raises():
|
||||||
|
tree = HistoryTree({})
|
||||||
|
id1 = tree.commit({"a": 1})
|
||||||
|
# Manually introduce a cycle
|
||||||
|
tree.nodes[id1]["parent"] = id1
|
||||||
|
with pytest.raises(ValueError, match="Cycle detected"):
|
||||||
|
tree.commit({"b": 2})
|
||||||
|
|
||||||
|
|
||||||
|
def test_branch_creation_on_detached_head():
|
||||||
|
tree = HistoryTree({})
|
||||||
|
id1 = tree.commit({"a": 1})
|
||||||
|
id2 = tree.commit({"b": 2})
|
||||||
|
# Detach head by checking out a non-tip node
|
||||||
|
tree.checkout(id1)
|
||||||
|
# head_id is now id1, which is no longer a branch tip (main points to id2)
|
||||||
|
id3 = tree.commit({"c": 3})
|
||||||
|
# A new branch should have been created
|
||||||
|
assert len(tree.branches) == 2
|
||||||
|
assert tree.nodes[id3]["parent"] == id1
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_migration():
|
||||||
|
legacy = {
|
||||||
|
"prompt_history": [
|
||||||
|
{"note": "Entry A", "seed": 1},
|
||||||
|
{"note": "Entry B", "seed": 2},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
tree = HistoryTree(legacy)
|
||||||
|
assert len(tree.nodes) == 2
|
||||||
|
assert tree.head_id is not None
|
||||||
|
assert tree.branches["main"] == tree.head_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_dict_roundtrip():
|
||||||
|
tree = HistoryTree({})
|
||||||
|
tree.commit({"x": 1}, note="test")
|
||||||
|
d = tree.to_dict()
|
||||||
|
tree2 = HistoryTree(d)
|
||||||
|
assert tree2.head_id == tree.head_id
|
||||||
|
assert tree2.nodes == tree.nodes
|
||||||
68
tests/test_json_loader.py
Normal file
68
tests/test_json_loader.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from json_loader import to_float, to_int, get_batch_item, read_json_data
|
||||||
|
|
||||||
|
|
||||||
|
class TestToFloat:
|
||||||
|
def test_valid(self):
|
||||||
|
assert to_float("3.14") == 3.14
|
||||||
|
assert to_float(5) == 5.0
|
||||||
|
|
||||||
|
def test_invalid(self):
|
||||||
|
assert to_float("abc") == 0.0
|
||||||
|
|
||||||
|
def test_none(self):
|
||||||
|
assert to_float(None) == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestToInt:
|
||||||
|
def test_valid(self):
|
||||||
|
assert to_int("7") == 7
|
||||||
|
assert to_int(3.9) == 3
|
||||||
|
|
||||||
|
def test_invalid(self):
|
||||||
|
assert to_int("xyz") == 0
|
||||||
|
|
||||||
|
def test_none(self):
|
||||||
|
assert to_int(None) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetBatchItem:
|
||||||
|
def test_valid_index(self):
|
||||||
|
data = {"batch_data": [{"a": 1}, {"a": 2}, {"a": 3}]}
|
||||||
|
assert get_batch_item(data, 2) == {"a": 2}
|
||||||
|
|
||||||
|
def test_clamp_high(self):
|
||||||
|
data = {"batch_data": [{"a": 1}, {"a": 2}]}
|
||||||
|
assert get_batch_item(data, 99) == {"a": 2}
|
||||||
|
|
||||||
|
def test_clamp_low(self):
|
||||||
|
data = {"batch_data": [{"a": 1}, {"a": 2}]}
|
||||||
|
assert get_batch_item(data, 0) == {"a": 1}
|
||||||
|
|
||||||
|
def test_no_batch_data(self):
|
||||||
|
data = {"key": "val"}
|
||||||
|
assert get_batch_item(data, 1) == data
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadJsonData:
|
||||||
|
def test_missing_file(self, tmp_path):
|
||||||
|
assert read_json_data(str(tmp_path / "nope.json")) == {}
|
||||||
|
|
||||||
|
def test_invalid_json(self, tmp_path):
|
||||||
|
p = tmp_path / "bad.json"
|
||||||
|
p.write_text("{broken")
|
||||||
|
assert read_json_data(str(p)) == {}
|
||||||
|
|
||||||
|
def test_non_dict_json(self, tmp_path):
|
||||||
|
p = tmp_path / "list.json"
|
||||||
|
p.write_text(json.dumps([1, 2, 3]))
|
||||||
|
assert read_json_data(str(p)) == {}
|
||||||
|
|
||||||
|
def test_valid(self, tmp_path):
|
||||||
|
p = tmp_path / "ok.json"
|
||||||
|
p.write_text(json.dumps({"key": "val"}))
|
||||||
|
assert read_json_data(str(p)) == {"key": "val"}
|
||||||
68
tests/test_utils.py
Normal file
68
tests/test_utils.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Mock streamlit before importing utils
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
sys.modules.setdefault("streamlit", MagicMock())
|
||||||
|
|
||||||
|
from utils import load_json, save_json, get_file_mtime, ALLOWED_BASE_DIR, DEFAULTS
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_json_valid(tmp_path):
|
||||||
|
p = tmp_path / "test.json"
|
||||||
|
data = {"key": "value"}
|
||||||
|
p.write_text(json.dumps(data))
|
||||||
|
result, mtime = load_json(p)
|
||||||
|
assert result == data
|
||||||
|
assert mtime > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_json_missing(tmp_path):
|
||||||
|
p = tmp_path / "nope.json"
|
||||||
|
result, mtime = load_json(p)
|
||||||
|
assert result == DEFAULTS.copy()
|
||||||
|
assert mtime == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_json_invalid(tmp_path):
|
||||||
|
p = tmp_path / "bad.json"
|
||||||
|
p.write_text("{not valid json")
|
||||||
|
result, mtime = load_json(p)
|
||||||
|
assert result == DEFAULTS.copy()
|
||||||
|
assert mtime == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_json_atomic(tmp_path):
|
||||||
|
p = tmp_path / "out.json"
|
||||||
|
data = {"hello": "world"}
|
||||||
|
save_json(p, data)
|
||||||
|
assert p.exists()
|
||||||
|
assert not p.with_suffix(".json.tmp").exists()
|
||||||
|
assert json.loads(p.read_text()) == data
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_json_overwrites(tmp_path):
|
||||||
|
p = tmp_path / "out.json"
|
||||||
|
save_json(p, {"a": 1})
|
||||||
|
save_json(p, {"b": 2})
|
||||||
|
assert json.loads(p.read_text()) == {"b": 2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_mtime_existing(tmp_path):
|
||||||
|
p = tmp_path / "f.txt"
|
||||||
|
p.write_text("x")
|
||||||
|
assert get_file_mtime(p) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_mtime_missing(tmp_path):
|
||||||
|
assert get_file_mtime(tmp_path / "missing.txt") == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowed_base_dir_is_set():
|
||||||
|
assert ALLOWED_BASE_DIR is not None
|
||||||
|
assert isinstance(ALLOWED_BASE_DIR, Path)
|
||||||
28
utils.py
28
utils.py
@@ -1,9 +1,18 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
|
|
||||||
|
# --- Magic String Keys ---
|
||||||
|
KEY_BATCH_DATA = "batch_data"
|
||||||
|
KEY_HISTORY_TREE = "history_tree"
|
||||||
|
KEY_PROMPT_HISTORY = "prompt_history"
|
||||||
|
KEY_SEQUENCE_NUMBER = "sequence_number"
|
||||||
|
|
||||||
# Configure logging for the application
|
# Configure logging for the application
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
@@ -52,8 +61,8 @@ DEFAULTS = {
|
|||||||
CONFIG_FILE = Path(".editor_config.json")
|
CONFIG_FILE = Path(".editor_config.json")
|
||||||
SNIPPETS_FILE = Path(".editor_snippets.json")
|
SNIPPETS_FILE = Path(".editor_snippets.json")
|
||||||
|
|
||||||
# Restrict directory navigation to this base path (resolve symlinks)
|
# No restriction on directory navigation
|
||||||
ALLOWED_BASE_DIR = Path.cwd().resolve()
|
ALLOWED_BASE_DIR = Path("/").resolve()
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
"""Loads the main editor configuration (Favorites, Last Dir, Servers)."""
|
"""Loads the main editor configuration (Favorites, Last Dir, Servers)."""
|
||||||
@@ -96,7 +105,7 @@ def save_snippets(snippets):
|
|||||||
with open(SNIPPETS_FILE, 'w') as f:
|
with open(SNIPPETS_FILE, 'w') as f:
|
||||||
json.dump(snippets, f, indent=4)
|
json.dump(snippets, f, indent=4)
|
||||||
|
|
||||||
def load_json(path):
|
def load_json(path: str | Path) -> tuple[dict[str, Any], float]:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return DEFAULTS.copy(), 0
|
return DEFAULTS.copy(), 0
|
||||||
@@ -108,20 +117,23 @@ def load_json(path):
|
|||||||
st.error(f"Error loading JSON: {e}")
|
st.error(f"Error loading JSON: {e}")
|
||||||
return DEFAULTS.copy(), 0
|
return DEFAULTS.copy(), 0
|
||||||
|
|
||||||
def save_json(path, data):
|
def save_json(path: str | Path, data: dict[str, Any]) -> None:
|
||||||
with open(path, 'w') as f:
|
path = Path(path)
|
||||||
|
tmp = path.with_suffix('.json.tmp')
|
||||||
|
with open(tmp, 'w') as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
|
os.replace(tmp, path)
|
||||||
|
|
||||||
def get_file_mtime(path):
|
def get_file_mtime(path: str | Path) -> float:
|
||||||
"""Returns the modification time of a file, or 0 if it doesn't exist."""
|
"""Returns the modification time of a file, or 0 if it doesn't exist."""
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if path.exists():
|
if path.exists():
|
||||||
return path.stat().st_mtime
|
return path.stat().st_mtime
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def generate_templates(current_dir):
|
def generate_templates(current_dir: Path) -> None:
|
||||||
"""Creates dummy template files if folder is empty."""
|
"""Creates dummy template files if folder is empty."""
|
||||||
save_json(current_dir / "template_i2v.json", DEFAULTS)
|
save_json(current_dir / "template_i2v.json", DEFAULTS)
|
||||||
|
|
||||||
batch_data = {"batch_data": [DEFAULTS.copy(), DEFAULTS.copy()]}
|
batch_data = {KEY_BATCH_DATA: [DEFAULTS.copy(), DEFAULTS.copy()]}
|
||||||
save_json(current_dir / "template_batch.json", batch_data)
|
save_json(current_dir / "template_batch.json", batch_data)
|
||||||
|
|||||||
Reference in New Issue
Block a user