Add accumulator retake workflow restore
This commit is contained in:
+11
@@ -99,6 +99,7 @@ try:
|
||||
accumulator_delete_payload,
|
||||
accumulator_list_payload,
|
||||
accumulator_move_payload,
|
||||
accumulator_retake_payload,
|
||||
accumulator_save_payload,
|
||||
profile_save_cached_payload,
|
||||
)
|
||||
@@ -155,6 +156,7 @@ except ImportError:
|
||||
accumulator_delete_payload,
|
||||
accumulator_list_payload,
|
||||
accumulator_move_payload,
|
||||
accumulator_retake_payload,
|
||||
accumulator_save_payload,
|
||||
profile_save_cached_payload,
|
||||
)
|
||||
@@ -206,6 +208,15 @@ if PromptServer is not None and web is not None:
|
||||
except Exception as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
|
||||
@PromptServer.instance.routes.post("/sxcp/accumulator/retake")
|
||||
async def sxcp_accumulator_retake(request):
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = accumulator_retake_payload(payload)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_CLASS_MAPPINGS.update(BUILDER_NODE_CLASS_MAPPINGS)
|
||||
|
||||
+68
-20
@@ -345,6 +345,71 @@ def accumulator_list_entries(store_key: str, preview_limit: int = 0) -> dict[str
|
||||
return result
|
||||
|
||||
|
||||
def _find_accumulator_entry(
|
||||
store: list[dict[str, Any]],
|
||||
preview_key: str = "",
|
||||
entry_id: str = "",
|
||||
index: int = 0,
|
||||
) -> tuple[int, dict[str, Any]]:
|
||||
preview_key = str(preview_key or "").strip()
|
||||
entry_id = str(entry_id or "").strip()
|
||||
if preview_key:
|
||||
for current_index, entry in enumerate(store):
|
||||
if _entry_preview_key(entry) == preview_key:
|
||||
return current_index, entry
|
||||
elif entry_id:
|
||||
for current_index, entry in enumerate(store):
|
||||
if str(entry.get("id") or "") == entry_id:
|
||||
return current_index, entry
|
||||
elif int(index) > 0:
|
||||
zero_index = int(index) - 1
|
||||
if 0 <= zero_index < len(store):
|
||||
return zero_index, store[zero_index]
|
||||
else:
|
||||
raise ValueError("entry_id or 1-based index is required")
|
||||
raise ValueError("accumulator entry not found")
|
||||
|
||||
|
||||
def _entry_workflow(entry: dict[str, Any]) -> Any:
|
||||
extra_pnginfo = entry.get("extra_pnginfo")
|
||||
if isinstance(extra_pnginfo, dict):
|
||||
workflow = extra_pnginfo.get("workflow") or extra_pnginfo.get("Workflow")
|
||||
if workflow:
|
||||
return _metadata_copy(workflow)
|
||||
prompt = entry.get("prompt")
|
||||
if isinstance(prompt, dict):
|
||||
workflow = prompt.get("workflow")
|
||||
if workflow:
|
||||
return _metadata_copy(workflow)
|
||||
return None
|
||||
|
||||
|
||||
def accumulator_retake_entry(
|
||||
store_key: str,
|
||||
preview_key: str = "",
|
||||
entry_id: str = "",
|
||||
index: int = 0,
|
||||
) -> dict[str, Any]:
|
||||
key = str(store_key or "").strip()
|
||||
if not key:
|
||||
raise ValueError("store_key is required for accumulator retake")
|
||||
store = _ACCUMULATOR_STORES.setdefault(key, [])
|
||||
zero_index, entry = _find_accumulator_entry(store, preview_key=preview_key, entry_id=entry_id, index=index)
|
||||
workflow = _entry_workflow(entry)
|
||||
if workflow is None:
|
||||
raise ValueError("selected accumulator entry does not include workflow metadata")
|
||||
entry_info = _entry_infos([entry])[0]
|
||||
entry_info["index"] = zero_index + 1
|
||||
return {
|
||||
"store_key": key,
|
||||
"entry": entry_info,
|
||||
"index": zero_index + 1,
|
||||
"workflow": workflow,
|
||||
"prompt": _metadata_copy(entry.get("prompt")),
|
||||
"extra_pnginfo": _metadata_copy(entry.get("extra_pnginfo")) if isinstance(entry.get("extra_pnginfo"), dict) else {},
|
||||
}
|
||||
|
||||
|
||||
def accumulator_delete_entries(
|
||||
store_key: str,
|
||||
preview_key: str = "",
|
||||
@@ -426,26 +491,9 @@ def accumulator_move_entry(
|
||||
result = accumulator_list_entries(key, preview_limit=preview_limit)
|
||||
result["moved"] = False
|
||||
return result
|
||||
zero_index = -1
|
||||
preview_key = str(preview_key or "").strip()
|
||||
entry_id = str(entry_id or "").strip()
|
||||
if preview_key:
|
||||
for current_index, entry in enumerate(store):
|
||||
if _entry_preview_key(entry) == preview_key:
|
||||
zero_index = current_index
|
||||
break
|
||||
elif entry_id:
|
||||
for current_index, entry in enumerate(store):
|
||||
if str(entry.get("id") or "") == entry_id:
|
||||
zero_index = current_index
|
||||
break
|
||||
elif int(index) > 0:
|
||||
candidate = int(index) - 1
|
||||
if candidate < len(store):
|
||||
zero_index = candidate
|
||||
else:
|
||||
raise ValueError("entry_id or 1-based index is required")
|
||||
if zero_index < 0:
|
||||
try:
|
||||
zero_index, _entry = _find_accumulator_entry(store, preview_key=preview_key, entry_id=entry_id, index=index)
|
||||
except ValueError:
|
||||
result = accumulator_list_entries(key, preview_limit=preview_limit)
|
||||
result["moved"] = False
|
||||
return result
|
||||
|
||||
@@ -7,6 +7,7 @@ try:
|
||||
accumulator_delete_entries,
|
||||
accumulator_list_entries,
|
||||
accumulator_move_entry,
|
||||
accumulator_retake_entry,
|
||||
accumulator_save_entries,
|
||||
)
|
||||
from .prompt_builder import save_character_profile_payload
|
||||
@@ -15,6 +16,7 @@ except ImportError: # Allows local smoke tests from the repository root.
|
||||
accumulator_delete_entries,
|
||||
accumulator_list_entries,
|
||||
accumulator_move_entry,
|
||||
accumulator_retake_entry,
|
||||
accumulator_save_entries,
|
||||
)
|
||||
from prompt_builder import save_character_profile_payload
|
||||
@@ -74,3 +76,13 @@ def accumulator_move_payload(payload: Any) -> dict[str, Any]:
|
||||
target_index=int(data.get("target_index") or 0),
|
||||
preview_limit=int(data.get("preview_limit") or 0),
|
||||
)
|
||||
|
||||
|
||||
def accumulator_retake_payload(payload: Any) -> dict[str, Any]:
|
||||
data = _payload(payload)
|
||||
return accumulator_retake_entry(
|
||||
store_key=str(data.get("store_key") or ""),
|
||||
preview_key=str(data.get("preview_key") or ""),
|
||||
entry_id=str(data.get("entry_id") or ""),
|
||||
index=int(data.get("index") or 0),
|
||||
)
|
||||
|
||||
+14
-1
@@ -8004,7 +8004,13 @@ def smoke_server_route_payload_policy() -> None:
|
||||
|
||||
key = "smoke_route_payload"
|
||||
loop_nodes._ACCUMULATOR_STORES[key] = [
|
||||
{"id": "first", "value": "alpha", "_sxcp_preview_key": "first-key"},
|
||||
{
|
||||
"id": "first",
|
||||
"value": "alpha",
|
||||
"_sxcp_preview_key": "first-key",
|
||||
"prompt": {"api": "prompt"},
|
||||
"extra_pnginfo": {"workflow": {"nodes": [{"id": 1, "type": "SmokeNode"}]}},
|
||||
},
|
||||
{"id": "second", "value": "beta", "_sxcp_preview_key": "second-key"},
|
||||
]
|
||||
try:
|
||||
@@ -8012,6 +8018,13 @@ def smoke_server_route_payload_policy() -> None:
|
||||
_expect(listed.get("count") == 2, "Accumulator list payload lost stored entries")
|
||||
_expect(listed["entries"][0].get("value") == "alpha", "Accumulator list payload lost value summary")
|
||||
|
||||
retake = server_routes.accumulator_retake_payload({"store_key": key, "preview_key": "first-key"})
|
||||
_expect(
|
||||
retake.get("workflow", {}).get("nodes", [{}])[0].get("type") == "SmokeNode",
|
||||
"Accumulator retake payload lost workflow metadata",
|
||||
)
|
||||
_expect(retake.get("prompt", {}).get("api") == "prompt", "Accumulator retake payload lost prompt metadata")
|
||||
|
||||
moved = server_routes.accumulator_move_payload({"store_key": key, "entry_id": "second", "target_index": "1"})
|
||||
_expect(moved.get("moved") is True, "Accumulator move payload did not report movement")
|
||||
_expect(moved.get("from_index") == 2 and moved.get("to_index") == 1, "Accumulator move payload changed indices")
|
||||
|
||||
@@ -247,6 +247,24 @@ async function postJson(path, payload) {
|
||||
return data;
|
||||
}
|
||||
|
||||
function workflowFromRetakeData(data) {
|
||||
const workflow = data?.workflow;
|
||||
if (!workflow) throw new Error("No workflow metadata found on this accumulator entry.");
|
||||
if (typeof workflow === "string") return JSON.parse(workflow);
|
||||
return workflow;
|
||||
}
|
||||
|
||||
async function loadWorkflow(workflow) {
|
||||
if (typeof app.loadGraphData === "function") {
|
||||
await app.loadGraphData(workflow);
|
||||
return;
|
||||
}
|
||||
if (!app.graph?.configure) throw new Error("This ComfyUI frontend cannot load workflow data.");
|
||||
app.graph.clear?.();
|
||||
app.graph.configure(workflow);
|
||||
app.graph.setDirtyCanvas?.(true, true);
|
||||
}
|
||||
|
||||
function injectStyles() {
|
||||
if (document.getElementById(STYLE_ID)) return;
|
||||
const css = `
|
||||
@@ -421,6 +439,12 @@ function renderCell(node, entry, imageParams, displayIndex, options = {}) {
|
||||
}
|
||||
thumb.draggable = true;
|
||||
thumb.onclick = () => markSelected(node, entry.index);
|
||||
thumb.oncontextmenu = async (event) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
markSelected(node, entry.index);
|
||||
await retakeEntry(node, entry);
|
||||
};
|
||||
thumb.ondragstart = (event) => {
|
||||
node._sxapDragEntry = entry;
|
||||
debugLog("dragstart", {index: entry.index, key: entryKey(entry), id: entry.id});
|
||||
@@ -444,7 +468,7 @@ function renderCell(node, entry, imageParams, displayIndex, options = {}) {
|
||||
const meta = document.createElement("div");
|
||||
meta.className = "sxap-meta";
|
||||
meta.textContent = "M";
|
||||
meta.title = "has metadata";
|
||||
meta.title = "has metadata; right-click image to retake";
|
||||
cell.appendChild(meta);
|
||||
}
|
||||
|
||||
@@ -668,6 +692,34 @@ async function saveBatch(node) {
|
||||
}
|
||||
}
|
||||
|
||||
async function retakeEntry(node, entry) {
|
||||
const key = storeKey(node);
|
||||
if (!key) {
|
||||
alert("Set the same explicit store_key on the Accumulator and Accumulator Preview first.");
|
||||
return;
|
||||
}
|
||||
if (!entry) return;
|
||||
if (!entry.has_metadata) {
|
||||
alert("This accumulator entry has no metadata to retake from.");
|
||||
return;
|
||||
}
|
||||
const label = entry.id || `#${entry.index}`;
|
||||
if (!confirm(`Retake ${label}? This will replace the current workflow with the workflow metadata saved on that entry.`)) return;
|
||||
try {
|
||||
const data = await postJson("/sxcp/accumulator/retake", {
|
||||
store_key: key,
|
||||
preview_key: entry.preview_key || "",
|
||||
entry_id: entry.id || "",
|
||||
index: entry.preview_key || entry.id ? 0 : entry.index,
|
||||
});
|
||||
const workflow = workflowFromRetakeData(data);
|
||||
await loadWorkflow(workflow);
|
||||
} catch (err) {
|
||||
console.error(`[${EXTENSION}] retake failed`, err);
|
||||
alert(`Retake failed: ${err}`);
|
||||
}
|
||||
}
|
||||
|
||||
function hideInternalWidgets(node) {
|
||||
for (const name of [
|
||||
"delete_action",
|
||||
|
||||
Reference in New Issue
Block a user