From d937c219ee51edd98aaafbed6f781dd3803fe2b9 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 28 Jun 2026 10:27:05 +0200 Subject: [PATCH] Add accumulator retake workflow restore --- __init__.py | 11 +++++ loop_nodes.py | 88 +++++++++++++++++++++++++++++--------- server_routes.py | 12 ++++++ tools/prompt_smoke.py | 15 ++++++- web/accumulator_preview.js | 54 ++++++++++++++++++++++- 5 files changed, 158 insertions(+), 22 deletions(-) diff --git a/__init__.py b/__init__.py index 2e2e697..9b7f4a4 100644 --- a/__init__.py +++ b/__init__.py @@ -99,6 +99,7 @@ try: accumulator_delete_payload, accumulator_list_payload, accumulator_move_payload, + accumulator_retake_payload, accumulator_save_payload, profile_save_cached_payload, ) @@ -155,6 +156,7 @@ except ImportError: accumulator_delete_payload, accumulator_list_payload, accumulator_move_payload, + accumulator_retake_payload, accumulator_save_payload, profile_save_cached_payload, ) @@ -206,6 +208,15 @@ if PromptServer is not None and web is not None: except Exception as exc: return web.json_response({"error": str(exc)}, status=400) + @PromptServer.instance.routes.post("/sxcp/accumulator/retake") + async def sxcp_accumulator_retake(request): + try: + payload = await request.json() + result = accumulator_retake_payload(payload) + return web.json_response(result) + except Exception as exc: + return web.json_response({"error": str(exc)}, status=400) + NODE_CLASS_MAPPINGS = {} NODE_CLASS_MAPPINGS.update(BUILDER_NODE_CLASS_MAPPINGS) diff --git a/loop_nodes.py b/loop_nodes.py index 16501e2..3186b18 100644 --- a/loop_nodes.py +++ b/loop_nodes.py @@ -345,6 +345,71 @@ def accumulator_list_entries(store_key: str, preview_limit: int = 0) -> dict[str return result +def _find_accumulator_entry( + store: list[dict[str, Any]], + preview_key: str = "", + entry_id: str = "", + index: int = 0, +) -> tuple[int, dict[str, Any]]: + preview_key = str(preview_key or "").strip() + entry_id = str(entry_id or "").strip() + if preview_key: + for current_index, entry in enumerate(store): + if _entry_preview_key(entry) == preview_key: + return current_index, entry + elif entry_id: + for current_index, entry in enumerate(store): + if str(entry.get("id") or "") == entry_id: + return current_index, entry + elif int(index) > 0: + zero_index = int(index) - 1 + if 0 <= zero_index < len(store): + return zero_index, store[zero_index] + else: + raise ValueError("entry_id or 1-based index is required") + raise ValueError("accumulator entry not found") + + +def _entry_workflow(entry: dict[str, Any]) -> Any: + extra_pnginfo = entry.get("extra_pnginfo") + if isinstance(extra_pnginfo, dict): + workflow = extra_pnginfo.get("workflow") or extra_pnginfo.get("Workflow") + if workflow: + return _metadata_copy(workflow) + prompt = entry.get("prompt") + if isinstance(prompt, dict): + workflow = prompt.get("workflow") + if workflow: + return _metadata_copy(workflow) + return None + + +def accumulator_retake_entry( + store_key: str, + preview_key: str = "", + entry_id: str = "", + index: int = 0, +) -> dict[str, Any]: + key = str(store_key or "").strip() + if not key: + raise ValueError("store_key is required for accumulator retake") + store = _ACCUMULATOR_STORES.setdefault(key, []) + zero_index, entry = _find_accumulator_entry(store, preview_key=preview_key, entry_id=entry_id, index=index) + workflow = _entry_workflow(entry) + if workflow is None: + raise ValueError("selected accumulator entry does not include workflow metadata") + entry_info = _entry_infos([entry])[0] + entry_info["index"] = zero_index + 1 + return { + "store_key": key, + "entry": entry_info, + "index": zero_index + 1, + "workflow": workflow, + "prompt": _metadata_copy(entry.get("prompt")), + "extra_pnginfo": _metadata_copy(entry.get("extra_pnginfo")) if isinstance(entry.get("extra_pnginfo"), dict) else {}, + } + + def accumulator_delete_entries( store_key: str, preview_key: str = "", @@ -426,26 +491,9 @@ def accumulator_move_entry( result = accumulator_list_entries(key, preview_limit=preview_limit) result["moved"] = False return result - zero_index = -1 - preview_key = str(preview_key or "").strip() - entry_id = str(entry_id or "").strip() - if preview_key: - for current_index, entry in enumerate(store): - if _entry_preview_key(entry) == preview_key: - zero_index = current_index - break - elif entry_id: - for current_index, entry in enumerate(store): - if str(entry.get("id") or "") == entry_id: - zero_index = current_index - break - elif int(index) > 0: - candidate = int(index) - 1 - if candidate < len(store): - zero_index = candidate - else: - raise ValueError("entry_id or 1-based index is required") - if zero_index < 0: + try: + zero_index, _entry = _find_accumulator_entry(store, preview_key=preview_key, entry_id=entry_id, index=index) + except ValueError: result = accumulator_list_entries(key, preview_limit=preview_limit) result["moved"] = False return result diff --git a/server_routes.py b/server_routes.py index 5c3f928..c97b8b3 100644 --- a/server_routes.py +++ b/server_routes.py @@ -7,6 +7,7 @@ try: accumulator_delete_entries, accumulator_list_entries, accumulator_move_entry, + accumulator_retake_entry, accumulator_save_entries, ) from .prompt_builder import save_character_profile_payload @@ -15,6 +16,7 @@ except ImportError: # Allows local smoke tests from the repository root. accumulator_delete_entries, accumulator_list_entries, accumulator_move_entry, + accumulator_retake_entry, accumulator_save_entries, ) from prompt_builder import save_character_profile_payload @@ -74,3 +76,13 @@ def accumulator_move_payload(payload: Any) -> dict[str, Any]: target_index=int(data.get("target_index") or 0), preview_limit=int(data.get("preview_limit") or 0), ) + + +def accumulator_retake_payload(payload: Any) -> dict[str, Any]: + data = _payload(payload) + return accumulator_retake_entry( + store_key=str(data.get("store_key") or ""), + preview_key=str(data.get("preview_key") or ""), + entry_id=str(data.get("entry_id") or ""), + index=int(data.get("index") or 0), + ) diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index bc1dfa8..c781025 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -8004,7 +8004,13 @@ def smoke_server_route_payload_policy() -> None: key = "smoke_route_payload" loop_nodes._ACCUMULATOR_STORES[key] = [ - {"id": "first", "value": "alpha", "_sxcp_preview_key": "first-key"}, + { + "id": "first", + "value": "alpha", + "_sxcp_preview_key": "first-key", + "prompt": {"api": "prompt"}, + "extra_pnginfo": {"workflow": {"nodes": [{"id": 1, "type": "SmokeNode"}]}}, + }, {"id": "second", "value": "beta", "_sxcp_preview_key": "second-key"}, ] try: @@ -8012,6 +8018,13 @@ def smoke_server_route_payload_policy() -> None: _expect(listed.get("count") == 2, "Accumulator list payload lost stored entries") _expect(listed["entries"][0].get("value") == "alpha", "Accumulator list payload lost value summary") + retake = server_routes.accumulator_retake_payload({"store_key": key, "preview_key": "first-key"}) + _expect( + retake.get("workflow", {}).get("nodes", [{}])[0].get("type") == "SmokeNode", + "Accumulator retake payload lost workflow metadata", + ) + _expect(retake.get("prompt", {}).get("api") == "prompt", "Accumulator retake payload lost prompt metadata") + moved = server_routes.accumulator_move_payload({"store_key": key, "entry_id": "second", "target_index": "1"}) _expect(moved.get("moved") is True, "Accumulator move payload did not report movement") _expect(moved.get("from_index") == 2 and moved.get("to_index") == 1, "Accumulator move payload changed indices") diff --git a/web/accumulator_preview.js b/web/accumulator_preview.js index 404b07f..9ad33a3 100644 --- a/web/accumulator_preview.js +++ b/web/accumulator_preview.js @@ -247,6 +247,24 @@ async function postJson(path, payload) { return data; } +function workflowFromRetakeData(data) { + const workflow = data?.workflow; + if (!workflow) throw new Error("No workflow metadata found on this accumulator entry."); + if (typeof workflow === "string") return JSON.parse(workflow); + return workflow; +} + +async function loadWorkflow(workflow) { + if (typeof app.loadGraphData === "function") { + await app.loadGraphData(workflow); + return; + } + if (!app.graph?.configure) throw new Error("This ComfyUI frontend cannot load workflow data."); + app.graph.clear?.(); + app.graph.configure(workflow); + app.graph.setDirtyCanvas?.(true, true); +} + function injectStyles() { if (document.getElementById(STYLE_ID)) return; const css = ` @@ -421,6 +439,12 @@ function renderCell(node, entry, imageParams, displayIndex, options = {}) { } thumb.draggable = true; thumb.onclick = () => markSelected(node, entry.index); + thumb.oncontextmenu = async (event) => { + event.preventDefault(); + event.stopPropagation(); + markSelected(node, entry.index); + await retakeEntry(node, entry); + }; thumb.ondragstart = (event) => { node._sxapDragEntry = entry; debugLog("dragstart", {index: entry.index, key: entryKey(entry), id: entry.id}); @@ -444,7 +468,7 @@ function renderCell(node, entry, imageParams, displayIndex, options = {}) { const meta = document.createElement("div"); meta.className = "sxap-meta"; meta.textContent = "M"; - meta.title = "has metadata"; + meta.title = "has metadata; right-click image to retake"; cell.appendChild(meta); } @@ -668,6 +692,34 @@ async function saveBatch(node) { } } +async function retakeEntry(node, entry) { + const key = storeKey(node); + if (!key) { + alert("Set the same explicit store_key on the Accumulator and Accumulator Preview first."); + return; + } + if (!entry) return; + if (!entry.has_metadata) { + alert("This accumulator entry has no metadata to retake from."); + return; + } + const label = entry.id || `#${entry.index}`; + if (!confirm(`Retake ${label}? This will replace the current workflow with the workflow metadata saved on that entry.`)) return; + try { + const data = await postJson("/sxcp/accumulator/retake", { + store_key: key, + preview_key: entry.preview_key || "", + entry_id: entry.id || "", + index: entry.preview_key || entry.id ? 0 : entry.index, + }); + const workflow = workflowFromRetakeData(data); + await loadWorkflow(workflow); + } catch (err) { + console.error(`[${EXTENSION}] retake failed`, err); + alert(`Retake failed: ${err}`); + } +} + function hideInternalWidgets(node) { for (const name of [ "delete_action",