Update tab_batch.py

This commit is contained in:
2026-01-02 13:26:57 +01:00
committed by GitHub
parent e5db5a5b55
commit af174839e8

View File

@@ -1,7 +1,7 @@
import streamlit as st import streamlit as st
import random import random
from utils import DEFAULTS, save_json, load_json from utils import DEFAULTS, save_json, load_json
from history_tree import HistoryTree # <--- NEW IMPORT from history_tree import HistoryTree
def create_batch_callback(original_filename, current_data, current_dir): def create_batch_callback(original_filename, current_data, current_dir):
new_name = f"batch_{original_filename}" new_name = f"batch_{original_filename}"
@@ -13,13 +13,12 @@ def create_batch_callback(original_filename, current_data, current_dir):
first_item = current_data.copy() first_item = current_data.copy()
if "prompt_history" in first_item: del first_item["prompt_history"] if "prompt_history" in first_item: del first_item["prompt_history"]
if "history_tree" in first_item: del first_item["history_tree"] # Don't duplicate tree if "history_tree" in first_item: del first_item["history_tree"]
first_item["sequence_number"] = 1 first_item["sequence_number"] = 1
new_data = { new_data = {
"batch_data": [first_item], "batch_data": [first_item],
# Initialize empty history for the new file
"history_tree": {}, "history_tree": {},
"prompt_history": [] "prompt_history": []
} }
@@ -37,6 +36,11 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
st.button("✨ Create Batch Copy", on_click=create_batch_callback, args=(selected_file_name, data, current_dir)) st.button("✨ Create Batch Copy", on_click=create_batch_callback, args=(selected_file_name, data, current_dir))
return return
# --- 1. RESTORED STATE INDICATOR ---
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
# -----------------------------------
batch_list = data.get("batch_data", []) batch_list = data.get("batch_data", [])
# --- ADD NEW SEQUENCE AREA --- # --- ADD NEW SEQUENCE AREA ---
@@ -50,10 +54,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
src_data, _ = load_json(current_dir / src_name) src_data, _ = load_json(current_dir / src_name)
with ac2: with ac2:
# Legacy history support for import source
src_hist = src_data.get("prompt_history", []) src_hist = src_data.get("prompt_history", [])
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else [] h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
sel_hist = st.selectbox("History Entry:", h_opts, key="batch_src_hist") sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist")
bc1, bc2, bc3 = st.columns(3) bc1, bc2, bc3 = st.columns(3)
@@ -63,7 +66,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"])) if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"]))
new_item["sequence_number"] = max_seq + 1 new_item["sequence_number"] = max_seq + 1
# Cleanup metadata keys from item
for k in ["prompt_history", "history_tree", "note", "loras"]: for k in ["prompt_history", "history_tree", "note", "loras"]:
if k in new_item: del new_item[k] if k in new_item: del new_item[k]
@@ -82,7 +84,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
item.update(flat) item.update(flat)
add_sequence(item) add_sequence(item)
if bc3.button(" From History (Legacy)", use_container_width=True, disabled=not src_hist): if bc3.button(" From History", use_container_width=True, disabled=not src_hist):
if sel_hist: if sel_hist:
idx = int(sel_hist.split(":")[0].replace("#", "")) - 1 idx = int(sel_hist.split(":")[0].replace("#", "")) - 1
item = DEFAULTS.copy() item = DEFAULTS.copy()
@@ -96,7 +98,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
st.markdown("---") st.markdown("---")
st.info(f"Batch contains {len(batch_list)} sequences.") st.info(f"Batch contains {len(batch_list)} sequences.")
# Standard keys to exclude from Custom List in Batch
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"] lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
standard_keys = { standard_keys = {
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed", "general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
@@ -166,7 +167,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
st.session_state[seed_key] = random.randint(0, 999999999999) st.session_state[seed_key] = random.randint(0, 999999999999)
st.rerun() st.rerun()
with s_row1: with s_row1:
# Prefer session state if dice was clicked
current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0))) current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0)))
val = st.number_input("Seed", value=current_seed, key=seed_key) val = st.number_input("Seed", value=current_seed, key=seed_key)
seq["seed"] = val seq["seed"] = val
@@ -195,7 +195,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with (lc1 if li % 2 == 0 else lc2): with (lc1 if li % 2 == 0 else lc2):
seq[lk] = st.text_input(lk.title(), value=seq.get(lk, ""), key=f"{prefix}_{lk}") seq[lk] = st.text_input(lk.title(), value=seq.get(lk, ""), key=f"{prefix}_{lk}")
# --- CUSTOM PARAMETERS (BATCH) --- # --- CUSTOM PARAMETERS ---
st.markdown("---") st.markdown("---")
st.caption("🔧 Custom Parameters") st.caption("🔧 Custom Parameters")
@@ -220,7 +220,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
if st.button("Add", key=f"{prefix}_add_cust"): if st.button("Add", key=f"{prefix}_add_cust"):
if new_k and new_k not in seq: if new_k and new_k not in seq:
seq[new_k] = new_v seq[new_k] = new_v
# Autosave on structure change
save_json(file_path, data) save_json(file_path, data)
st.session_state.ui_reset_token += 1 st.session_state.ui_reset_token += 1
st.rerun() st.rerun()
@@ -242,24 +241,23 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
with col_save: with col_save:
if st.button("💾 Save & Snap", use_container_width=True): if st.button("💾 Save & Snap", use_container_width=True):
# 1. Update Data with the latest list
data["batch_data"] = batch_list data["batch_data"] = batch_list
# 2. Init/Load History Engine # Commit to Tree
tree_data = data.get("history_tree", {}) tree_data = data.get("history_tree", {})
htree = HistoryTree(tree_data) htree = HistoryTree(tree_data)
# 3. Create Clean Snapshot
# Remove the recursive history tree from the payload we are saving INTO the history tree
snapshot_payload = data.copy() snapshot_payload = data.copy()
if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"] if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"]
# 4. Commit
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update") htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
# 5. Write Tree back to Data
data["history_tree"] = htree.to_dict() data["history_tree"] = htree.to_dict()
# 6. Disk Save
save_json(file_path, data) save_json(file_path, data)
# CLEAR THE INDICATOR SINCE WE MOVED FORWARD
if 'restored_indicator' in st.session_state:
del st.session_state.restored_indicator
st.toast("Batch Saved & Snapshot Created!", icon="🚀") st.toast("Batch Saved & Snapshot Created!", icon="🚀")
st.rerun()