Update tab_batch.py

This commit is contained in:
2026-01-02 19:15:21 +01:00
committed by GitHub
parent 9045bbe636
commit 7df798ccd1

View File

@@ -1,6 +1,6 @@
import streamlit as st
import random
from utils import DEFAULTS, save_json, load_json, render_smart_input
from utils import DEFAULTS, save_json, load_json
from history_tree import HistoryTree
def create_batch_callback(original_filename, current_data, current_dir):
@@ -187,38 +187,18 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
seq["reference image path"] = st.text_input("Ref Img", value=seq.get("reference image path", ""), key=f"{prefix}_ri2")
seq["flf image path"] = st.text_input("FLF Img", value=seq.get("flf image path", ""), key=f"{prefix}_flfi")
# --- LoRA Settings (SMART INPUTS + SAFE FLOAT FIX) ---
# --- LoRA Settings (Reverted to plain text) ---
with st.expander("💊 LoRA Settings"):
meta = st.session_state.get("comfy_meta", {})
lora_list = meta.get("loras", [])
lc1, lc2 = st.columns(2)
# HELPER: Safe conversion
def get_safe_float(val, default=1.0):
try:
return float(val)
except (ValueError, TypeError):
# If it's a string like "<lora:name>", return default 1.0
return default
lc1, lc2, lc3 = st.columns(3)
with lc1:
st.caption("LoRA 1")
seq["lora 1 high"] = render_smart_input("Model", f"{prefix}_l1h", seq.get("lora 1 high", ""), lora_list)
val = get_safe_float(seq.get("lora 1 low", 1.0))
seq["lora 1 low"] = str(st.slider("Strength", 0.0, 2.0, val, 0.05, key=f"{prefix}_l1l"))
seq["lora 1 high"] = st.text_input("LoRA 1 Name", value=seq.get("lora 1 high", ""), key=f"{prefix}_l1h")
seq["lora 1 low"] = st.text_input("LoRA 1 Strength", value=str(seq.get("lora 1 low", "")), key=f"{prefix}_l1l")
with lc2:
st.caption("LoRA 2")
seq["lora 2 high"] = render_smart_input("Model", f"{prefix}_l2h", seq.get("lora 2 high", ""), lora_list)
val = get_safe_float(seq.get("lora 2 low", 1.0))
seq["lora 2 low"] = str(st.slider("Strength", 0.0, 2.0, val, 0.05, key=f"{prefix}_l2l"))
with lc1:
st.caption("LoRA 3")
seq["lora 3 high"] = render_smart_input("Model", f"{prefix}_l3h", seq.get("lora 3 high", ""), lora_list)
val = get_safe_float(seq.get("lora 3 low", 1.0))
seq["lora 3 low"] = str(st.slider("Strength", 0.0, 2.0, val, 0.05, key=f"{prefix}_l3l"))
seq["lora 2 high"] = st.text_input("LoRA 2 Name", value=seq.get("lora 2 high", ""), key=f"{prefix}_l2h")
seq["lora 2 low"] = st.text_input("LoRA 2 Strength", value=str(seq.get("lora 2 low", "")), key=f"{prefix}_l2l")
with lc3:
seq["lora 3 high"] = st.text_input("LoRA 3 Name", value=seq.get("lora 3 high", ""), key=f"{prefix}_l3h")
seq["lora 3 low"] = st.text_input("LoRA 3 Strength", value=str(seq.get("lora 3 low", "")), key=f"{prefix}_l3l")
# --- CUSTOM PARAMETERS ---
st.markdown("---")
@@ -283,3 +263,4 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
del st.session_state.restored_indicator
st.toast("Batch Saved & Snapshot Created!", icon="🚀")
st.rerun()