334 lines
16 KiB
Python
334 lines
16 KiB
Python
import streamlit as st
|
||
import random
|
||
import copy
|
||
from utils import DEFAULTS, save_json, load_json
|
||
from history_tree import HistoryTree
|
||
|
||
def create_batch_callback(original_filename, current_data, current_dir):
|
||
new_name = f"batch_{original_filename}"
|
||
new_path = current_dir / new_name
|
||
|
||
if new_path.exists():
|
||
st.toast(f"File {new_name} already exists!", icon="⚠️")
|
||
return
|
||
|
||
first_item = current_data.copy()
|
||
if "prompt_history" in first_item: del first_item["prompt_history"]
|
||
if "history_tree" in first_item: del first_item["history_tree"]
|
||
|
||
first_item["sequence_number"] = 1
|
||
|
||
new_data = {
|
||
"batch_data": [first_item],
|
||
"history_tree": {},
|
||
"prompt_history": []
|
||
}
|
||
|
||
save_json(new_path, new_data)
|
||
st.toast(f"Created {new_name}", icon="✨")
|
||
st.session_state.file_selector = new_name
|
||
|
||
|
||
def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
|
||
is_batch_file = "batch_data" in data or isinstance(data, list)
|
||
|
||
if not is_batch_file:
|
||
st.warning("This is a Single file. To use Batch mode, create a copy.")
|
||
st.button("✨ Create Batch Copy", on_click=create_batch_callback, args=(selected_file_name, data, current_dir))
|
||
return
|
||
|
||
if 'restored_indicator' in st.session_state and st.session_state.restored_indicator:
|
||
st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")
|
||
|
||
batch_list = data.get("batch_data", [])
|
||
|
||
# --- ADD NEW SEQUENCE AREA ---
|
||
st.subheader("Add New Sequence")
|
||
ac1, ac2 = st.columns(2)
|
||
|
||
with ac1:
|
||
file_options = [f.name for f in json_files]
|
||
d_idx = file_options.index(selected_file_name) if selected_file_name in file_options else 0
|
||
src_name = st.selectbox("Source File:", file_options, index=d_idx, key="batch_src_file")
|
||
src_data, _ = load_json(current_dir / src_name)
|
||
|
||
with ac2:
|
||
src_hist = src_data.get("prompt_history", [])
|
||
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
|
||
sel_hist = st.selectbox("History Entry (Legacy):", h_opts, key="batch_src_hist")
|
||
|
||
bc1, bc2, bc3 = st.columns(3)
|
||
|
||
def add_sequence(new_item):
|
||
max_seq = 0
|
||
for s in batch_list:
|
||
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"]))
|
||
new_item["sequence_number"] = max_seq + 1
|
||
|
||
for k in ["prompt_history", "history_tree", "note", "loras"]:
|
||
if k in new_item: del new_item[k]
|
||
|
||
batch_list.append(new_item)
|
||
data["batch_data"] = batch_list
|
||
save_json(file_path, data)
|
||
st.session_state.ui_reset_token += 1
|
||
st.rerun()
|
||
|
||
if bc1.button("➕ Add Empty", use_container_width=True):
|
||
add_sequence(DEFAULTS.copy())
|
||
|
||
if bc2.button("➕ From File", use_container_width=True, help=f"Copy {src_name}"):
|
||
item = DEFAULTS.copy()
|
||
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
||
item.update(flat)
|
||
add_sequence(item)
|
||
|
||
if bc3.button("➕ From History", use_container_width=True, disabled=not src_hist):
|
||
if sel_hist:
|
||
idx = int(sel_hist.split(":")[0].replace("#", "")) - 1
|
||
item = DEFAULTS.copy()
|
||
h_item = src_hist[idx]
|
||
item.update(h_item)
|
||
if "loras" in h_item and isinstance(h_item["loras"], dict):
|
||
item.update(h_item["loras"])
|
||
add_sequence(item)
|
||
|
||
# --- RENDER LIST ---
|
||
st.markdown("---")
|
||
st.info(f"Batch contains {len(batch_list)} sequences.")
|
||
|
||
# Updated LoRA keys to match new logic
|
||
lora_keys = ["lora 1 high", "lora 1 low", "lora 2 high", "lora 2 low", "lora 3 high", "lora 3 low"]
|
||
standard_keys = {
|
||
"general_prompt", "general_negative", "current_prompt", "negative", "prompt", "seed",
|
||
"camera", "flf", "sequence_number"
|
||
}
|
||
standard_keys.update(lora_keys)
|
||
standard_keys.update([
|
||
"frame_to_skip", "input_a_frames", "input_b_frames", "reference switch", "vace schedule",
|
||
"reference path", "video file path", "reference image path", "flf image path"
|
||
])
|
||
|
||
for i, seq in enumerate(batch_list):
|
||
seq_num = seq.get("sequence_number", i+1)
|
||
prefix = f"{selected_file_name}_seq{i}_v{st.session_state.ui_reset_token}"
|
||
|
||
with st.expander(f"🎬 Sequence #{seq_num}", expanded=False):
|
||
# --- ACTION ROW ---
|
||
act_c1, act_c2, act_c3, act_c4 = st.columns([1.2, 1.8, 1.2, 0.5])
|
||
|
||
# 1. Copy Source
|
||
with act_c1:
|
||
if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy", use_container_width=True):
|
||
item = DEFAULTS.copy()
|
||
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
||
item.update(flat)
|
||
item["sequence_number"] = seq_num
|
||
for k in ["prompt_history", "history_tree"]:
|
||
if k in item: del item[k]
|
||
batch_list[i] = item
|
||
data["batch_data"] = batch_list
|
||
save_json(file_path, data)
|
||
st.session_state.ui_reset_token += 1
|
||
st.toast("Copied!", icon="📥")
|
||
st.rerun()
|
||
|
||
# 2. Cloning Tools
|
||
with act_c2:
|
||
cl_1, cl_2 = st.columns(2)
|
||
if cl_1.button("👯 Next", key=f"{prefix}_c_next", help="Clone and insert below", use_container_width=True):
|
||
new_seq = seq.copy()
|
||
max_sn = 0
|
||
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
|
||
new_seq["sequence_number"] = max_sn + 1
|
||
batch_list.insert(i + 1, new_seq)
|
||
data["batch_data"] = batch_list
|
||
save_json(file_path, data)
|
||
st.session_state.ui_reset_token += 1
|
||
st.toast("Cloned to Next!", icon="👯")
|
||
st.rerun()
|
||
|
||
if cl_2.button("⏬ End", key=f"{prefix}_c_end", help="Clone and add to bottom", use_container_width=True):
|
||
new_seq = seq.copy()
|
||
max_sn = 0
|
||
for s in batch_list: max_sn = max(max_sn, int(s.get("sequence_number", 0)))
|
||
new_seq["sequence_number"] = max_sn + 1
|
||
batch_list.append(new_seq)
|
||
data["batch_data"] = batch_list
|
||
save_json(file_path, data)
|
||
st.session_state.ui_reset_token += 1
|
||
st.toast("Cloned to End!", icon="⏬")
|
||
st.rerun()
|
||
|
||
# 3. Promote
|
||
with act_c3:
|
||
if st.button("↖️ Promote", key=f"{prefix}_prom", help="Save as Single File", use_container_width=True):
|
||
single_data = seq.copy()
|
||
single_data["prompt_history"] = data.get("prompt_history", [])
|
||
single_data["history_tree"] = data.get("history_tree", {})
|
||
if "sequence_number" in single_data: del single_data["sequence_number"]
|
||
save_json(file_path, single_data)
|
||
st.toast("Converted to Single!", icon="✅")
|
||
st.rerun()
|
||
|
||
# 4. Remove
|
||
with act_c4:
|
||
if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
|
||
batch_list.pop(i)
|
||
data["batch_data"] = batch_list
|
||
save_json(file_path, data)
|
||
st.rerun()
|
||
|
||
st.markdown("---")
|
||
c1, c2 = st.columns([2, 1])
|
||
with c1:
|
||
seq["general_prompt"] = st.text_area("General Prompt", value=seq.get("general_prompt", ""), height=60, key=f"{prefix}_gp")
|
||
seq["general_negative"] = st.text_area("General Negative", value=seq.get("general_negative", ""), height=60, key=f"{prefix}_gn")
|
||
seq["current_prompt"] = st.text_area("Specific Prompt", value=seq.get("current_prompt", ""), height=100, key=f"{prefix}_sp")
|
||
seq["negative"] = st.text_area("Specific Negative", value=seq.get("negative", ""), height=60, key=f"{prefix}_sn")
|
||
|
||
with c2:
|
||
seq["sequence_number"] = st.number_input("Sequence Number", value=int(seq_num), key=f"{prefix}_sn_val")
|
||
|
||
s_row1, s_row2 = st.columns([3, 1])
|
||
seed_key = f"{prefix}_seed"
|
||
with s_row2:
|
||
st.write("")
|
||
st.write("")
|
||
if st.button("🎲", key=f"{prefix}_rand"):
|
||
st.session_state[seed_key] = random.randint(0, 999999999999)
|
||
st.rerun()
|
||
with s_row1:
|
||
current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0)))
|
||
val = st.number_input("Seed", value=current_seed, key=seed_key)
|
||
seq["seed"] = val
|
||
|
||
seq["camera"] = st.text_input("Camera", value=seq.get("camera", ""), key=f"{prefix}_cam")
|
||
seq["flf"] = st.text_input("FLF", value=str(seq.get("flf", DEFAULTS["flf"])), key=f"{prefix}_flf")
|
||
|
||
if "video file path" in seq or "vace" in selected_file_name:
|
||
seq["video file path"] = st.text_input("Video File Path", value=seq.get("video file path", ""), key=f"{prefix}_vid")
|
||
with st.expander("VACE Settings"):
|
||
seq["frame_to_skip"] = st.number_input("Frame to Skip", value=int(seq.get("frame_to_skip", 81)), key=f"{prefix}_fts")
|
||
seq["input_a_frames"] = st.number_input("Input A Frames", value=int(seq.get("input_a_frames", 0)), key=f"{prefix}_ia")
|
||
seq["input_b_frames"] = st.number_input("Input B Frames", value=int(seq.get("input_b_frames", 0)), key=f"{prefix}_ib")
|
||
seq["reference switch"] = st.number_input("Reference Switch", value=int(seq.get("reference switch", 1)), key=f"{prefix}_rsw")
|
||
seq["vace schedule"] = st.number_input("VACE Schedule", value=int(seq.get("vace schedule", 1)), key=f"{prefix}_vsc")
|
||
seq["reference path"] = st.text_input("Reference Path", value=seq.get("reference path", ""), key=f"{prefix}_rp")
|
||
seq["reference image path"] = st.text_input("Reference Image Path", value=seq.get("reference image path", ""), key=f"{prefix}_rip")
|
||
|
||
if "i2v" in selected_file_name and "vace" not in selected_file_name:
|
||
seq["reference image path"] = st.text_input("Reference Image Path", value=seq.get("reference image path", ""), key=f"{prefix}_ri2")
|
||
seq["flf image path"] = st.text_input("FLF Image Path", value=seq.get("flf image path", ""), key=f"{prefix}_flfi")
|
||
|
||
# --- UPDATED: LoRA Settings with Tag Wrapping ---
|
||
with st.expander("💊 LoRA Settings"):
|
||
lc1, lc2, lc3 = st.columns(3)
|
||
|
||
# Helper to render the tag wrapper UI
|
||
def render_lora_col(col_obj, lora_idx):
|
||
with col_obj:
|
||
st.caption(f"**LoRA {lora_idx}**")
|
||
|
||
# --- HIGH ---
|
||
k_high = f"lora {lora_idx} high"
|
||
raw_h = str(seq.get(k_high, ""))
|
||
# Strip tags for display
|
||
disp_h = raw_h.replace("<lora:", "").replace(">", "")
|
||
|
||
st.write("High:")
|
||
rh1, rh2, rh3 = st.columns([0.25, 1, 0.1])
|
||
rh1.markdown("<div style='text-align: right; padding-top: 8px;'><code><lora:</code></div>", unsafe_allow_html=True)
|
||
val_h = rh2.text_input(f"L{lora_idx}H", value=disp_h, key=f"{prefix}_l{lora_idx}h", label_visibility="collapsed")
|
||
rh3.markdown("<div style='padding-top: 8px;'><code>></code></div>", unsafe_allow_html=True)
|
||
|
||
if val_h:
|
||
seq[k_high] = f"<lora:{val_h}>"
|
||
else:
|
||
seq[k_high] = ""
|
||
|
||
# --- LOW ---
|
||
k_low = f"lora {lora_idx} low"
|
||
raw_l = str(seq.get(k_low, ""))
|
||
# Strip tags for display
|
||
disp_l = raw_l.replace("<lora:", "").replace(">", "")
|
||
|
||
st.write("Low:")
|
||
rl1, rl2, rl3 = st.columns([0.25, 1, 0.1])
|
||
rl1.markdown("<div style='text-align: right; padding-top: 8px;'><code><lora:</code></div>", unsafe_allow_html=True)
|
||
val_l = rl2.text_input(f"L{lora_idx}L", value=disp_l, key=f"{prefix}_l{lora_idx}l", label_visibility="collapsed")
|
||
rl3.markdown("<div style='padding-top: 8px;'><code>></code></div>", unsafe_allow_html=True)
|
||
|
||
if val_l:
|
||
seq[k_low] = f"<lora:{val_l}>"
|
||
else:
|
||
seq[k_low] = ""
|
||
|
||
render_lora_col(lc1, 1)
|
||
render_lora_col(lc2, 2)
|
||
render_lora_col(lc3, 3)
|
||
|
||
# --- CUSTOM PARAMETERS ---
|
||
st.markdown("---")
|
||
st.caption("🔧 Custom Parameters")
|
||
|
||
custom_keys = [k for k in seq.keys() if k not in standard_keys]
|
||
keys_to_remove = []
|
||
|
||
if custom_keys:
|
||
for k in custom_keys:
|
||
ck1, ck2, ck3 = st.columns([1, 2, 0.5])
|
||
ck1.text_input("Key", value=k, disabled=True, key=f"{prefix}_ck_lbl_{k}", label_visibility="collapsed")
|
||
val = ck2.text_input("Value", value=str(seq[k]), key=f"{prefix}_cv_{k}", label_visibility="collapsed")
|
||
seq[k] = val
|
||
|
||
if ck3.button("🗑️", key=f"{prefix}_cdel_{k}"):
|
||
keys_to_remove.append(k)
|
||
|
||
with st.expander("➕ Add Parameter"):
|
||
nk_col, nv_col = st.columns(2)
|
||
new_k = nk_col.text_input("Key", key=f"{prefix}_new_k")
|
||
new_v = nv_col.text_input("Value", key=f"{prefix}_new_v")
|
||
|
||
if st.button("Add", key=f"{prefix}_add_cust"):
|
||
if new_k and new_k not in seq:
|
||
seq[new_k] = new_v
|
||
save_json(file_path, data)
|
||
st.session_state.ui_reset_token += 1
|
||
st.rerun()
|
||
|
||
if keys_to_remove:
|
||
for k in keys_to_remove:
|
||
del seq[k]
|
||
save_json(file_path, data)
|
||
st.session_state.ui_reset_token += 1
|
||
st.rerun()
|
||
|
||
st.markdown("---")
|
||
|
||
# --- SAVE ACTIONS WITH HISTORY COMMIT ---
|
||
col_save, col_note = st.columns([1, 2])
|
||
|
||
with col_note:
|
||
commit_msg = st.text_input("Change Note (Optional)", placeholder="e.g. Added sequence 3")
|
||
|
||
with col_save:
|
||
if st.button("💾 Save & Snap", use_container_width=True):
|
||
data["batch_data"] = batch_list
|
||
|
||
tree_data = data.get("history_tree", {})
|
||
htree = HistoryTree(tree_data)
|
||
|
||
snapshot_payload = copy.deepcopy(data)
|
||
if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"]
|
||
|
||
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
|
||
|
||
data["history_tree"] = htree.to_dict()
|
||
save_json(file_path, data)
|
||
|
||
if 'restored_indicator' in st.session_state:
|
||
del st.session_state.restored_indicator
|
||
|
||
st.toast("Batch Saved & Snapshot Created!", icon="🚀")
|
||
st.rerun() |