From cbe2355ef6c9d7a63d42482a3f1cdabdb0f735f5 Mon Sep 17 00:00:00 2001 From: ethanfel Date: Fri, 2 Jan 2026 13:16:17 +0100 Subject: [PATCH] Update tab_batch.py --- tab_batch.py | 61 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/tab_batch.py b/tab_batch.py index f91de00..9c5695c 100644 --- a/tab_batch.py +++ b/tab_batch.py @@ -1,6 +1,7 @@ import streamlit as st import random from utils import DEFAULTS, save_json, load_json +from history_tree import HistoryTree # <--- NEW IMPORT def create_batch_callback(original_filename, current_data, current_dir): new_name = f"batch_{original_filename}" @@ -12,11 +13,15 @@ def create_batch_callback(original_filename, current_data, current_dir): first_item = current_data.copy() if "prompt_history" in first_item: del first_item["prompt_history"] + if "history_tree" in first_item: del first_item["history_tree"] # Don't duplicate tree + first_item["sequence_number"] = 1 new_data = { "batch_data": [first_item], - "prompt_history": current_data.get("prompt_history", []) + # Initialize empty history for the new file + "history_tree": {}, + "prompt_history": [] } save_json(new_path, new_data) @@ -34,6 +39,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi batch_list = data.get("batch_data", []) + # --- ADD NEW SEQUENCE AREA --- st.subheader("Add New Sequence") ac1, ac2 = st.columns(2) @@ -44,6 +50,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi src_data, _ = load_json(current_dir / src_name) with ac2: + # Legacy history support for import source src_hist = src_data.get("prompt_history", []) h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else [] sel_hist = st.selectbox("History Entry:", h_opts, key="batch_src_hist") @@ -55,7 +62,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi for s in batch_list: if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"])) new_item["sequence_number"] = max_seq + 1 - for k in ["prompt_history", "note", "loras"]: + + # Cleanup metadata keys from item + for k in ["prompt_history", "history_tree", "note", "loras"]: if k in new_item: del new_item[k] batch_list.append(new_item) @@ -73,7 +82,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi item.update(flat) add_sequence(item) - if bc3.button("➕ From History", use_container_width=True, disabled=not src_hist): + if bc3.button("➕ From History (Legacy)", use_container_width=True, disabled=not src_hist): if sel_hist: idx = int(sel_hist.split(":")[0].replace("#", "")) - 1 item = DEFAULTS.copy() @@ -83,6 +92,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi item.update(h_item["loras"]) add_sequence(item) + # --- RENDER LIST --- st.markdown("---") st.info(f"Batch contains {len(batch_list)} sequences.") @@ -110,7 +120,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data item.update(flat) item["sequence_number"] = seq_num - if "prompt_history" in item: del item["prompt_history"] + + for k in ["prompt_history", "history_tree"]: + if k in item: del item[k] + batch_list[i] = item data["batch_data"] = batch_list save_json(file_path, data) @@ -121,6 +134,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi if b2.button("↖️ Promote to Single", key=f"{prefix}_prom"): single_data = seq.copy() single_data["prompt_history"] = data.get("prompt_history", []) + single_data["history_tree"] = data.get("history_tree", {}) if "sequence_number" in single_data: del single_data["sequence_number"] save_json(file_path, single_data) st.toast("Converted to Single!", icon="✅") @@ -152,7 +166,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi st.session_state[seed_key] = random.randint(0, 999999999999) st.rerun() with s_row1: - val = st.number_input("Seed", value=int(seq.get("seed", 0)), key=seed_key) + # Prefer session state if dice was clicked + current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0))) + val = st.number_input("Seed", value=current_seed, key=seed_key) seq["seed"] = val seq["camera"] = st.text_input("Camera", value=seq.get("camera", ""), key=f"{prefix}_cam") @@ -204,6 +220,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi if st.button("Add", key=f"{prefix}_add_cust"): if new_k and new_k not in seq: seq[new_k] = new_v + # Autosave on structure change save_json(file_path, data) st.session_state.ui_reset_token += 1 st.rerun() @@ -216,7 +233,33 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi st.rerun() st.markdown("---") - if st.button("💾 Save Batch Changes"): - data["batch_data"] = batch_list - save_json(file_path, data) - st.toast("Batch saved!", icon="🚀") + + # --- SAVE ACTIONS WITH HISTORY COMMIT --- + col_save, col_note = st.columns([1, 2]) + + with col_note: + commit_msg = st.text_input("Change Note (Optional)", placeholder="e.g. Added sequence 3") + + with col_save: + if st.button("💾 Save & Snap", use_container_width=True): + # 1. Update Data with the latest list + data["batch_data"] = batch_list + + # 2. Init/Load History Engine + tree_data = data.get("history_tree", {}) + htree = HistoryTree(tree_data) + + # 3. Create Clean Snapshot + # Remove the recursive history tree from the payload we are saving INTO the history tree + snapshot_payload = data.copy() + if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"] + + # 4. Commit + htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update") + + # 5. Write Tree back to Data + data["history_tree"] = htree.to_dict() + + # 6. Disk Save + save_json(file_path, data) + st.toast("Batch Saved & Snapshot Created!", icon="🚀")