Update tab_batch.py
This commit is contained in:
61
tab_batch.py
61
tab_batch.py
@@ -1,6 +1,7 @@
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
import random
|
import random
|
||||||
from utils import DEFAULTS, save_json, load_json
|
from utils import DEFAULTS, save_json, load_json
|
||||||
|
from history_tree import HistoryTree # <--- NEW IMPORT
|
||||||
|
|
||||||
def create_batch_callback(original_filename, current_data, current_dir):
|
def create_batch_callback(original_filename, current_data, current_dir):
|
||||||
new_name = f"batch_{original_filename}"
|
new_name = f"batch_{original_filename}"
|
||||||
@@ -12,11 +13,15 @@ def create_batch_callback(original_filename, current_data, current_dir):
|
|||||||
|
|
||||||
first_item = current_data.copy()
|
first_item = current_data.copy()
|
||||||
if "prompt_history" in first_item: del first_item["prompt_history"]
|
if "prompt_history" in first_item: del first_item["prompt_history"]
|
||||||
|
if "history_tree" in first_item: del first_item["history_tree"] # Don't duplicate tree
|
||||||
|
|
||||||
first_item["sequence_number"] = 1
|
first_item["sequence_number"] = 1
|
||||||
|
|
||||||
new_data = {
|
new_data = {
|
||||||
"batch_data": [first_item],
|
"batch_data": [first_item],
|
||||||
"prompt_history": current_data.get("prompt_history", [])
|
# Initialize empty history for the new file
|
||||||
|
"history_tree": {},
|
||||||
|
"prompt_history": []
|
||||||
}
|
}
|
||||||
|
|
||||||
save_json(new_path, new_data)
|
save_json(new_path, new_data)
|
||||||
@@ -34,6 +39,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
|
|
||||||
batch_list = data.get("batch_data", [])
|
batch_list = data.get("batch_data", [])
|
||||||
|
|
||||||
|
# --- ADD NEW SEQUENCE AREA ---
|
||||||
st.subheader("Add New Sequence")
|
st.subheader("Add New Sequence")
|
||||||
ac1, ac2 = st.columns(2)
|
ac1, ac2 = st.columns(2)
|
||||||
|
|
||||||
@@ -44,6 +50,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
src_data, _ = load_json(current_dir / src_name)
|
src_data, _ = load_json(current_dir / src_name)
|
||||||
|
|
||||||
with ac2:
|
with ac2:
|
||||||
|
# Legacy history support for import source
|
||||||
src_hist = src_data.get("prompt_history", [])
|
src_hist = src_data.get("prompt_history", [])
|
||||||
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
|
h_opts = [f"#{i+1}: {h.get('note', 'No Note')} ({h.get('prompt', '')[:15]}...)" for i, h in enumerate(src_hist)] if src_hist else []
|
||||||
sel_hist = st.selectbox("History Entry:", h_opts, key="batch_src_hist")
|
sel_hist = st.selectbox("History Entry:", h_opts, key="batch_src_hist")
|
||||||
@@ -55,7 +62,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
for s in batch_list:
|
for s in batch_list:
|
||||||
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"]))
|
if "sequence_number" in s: max_seq = max(max_seq, int(s["sequence_number"]))
|
||||||
new_item["sequence_number"] = max_seq + 1
|
new_item["sequence_number"] = max_seq + 1
|
||||||
for k in ["prompt_history", "note", "loras"]:
|
|
||||||
|
# Cleanup metadata keys from item
|
||||||
|
for k in ["prompt_history", "history_tree", "note", "loras"]:
|
||||||
if k in new_item: del new_item[k]
|
if k in new_item: del new_item[k]
|
||||||
|
|
||||||
batch_list.append(new_item)
|
batch_list.append(new_item)
|
||||||
@@ -73,7 +82,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
item.update(flat)
|
item.update(flat)
|
||||||
add_sequence(item)
|
add_sequence(item)
|
||||||
|
|
||||||
if bc3.button("➕ From History", use_container_width=True, disabled=not src_hist):
|
if bc3.button("➕ From History (Legacy)", use_container_width=True, disabled=not src_hist):
|
||||||
if sel_hist:
|
if sel_hist:
|
||||||
idx = int(sel_hist.split(":")[0].replace("#", "")) - 1
|
idx = int(sel_hist.split(":")[0].replace("#", "")) - 1
|
||||||
item = DEFAULTS.copy()
|
item = DEFAULTS.copy()
|
||||||
@@ -83,6 +92,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
item.update(h_item["loras"])
|
item.update(h_item["loras"])
|
||||||
add_sequence(item)
|
add_sequence(item)
|
||||||
|
|
||||||
|
# --- RENDER LIST ---
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
st.info(f"Batch contains {len(batch_list)} sequences.")
|
st.info(f"Batch contains {len(batch_list)} sequences.")
|
||||||
|
|
||||||
@@ -110,7 +120,10 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
flat = src_data["batch_data"][0] if "batch_data" in src_data and src_data["batch_data"] else src_data
|
||||||
item.update(flat)
|
item.update(flat)
|
||||||
item["sequence_number"] = seq_num
|
item["sequence_number"] = seq_num
|
||||||
if "prompt_history" in item: del item["prompt_history"]
|
|
||||||
|
for k in ["prompt_history", "history_tree"]:
|
||||||
|
if k in item: del item[k]
|
||||||
|
|
||||||
batch_list[i] = item
|
batch_list[i] = item
|
||||||
data["batch_data"] = batch_list
|
data["batch_data"] = batch_list
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
@@ -121,6 +134,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
if b2.button("↖️ Promote to Single", key=f"{prefix}_prom"):
|
if b2.button("↖️ Promote to Single", key=f"{prefix}_prom"):
|
||||||
single_data = seq.copy()
|
single_data = seq.copy()
|
||||||
single_data["prompt_history"] = data.get("prompt_history", [])
|
single_data["prompt_history"] = data.get("prompt_history", [])
|
||||||
|
single_data["history_tree"] = data.get("history_tree", {})
|
||||||
if "sequence_number" in single_data: del single_data["sequence_number"]
|
if "sequence_number" in single_data: del single_data["sequence_number"]
|
||||||
save_json(file_path, single_data)
|
save_json(file_path, single_data)
|
||||||
st.toast("Converted to Single!", icon="✅")
|
st.toast("Converted to Single!", icon="✅")
|
||||||
@@ -152,7 +166,9 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
st.session_state[seed_key] = random.randint(0, 999999999999)
|
st.session_state[seed_key] = random.randint(0, 999999999999)
|
||||||
st.rerun()
|
st.rerun()
|
||||||
with s_row1:
|
with s_row1:
|
||||||
val = st.number_input("Seed", value=int(seq.get("seed", 0)), key=seed_key)
|
# Prefer session state if dice was clicked
|
||||||
|
current_seed = st.session_state.get(seed_key, int(seq.get("seed", 0)))
|
||||||
|
val = st.number_input("Seed", value=current_seed, key=seed_key)
|
||||||
seq["seed"] = val
|
seq["seed"] = val
|
||||||
|
|
||||||
seq["camera"] = st.text_input("Camera", value=seq.get("camera", ""), key=f"{prefix}_cam")
|
seq["camera"] = st.text_input("Camera", value=seq.get("camera", ""), key=f"{prefix}_cam")
|
||||||
@@ -204,6 +220,7 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
if st.button("Add", key=f"{prefix}_add_cust"):
|
if st.button("Add", key=f"{prefix}_add_cust"):
|
||||||
if new_k and new_k not in seq:
|
if new_k and new_k not in seq:
|
||||||
seq[new_k] = new_v
|
seq[new_k] = new_v
|
||||||
|
# Autosave on structure change
|
||||||
save_json(file_path, data)
|
save_json(file_path, data)
|
||||||
st.session_state.ui_reset_token += 1
|
st.session_state.ui_reset_token += 1
|
||||||
st.rerun()
|
st.rerun()
|
||||||
@@ -216,7 +233,33 @@ def render_batch_processor(data, file_path, json_files, current_dir, selected_fi
|
|||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
st.markdown("---")
|
st.markdown("---")
|
||||||
if st.button("💾 Save Batch Changes"):
|
|
||||||
data["batch_data"] = batch_list
|
# --- SAVE ACTIONS WITH HISTORY COMMIT ---
|
||||||
save_json(file_path, data)
|
col_save, col_note = st.columns([1, 2])
|
||||||
st.toast("Batch saved!", icon="🚀")
|
|
||||||
|
with col_note:
|
||||||
|
commit_msg = st.text_input("Change Note (Optional)", placeholder="e.g. Added sequence 3")
|
||||||
|
|
||||||
|
with col_save:
|
||||||
|
if st.button("💾 Save & Snap", use_container_width=True):
|
||||||
|
# 1. Update Data with the latest list
|
||||||
|
data["batch_data"] = batch_list
|
||||||
|
|
||||||
|
# 2. Init/Load History Engine
|
||||||
|
tree_data = data.get("history_tree", {})
|
||||||
|
htree = HistoryTree(tree_data)
|
||||||
|
|
||||||
|
# 3. Create Clean Snapshot
|
||||||
|
# Remove the recursive history tree from the payload we are saving INTO the history tree
|
||||||
|
snapshot_payload = data.copy()
|
||||||
|
if "history_tree" in snapshot_payload: del snapshot_payload["history_tree"]
|
||||||
|
|
||||||
|
# 4. Commit
|
||||||
|
htree.commit(snapshot_payload, note=commit_msg if commit_msg else "Batch Update")
|
||||||
|
|
||||||
|
# 5. Write Tree back to Data
|
||||||
|
data["history_tree"] = htree.to_dict()
|
||||||
|
|
||||||
|
# 6. Disk Save
|
||||||
|
save_json(file_path, data)
|
||||||
|
st.toast("Batch Saved & Snapshot Created!", icon="🚀")
|
||||||
|
|||||||
Reference in New Issue
Block a user