diff --git a/tab_batch.py b/tab_batch.py index 9c5695c..881258c 100644 --- a/tab_batch.py +++ b/tab_batch.py @@ -1,7 +1,7 @@ import streamlit as st import random from utils import DEFAULTS, save_json, load_json -from history_tree import HistoryTree # <--- NEW IMPORT +from history_tree import HistoryTree def create_batch_callback(original_filename, current_data, current_dir): new_name = f"batch_{original_filename}" @@ -13,13 +13,12 @@ def create_batch_callback(original_filename, current_data, current_dir): first_item = current_data.copy() if "prompt_history" in first_item: del first_item["prompt_history"] - if "history_tree" in first_item: del first_item["history_tree"] # Don't duplicate tree + if "history_tree" in first_item: del first_item["history_tree"] first_item["sequence_number"] = 1 new_data = { "batch_data": [first_item], - # Initialize empty history for the new file "history_tree": {}, "prompt_history": [] } @@ -37,6 +36,11 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi st.button("✨ Create Batch Copy", on_click=create_batch_callback, args=(selected_file_name, data, current_dir)) return + # --- 1. RESTORED STATE INDICATOR --- + if 'restored_indicator' in st.session_state and st.session_state.restored_indicator: + st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**") + # ----------------------------------- + batch_list = data.get("batch_data", []) # --- ADD NEW SEQUENCE AREA --- @@ -50,10 +54,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi src_data, _ = load_json(current_dir / src_name) with ac2: - # Legacy history support for import source src_hist = src_data.get("prompt_history", []) h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else [] - sel_hist = st.selectbox("History Entry:", h_opts, key="batch_src_hist") + sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist") bc1, bc2, bc3 = st.columns(3) @@ -63,7 +66,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"])) new_item["sequence_number"] = max_seq + 1 - # Cleanup metadata keys from item for k in ["prompt_history", "history_tree", "note", "loras"]: if k in new_item: del new_item[k] @@ -82,7 +84,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi item.update(flat) add_sequence(item) - if bc3.button("➕ From History (Legacy)", use_container_width=True, disabled=not src_hist): + if bc3.button("➕ From History", use_container_width=True, disabled=not src_hist): if sel_hist: idx = int(sel_hist.split(":")[0].replace("#", "")) - 1 item = DEFAULTS.copy() @@ -96,7 +98,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi st.markdown("---") st.info(f"Batch contains {len(batch_list)} sequences.") - # Standard keys to exclude from Custom List in Batch lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"] standard_keys = { "general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed", @@ -166,7 +167,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi st.session_state[seed_key] = random.randint(0, 999999999999) st.rerun() with s_row1: - # Prefer session state if dice was clicked current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0))) val = st.number_input("Seed", value=current_seed, key=seed_key) seq["seed"] = val @@ -195,7 +195,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi with (lc1 if li % 2 == 0 else lc2): seq[lk] = st.text_input(lk.title(), value=seq.get(lk, ""), key=f"{prefix}_{lk}") - # --- CUSTOM PARAMETERS (BATCH) --- + # --- CUSTOM PARAMETERS --- st.markdown("---") st.caption("🔧 Custom Parameters") @@ -220,7 +220,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi if st.button("Add", key=f"{prefix}_add_cust"): if new_k and new_k not in seq: seq[new_k] = new_v - # Autosave on structure change save_json(file_path, data) st.session_state.ui_reset_token += 1 st.rerun() @@ -242,24 +241,23 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi with col_save: if st.button("💾 Save & Snap", use_container_width=True): - # 1. Update Data with the latest list data["batch_data"] = batch_list - # 2. Init/Load History Engine + # Commit to Tree tree_data = data.get("history_tree", {}) htree = HistoryTree(tree_data) - # 3. Create Clean Snapshot - # Remove the recursive history tree from the payload we are saving INTO the history tree snapshot_payload = data.copy() if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"] - # 4. Commit htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update") - # 5. Write Tree back to Data data["history_tree"] = htree.to_dict() - - # 6. Disk Save save_json(file_path, data) + + # CLEAR THE INDICATOR SINCE WE MOVED FORWARD + if 'restored_indicator' in st.session_state: + del st.session_state.restored_indicator + st.toast("Batch Saved & Snapshot Created!", icon="🚀") + st.rerun()