diff --git a/tab_batch.py b/tab_batch.py
index 881258c..80ae475 100644
--- a/tab_batch.py
+++ b/tab_batch.py
@@ -1,6 +1,6 @@
 import streamlit as st
 import random
-from utils import DEFAULTS, save_json, load_json
+from utils import DEFAULTS, save_json, load_json, render_smart_input
 from history_tree import HistoryTree
 
 def create_batch_callback(original_filename, current_data, current_dir):
@@ -39,7 +39,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
     # --- 1. RESTORED STATE INDICATOR ---
     if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
         st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
-    # -----------------------------------
 
     batch_list = data.get("batch_data", [])
 
@@ -189,11 +188,27 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
             seq["reference image path"] = st.text_input("Ref Img", value=seq.get("reference image path", ""), key=f"{prefix}_ri2")
             seq["flf image path"] = st.text_input("FLF Img", value=seq.get("flf image path", ""), key=f"{prefix}_flfi")
 
-            with st.expander("LoRA Settings"):
+            # --- LoRA Settings (SMART INPUTS) ---
+            with st.expander("💊 LoRA Settings"):
+                meta = st.session_state.get("comfy_meta", {})
+                lora_list = meta.get("loras", [])
+
                 lc1, lc2 = st.columns(2)
-                for li, lk in enumerate(lora_keys):
-                    with (lc1 if li % 2 == 0 else lc2):
-                        seq[lk] = st.text_input(lk.title(), value=seq.get(lk, ""), key=f"{prefix}_{lk}")
+
+                with lc1:
+                    st.caption("LoRA 1")
+                    seq["lora 1 high"] = render_smart_input("Model", f"{prefix}_l1h", seq.get("lora 1 high", ""), lora_list)
+                    seq["lora 1 low"] = str(st.slider("Strength", 0.0, 2.0, float(seq.get("lora 1 low", 1.0)), 0.05, key=f"{prefix}_l1l"))
+
+                with lc2:
+                    st.caption("LoRA 2")
+                    seq["lora 2 high"] = render_smart_input("Model", f"{prefix}_l2h", seq.get("lora 2 high", ""), lora_list)
+                    seq["lora 2 low"] = str(st.slider("Strength", 0.0, 2.0, float(seq.get("lora 2 low", 1.0)), 0.05, key=f"{prefix}_l2l"))
+
+                with lc1:
+                    st.caption("LoRA 3")
+                    seq["lora 3 high"] = render_smart_input("Model", f"{prefix}_l3h", seq.get("lora 3 high", ""), lora_list)
+                    seq["lora 3 low"] = str(st.slider("Strength", 0.0, 2.0, float(seq.get("lora 3 low", 1.0)), 0.05, key=f"{prefix}_l3l"))
 
             # --- CUSTOM PARAMETERS ---
             st.markdown("---")
@@ -243,7 +258,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
     if st.button("💾 Save & Snap", use_container_width=True):
         data["batch_data"] = batch_list
 
-        # Commit to Tree
        tree_data = data.get("history_tree", {})
        htree = HistoryTree(tree_data)
 
@@ -255,7 +269,6 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
         data["history_tree"] = htree.to_dict()
         save_json(file_path, data)
 
-        # CLEAR THE INDICATOR SINCE WE MOVED FORWARD
         if 'restored_indicator' in st.session_state:
             del st.session_state.restored_indicator
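
Note: `render_smart_input` is imported from `utils` but its body is not part of this diff. For reviewers unfamiliar with it, below is a minimal sketch of the shape its call sites imply (label, widget key, current value, list of known options, returning the chosen string). This is an assumption for illustration only, not the actual `utils` implementation.

```python
import streamlit as st

def render_smart_input(label, key, current_value, options):
    """Hypothetical sketch: offer a selectbox when option metadata (e.g. the
    LoRA list from comfy_meta) is available, otherwise fall back to a plain
    text input. The real utils.render_smart_input may behave differently."""
    if options:
        choices = list(options)
        # Keep the stored value selectable even if it is missing from the list.
        if current_value and current_value not in choices:
            choices.insert(0, current_value)
        index = choices.index(current_value) if current_value in choices else 0
        return st.selectbox(label, choices, index=index, key=key)
    # No metadata available: free-form entry.
    return st.text_input(label, value=current_value, key=key)
```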