Fix page reset on save: async I/O and avoid needless deepcopy

Root cause: save_json + sync_to_db blocked the event loop while
serializing the growing history tree (all snapshots) as JSON,
causing NiceGUI websocket timeout and client reconnect.

- Skip history tree in snapshot deepcopy (copied only to discard)
- Move save_json/sync_to_db to asyncio.to_thread in all callbacks
- Make save_and_snap, commit, sort_by_number, shift_fts,
  _add_sequence, apply_mass_update async

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-17 23:17:53 +01:00
parent f3ad3e01bc
commit 5aac1677f7
+29 -26
View File
@@ -1,3 +1,4 @@
import asyncio
import copy import copy
import json import json
import math import math
@@ -260,22 +261,22 @@ def render_batch_processor(state: AppState):
src_file_select.on_value_change(lambda _: _update_src()) src_file_select.on_value_change(lambda _: _update_src())
_update_src() _update_src()
def _add_sequence(new_item): async def _add_sequence(new_item):
new_item[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1 new_item[KEY_SEQUENCE_NUMBER] = max_main_seq_number(batch_list) + 1
for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, 'note', 'loras']: for k in [KEY_PROMPT_HISTORY, KEY_HISTORY_TREE, 'note', 'loras']:
new_item.pop(k, None) new_item.pop(k, None)
batch_list.append(new_item) batch_list.append(new_item)
data[KEY_BATCH_DATA] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) await asyncio.to_thread(save_json, file_path, data)
if state.db_enabled and state.current_project and state.db: if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, data)
render_sequence_list.refresh() render_sequence_list.refresh()
with ui.row().classes('q-mt-sm'): with ui.row().classes('q-mt-sm'):
def add_empty(): async def add_empty():
_add_sequence(copy.deepcopy(DEFAULTS)) await _add_sequence(copy.deepcopy(DEFAULTS))
def add_from_source(): async def add_from_source():
item = copy.deepcopy(DEFAULTS) item = copy.deepcopy(DEFAULTS)
src_batch = _src_cache['batch'] src_batch = _src_cache['batch']
sel_idx = src_seq_select.value sel_idx = src_seq_select.value
@@ -283,7 +284,7 @@ def render_batch_processor(state: AppState):
item.update(copy.deepcopy(src_batch[int(sel_idx)])) item.update(copy.deepcopy(src_batch[int(sel_idx)]))
elif _src_cache['data']: elif _src_cache['data']:
item.update(copy.deepcopy(_src_cache['data'])) item.update(copy.deepcopy(_src_cache['data']))
_add_sequence(item) await _add_sequence(item)
ui.button('Add Empty', icon='add', on_click=add_empty) ui.button('Add Empty', icon='add', on_click=add_empty)
ui.button('From Source', icon='file_download', on_click=add_from_source) ui.button('From Source', icon='file_download', on_click=add_from_source)
@@ -300,12 +301,12 @@ def render_batch_processor(state: AppState):
} }
standard_keys.update(lora_keys) standard_keys.update(lora_keys)
def sort_by_number(): async def sort_by_number():
batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0))) batch_list.sort(key=lambda s: int(s.get(KEY_SEQUENCE_NUMBER, 0)))
data[KEY_BATCH_DATA] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) await asyncio.to_thread(save_json, file_path, data)
if state.db_enabled and state.current_project and state.db: if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, data)
ui.notify('Sorted by sequence number!', type='positive') ui.notify('Sorted by sequence number!', type='positive')
render_sequence_list.refresh() render_sequence_list.refresh()
@@ -335,12 +336,13 @@ def render_batch_processor(state: AppState):
commit_input = ui.input('Change Note (Optional)', commit_input = ui.input('Change Note (Optional)',
placeholder='e.g. Added sequence 3').classes('col') placeholder='e.g. Added sequence 3').classes('col')
def save_and_snap(): async def save_and_snap():
data[KEY_BATCH_DATA] = batch_list data[KEY_BATCH_DATA] = batch_list
tree_data = data.get(KEY_HISTORY_TREE, {}) tree_data = data.get(KEY_HISTORY_TREE, {})
htree = HistoryTree(tree_data) htree = HistoryTree(tree_data)
snapshot_payload = copy.deepcopy(data) # Only deepcopy the data we need (skip history tree — it's huge and gets discarded)
snapshot_payload.pop(KEY_HISTORY_TREE, None) snapshot_payload = {k: copy.deepcopy(v) for k, v in data.items()
if k != KEY_HISTORY_TREE}
note = commit_input.value if commit_input.value else _auto_change_note(htree, batch_list) note = commit_input.value if commit_input.value else _auto_change_note(htree, batch_list)
try: try:
htree.commit(snapshot_payload, note=note) htree.commit(snapshot_payload, note=note)
@@ -348,9 +350,10 @@ def render_batch_processor(state: AppState):
ui.notify(f'Save failed: {e}', type='negative') ui.notify(f'Save failed: {e}', type='negative')
return return
data[KEY_HISTORY_TREE] = htree.to_dict() data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data) # Run heavy I/O off the event loop to prevent websocket timeout
await asyncio.to_thread(save_json, file_path, data)
if state.db_enabled and state.current_project and state.db: if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, data)
state.restored_indicator = None state.restored_indicator = None
commit_input.set_value('') commit_input.set_value('')
ui.notify('Batch Saved & Snapshot Created!', type='positive') ui.notify('Batch Saved & Snapshot Created!', type='positive')
@@ -365,11 +368,11 @@ def render_batch_processor(state: AppState):
def _render_sequence_card(i, seq, batch_list, data, file_path, state, def _render_sequence_card(i, seq, batch_list, data, file_path, state,
src_cache, src_seq_select, standard_keys, src_cache, src_seq_select, standard_keys,
refresh_list): refresh_list):
def commit(message=None): async def commit(message=None):
data[KEY_BATCH_DATA] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) await asyncio.to_thread(save_json, file_path, data)
if state.db_enabled and state.current_project and state.db: if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, data)
if message: if message:
ui.notify(message, type='positive') ui.notify(message, type='positive')
refresh_list.refresh() refresh_list.refresh()
@@ -647,7 +650,7 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_li
_original_fts = _safe_int(seq.get('frame_to_skip', FRAME_TO_SKIP_DEFAULT), FRAME_TO_SKIP_DEFAULT) _original_fts = _safe_int(seq.get('frame_to_skip', FRAME_TO_SKIP_DEFAULT), FRAME_TO_SKIP_DEFAULT)
def shift_fts(idx=i, orig=_original_fts): async def shift_fts(idx=i, orig=_original_fts):
new_fts = _safe_int(fts_input.value, orig) new_fts = _safe_int(fts_input.value, orig)
delta = new_fts - orig delta = new_fts - orig
if delta == 0: if delta == 0:
@@ -659,9 +662,9 @@ def _render_vace_settings(i, seq, batch_list, data, file_path, state, refresh_li
batch_list[j].get('frame_to_skip', FRAME_TO_SKIP_DEFAULT), FRAME_TO_SKIP_DEFAULT) + delta batch_list[j].get('frame_to_skip', FRAME_TO_SKIP_DEFAULT), FRAME_TO_SKIP_DEFAULT) + delta
shifted += 1 shifted += 1
data[KEY_BATCH_DATA] = batch_list data[KEY_BATCH_DATA] = batch_list
save_json(file_path, data) await asyncio.to_thread(save_json, file_path, data)
if state.db_enabled and state.current_project and state.db: if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, data)
ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive') ui.notify(f'Shifted {shifted} sequences by {delta:+d}', type='positive')
refresh_list.refresh() refresh_list.refresh()
@@ -779,7 +782,7 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
select_all_cb.on_value_change(on_select_all) select_all_cb.on_value_change(on_select_all)
def apply_mass_update(): async def apply_mass_update():
src_idx = source_select.value src_idx = source_select.value
if src_idx is None or src_idx >= len(batch_list): if src_idx is None or src_idx >= len(batch_list):
ui.notify('Source sequence no longer exists', type='warning') ui.notify('Source sequence no longer exists', type='warning')
@@ -802,17 +805,17 @@ def _render_mass_update(batch_list, data, file_path, state: AppState, refresh_li
data[KEY_BATCH_DATA] = batch_list data[KEY_BATCH_DATA] = batch_list
htree = HistoryTree(data.get(KEY_HISTORY_TREE, {})) htree = HistoryTree(data.get(KEY_HISTORY_TREE, {}))
snapshot = copy.deepcopy(data) snapshot = {k: copy.deepcopy(v) for k, v in data.items()
snapshot.pop(KEY_HISTORY_TREE, None) if k != KEY_HISTORY_TREE}
try: try:
htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}") htree.commit(snapshot, f"Mass update: {', '.join(selected_keys)}")
except ValueError as e: except ValueError as e:
ui.notify(f'Mass update failed: {e}', type='negative') ui.notify(f'Mass update failed: {e}', type='negative')
return return
data[KEY_HISTORY_TREE] = htree.to_dict() data[KEY_HISTORY_TREE] = htree.to_dict()
save_json(file_path, data) await asyncio.to_thread(save_json, file_path, data)
if state.db_enabled and state.current_project and state.db: if state.db_enabled and state.current_project and state.db:
sync_to_db(state.db, state.current_project, file_path, data) await asyncio.to_thread(sync_to_db, state.db, state.current_project, file_path, data)
ui.notify(f'Updated {len(targets)} sequences', type='positive') ui.notify(f'Updated {len(targets)} sequences', type='positive')
if refresh_list: if refresh_list:
refresh_list.refresh() refresh_list.refresh()