Add accumulator preview and batch save node
This commit is contained in:
+317
-13
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
@@ -18,6 +21,23 @@ try:
|
||||
except Exception:
|
||||
ALL_NODE_CLASS_MAPPINGS = {}
|
||||
|
||||
try:
|
||||
import folder_paths
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from comfy.cli_args import args
|
||||
except Exception:
|
||||
folder_paths = None
|
||||
np = None
|
||||
Image = None
|
||||
PngInfo = None
|
||||
|
||||
class _ArgsFallback:
|
||||
disable_metadata = True
|
||||
|
||||
args = _ArgsFallback()
|
||||
|
||||
|
||||
MAX_LOOP_VALUES = 20
|
||||
MAX_CARRY_VALUES = MAX_LOOP_VALUES - 2
|
||||
@@ -25,6 +45,7 @@ COLLECTION_MODES = ["auto_batch", "list", "image_batch", "latent_batch", "string
|
||||
ACCUMULATOR_ACTIONS = ["append_variant", "replace_by_entry_id", "append", "clear_then_append", "clear", "read"]
|
||||
ACCUMULATOR_IMAGE_BATCH_MODES = ["same_size_only", "resize_to_first"]
|
||||
ACCUMULATOR_IMAGE_GROUPS = 4
|
||||
ACCUMULATOR_PREVIEW_DELETE_ACTIONS = ["none", "delete_entry_id", "delete_index", "clear"]
|
||||
|
||||
_ACCUMULATOR_STORES: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
@@ -154,6 +175,197 @@ def _group_image_batches(images: list[Any]) -> list[Any]:
|
||||
return [batch for batch in batches if batch is not None]
|
||||
|
||||
|
||||
def _accumulator_store_key(store_key: str, unique_id: Any = None) -> str:
|
||||
key = str(store_key or "").strip()
|
||||
if key:
|
||||
return key
|
||||
return f"node:{unique_id}"
|
||||
|
||||
|
||||
def _entry_value_summary(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
text = str(value)
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text[:160]
|
||||
|
||||
|
||||
def _entry_infos(store: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
entries = []
|
||||
for index, entry in enumerate(store, start=1):
|
||||
image = entry.get("image")
|
||||
shape = _image_shape(image)
|
||||
entries.append(
|
||||
{
|
||||
"index": index,
|
||||
"id": str(entry.get("id") or ""),
|
||||
"has_image": image is not None,
|
||||
"shape": list(shape) if shape is not None else [],
|
||||
"value": _entry_value_summary(entry.get("value")),
|
||||
}
|
||||
)
|
||||
return entries
|
||||
|
||||
|
||||
def _accumulator_status(key: str, store: list[dict[str, Any]]) -> str:
|
||||
images = [entry.get("image") for entry in store if entry.get("image") is not None]
|
||||
shapes = []
|
||||
for image in images:
|
||||
shape = _image_shape(image)
|
||||
if shape is not None and shape not in shapes:
|
||||
shapes.append(shape)
|
||||
shape_text = ", ".join(f"{shape[1]}x{shape[0]}" for shape in shapes) or "no images"
|
||||
return f"key={key}; entries={len(store)}; image_entries={len(images)}; formats={shape_text}"
|
||||
|
||||
|
||||
def accumulator_list_entries(store_key: str) -> dict[str, Any]:
|
||||
key = str(store_key or "").strip()
|
||||
if not key:
|
||||
raise ValueError("store_key is required for accumulator preview actions")
|
||||
store = _ACCUMULATOR_STORES.setdefault(key, [])
|
||||
return {
|
||||
"store_key": key,
|
||||
"entries": _entry_infos(store),
|
||||
"count": len(store),
|
||||
"status": _accumulator_status(key, store),
|
||||
}
|
||||
|
||||
|
||||
def accumulator_delete_entries(
|
||||
store_key: str,
|
||||
entry_id: str = "",
|
||||
index: int = 0,
|
||||
clear: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
key = str(store_key or "").strip()
|
||||
if not key:
|
||||
raise ValueError("store_key is required for accumulator preview actions")
|
||||
store = _ACCUMULATOR_STORES.setdefault(key, [])
|
||||
removed = 0
|
||||
if clear:
|
||||
removed = len(store)
|
||||
store.clear()
|
||||
else:
|
||||
entry_id = str(entry_id or "").strip()
|
||||
if entry_id:
|
||||
before = len(store)
|
||||
store[:] = [entry for entry in store if str(entry.get("id") or "") != entry_id]
|
||||
removed = before - len(store)
|
||||
elif int(index) > 0:
|
||||
zero_index = int(index) - 1
|
||||
if zero_index < len(store):
|
||||
del store[zero_index]
|
||||
removed = 1
|
||||
else:
|
||||
raise ValueError("entry_id or 1-based index is required")
|
||||
result = accumulator_list_entries(key)
|
||||
result["removed"] = removed
|
||||
return result
|
||||
|
||||
|
||||
def _require_image_saving() -> None:
|
||||
if folder_paths is None or np is None or Image is None:
|
||||
raise RuntimeError("Image preview/save helpers require ComfyUI image dependencies.")
|
||||
|
||||
|
||||
def _metadata(prompt: Any, extra_pnginfo: Any) -> Any:
|
||||
if args.disable_metadata or PngInfo is None:
|
||||
return None
|
||||
metadata = PngInfo()
|
||||
if prompt is not None:
|
||||
metadata.add_text("prompt", json.dumps(prompt))
|
||||
if extra_pnginfo is not None:
|
||||
for key, value in extra_pnginfo.items():
|
||||
metadata.add_text(key, json.dumps(value))
|
||||
return metadata
|
||||
|
||||
|
||||
def _image_to_pil(image: Any) -> Any:
|
||||
_require_image_saving()
|
||||
tensor = image
|
||||
try:
|
||||
if len(tensor.shape) == 4:
|
||||
tensor = tensor[0]
|
||||
except Exception:
|
||||
pass
|
||||
image_data = 255.0 * tensor.cpu().numpy()
|
||||
return Image.fromarray(np.clip(image_data, 0, 255).astype(np.uint8))
|
||||
|
||||
|
||||
def _safe_filename_prefix(prefix: str, default: str = "sxcp_accumulator") -> str:
|
||||
text = str(prefix or "").strip() or default
|
||||
text = os.path.basename(text)
|
||||
text = re.sub(r"[^A-Za-z0-9._-]+", "_", text).strip("._")
|
||||
return text or default
|
||||
|
||||
|
||||
def _next_save_counter(folder: str, prefix: str) -> int:
|
||||
pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)_\.png$")
|
||||
counter = 1
|
||||
try:
|
||||
for filename in os.listdir(folder):
|
||||
match = pattern.match(filename)
|
||||
if match:
|
||||
counter = max(counter, int(match.group(1)) + 1)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return counter
|
||||
|
||||
|
||||
def _preview_image_results(images: list[Any], preview_limit: int, prompt: Any, extra_pnginfo: Any) -> list[dict[str, str]]:
|
||||
if not images:
|
||||
return []
|
||||
_require_image_saving()
|
||||
output_dir = folder_paths.get_temp_directory()
|
||||
preview_images = images[: max(1, int(preview_limit))]
|
||||
first_shape = _image_shape(preview_images[0])
|
||||
height, width = (first_shape[0], first_shape[1]) if first_shape else (512, 512)
|
||||
prefix = "SxCPAccumulatorPreview_temp_" + "".join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5))
|
||||
full_output_folder, filename, counter, subfolder, _filename_prefix = folder_paths.get_save_image_path(prefix, output_dir, width, height)
|
||||
metadata = _metadata(prompt, extra_pnginfo)
|
||||
results = []
|
||||
for image in preview_images:
|
||||
file = f"{filename}_{counter:05}_.png"
|
||||
_image_to_pil(image).save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=1)
|
||||
results.append({"filename": file, "subfolder": subfolder, "type": "temp"})
|
||||
counter += 1
|
||||
return results
|
||||
|
||||
|
||||
def _resolve_save_folder(save_path: str) -> str:
|
||||
_require_image_saving()
|
||||
raw_path = os.path.expanduser(str(save_path or "").strip())
|
||||
if not raw_path:
|
||||
raw_path = "sxcp_accumulator"
|
||||
if os.path.isabs(raw_path):
|
||||
return raw_path
|
||||
return os.path.join(folder_paths.get_output_directory(), raw_path)
|
||||
|
||||
|
||||
def _save_images_to_folder(
|
||||
images: list[Any],
|
||||
save_path: str,
|
||||
filename_prefix: str,
|
||||
prompt: Any,
|
||||
extra_pnginfo: Any,
|
||||
) -> list[str]:
|
||||
if not images:
|
||||
return []
|
||||
folder = _resolve_save_folder(save_path)
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
prefix = _safe_filename_prefix(filename_prefix)
|
||||
counter = _next_save_counter(folder, prefix)
|
||||
metadata = _metadata(prompt, extra_pnginfo)
|
||||
saved_paths = []
|
||||
for image in images:
|
||||
file = f"{prefix}_{counter:05}_.png"
|
||||
path = os.path.join(folder, file)
|
||||
_image_to_pil(image).save(path, pnginfo=metadata, compress_level=4)
|
||||
saved_paths.append(path)
|
||||
counter += 1
|
||||
return saved_paths
|
||||
|
||||
|
||||
def _as_list(collection: Any) -> list[Any]:
|
||||
if collection is None:
|
||||
return []
|
||||
@@ -449,8 +661,7 @@ class SxCPAccumulator:
|
||||
return random.random()
|
||||
|
||||
def _store_key(self, store_key: str, unique_id: Any) -> str:
|
||||
key = str(store_key or "").strip()
|
||||
return key or f"node:{unique_id}"
|
||||
return _accumulator_store_key(store_key, unique_id)
|
||||
|
||||
def _entry_id(self, entry_id: Any, entry_tag: str, image_index: int, image_count: int) -> str:
|
||||
text = "" if entry_id is None else str(entry_id).strip()
|
||||
@@ -523,18 +734,8 @@ class SxCPAccumulator:
|
||||
return collection
|
||||
|
||||
def _status(self, key: str, store: list[dict[str, Any]], image_batch: Any, image_batches: list[Any]) -> str:
|
||||
images = [entry.get("image") for entry in store if entry.get("image") is not None]
|
||||
shapes = []
|
||||
for image in images:
|
||||
shape = _image_shape(image)
|
||||
if shape is not None and shape not in shapes:
|
||||
shapes.append(shape)
|
||||
shape_text = ", ".join(f"{shape[1]}x{shape[0]}" for shape in shapes) or "no images"
|
||||
batch_state = "all images batched" if image_batch is not None else "mixed sizes or no image batch"
|
||||
return (
|
||||
f"key={key}; entries={len(store)}; image_entries={len(images)}; "
|
||||
f"formats={shape_text}; grouped_batches={len(image_batches)}; {batch_state}"
|
||||
)
|
||||
return f"{_accumulator_status(key, store)}; grouped_batches={len(image_batches)}; {batch_state}"
|
||||
|
||||
def accumulate(
|
||||
self,
|
||||
@@ -573,6 +774,107 @@ class SxCPAccumulator:
|
||||
return tuple([self._collection(store), image_batch, images] + grouped_outputs + [len(store), status])
|
||||
|
||||
|
||||
class SxCPAccumulatorPreview:
|
||||
OUTPUT_NODE = True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"store_key": ("STRING", {"default": "", "multiline": False}),
|
||||
"preview_limit": ("INT", {"default": 64, "min": 1, "max": 512, "step": 1}),
|
||||
"delete_action": (ACCUMULATOR_PREVIEW_DELETE_ACTIONS, {"default": "none"}),
|
||||
"delete_entry_id": ("STRING", {"default": "", "multiline": False}),
|
||||
"delete_index": ("INT", {"default": 0, "min": 0, "max": 100000, "step": 1}),
|
||||
"save_batch": ("BOOLEAN", {"default": False}),
|
||||
"finished": ("BOOLEAN", {"default": True}),
|
||||
"save_path": ("STRING", {"default": "sxcp_accumulator", "multiline": False}),
|
||||
"filename_prefix": ("STRING", {"default": "sxcp_accum", "multiline": False}),
|
||||
"clear_after_save": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"hidden": {
|
||||
"prompt": "PROMPT",
|
||||
"extra_pnginfo": "EXTRA_PNGINFO",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("INT", "STRING", "STRING")
|
||||
RETURN_NAMES = ("count", "status", "saved_paths_json")
|
||||
FUNCTION = "preview"
|
||||
CATEGORY = "prompt_builder/loop"
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, *args, **kwargs):
|
||||
return random.random()
|
||||
|
||||
def _delete_from_inputs(self, key: str, delete_action: str, delete_entry_id: str, delete_index: int) -> int:
|
||||
action = delete_action if delete_action in ACCUMULATOR_PREVIEW_DELETE_ACTIONS else "none"
|
||||
if action == "none":
|
||||
return 0
|
||||
if action == "clear":
|
||||
return int(accumulator_delete_entries(key, clear=True).get("removed", 0))
|
||||
if action == "delete_entry_id":
|
||||
entry_id = str(delete_entry_id or "").strip()
|
||||
if not entry_id:
|
||||
return 0
|
||||
return int(accumulator_delete_entries(key, entry_id=entry_id).get("removed", 0))
|
||||
if action == "delete_index":
|
||||
index = int(delete_index)
|
||||
if index <= 0:
|
||||
return 0
|
||||
return int(accumulator_delete_entries(key, index=index).get("removed", 0))
|
||||
return 0
|
||||
|
||||
def preview(
|
||||
self,
|
||||
store_key,
|
||||
preview_limit,
|
||||
delete_action,
|
||||
delete_entry_id,
|
||||
delete_index,
|
||||
save_batch,
|
||||
finished,
|
||||
save_path,
|
||||
filename_prefix,
|
||||
clear_after_save,
|
||||
prompt=None,
|
||||
extra_pnginfo=None,
|
||||
unique_id=None,
|
||||
):
|
||||
key = _accumulator_store_key(store_key, unique_id)
|
||||
store = _ACCUMULATOR_STORES.setdefault(key, [])
|
||||
removed = self._delete_from_inputs(key, delete_action, delete_entry_id, delete_index)
|
||||
images = [entry["image"] for entry in store if entry.get("image") is not None]
|
||||
|
||||
saved_paths: list[str] = []
|
||||
save_status = ""
|
||||
if bool(save_batch) and bool(finished):
|
||||
saved_paths = _save_images_to_folder(images, save_path, filename_prefix, prompt, extra_pnginfo)
|
||||
save_status = f"; saved={len(saved_paths)}"
|
||||
if saved_paths and bool(clear_after_save):
|
||||
store.clear()
|
||||
images = []
|
||||
save_status += "; cleared_after_save"
|
||||
|
||||
preview_images = _preview_image_results(images, preview_limit, prompt, extra_pnginfo)
|
||||
entries = _entry_infos(store)
|
||||
status = _accumulator_status(key, store)
|
||||
if removed:
|
||||
status += f"; removed={removed}"
|
||||
status += save_status
|
||||
saved_json = json.dumps(saved_paths, ensure_ascii=True)
|
||||
return {
|
||||
"ui": {
|
||||
"images": preview_images,
|
||||
"entries": entries,
|
||||
"status": [status],
|
||||
"saved_paths": saved_paths,
|
||||
},
|
||||
"result": (len(store), status, saved_json),
|
||||
}
|
||||
|
||||
|
||||
class SxCPForLoopEnd:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@@ -697,6 +999,7 @@ LOOP_NODE_CLASS_MAPPINGS = {
|
||||
"SxCPForLoopEnd": SxCPForLoopEnd,
|
||||
"SxCPLoopAppend": SxCPLoopAppend,
|
||||
"SxCPAccumulator": SxCPAccumulator,
|
||||
"SxCPAccumulatorPreview": SxCPAccumulatorPreview,
|
||||
"SxCPLoopIntAdd": SxCPLoopIntAdd,
|
||||
"SxCPLoopLessThan": SxCPLoopLessThan,
|
||||
"SxCPLoopLessThanOrEqual": SxCPLoopLessThanOrEqual,
|
||||
@@ -709,6 +1012,7 @@ LOOP_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SxCPForLoopEnd": "SxCP For Loop End",
|
||||
"SxCPLoopAppend": "SxCP Loop Append",
|
||||
"SxCPAccumulator": "SxCP Accumulator",
|
||||
"SxCPAccumulatorPreview": "SxCP Accumulator Preview",
|
||||
"SxCPLoopIntAdd": "SxCP Loop Int Add",
|
||||
"SxCPLoopLessThan": "SxCP Loop Less Than",
|
||||
"SxCPLoopLessThanOrEqual": "SxCP Loop Less Than Or Equal",
|
||||
|
||||
Reference in New Issue
Block a user