Integrate Graphviz for history tree visualization

Added Graphviz integration for history tree visualization and improved comments for clarity.
This commit is contained in:
2026-01-02 01:19:41 +01:00
committed by GitHub
parent 726430ba90
commit 33d7321809

View File

@@ -1,27 +1,43 @@
import streamlit as st import streamlit as st
import random import random
import graphviz
from utils import DEFAULTS, save_json, get_file_mtime from utils import DEFAULTS, save_json, get_file_mtime
from history_tree import HistoryTree
def render_single_editor(data, file_path): def render_single_editor(data, file_path):
# 1. Safety Check: Ensure this isn't a batch file
is_batch_file = "batch_data" in data or isinstance(data, list) is_batch_file = "batch_data" in data or isinstance(data, list)
if is_batch_file: if is_batch_file:
st.info("This is a batch file. Switch to the 'Batch Processor' tab.") st.info("This is a batch file. Switch to the 'Batch Processor' tab.")
return return
# 2. Initialize History Engine
# We load the tree from the JSON. If it doesn't exist, we look for the old list to migrate.
tree_data = data.get("history_tree", {})
if "prompt_history" in data and not tree_data:
# Migration path for old files
tree_data = {"prompt_history": data["prompt_history"]}
htree = HistoryTree(tree_data)
col1, col2 = st.columns([2, 1]) col1, col2 = st.columns([2, 1])
# Unique prefix for this file's widgets + Version Token # 3. Generate Unique Key for Widgets
# We append 'ui_reset_token' so that changing it forces Streamlit to re-render all inputs
fk = f"{file_path.name}_v{st.session_state.ui_reset_token}" fk = f"{file_path.name}_v{st.session_state.ui_reset_token}"
# --- FORM --- # ==============================================================================
# LEFT COLUMN: THE EDITOR FORM
# ==============================================================================
with col1: with col1:
# --- PROMPTS ---
with st.expander("🌍 General Prompts (Global Layer)", expanded=False): with st.expander("🌍 General Prompts (Global Layer)", expanded=False):
gen_prompt = st.text_area("General Prompt", value=data.get("general_prompt", ""), height=100, key=f"{fk}_gp") gen_prompt = st.text_area("General Prompt", value=data.get("general_prompt", ""), height=100, key=f"{fk}_gp")
gen_negative = st.text_area("General Negative", value=data.get("general_negative", DEFAULTS["general_negative"]), height=100, key=f"{fk}_gn") gen_negative = st.text_area("General Negative", value=data.get("general_negative", DEFAULTS["general_negative"]), height=100, key=f"{fk}_gn")
st.write("📝 **Specific Prompts**") st.write("📝 **Specific Prompts**")
current_prompt_val = data.get("current_prompt", "") current_prompt_val = data.get("current_prompt", "")
# Logic to append snippets from sidebar
if 'append_prompt' in st.session_state: if 'append_prompt' in st.session_state:
current_prompt_val = (current_prompt_val.strip() + ", " + st.session_state.append_prompt).strip(', ') current_prompt_val = (current_prompt_val.strip() + ", " + st.session_state.append_prompt).strip(', ')
del st.session_state.append_prompt del st.session_state.append_prompt
@@ -29,14 +45,14 @@ def render_single_editor(data, file_path):
new_prompt = st.text_area("Specific Prompt", value=current_prompt_val, height=150, key=f"{fk}_sp") new_prompt = st.text_area("Specific Prompt", value=current_prompt_val, height=150, key=f"{fk}_sp")
new_negative = st.text_area("Specific Negative", value=data.get("negative", ""), height=100, key=f"{fk}_sn") new_negative = st.text_area("Specific Negative", value=data.get("negative", ""), height=100, key=f"{fk}_sn")
# Seed # --- SEED ---
col_seed_val, col_seed_btn = st.columns([4, 1]) col_seed_val, col_seed_btn = st.columns([4, 1])
seed_key = f"{fk}_seed" seed_key = f"{fk}_seed"
with col_seed_btn: with col_seed_btn:
st.write("") st.write("")
st.write("") st.write("")
if st.button("🎲 Randomize", key=f"{fk}_rand"): if st.button("🎲", key=f"{fk}_rand", help="Randomize Seed"):
st.session_state[seed_key] = random.randint(0, 999999999999) st.session_state[seed_key] = random.randint(0, 999999999999)
st.rerun() st.rerun()
@@ -45,7 +61,7 @@ def render_single_editor(data, file_path):
new_seed = st.number_input("Seed", value=seed_val, step=1, min_value=0, format="%d", key=seed_key) new_seed = st.number_input("Seed", value=seed_val, step=1, min_value=0, format="%d", key=seed_key)
data["seed"] = new_seed data["seed"] = new_seed
# LoRAs # --- LORAS ---
st.subheader("LoRAs") st.subheader("LoRAs")
l_col1, l_col2 = st.columns(2) l_col1, l_col2 = st.columns(2)
loras = {} loras = {}
@@ -54,19 +70,20 @@ def render_single_editor(data, file_path):
with (l_col1 if i % 2 == 0 else l_col2): with (l_col1 if i % 2 == 0 else l_col2):
loras[k] = st.text_input(k.title(), value=data.get(k, ""), key=f"{fk}_{k}") loras[k] = st.text_input(k.title(), value=data.get(k, ""), key=f"{fk}_{k}")
# Settings # --- STANDARD SETTINGS ---
st.subheader("Settings") st.subheader("Settings")
spec_fields = {} spec_fields = {}
spec_fields["camera"] = st.text_input("Camera", value=str(data.get("camera", DEFAULTS["camera"])), key=f"{fk}_cam") spec_fields["camera"] = st.text_input("Camera", value=str(data.get("camera", DEFAULTS["camera"])), key=f"{fk}_cam")
spec_fields["flf"] = st.text_input("FLF", value=str(data.get("flf", DEFAULTS["flf"])), key=f"{fk}_flf") spec_fields["flf"] = st.text_input("FLF", value=str(data.get("flf", DEFAULTS["flf"])), key=f"{fk}_flf")
# Explicitly track standard setting keys to exclude them from custom list # Define what is "Standard" so Custom Param logic knows what to ignore
standard_keys = { standard_keys = {
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed", "general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
"camera", "flf", "batch_data", "prompt_history", "sequence_number", "ui_reset_token" "camera", "flf", "batch_data", "prompt_history", "history_tree", "sequence_number", "ui_reset_token"
} }
standard_keys.update(lora_keys) # Add LoRAs to reserved list standard_keys.update(lora_keys)
# Conditional Logic for VACE vs I2V
if "vace" in file_path.name: if "vace" in file_path.name:
vace_keys = ["frame_to_skip", "input_a_frames", "input_b_frames", "reference switch", "vace schedule", "reference path", "video file path", "reference image path"] vace_keys = ["frame_to_skip", "input_a_frames", "input_b_frames", "reference switch", "vace schedule", "reference path", "video file path", "reference image path"]
standard_keys.update(vace_keys) standard_keys.update(vace_keys)
@@ -78,6 +95,7 @@ def render_single_editor(data, file_path):
spec_fields["vace schedule"] = st.number_input("VACE Schedule", value=int(data.get("vace schedule", 1)), key=f"{fk}_vsc") spec_fields["vace schedule"] = st.number_input("VACE Schedule", value=int(data.get("vace schedule", 1)), key=f"{fk}_vsc")
for f in ["reference path", "video file path", "reference image path"]: for f in ["reference path", "video file path", "reference image path"]:
spec_fields[f] = st.text_input(f.title(), value=str(data.get(f, "")), key=f"{fk}_{f}") spec_fields[f] = st.text_input(f.title(), value=str(data.get(f, "")), key=f"{fk}_{f}")
elif "i2v" in file_path.name: elif "i2v" in file_path.name:
i2v_keys = ["reference image path", "flf image path", "video file path"] i2v_keys = ["reference image path", "flf image path", "video file path"]
standard_keys.update(i2v_keys) standard_keys.update(i2v_keys)
@@ -85,13 +103,12 @@ def render_single_editor(data, file_path):
for f in i2v_keys: for f in i2v_keys:
spec_fields[f] = st.text_input(f.title(), value=str(data.get(f, "")), key=f"{fk}_{f}") spec_fields[f] = st.text_input(f.title(), value=str(data.get(f, "")), key=f"{fk}_{f}")
# --- CUSTOM PARAMETERS LOGIC (FIXED) --- # --- CUSTOM PARAMETERS ---
st.markdown("---") st.markdown("---")
st.subheader("🔧 Custom Parameters") st.subheader("🔧 Custom Parameters")
# Filter keys: Only those NOT in the standard set # Filter: Any key in data that is NOT in standard_keys is a Custom Key
custom_keys = [k for k in data.keys() if k not in standard_keys] custom_keys = [k for k in data.keys() if k not in standard_keys]
keys_to_remove = [] keys_to_remove = []
if custom_keys: if custom_keys:
@@ -106,7 +123,7 @@ def render_single_editor(data, file_path):
else: else:
st.caption("No custom keys added.") st.caption("No custom keys added.")
# Add New Key Interface # Add New Interface
with st.expander(" Add New Parameter"): with st.expander(" Add New Parameter"):
nk_col, nv_col = st.columns(2) nk_col, nv_col = st.columns(2)
new_k = nk_col.text_input("Key Name", key=f"{fk}_new_k") new_k = nk_col.text_input("Key Name", key=f"{fk}_new_k")
@@ -119,120 +136,117 @@ def render_single_editor(data, file_path):
elif new_k in data: elif new_k in data:
st.error(f"Key '{new_k}' already exists!") st.error(f"Key '{new_k}' already exists!")
# Apply Removals # Process removals
if keys_to_remove: if keys_to_remove:
for k in keys_to_remove: for k in keys_to_remove:
del data[k] del data[k]
st.rerun() st.rerun()
# --- ACTIONS & HISTORY --- # ==============================================================================
# RIGHT COLUMN: ACTIONS & TIMELINE
# ==============================================================================
with col2: with col2:
# 1. Capture State (Form -> Dict)
current_state = { current_state = {
"general_prompt": gen_prompt, "general_negative": gen_negative, "general_prompt": gen_prompt, "general_negative": gen_negative,
"current_prompt": new_prompt, "negative": new_negative, "current_prompt": new_prompt, "negative": new_negative,
"seed": new_seed, **loras, **spec_fields "seed": new_seed, **loras, **spec_fields
} }
# MERGE CUSTOM KEYS # Merge Custom Keys into current_state so they are saved
for k in custom_keys: for k in custom_keys:
if k not in keys_to_remove: if k not in keys_to_remove:
current_state[k] = data[k] current_state[k] = data[k]
st.session_state.single_editor_cache = current_state st.session_state.single_editor_cache = current_state
# 2. Disk Operations
st.subheader("Actions") st.subheader("Actions")
current_disk_mtime = get_file_mtime(file_path) current_disk_mtime = get_file_mtime(file_path)
is_conflict = current_disk_mtime > st.session_state.last_mtime is_conflict = current_disk_mtime > st.session_state.last_mtime
if is_conflict: if is_conflict:
st.error("⚠️ CONFLICT: Disk changed!") st.error("⚠️ CONFLICT: Disk changed!")
if st.button("Force Save"): c1, c2 = st.columns(2)
if c1.button("Force Save"):
data.update(current_state) data.update(current_state)
st.session_state.last_mtime = save_json(file_path, data) st.session_state.last_mtime = save_json(file_path, data)
st.session_state.data_cache = data st.session_state.data_cache = data
st.toast("Saved!", icon="⚠️") st.toast("Saved!", icon="⚠️")
st.rerun() st.rerun()
if st.button("Reload File"): if c2.button("Reload"):
st.session_state.loaded_file = None st.session_state.loaded_file = None
st.rerun() st.rerun()
else: else:
if st.button("💾 Update File", use_container_width=True): if st.button("💾 Quick Save (Update Disk)", use_container_width=True):
data.update(current_state) data.update(current_state)
st.session_state.last_mtime = save_json(file_path, data) st.session_state.last_mtime = save_json(file_path, data)
st.session_state.data_cache = data st.session_state.data_cache = data
st.toast("Updated!", icon="") st.toast("Saved!", icon="")
st.markdown("---")
archive_note = st.text_input("Archive Note")
if st.button("📦 Snapshot to History", use_container_width=True):
entry = {"note": archive_note if archive_note else "Snapshot", **current_state}
if "prompt_history" not in data: data["prompt_history"] = []
data["prompt_history"].insert(0, entry)
data.update(entry)
st.session_state.last_mtime = save_json(file_path, data)
st.session_state.data_cache = data
st.toast("Archived!", icon="📦")
st.rerun()
# --- FULL HISTORY PANEL ---
st.markdown("---") st.markdown("---")
st.subheader("History")
history = data.get("prompt_history", [])
if not history: # 3. HISTORY TREE TIMELINE
st.caption("No history yet.") st.subheader("Timeline & Branching")
# Render Graphviz Tree
try:
graph_dot = htree.generate_graphviz()
st.graphviz_chart(graph_dot, use_container_width=True)
except Exception as e:
st.error(f"Graph Error: {e}")
for idx, h in enumerate(history): # Snapshot Controls
note = h.get('note', 'No Note') st.caption("Create Snapshot (Commits current state to timeline)")
c_col1, c_col2 = st.columns([3, 1])
commit_note = c_col1.text_input("Snapshot Note", placeholder="e.g. Added fog", label_visibility="collapsed", key=f"{fk}_snote")
if c_col2.button("📷 Snap", help="Save Snapshot"):
# Prepare full snapshot data
full_snapshot = data.copy()
full_snapshot.update(current_state)
with st.container(): # Clean recursive keys
if st.session_state.edit_history_idx == idx: if "history_tree" in full_snapshot: del full_snapshot["history_tree"]
with st.expander(f"📝 Editing: {note}", expanded=True): if "prompt_history" in full_snapshot: del full_snapshot["prompt_history"]
edit_note = st.text_input("Note", value=note, key=f"h_en_{idx}")
edit_seed = st.number_input("Seed", value=int(h.get('seed', 0)), key=f"h_es_{idx}") # Commit to Tree
edit_gp = st.text_area("General P", value=h.get('general_prompt', ''), height=60, key=f"h_egp_{idx}") htree.commit(full_snapshot, note=commit_note if commit_note else "Snapshot")
edit_sp = st.text_area("Specific P", value=h.get('prompt', ''), height=100, key=f"h_esp_{idx}")
# Save Tree back to main Data object
hc1, hc2 = st.columns([1, 4]) data["history_tree"] = htree.to_dict()
if hc1.button("💾 Save", key=f"h_save_{idx}"): if "prompt_history" in data: del data["prompt_history"] # Clean legacy
h.update({
'note': edit_note, 'seed': edit_seed, save_json(file_path, data)
'general_prompt': edit_gp, st.session_state.ui_reset_token += 1
'prompt': edit_sp st.toast("Snapshot created!", icon="📸")
}) st.rerun()
st.session_state.last_mtime = save_json(file_path, data)
st.session_state.data_cache = data
st.session_state.edit_history_idx = None
st.rerun()
if hc2.button("Cancel", key=f"h_can_{idx}"):
st.session_state.edit_history_idx = None
st.rerun()
else:
with st.expander(f"#{idx+1}: {note}"):
st.caption(f"Seed: {h.get('seed', 0)}")
st.text(f"SPEC: {h.get('prompt', '')[:40]}...")
view_data = {k:v for k,v in h.items() if k not in ['prompt', 'negative', 'general_prompt', 'general_negative', 'note']}
st.json(view_data, expanded=False)
bh1, bh2, bh3 = st.columns([2, 1, 1]) st.divider()
if bh1.button("Restore", key=f"h_rest_{idx}", use_container_width=True): # Restore Controls
data.update(h) all_nodes = htree.nodes.values()
if 'prompt' in h: data['current_prompt'] = h['prompt'] # Sort chronologically reverse
st.session_state.last_mtime = save_json(file_path, data) sorted_nodes = sorted(all_nodes, key=lambda x: x["timestamp"], reverse=True)
st.session_state.data_cache = data
st.session_state.ui_reset_token += 1 node_options = {n["id"]: f"{n.get('note','Step')} ({n['id']})" for n in sorted_nodes}
st.toast("Restored!", icon="")
st.rerun() if not node_options:
st.caption("No timeline history yet.")
if bh2.button("✏️", key=f"h_edit_{idx}"): else:
st.session_state.edit_history_idx = idx selected_node_id = st.selectbox("Jump to Time:", options=list(node_options.keys()), format_func=lambda x: node_options[x], key=f"{fk}_jumpbox")
st.rerun()
if st.button("⏪ Jump / Restore", use_container_width=True):
if bh3.button("🗑️", key=f"h_del_{idx}"): restored_data = htree.checkout(selected_node_id)
history.pop(idx) if restored_data:
st.session_state.last_mtime = save_json(file_path, data) # 1. Update working data
st.session_state.data_cache = data data.update(restored_data)
st.rerun()
# 2. Save the HEAD move
data["history_tree"] = htree.to_dict()
save_json(file_path, data)
# 3. Force UI Reset
st.session_state.ui_reset_token += 1
st.toast(f"Jumped to {node_options[selected_node_id]}", icon="")
st.rerun()