diff --git a/tab_batch.py b/tab_batch.py index 5eb2f0e..86221b1 100644 --- a/tab_batch.py +++ b/tab_batch.py @@ -1,6 +1,6 @@ import streamlit as st import random -from utils import DEFAULTS, save_json, load_json, render_smart_input +from utils import DEFAULTS, save_json, load_json from history_tree import HistoryTree def create_batch_callback(original_filename, current_data, current_dir): @@ -187,38 +187,18 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi seq["reference image path"] = st.text_input("Ref Img", value=seq.get("reference image path", ""), key=f"{prefix}_ri2") seq["flf image path"] = st.text_input("FLF Img", value=seq.get("flf image path", ""), key=f"{prefix}_flfi") - # --- LoRA Settings (SMART INPUTS + SAFE FLOAT FIX) --- + # --- LoRA Settings (Reverted to plain text) --- with st.expander("💊 LoRA Settings"): - meta = st.session_state.get("comfy_meta", {}) - lora_list = meta.get("loras", []) - - lc1, lc2 = st.columns(2) - - # HELPER: Safe conversion - def get_safe_float(val, default=1.0): - try: - return float(val) - except (ValueError, TypeError): - # If it's a string like "", return default 1.0 - return default - + lc1, lc2, lc3 = st.columns(3) with lc1: - st.caption("LoRA 1") - seq["lora 1 high"] = render_smart_input("Model", f"{prefix}_l1h", seq.get("lora 1 high", ""), lora_list) - val = get_safe_float(seq.get("lora 1 low", 1.0)) - seq["lora 1 low"] = str(st.slider("Strength", 0.0, 2.0, val, 0.05, key=f"{prefix}_l1l")) - + seq["lora 1 high"] = st.text_input("LoRA 1 Name", value=seq.get("lora 1 high", ""), key=f"{prefix}_l1h") + seq["lora 1 low"] = st.text_input("LoRA 1 Strength", value=str(seq.get("lora 1 low", "")), key=f"{prefix}_l1l") with lc2: - st.caption("LoRA 2") - seq["lora 2 high"] = render_smart_input("Model", f"{prefix}_l2h", seq.get("lora 2 high", ""), lora_list) - val = get_safe_float(seq.get("lora 2 low", 1.0)) - seq["lora 2 low"] = str(st.slider("Strength", 0.0, 2.0, val, 0.05, key=f"{prefix}_l2l")) - - with lc1: - st.caption("LoRA 3") - seq["lora 3 high"] = render_smart_input("Model", f"{prefix}_l3h", seq.get("lora 3 high", ""), lora_list) - val = get_safe_float(seq.get("lora 3 low", 1.0)) - seq["lora 3 low"] = str(st.slider("Strength", 0.0, 2.0, val, 0.05, key=f"{prefix}_l3l")) + seq["lora 2 high"] = st.text_input("LoRA 2 Name", value=seq.get("lora 2 high", ""), key=f"{prefix}_l2h") + seq["lora 2 low"] = st.text_input("LoRA 2 Strength", value=str(seq.get("lora 2 low", "")), key=f"{prefix}_l2l") + with lc3: + seq["lora 3 high"] = st.text_input("LoRA 3 Name", value=seq.get("lora 3 high", ""), key=f"{prefix}_l3h") + seq["lora 3 low"] = st.text_input("LoRA 3 Strength", value=str(seq.get("lora 3 low", "")), key=f"{prefix}_l3l") # --- CUSTOM PARAMETERS --- st.markdown("---") @@ -283,3 +263,4 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi del st.session_state.restored_indicator st.toast("Batch Saved & Snapshot Created!", icon="🚀") + st.rerun()