"""Streamlit UI for editing batch files.

A batch file stores a list of sequence dicts under ``KEY_BATCH_DATA``,
alongside a prompt history and a history tree of snapshots.
"""
import copy
import random
from pathlib import Path

import streamlit as st

from utils import (
    DEFAULTS,
    KEY_BATCH_DATA,
    KEY_HISTORY_TREE,
    KEY_PROMPT_HISTORY,
    KEY_SEQUENCE_NUMBER,
    load_json,
    save_json,
)
from history_tree import HistoryTree

IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif"}

# Sub-segments are encoded as parent * SUB_SEGMENT_MULTIPLIER + sub_index,
# e.g. 3002 is the second sub-segment of sequence 3.
SUB_SEGMENT_MULTIPLIER = 1000

LORA_KEYS = [
    "lora 1 high", "lora 1 low",
    "lora 2 high", "lora 2 low",
    "lora 3 high", "lora 3 low",
]

# Keys rendered by dedicated widgets; every other key of a sequence is shown
# in the "Custom Parameters" section.
STANDARD_KEYS = frozenset(
    {
        "general_prompt", "general_negative", "current_prompt", "negative",
        "prompt", "seed", "cfg", "camera", "flf", KEY_SEQUENCE_NUMBER,
        "frame_to_skip", "end_frame", "transition", "vace_length",
        "input_a_frames", "input_b_frames", "reference switch",
        "vace schedule", "reference path", "video file path",
        "reference image path", "flf image path",
        *LORA_KEYS,
    }
)

VACE_MODES = [
    "End Extend", "Pre Extend", "Middle Extend", "Edge Extend",
    "Join Extend", "Bidirectional Extend", "Frame Interpolation",
    "Replace/Inpaint", "Video Inpaint", "Keyframe",
]

# Human-readable formulas shown in the "Mode Reference" popover, indexed like
# VACE_MODES.
VACE_FORMULAS = [
    "base + A",       # 0 End Extend
    "base + B",       # 1 Pre Extend
    "base + A + B",   # 2 Middle Extend
    "base + A + B",   # 3 Edge Extend
    "base + A + B",   # 4 Join Extend
    "base + A + B",   # 5 Bidirectional
    "(B-1) * step",   # 6 Frame Interpolation
    "snap(source)",   # 7 Replace/Inpaint
    "snap(source)",   # 8 Video Inpaint
    "base + A + B",   # 9 Keyframe
]

LORA_TAG_PREFIX = "<lora:"
LORA_TAG_SUFFIX = ">"

# Display-only fragments drawn around each LoRA name input. HTML entities are
# used so the browser shows the angle brackets instead of parsing a tag.
_LORA_OPEN_HTML = (
    "<div style='text-align: right; padding-top: 8px;'>"
    "<code>&lt;lora:</code></div>"
)
_LORA_CLOSE_HTML = "<div style='padding-top: 8px;'><code>&gt;</code></div>"


# --------------------------------------------------------------------------
# Sequence-number helpers
# --------------------------------------------------------------------------

def is_subsegment(seq_num):
    """Return True if seq_num is a sub-segment (>= SUB_SEGMENT_MULTIPLIER)."""
    return int(seq_num) >= SUB_SEGMENT_MULTIPLIER


def parent_of(seq_num):
    """Return the parent segment number (or seq_num itself if it is a parent)."""
    seq_num = int(seq_num)
    return seq_num // SUB_SEGMENT_MULTIPLIER if is_subsegment(seq_num) else seq_num


def sub_index_of(seq_num):
    """Return the sub-index of seq_num (0 for a parent)."""
    seq_num = int(seq_num)
    return seq_num % SUB_SEGMENT_MULTIPLIER if is_subsegment(seq_num) else 0


def format_seq_label(seq_num):
    """Return a display label: 'Sequence #3' or 'Sub #2.1'."""
    seq_num = int(seq_num)
    if is_subsegment(seq_num):
        return f"Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)}"
    return f"Sequence #{seq_num}"


def next_sub_segment_number(batch_list, parent_seq_num):
    """Return the next free sub-segment number under parent_seq_num.

    Raises:
        ValueError: if the parent already uses the highest encodable
            sub-index. Otherwise the result would spill into the next
            parent's number range.
    """
    parent_seq_num = int(parent_seq_num)
    max_sub = 0
    for entry in batch_list:
        num = int(entry.get(KEY_SEQUENCE_NUMBER, 0))
        if is_subsegment(num) and parent_of(num) == parent_seq_num:
            max_sub = max(max_sub, sub_index_of(num))
    next_index = max_sub + 1
    if next_index >= SUB_SEGMENT_MULTIPLIER:
        raise ValueError(
            f"Sequence #{parent_seq_num} already has the maximum number of sub-segments"
        )
    return parent_seq_num * SUB_SEGMENT_MULTIPLIER + next_index


def find_insert_position(batch_list, parent_index, parent_seq_num):
    """Return the index just past the run of parent_seq_num's sub-segments.

    Scanning starts right after parent_index and stops at the first entry that
    is not a sub-segment of parent_seq_num.
    """
    parent_seq_num = int(parent_seq_num)
    pos = parent_index + 1
    while pos < len(batch_list):
        num = int(batch_list[pos].get(KEY_SEQUENCE_NUMBER, 0))
        if not (is_subsegment(num) and parent_of(num) == parent_seq_num):
            break
        pos += 1
    return pos


def _next_parent_number(batch_list):
    """Return 1 + the highest top-level sequence number in batch_list (min 1)."""
    highest = 0
    for entry in batch_list:
        num = int(entry.get(KEY_SEQUENCE_NUMBER, 0))
        if not is_subsegment(num):
            highest = max(highest, num)
    return highest + 1


def _sequence_sort_key(entry):
    """Sort key that keeps sub-segments directly after their parent."""
    num = int(entry.get(KEY_SEQUENCE_NUMBER, 0))
    return (parent_of(num), sub_index_of(num))


# --------------------------------------------------------------------------
# Persistence helpers
# --------------------------------------------------------------------------

def _bump_reset_token():
    """Advance the UI reset token.

    Widget keys embed this token, so the next run renders every widget with
    a fresh key.
    """
    st.session_state.ui_reset_token = st.session_state.get("ui_reset_token", 0) + 1


def _save_and_rerun(data, file_path, batch_list, toast=None, icon=None):
    """Store batch_list in data, save it to file_path, then rerun the script.

    st.rerun() stops the current run, so this function does not return.
    """
    data[KEY_BATCH_DATA] = batch_list
    save_json(file_path, data)
    _bump_reset_token()
    if toast:
        st.toast(toast, icon=icon)
    st.rerun()


def _commit_history(data, note):
    """Commit a snapshot of data (excluding the tree itself) to its history tree."""
    htree = HistoryTree(data.get(KEY_HISTORY_TREE, {}))
    snapshot_payload = copy.deepcopy(data)
    snapshot_payload.pop(KEY_HISTORY_TREE, None)
    htree.commit(snapshot_payload, note=note)
    data[KEY_HISTORY_TREE] = htree.to_dict()


def _item_from_source(src_batch, sel_seq_idx, src_data):
    """Build a new sequence dict: DEFAULTS overlaid with the selected source.

    The selected source is the chosen sequence of a batch source, or the
    whole single-file source. Deep copies keep the new item independent of
    both DEFAULTS and the loaded source.
    """
    item = copy.deepcopy(DEFAULTS)
    if src_batch and sel_seq_idx is not None:
        source = src_batch[sel_seq_idx]
    else:
        source = src_data
    item.update(copy.deepcopy(source))
    return item


# --------------------------------------------------------------------------
# LoRA tag helpers
# --------------------------------------------------------------------------

def _strip_lora_tag(raw):
    """Return the bare LoRA name from a '<lora:NAME>' string (or plain text)."""
    if raw is None:
        return ""
    text = str(raw).strip()
    if text.startswith(LORA_TAG_PREFIX):
        text = text[len(LORA_TAG_PREFIX):]
    if text.endswith(LORA_TAG_SUFFIX):
        text = text[:-len(LORA_TAG_SUFFIX)]
    return text.strip()


def _wrap_lora_tag(name):
    """Wrap name as '<lora:NAME>'; blank input yields ''.

    Any tag the user typed themselves is stripped first, so the value is
    never wrapped twice.
    """
    bare = _strip_lora_tag(name)
    return f"{LORA_TAG_PREFIX}{bare}{LORA_TAG_SUFFIX}" if bare else ""


# --------------------------------------------------------------------------
# VACE helpers
# --------------------------------------------------------------------------

def _vace_extra_frames(mode_idx, input_a, input_b):
    """Return the frames a VACE mode adds on top of the base length.

    NOTE(review): modes 6-8 list different formulas in VACE_FORMULAS but, as
    before, are computed here as base + A + B. Confirm this is intended.
    """
    if mode_idx == 0:  # End Extend: base + A
        return input_a
    if mode_idx == 1:  # Pre Extend: base + B
        return input_b
    return input_a + input_b  # every other mode: base + A + B


def _snap_4n1(frames):
    """Round frames up to the nearest value of the form 4n + 1 (1, 5, 9, ...)."""
    return ((frames + 2) // 4) * 4 + 1


# --------------------------------------------------------------------------
# UI sections
# --------------------------------------------------------------------------

def _render_mass_update(batch_list, data, file_path, key_prefix):
    """Render the mass-update section.

    Copies the selected fields from one sequence into the chosen target
    sequences, commits a history snapshot, saves, and reruns.
    """
    with st.expander("🔄 Mass Update", expanded=False):
        if len(batch_list) < 2:
            st.info("Need at least 2 sequences for mass update.")
            return

        source_idx = st.selectbox(
            "Copy from sequence:",
            range(len(batch_list)),
            format_func=lambda i: format_seq_label(
                batch_list[i].get(KEY_SEQUENCE_NUMBER, i + 1)
            ),
            key=f"{key_prefix}_mass_src",
        )
        source_seq = batch_list[source_idx]

        # The sequence number identifies a row, so it is never copied.
        available_keys = [k for k in source_seq if k != KEY_SEQUENCE_NUMBER]
        selected_keys = st.multiselect(
            "Fields to copy:", available_keys, key=f"{key_prefix}_mass_fields"
        )
        if not selected_keys:
            return

        st.write("Apply to:")
        select_all = st.checkbox("Select All", key=f"{key_prefix}_mass_all")
        target_cols = st.columns(min(4, len(batch_list) - 1))
        others = [(i, s) for i, s in enumerate(batch_list) if i != source_idx]
        target_indices = []
        for col_idx, (i, seq) in enumerate(others):
            seq_num = seq.get(KEY_SEQUENCE_NUMBER, i + 1)
            with target_cols[col_idx % len(target_cols)]:
                # With "Select All" ticked the individual boxes are not drawn.
                checked = select_all or st.checkbox(
                    format_seq_label(seq_num), key=f"{key_prefix}_mass_t{i}"
                )
            if checked:
                target_indices.append(i)

        # Without targets there is nothing to apply (and nothing to snapshot).
        if not target_indices:
            return

        with st.expander("Preview changes", expanded=True):
            for key in selected_keys:
                text = str(source_seq.get(key, ""))
                display_val = text[:100] + "..." if len(text) > 100 else text
                st.caption(f"**{key}**: {display_val}")

        if st.button("Apply Changes", type="primary", key=f"{key_prefix}_mass_apply"):
            for i in target_indices:
                for key in selected_keys:
                    batch_list[i][key] = copy.deepcopy(source_seq.get(key))
            data[KEY_BATCH_DATA] = batch_list
            _commit_history(data, f"Mass update: {', '.join(selected_keys)}")
            save_json(file_path, data)
            st.session_state.data_cache = data
            _bump_reset_token()
            st.toast(f"Updated {len(target_indices)} sequences", icon="✅")
            st.rerun()


def create_batch_callback(original_filename, current_data, current_dir):
    """Create ``batch_<original_filename>`` seeded with current_data as sequence 1.

    Intended as an ``on_click`` callback. If the target file already exists,
    only a warning toast is shown.
    """
    new_name = f"batch_{original_filename}"
    new_path = current_dir / new_name
    if new_path.exists():
        st.toast(f"File {new_name} already exists!", icon="⚠️")
        return

    first_item = copy.deepcopy(current_data)
    first_item.pop(KEY_PROMPT_HISTORY, None)
    first_item.pop(KEY_HISTORY_TREE, None)
    first_item[KEY_SEQUENCE_NUMBER] = 1

    new_data = {
        KEY_BATCH_DATA: [first_item],
        KEY_HISTORY_TREE: {},
        KEY_PROMPT_HISTORY: [],
    }
    save_json(new_path, new_data)
    st.toast(f"Created {new_name}", icon="✨")
    st.session_state.file_selector = new_name


def _append_sequence(data, file_path, batch_list, new_item):
    """Append new_item as the next top-level sequence, save, and rerun."""
    new_item[KEY_SEQUENCE_NUMBER] = _next_parent_number(batch_list)
    for k in (KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, "note", "loras"):
        new_item.pop(k, None)
    batch_list.append(new_item)
    _save_and_rerun(data, file_path, batch_list)


def _render_add_sequence(data, file_path, batch_list, json_files, current_dir,
                         selected_file_name):
    """Render the "Add New Sequence" area.

    Returns (src_name, src_data, src_batch, sel_seq_idx), the import source
    that the per-sequence "Copy" buttons also use.
    """
    st.subheader("Add New Sequence")
    ac1, ac2 = st.columns(2)

    with ac1:
        file_options = [f.name for f in json_files]
        d_idx = (
            file_options.index(selected_file_name)
            if selected_file_name in file_options
            else 0
        )
        src_name = st.selectbox(
            "Source File:", file_options, index=d_idx, key="batch_src_file"
        )
        src_data = {}
        # selectbox returns None when there are no files to choose from.
        if src_name:
            src_data, _ = load_json(current_dir / src_name)
        if isinstance(src_data, list):
            src_data = {KEY_BATCH_DATA: src_data}
        elif not isinstance(src_data, dict):
            src_data = {}

    with ac2:
        src_batch = src_data.get(KEY_BATCH_DATA, [])
        if src_batch:
            sel_seq_idx = st.selectbox(
                "Source Sequence:",
                list(range(len(src_batch))),
                format_func=lambda i: format_seq_label(
                    src_batch[i].get(KEY_SEQUENCE_NUMBER, i + 1)
                ),
                key="batch_src_seq",
            )
        else:
            st.caption("Single file (no sequences)")
            sel_seq_idx = None

    bc1, bc2 = st.columns(2)
    if bc1.button("➕ Add Empty", use_container_width=True):
        _append_sequence(data, file_path, batch_list, copy.deepcopy(DEFAULTS))
    if bc2.button("➕ From Source", use_container_width=True,
                  help=f"Import from {src_name}"):
        _append_sequence(
            data, file_path, batch_list,
            _item_from_source(src_batch, sel_seq_idx, src_data),
        )

    return src_name, src_data, src_batch, sel_seq_idx


def _render_sequence_actions(data, file_path, batch_list, i, seq, seq_num,
                             prefix, source):
    """Render the action row for sequence i: copy, clone, promote, delete.

    Every action saves immediately and reruns.
    """
    src_name, src_data, src_batch, sel_seq_idx = source
    act_c1, act_c2, act_c3, act_c4 = st.columns([1.2, 1.8, 1.2, 0.5])

    # 1. Replace this sequence with the selected source (keeping its number).
    with act_c1:
        if st.button(f"📥 Copy {src_name}", key=f"{prefix}_copy",
                     use_container_width=True):
            item = _item_from_source(src_batch, sel_seq_idx, src_data)
            item[KEY_SEQUENCE_NUMBER] = seq_num
            for k in (KEY_PROMPT_HISTORY, KEY_HISTORY_TREE):
                item.pop(k, None)
            batch_list[i] = item
            _save_and_rerun(data, file_path, batch_list, "Copied!", "📥")

    # 2. Cloning tools.
    with act_c2:
        cl_1, cl_2, cl_3 = st.columns(3)

        if cl_1.button("👯 Next", key=f"{prefix}_c_next",
                       help="Clone and insert below", use_container_width=True):
            new_seq = copy.deepcopy(seq)
            new_seq[KEY_SEQUENCE_NUMBER] = _next_parent_number(batch_list)
            # Insert after the current parent's run of sub-segments so the new
            # top-level sequence never splits a parent from its subs.
            insert_pos = find_insert_position(batch_list, i, parent_of(seq_num))
            batch_list.insert(insert_pos, new_seq)
            _save_and_rerun(data, file_path, batch_list, "Cloned to Next!", "👯")

        if cl_2.button("⏬ End", key=f"{prefix}_c_end",
                       help="Clone and add to bottom", use_container_width=True):
            new_seq = copy.deepcopy(seq)
            new_seq[KEY_SEQUENCE_NUMBER] = _next_parent_number(batch_list)
            batch_list.append(new_seq)
            _save_and_rerun(data, file_path, batch_list, "Cloned to End!", "⏬")

        if cl_3.button("🔗 Sub", key=f"{prefix}_c_sub",
                       help="Clone as sub-segment", use_container_width=True):
            p_seq_num = parent_of(seq_num)
            # Locate the parent row. For an orphan sub-segment, fall back to
            # this row.
            p_idx = i
            if is_subsegment(seq_num):
                p_idx = next(
                    (pi for pi, ps in enumerate(batch_list)
                     if int(ps.get(KEY_SEQUENCE_NUMBER, 0)) == p_seq_num),
                    i,
                )
            try:
                new_num = next_sub_segment_number(batch_list, p_seq_num)
            except ValueError as err:
                st.toast(str(err), icon="⚠️")
            else:
                new_seq = copy.deepcopy(seq)
                new_seq[KEY_SEQUENCE_NUMBER] = new_num
                insert_pos = find_insert_position(batch_list, p_idx, p_seq_num)
                batch_list.insert(insert_pos, new_seq)
                _save_and_rerun(data, file_path, batch_list,
                                f"Created {format_seq_label(new_num)}!", "🔗")

    # 3. Promote: overwrite this file with this sequence as a single file.
    with act_c3:
        if st.button("↖️ Promote", key=f"{prefix}_prom",
                     help="Save as Single File", use_container_width=True):
            single_data = seq.copy()
            single_data[KEY_PROMPT_HISTORY] = data.get(KEY_PROMPT_HISTORY, [])
            single_data[KEY_HISTORY_TREE] = data.get(KEY_HISTORY_TREE, {})
            single_data.pop(KEY_SEQUENCE_NUMBER, None)
            save_json(file_path, single_data)
            st.session_state.data_cache = single_data
            _bump_reset_token()
            st.toast("Converted to Single!", icon="✅")
            st.rerun()

    # 4. Remove.
    with act_c4:
        if st.button("🗑️", key=f"{prefix}_del", use_container_width=True):
            batch_list.pop(i)
            _save_and_rerun(data, file_path, batch_list)


def _render_prompt_fields(seq, prefix):
    """Render the four prompt text areas and write the results into seq."""
    for label, field, height, suffix in (
        ("General Prompt", "general_prompt", 60, "gp"),
        ("General Negative", "general_negative", 60, "gn"),
        ("Specific Prompt", "current_prompt", 300, "sp"),
        ("Specific Negative", "negative", 60, "sn"),
    ):
        seq[field] = st.text_area(
            label, value=seq.get(field, ""), height=height, key=f"{prefix}_{suffix}"
        )


def _render_settings_fields(seq, seq_num, prefix):
    """Render number, seed, CFG, camera, FLF, end frame, and path inputs."""
    if is_subsegment(seq_num):
        sn_label = (
            f"Sequence Number (↳ Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)})"
        )
    else:
        sn_label = "Sequence Number"
    seq[KEY_SEQUENCE_NUMBER] = st.number_input(
        sn_label, value=int(seq_num), key=f"{prefix}_sn_val"
    )

    s_row1, s_row2 = st.columns([3, 1])
    seed_key = f"{prefix}_seed"
    with s_row2:
        st.write("")
        st.write("")
        # Runs before the seed widget below, so writing its key is allowed.
        if st.button("🎲", key=f"{prefix}_rand"):
            st.session_state[seed_key] = random.randint(0, 999999999999)
            st.rerun()
    with s_row1:
        current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0)))
        seq["seed"] = st.number_input("Seed", value=current_seed, key=seed_key)

    seq["cfg"] = st.number_input(
        "CFG", value=float(seq.get("cfg", DEFAULTS["cfg"])), step=0.5,
        format="%.1f", key=f"{prefix}_cfg",
    )
    seq["camera"] = st.text_input(
        "Camera", value=seq.get("camera", ""), key=f"{prefix}_cam"
    )
    seq["flf"] = st.text_input(
        "FLF", value=str(seq.get("flf", DEFAULTS["flf"])), key=f"{prefix}_flf"
    )
    seq["end_frame"] = st.number_input(
        "End Frame", value=int(seq.get("end_frame", 0)), key=f"{prefix}_ef"
    )
    seq["video file path"] = st.text_input(
        "Video File Path", value=seq.get("video file path", ""), key=f"{prefix}_vid"
    )

    for img_label, img_key, img_suffix in (
        ("Reference Image Path", "reference image path", "rip"),
        ("Reference Path", "reference path", "rp"),
        ("FLF Image Path", "flf image path", "flfi"),
    ):
        img_col, prev_col = st.columns([5, 1])
        seq[img_key] = img_col.text_input(
            img_label, value=seq.get(img_key, ""), key=f"{prefix}_{img_suffix}"
        )
        img_path = Path(seq[img_key]) if seq[img_key] else None
        # Offer a preview only for existing files with an image extension.
        if (img_path and img_path.is_file()
                and img_path.suffix.lower() in IMAGE_EXTENSIONS):
            with prev_col.popover("👁"):
                st.image(str(img_path), use_container_width=True)


def _render_vace_settings(data, file_path, batch_list, i, seq, prefix):
    """Render the VACE settings expander for sequence i."""
    with st.expander("VACE Settings"):
        # --- Frame to skip, with a button to shift all following sequences ---
        fts_col, fts_btn = st.columns([3, 1])
        saved_fts_key = f"{prefix}_fts_saved"
        # Remember the value at first render so later edits yield a delta.
        if saved_fts_key not in st.session_state:
            st.session_state[saved_fts_key] = int(seq.get("frame_to_skip", 81))
        old_fts = st.session_state[saved_fts_key]
        seq["frame_to_skip"] = fts_col.number_input(
            "Frame to Skip", value=old_fts, key=f"{prefix}_fts"
        )
        delta = int(seq["frame_to_skip"]) - old_fts
        delta_label = f"Shift ↓ ({delta:+d})" if delta != 0 else "Shift ↓ (0)"
        fts_btn.write("")
        fts_btn.write("")
        # The button is disabled while delta == 0, so a click implies a change.
        if fts_btn.button(delta_label, key=f"{prefix}_fts_shift",
                          help="Apply delta to all following sequences",
                          disabled=(delta == 0)):
            following = batch_list[i + 1:]
            for later in following:
                later["frame_to_skip"] = int(later.get("frame_to_skip", 81)) + delta
            _save_and_rerun(data, file_path, batch_list,
                            f"Shifted {len(following)} sequences by {delta:+d}", "⏬")

        seq["transition"] = st.text_input(
            "Transition", value=str(seq.get("transition", "1-2")),
            key=f"{prefix}_trans",
        )

        # --- VACE schedule (mode index) ---
        vs_col, vs_label = st.columns([3, 1])
        max_mode = len(VACE_MODES) - 1
        # Clamp: number_input rejects a value outside [min_value, max_value].
        sched_val = min(max(int(seq.get("vace schedule", 1)), 0), max_mode)
        seq["vace schedule"] = vs_col.number_input(
            "VACE Schedule", value=sched_val, min_value=0, max_value=max_mode,
            key=f"{prefix}_vsc",
        )
        mode_idx = int(seq["vace schedule"])
        vs_label.write("")
        vs_label.write("")
        vs_label.caption(VACE_MODES[mode_idx])

        with st.popover("📋 Mode Reference"):
            st.markdown(
                "| # | Mode | Formula |\n"
                "|:--|:-----|:--------|\n"
                + "\n".join(
                    f"| **{j}** | {mode} | `{formula}` |"
                    for j, (mode, formula) in enumerate(zip(VACE_MODES, VACE_FORMULAS))
                )
                + "\n\n*All totals snapped to 4n+1 (1,5,9,…,49,…,81,…)*"
            )

        seq["input_a_frames"] = st.number_input(
            "Input A Frames", value=int(seq.get("input_a_frames", 16)),
            key=f"{prefix}_ia",
        )
        seq["input_b_frames"] = st.number_input(
            "Input B Frames", value=int(seq.get("input_b_frames", 16)),
            key=f"{prefix}_ib",
        )
        extra = _vace_extra_frames(
            mode_idx, int(seq["input_a_frames"]), int(seq["input_b_frames"])
        )

        # The stored total includes A/B frames; the user edits the base length.
        stored_total = int(seq.get("vace_length", 49))
        base_length = max(stored_total - extra, 1)
        vl_col, vl_out = st.columns([3, 1])
        new_base = vl_col.number_input(
            "VACE Length", value=base_length, min_value=1, key=f"{prefix}_vl"
        )
        # Snap to 4n+1 to match the VACE sampler.
        seq["vace_length"] = _snap_4n1(new_base + extra)
        vl_out.metric("Output", seq["vace_length"])

        seq["reference switch"] = st.number_input(
            "Reference Switch", value=int(seq.get("reference switch", 1)),
            key=f"{prefix}_rsw",
        )


def _render_lora_col(col_obj, seq, prefix, lora_idx):
    """Render the high/low inputs for LoRA lora_idx inside col_obj.

    The user edits the bare name; seq stores '<lora:NAME>', or '' when blank.
    """
    with col_obj:
        st.caption(f"**LoRA {lora_idx}**")
        for tier, suffix in (("high", "h"), ("low", "l")):
            key = f"lora {lora_idx} {tier}"
            st.write(f"{tier.capitalize()}:")
            c_open, c_input, c_close = st.columns([0.25, 1, 0.1])
            c_open.markdown(_LORA_OPEN_HTML, unsafe_allow_html=True)
            name = c_input.text_input(
                f"L{lora_idx}{suffix.upper()}",
                value=_strip_lora_tag(seq.get(key, "")),
                key=f"{prefix}_l{lora_idx}{suffix}",
                label_visibility="collapsed",
            )
            c_close.markdown(_LORA_CLOSE_HTML, unsafe_allow_html=True)
            seq[key] = _wrap_lora_tag(name)


def _render_lora_settings(seq, prefix):
    """Render the three LoRA columns."""
    with st.expander("💊 LoRA Settings"):
        for col_obj, lora_idx in zip(st.columns(3), (1, 2, 3)):
            _render_lora_col(col_obj, seq, prefix, lora_idx)


def _render_custom_params(data, file_path, batch_list, seq, prefix):
    """Render editors for non-standard keys, plus add/remove controls."""
    st.markdown("---")
    st.caption("🔧 Custom Parameters")
    keys_to_remove = []
    for k in [key for key in seq if key not in STANDARD_KEYS]:
        ck1, ck2, ck3 = st.columns([1, 2, 0.5])
        ck1.text_input("Key", value=k, disabled=True, key=f"{prefix}_ck_lbl_{k}",
                       label_visibility="collapsed")
        shown = str(seq[k])
        edited = ck2.text_input("Value", value=shown, key=f"{prefix}_cv_{k}",
                                label_visibility="collapsed")
        # Only overwrite when edited, so non-string values keep their type.
        if edited != shown:
            seq[k] = edited
        if ck3.button("🗑️", key=f"{prefix}_cdel_{k}"):
            keys_to_remove.append(k)

    with st.expander("➕ Add Parameter"):
        nk_col, nv_col = st.columns(2)
        new_k = nk_col.text_input("Key", key=f"{prefix}_new_k")
        new_v = nv_col.text_input("Value", key=f"{prefix}_new_v")
        if st.button("Add", key=f"{prefix}_add_cust"):
            new_k = new_k.strip()
            if not new_k:
                st.toast("Parameter key cannot be empty", icon="⚠️")
            elif new_k in seq:
                st.toast(f"Parameter '{new_k}' already exists", icon="⚠️")
            else:
                seq[new_k] = new_v
                _save_and_rerun(data, file_path, batch_list)

    if keys_to_remove:
        for k in keys_to_remove:
            del seq[k]
        _save_and_rerun(data, file_path, batch_list)


def _render_save_actions(data, file_path, batch_list):
    """Render "Save & Snap": commit a history snapshot, save, and rerun."""
    col_save, col_note = st.columns([1, 2])
    with col_note:
        commit_msg = st.text_input("Change Note (Optional)",
                                   placeholder="e.g. Added sequence 3")
    with col_save:
        if st.button("💾 Save & Snap", use_container_width=True):
            data[KEY_BATCH_DATA] = batch_list
            _commit_history(data, commit_msg if commit_msg else "Batch Update")
            save_json(file_path, data)
            if "restored_indicator" in st.session_state:
                del st.session_state.restored_indicator
            st.toast("Batch Saved & Snapshot Created!", icon="🚀")
            st.rerun()


def render_batch_processor(data, file_path, json_files, current_dir, selected_file_name):
    """Render the batch editor for data loaded from file_path.

    For a non-batch (single) file, only a button to create a ``batch_`` copy
    is shown. Structural actions (add, clone, copy, delete, sort, shift,
    custom-parameter changes) save immediately and rerun. Plain field edits
    only update the in-memory sequence dicts until the next save. "Save &
    Snap" also commits a history snapshot.
    """
    if isinstance(data, list):
        # A bare list is accepted as a batch file; wrap it so the dict-based
        # code below (data.get / data[KEY_BATCH_DATA]) works.
        data = {KEY_BATCH_DATA: data}

    if KEY_BATCH_DATA not in data:
        st.warning("This is a Single file. To use Batch mode, create a copy.")
        st.button("✨ Create Batch Copy", on_click=create_batch_callback,
                  args=(selected_file_name, data, current_dir))
        return

    if st.session_state.get("restored_indicator"):
        st.info(f"📍 Editing Restored Version: **{st.session_state.restored_indicator}**")

    batch_list = data.get(KEY_BATCH_DATA, [])

    source = _render_add_sequence(data, file_path, batch_list, json_files,
                                  current_dir, selected_file_name)

    # --- RENDER LIST ---
    st.markdown("---")
    info_col, reorder_col = st.columns([3, 1])
    info_col.info(f"Batch contains {len(batch_list)} sequences.")
    if reorder_col.button("🔢 Sort by Number", use_container_width=True,
                          help="Reorder sequences by sequence number"):
        # Sort by (parent, sub-index) so sub-segments stay under their parent.
        batch_list.sort(key=_sequence_sort_key)
        _save_and_rerun(data, file_path, batch_list, "Sorted by sequence number!", "🔢")

    ui_reset_token = st.session_state.get("ui_reset_token", 0)
    _render_mass_update(batch_list, data, file_path,
                        f"{selected_file_name}_v{ui_reset_token}")

    for i, seq in enumerate(batch_list):
        seq_num = seq.get(KEY_SEQUENCE_NUMBER, i + 1)
        # The reset token in every key forces fresh widgets after each save.
        prefix = f"{selected_file_name}_seq{i}_v{ui_reset_token}"
        if is_subsegment(seq_num):
            expander_label = (
                f"🔗 ↳ Sub #{parent_of(seq_num)}.{sub_index_of(seq_num)} ({int(seq_num)})"
            )
        else:
            expander_label = f"🎬 Sequence #{seq_num}"

        with st.expander(expander_label, expanded=False):
            _render_sequence_actions(data, file_path, batch_list, i, seq,
                                     seq_num, prefix, source)
            st.markdown("---")
            c1, c2 = st.columns([2, 1])
            with c1:
                _render_prompt_fields(seq, prefix)
            with c2:
                _render_settings_fields(seq, seq_num, prefix)
                _render_vace_settings(data, file_path, batch_list, i, seq, prefix)
            _render_lora_settings(seq, prefix)
            _render_custom_params(data, file_path, batch_list, seq, prefix)

    st.markdown("---")
    _render_save_actions(data, file_path, batch_list)