Add accumulator preview and batch save node

This commit is contained in:
2026-06-25 02:19:07 +02:00
parent 0b1b79445e
commit 71f4f162eb
3 changed files with 583 additions and 15 deletions
+46 -2
View File
@@ -117,6 +117,15 @@ COMMON_INPUT_TOOLTIPS = {
"image": "Image to store in the accumulator.",
"entry_id": "Stable ID used for replace_by_entry_id or grouping variants.",
"entry_tag": "Optional suffix added to entry_id.",
"preview_limit": "Maximum number of accumulator images to show in the preview panel.",
"delete_action": "Optional execution-time delete operation. JS buttons can delete interactively without setting this.",
"delete_entry_id": "Entry id to delete when delete_action is delete_entry_id.",
"delete_index": "1-based entry index to delete when delete_action is delete_index. 0 disables it.",
"save_batch": "When enabled, save all current accumulator images once finished is true.",
"finished": "Gate for saving. Outside a loop, leave true; inside a loop, wire a final-iteration signal.",
"save_path": "Folder to save the accumulator batch. Relative paths are inside ComfyUI output; absolute paths are used directly.",
"filename_prefix": "Filename prefix for saved accumulator images.",
"clear_after_save": "Clear the accumulator store after a successful batch save.",
"clothing": "Built-in clothing density for legacy direct generation. Category/profile nodes can override this.",
"poses": "Built-in pose pool for legacy direct generation.",
"backside_bias": "Legacy bias toward rear/backside poses where that category supports it.",
@@ -317,7 +326,13 @@ def _install_input_tooltips(node_classes: dict[str, type]) -> None:
node_class._sxcp_tooltips_installed = True
try:
from .loop_nodes import ANY_TYPE, LOOP_NODE_CLASS_MAPPINGS, LOOP_NODE_DISPLAY_NAME_MAPPINGS
from .loop_nodes import (
ANY_TYPE,
LOOP_NODE_CLASS_MAPPINGS,
LOOP_NODE_DISPLAY_NAME_MAPPINGS,
accumulator_delete_entries,
accumulator_list_entries,
)
from .prompt_builder import (
build_camera_config_json,
build_camera_orbit_config_json,
@@ -386,7 +401,13 @@ try:
from .caption_naturalizer import naturalize_caption
from .krea_formatter import format_krea2_prompt
except ImportError:
from loop_nodes import ANY_TYPE, LOOP_NODE_CLASS_MAPPINGS, LOOP_NODE_DISPLAY_NAME_MAPPINGS
from loop_nodes import (
ANY_TYPE,
LOOP_NODE_CLASS_MAPPINGS,
LOOP_NODE_DISPLAY_NAME_MAPPINGS,
accumulator_delete_entries,
accumulator_list_entries,
)
from prompt_builder import (
build_camera_config_json,
build_camera_orbit_config_json,
@@ -469,6 +490,29 @@ if PromptServer is not None and web is not None:
except Exception as exc:
return web.json_response({"error": str(exc)}, status=400)
@PromptServer.instance.routes.post("/sxcp/accumulator/list")
async def sxcp_accumulator_list(request):
try:
payload = await request.json()
result = accumulator_list_entries(str(payload.get("store_key") or ""))
return web.json_response(result)
except Exception as exc:
return web.json_response({"error": str(exc)}, status=400)
@PromptServer.instance.routes.post("/sxcp/accumulator/delete")
async def sxcp_accumulator_delete(request):
try:
payload = await request.json()
result = accumulator_delete_entries(
store_key=str(payload.get("store_key") or ""),
entry_id=str(payload.get("entry_id") or ""),
index=int(payload.get("index") or 0),
clear=bool(payload.get("clear")),
)
return web.json_response(result)
except Exception as exc:
return web.json_response({"error": str(exc)}, status=400)
class SxCPPromptBuilder:
@classmethod
+317 -13
View File
@@ -1,6 +1,9 @@
from __future__ import annotations
import json
import os
import random
import re
from typing import Any
try:
@@ -18,6 +21,23 @@ try:
except Exception:
ALL_NODE_CLASS_MAPPINGS = {}
try:
import folder_paths
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from comfy.cli_args import args
except Exception:
folder_paths = None
np = None
Image = None
PngInfo = None
class _ArgsFallback:
disable_metadata = True
args = _ArgsFallback()
MAX_LOOP_VALUES = 20
MAX_CARRY_VALUES = MAX_LOOP_VALUES - 2
@@ -25,6 +45,7 @@ COLLECTION_MODES = ["auto_batch", "list", "image_batch", "latent_batch", "string
ACCUMULATOR_ACTIONS = ["append_variant", "replace_by_entry_id", "append", "clear_then_append", "clear", "read"]
ACCUMULATOR_IMAGE_BATCH_MODES = ["same_size_only", "resize_to_first"]
ACCUMULATOR_IMAGE_GROUPS = 4
ACCUMULATOR_PREVIEW_DELETE_ACTIONS = ["none", "delete_entry_id", "delete_index", "clear"]
_ACCUMULATOR_STORES: dict[str, list[dict[str, Any]]] = {}
@@ -154,6 +175,197 @@ def _group_image_batches(images: list[Any]) -> list[Any]:
return [batch for batch in batches if batch is not None]
def _accumulator_store_key(store_key: str, unique_id: Any = None) -> str:
key = str(store_key or "").strip()
if key:
return key
return f"node:{unique_id}"
def _entry_value_summary(value: Any) -> str:
if value is None:
return ""
text = str(value)
text = re.sub(r"\s+", " ", text).strip()
return text[:160]
def _entry_infos(store: list[dict[str, Any]]) -> list[dict[str, Any]]:
entries = []
for index, entry in enumerate(store, start=1):
image = entry.get("image")
shape = _image_shape(image)
entries.append(
{
"index": index,
"id": str(entry.get("id") or ""),
"has_image": image is not None,
"shape": list(shape) if shape is not None else [],
"value": _entry_value_summary(entry.get("value")),
}
)
return entries
def _accumulator_status(key: str, store: list[dict[str, Any]]) -> str:
images = [entry.get("image") for entry in store if entry.get("image") is not None]
shapes = []
for image in images:
shape = _image_shape(image)
if shape is not None and shape not in shapes:
shapes.append(shape)
shape_text = ", ".join(f"{shape[1]}x{shape[0]}" for shape in shapes) or "no images"
return f"key={key}; entries={len(store)}; image_entries={len(images)}; formats={shape_text}"
def accumulator_list_entries(store_key: str) -> dict[str, Any]:
key = str(store_key or "").strip()
if not key:
raise ValueError("store_key is required for accumulator preview actions")
store = _ACCUMULATOR_STORES.setdefault(key, [])
return {
"store_key": key,
"entries": _entry_infos(store),
"count": len(store),
"status": _accumulator_status(key, store),
}
def accumulator_delete_entries(
store_key: str,
entry_id: str = "",
index: int = 0,
clear: bool = False,
) -> dict[str, Any]:
key = str(store_key or "").strip()
if not key:
raise ValueError("store_key is required for accumulator preview actions")
store = _ACCUMULATOR_STORES.setdefault(key, [])
removed = 0
if clear:
removed = len(store)
store.clear()
else:
entry_id = str(entry_id or "").strip()
if entry_id:
before = len(store)
store[:] = [entry for entry in store if str(entry.get("id") or "") != entry_id]
removed = before - len(store)
elif int(index) > 0:
zero_index = int(index) - 1
if zero_index < len(store):
del store[zero_index]
removed = 1
else:
raise ValueError("entry_id or 1-based index is required")
result = accumulator_list_entries(key)
result["removed"] = removed
return result
def _require_image_saving() -> None:
if folder_paths is None or np is None or Image is None:
raise RuntimeError("Image preview/save helpers require ComfyUI image dependencies.")
def _metadata(prompt: Any, extra_pnginfo: Any) -> Any:
if args.disable_metadata or PngInfo is None:
return None
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for key, value in extra_pnginfo.items():
metadata.add_text(key, json.dumps(value))
return metadata
def _image_to_pil(image: Any) -> Any:
_require_image_saving()
tensor = image
try:
if len(tensor.shape) == 4:
tensor = tensor[0]
except Exception:
pass
image_data = 255.0 * tensor.cpu().numpy()
return Image.fromarray(np.clip(image_data, 0, 255).astype(np.uint8))
def _safe_filename_prefix(prefix: str, default: str = "sxcp_accumulator") -> str:
text = str(prefix or "").strip() or default
text = os.path.basename(text)
text = re.sub(r"[^A-Za-z0-9._-]+", "_", text).strip("._")
return text or default
def _next_save_counter(folder: str, prefix: str) -> int:
pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)_\.png$")
counter = 1
try:
for filename in os.listdir(folder):
match = pattern.match(filename)
if match:
counter = max(counter, int(match.group(1)) + 1)
except FileNotFoundError:
pass
return counter
def _preview_image_results(images: list[Any], preview_limit: int, prompt: Any, extra_pnginfo: Any) -> list[dict[str, str]]:
if not images:
return []
_require_image_saving()
output_dir = folder_paths.get_temp_directory()
preview_images = images[: max(1, int(preview_limit))]
first_shape = _image_shape(preview_images[0])
height, width = (first_shape[0], first_shape[1]) if first_shape else (512, 512)
prefix = "SxCPAccumulatorPreview_temp_" + "".join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5))
full_output_folder, filename, counter, subfolder, _filename_prefix = folder_paths.get_save_image_path(prefix, output_dir, width, height)
metadata = _metadata(prompt, extra_pnginfo)
results = []
for image in preview_images:
file = f"{filename}_{counter:05}_.png"
_image_to_pil(image).save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=1)
results.append({"filename": file, "subfolder": subfolder, "type": "temp"})
counter += 1
return results
def _resolve_save_folder(save_path: str) -> str:
_require_image_saving()
raw_path = os.path.expanduser(str(save_path or "").strip())
if not raw_path:
raw_path = "sxcp_accumulator"
if os.path.isabs(raw_path):
return raw_path
return os.path.join(folder_paths.get_output_directory(), raw_path)
def _save_images_to_folder(
images: list[Any],
save_path: str,
filename_prefix: str,
prompt: Any,
extra_pnginfo: Any,
) -> list[str]:
if not images:
return []
folder = _resolve_save_folder(save_path)
os.makedirs(folder, exist_ok=True)
prefix = _safe_filename_prefix(filename_prefix)
counter = _next_save_counter(folder, prefix)
metadata = _metadata(prompt, extra_pnginfo)
saved_paths = []
for image in images:
file = f"{prefix}_{counter:05}_.png"
path = os.path.join(folder, file)
_image_to_pil(image).save(path, pnginfo=metadata, compress_level=4)
saved_paths.append(path)
counter += 1
return saved_paths
def _as_list(collection: Any) -> list[Any]:
if collection is None:
return []
@@ -449,8 +661,7 @@ class SxCPAccumulator:
return random.random()
def _store_key(self, store_key: str, unique_id: Any) -> str:
key = str(store_key or "").strip()
return key or f"node:{unique_id}"
return _accumulator_store_key(store_key, unique_id)
def _entry_id(self, entry_id: Any, entry_tag: str, image_index: int, image_count: int) -> str:
text = "" if entry_id is None else str(entry_id).strip()
@@ -523,18 +734,8 @@ class SxCPAccumulator:
return collection
def _status(self, key: str, store: list[dict[str, Any]], image_batch: Any, image_batches: list[Any]) -> str:
images = [entry.get("image") for entry in store if entry.get("image") is not None]
shapes = []
for image in images:
shape = _image_shape(image)
if shape is not None and shape not in shapes:
shapes.append(shape)
shape_text = ", ".join(f"{shape[1]}x{shape[0]}" for shape in shapes) or "no images"
batch_state = "all images batched" if image_batch is not None else "mixed sizes or no image batch"
return (
f"key={key}; entries={len(store)}; image_entries={len(images)}; "
f"formats={shape_text}; grouped_batches={len(image_batches)}; {batch_state}"
)
return f"{_accumulator_status(key, store)}; grouped_batches={len(image_batches)}; {batch_state}"
def accumulate(
self,
@@ -573,6 +774,107 @@ class SxCPAccumulator:
return tuple([self._collection(store), image_batch, images] + grouped_outputs + [len(store), status])
class SxCPAccumulatorPreview:
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"store_key": ("STRING", {"default": "", "multiline": False}),
"preview_limit": ("INT", {"default": 64, "min": 1, "max": 512, "step": 1}),
"delete_action": (ACCUMULATOR_PREVIEW_DELETE_ACTIONS, {"default": "none"}),
"delete_entry_id": ("STRING", {"default": "", "multiline": False}),
"delete_index": ("INT", {"default": 0, "min": 0, "max": 100000, "step": 1}),
"save_batch": ("BOOLEAN", {"default": False}),
"finished": ("BOOLEAN", {"default": True}),
"save_path": ("STRING", {"default": "sxcp_accumulator", "multiline": False}),
"filename_prefix": ("STRING", {"default": "sxcp_accum", "multiline": False}),
"clear_after_save": ("BOOLEAN", {"default": False}),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("INT", "STRING", "STRING")
RETURN_NAMES = ("count", "status", "saved_paths_json")
FUNCTION = "preview"
CATEGORY = "prompt_builder/loop"
@classmethod
def IS_CHANGED(cls, *args, **kwargs):
return random.random()
def _delete_from_inputs(self, key: str, delete_action: str, delete_entry_id: str, delete_index: int) -> int:
action = delete_action if delete_action in ACCUMULATOR_PREVIEW_DELETE_ACTIONS else "none"
if action == "none":
return 0
if action == "clear":
return int(accumulator_delete_entries(key, clear=True).get("removed", 0))
if action == "delete_entry_id":
entry_id = str(delete_entry_id or "").strip()
if not entry_id:
return 0
return int(accumulator_delete_entries(key, entry_id=entry_id).get("removed", 0))
if action == "delete_index":
index = int(delete_index)
if index <= 0:
return 0
return int(accumulator_delete_entries(key, index=index).get("removed", 0))
return 0
def preview(
self,
store_key,
preview_limit,
delete_action,
delete_entry_id,
delete_index,
save_batch,
finished,
save_path,
filename_prefix,
clear_after_save,
prompt=None,
extra_pnginfo=None,
unique_id=None,
):
key = _accumulator_store_key(store_key, unique_id)
store = _ACCUMULATOR_STORES.setdefault(key, [])
removed = self._delete_from_inputs(key, delete_action, delete_entry_id, delete_index)
images = [entry["image"] for entry in store if entry.get("image") is not None]
saved_paths: list[str] = []
save_status = ""
if bool(save_batch) and bool(finished):
saved_paths = _save_images_to_folder(images, save_path, filename_prefix, prompt, extra_pnginfo)
save_status = f"; saved={len(saved_paths)}"
if saved_paths and bool(clear_after_save):
store.clear()
images = []
save_status += "; cleared_after_save"
preview_images = _preview_image_results(images, preview_limit, prompt, extra_pnginfo)
entries = _entry_infos(store)
status = _accumulator_status(key, store)
if removed:
status += f"; removed={removed}"
status += save_status
saved_json = json.dumps(saved_paths, ensure_ascii=True)
return {
"ui": {
"images": preview_images,
"entries": entries,
"status": [status],
"saved_paths": saved_paths,
},
"result": (len(store), status, saved_json),
}
class SxCPForLoopEnd:
@classmethod
def INPUT_TYPES(cls):
@@ -697,6 +999,7 @@ LOOP_NODE_CLASS_MAPPINGS = {
"SxCPForLoopEnd": SxCPForLoopEnd,
"SxCPLoopAppend": SxCPLoopAppend,
"SxCPAccumulator": SxCPAccumulator,
"SxCPAccumulatorPreview": SxCPAccumulatorPreview,
"SxCPLoopIntAdd": SxCPLoopIntAdd,
"SxCPLoopLessThan": SxCPLoopLessThan,
"SxCPLoopLessThanOrEqual": SxCPLoopLessThanOrEqual,
@@ -709,6 +1012,7 @@ LOOP_NODE_DISPLAY_NAME_MAPPINGS = {
"SxCPForLoopEnd": "SxCP For Loop End",
"SxCPLoopAppend": "SxCP Loop Append",
"SxCPAccumulator": "SxCP Accumulator",
"SxCPAccumulatorPreview": "SxCP Accumulator Preview",
"SxCPLoopIntAdd": "SxCP Loop Int Add",
"SxCPLoopLessThan": "SxCP Loop Less Than",
"SxCPLoopLessThanOrEqual": "SxCP Loop Less Than Or Equal",
+220
View File
@@ -0,0 +1,220 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
const EXTENSION = "ethanfel.prompt_builder.accumulator_preview";
const NODE_NAME = "SxCPAccumulatorPreview";
const entryCache = new Map();
function widget(node, name) {
return node.widgets?.find((w) => w.name === name);
}
function hideWidget(w) {
if (!w) return;
if (w.origType === undefined) w.origType = w.type;
w.type = "hidden";
w.hidden = true;
w.computeSize = () => [0, -4];
}
function resizeNode(node) {
const size = node.computeSize?.();
if (size) node.setSize?.(size);
app.graph?.setDirtyCanvas(true, true);
}
function nodeKey(nodeOrId) {
return String(typeof nodeOrId === "object" ? nodeOrId?.id : nodeOrId);
}
function isAccumulatorPreviewNode(node) {
return node?.comfyClass === NODE_NAME || node?.type === NODE_NAME;
}
function getNodeById(id) {
return app.graph?.getNodeById?.(Number(id)) || app.graph?._nodes_by_id?.[id] || app.graph?._nodes_by_id?.[Number(id)];
}
function asArray(value) {
if (!value) return [];
return Array.isArray(value) ? value : [value];
}
function outputStatus(output) {
const status = output?.status;
if (Array.isArray(status)) return status[0] || "";
return status || "";
}
function outputEntries(output) {
const entries = output?.entries;
if (!entries) return [];
if (Array.isArray(entries) && entries.length === 1 && Array.isArray(entries[0])) return entries[0];
return asArray(entries);
}
function entryLabel(entry) {
const index = entry?.index ?? "?";
const id = entry?.id ? ` ${entry.id}` : "";
const image = entry?.has_image ? " image" : " value";
const shape = Array.isArray(entry?.shape) && entry.shape.length >= 2 ? ` ${entry.shape[1]}x${entry.shape[0]}` : "";
return `#${index}${id}${image}${shape}`.trim();
}
function setStatus(node, status) {
if (!node._sxcpAccumulatorStatusWidget) return;
node._sxcpAccumulatorStatusWidget.value = status || "no accumulator data";
node.setDirtyCanvas?.(true, true);
}
function setEntries(node, entries, status = "") {
entries = asArray(entries);
entryCache.set(nodeKey(node), entries);
node._sxcpAccumulatorEntries = entries;
if (node._sxcpEntrySelectWidget) {
const labels = entries.map(entryLabel);
node._sxcpEntrySelectWidget.options.values = labels.length ? labels : ["no entries"];
if (!labels.includes(node._sxcpEntrySelectWidget.value)) {
node._sxcpEntrySelectWidget.value = labels[0] || "no entries";
}
}
setStatus(node, status || `${entries.length} entries`);
resizeNode(node);
}
function selectedEntry(node) {
const entries = entryCache.get(nodeKey(node)) || node._sxcpAccumulatorEntries || [];
const selected = widget(node, "selected_entry")?.value || node._sxcpEntrySelectWidget?.value || "";
const labels = entries.map(entryLabel);
const index = labels.indexOf(selected);
return index >= 0 ? entries[index] : entries[0];
}
function storeKey(node) {
return String(widget(node, "store_key")?.value || "").trim();
}
async function postJson(path, payload) {
const response = await api.fetchApi(path, {
method: "POST",
headers: {"Content-Type": "application/json"},
body: JSON.stringify(payload),
});
const data = await response.json();
if (!response.ok) throw new Error(data?.error || response.statusText);
return data;
}
async function refreshEntries(node) {
const key = storeKey(node);
if (!key) {
alert("Set the same explicit store_key on the Accumulator and Accumulator Preview first.");
return;
}
try {
const data = await postJson("/sxcp/accumulator/list", {store_key: key});
setEntries(node, data.entries || [], data.status || "");
} catch (err) {
console.error(`[${EXTENSION}] refresh failed`, err);
alert(`Refresh failed: ${err}`);
}
}
async function deleteSelected(node) {
const key = storeKey(node);
if (!key) {
alert("Set the same explicit store_key on the Accumulator and Accumulator Preview first.");
return;
}
const entry = selectedEntry(node);
if (!entry) {
alert("No accumulator entry selected.");
return;
}
const label = entryLabel(entry);
if (!confirm(`Delete accumulator entry ${label}?`)) return;
try {
const data = await postJson("/sxcp/accumulator/delete", {
store_key: key,
entry_id: entry.id || "",
index: entry.id ? 0 : entry.index,
clear: false,
});
setEntries(node, data.entries || [], `${data.status || ""}; deleted=${data.removed || 0}; rerun preview to refresh images`);
} catch (err) {
console.error(`[${EXTENSION}] delete failed`, err);
alert(`Delete failed: ${err}`);
}
}
async function clearStore(node) {
const key = storeKey(node);
if (!key) {
alert("Set the same explicit store_key on the Accumulator and Accumulator Preview first.");
return;
}
if (!confirm(`Clear all entries from accumulator "${key}"?`)) return;
try {
const data = await postJson("/sxcp/accumulator/delete", {store_key: key, clear: true});
setEntries(node, data.entries || [], `${data.status || ""}; cleared=${data.removed || 0}; rerun preview to refresh images`);
} catch (err) {
console.error(`[${EXTENSION}] clear failed`, err);
alert(`Clear failed: ${err}`);
}
}
function setupNode(node) {
hideWidget(widget(node, "delete_action"));
hideWidget(widget(node, "delete_entry_id"));
hideWidget(widget(node, "delete_index"));
if (!node._sxcpEntrySelectWidget) {
node._sxcpEntrySelectWidget = node.addWidget("combo", "selected_entry", "no entries", () => {}, {values: ["no entries"]});
node._sxcpEntrySelectWidget.serialize = false;
}
if (!node._sxcpAccumulatorStatusWidget) {
node._sxcpAccumulatorStatusWidget = node.addWidget("text", "accumulator_status", "no accumulator data", () => {});
node._sxcpAccumulatorStatusWidget.serialize = false;
}
if (!node._sxcpDeleteSelectedButton) {
node._sxcpDeleteSelectedButton = node.addWidget("button", "Delete Selected Entry", null, () => deleteSelected(node));
}
if (!node._sxcpClearButton) {
node._sxcpClearButton = node.addWidget("button", "Clear Accumulator", null, () => clearStore(node));
}
if (!node._sxcpRefreshButton) {
node._sxcpRefreshButton = node.addWidget("button", "Refresh Entry List", null, () => refreshEntries(node));
}
resizeNode(node);
}
app.registerExtension({
name: EXTENSION,
async setup() {
api.addEventListener("executed", ({detail}) => {
const node = getNodeById(detail?.node);
if (!isAccumulatorPreviewNode(node)) return;
const output = detail?.output || {};
setEntries(node, outputEntries(output), outputStatus(output));
});
},
async beforeRegisterNodeDef(nodeType, nodeData) {
if (nodeData.name !== NODE_NAME) return;
const onNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function () {
const result = onNodeCreated?.apply(this, arguments);
setupNode(this);
return result;
};
const onConfigure = nodeType.prototype.onConfigure;
nodeType.prototype.onConfigure = function () {
const result = onConfigure?.apply(this, arguments);
queueMicrotask(() => setupNode(this));
return result;
};
},
});