"""Loop, collection and accumulator nodes for ComfyUI (``SxCP`` prefix).

The while-loop pair (``SxCPWhileLoopStart`` / ``SxCPWhileLoopEnd``) re-creates
the nodes between its start and end node through ``GraphBuilder`` on every
iteration. The for-loop pair (``SxCPForLoopStart`` / ``SxCPForLoopEnd``) is
built on top of it and threads an index and a growing collection through the
while-loop value slots. ``SxCPAccumulator`` keeps entries in a module-level
store keyed by a user-supplied (or per-node) key.
"""

from __future__ import annotations

import random
from typing import Any

try:
    from comfy_execution.graph import ExecutionBlocker
    from comfy_execution.graph_utils import GraphBuilder, is_link
except Exception:
    # Allows local syntax/import checks outside ComfyUI.
    ExecutionBlocker = None
    GraphBuilder = None

    def is_link(value: Any) -> bool:
        return isinstance(value, list) and len(value) == 2


try:
    from nodes import NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS
except Exception:
    ALL_NODE_CLASS_MAPPINGS = {}


# Number of value slots carried by the while-loop pair.
MAX_LOOP_VALUES = 20
# The for-loop pair reserves while-loop slot 0 for the index and slot 1 for the
# collection; the remaining slots carry user values initial_value1..N.
MAX_CARRY_VALUES = MAX_LOOP_VALUES - 2
COLLECTION_MODES = ["auto_batch", "list", "image_batch", "latent_batch", "string_lines"]
ACCUMULATOR_ACTIONS = ["replace_by_entry_id", "append", "clear_then_append", "clear", "read"]
ACCUMULATOR_IMAGE_BATCH_MODES = ["same_size_only", "resize_to_first"]
# Number of per-shape grouped image batch outputs exposed by SxCPAccumulator.
ACCUMULATOR_IMAGE_GROUPS = 4

# Accumulator entries keyed by store key. Module-level, so entries survive
# between node executions for as long as this module stays loaded.
_ACCUMULATOR_STORES: dict[str, list[dict[str, Any]]] = {}


class AnyType(str):
    """Type tag whose ``!=`` comparison is always False.

    NOTE(review): presumably used so that type-compatibility checks written
    with ``!=`` accept any connection on sockets typed with ``ANY_TYPE``.
    """

    def __ne__(self, _other: object) -> bool:
        return False


ANY_TYPE = AnyType("*")


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _require_graph_builder() -> None:
    """Raise RuntimeError when ComfyUI's GraphBuilder could not be imported."""
    if GraphBuilder is None:
        raise RuntimeError("SxCP loop nodes require ComfyUI's comfy_execution GraphBuilder.")


def _execution_blocker() -> Any:
    """Return ``ExecutionBlocker(None)``, or None when running outside ComfyUI."""
    return ExecutionBlocker(None) if ExecutionBlocker is not None else None


def _torch_cat(first: Any, second: Any) -> Any | None:
    """Concatenate two tensors along dim 0.

    Returns None when torch is unavailable, when either argument is not a
    tensor, or when torch rejects the concatenation (mismatched shapes,
    ranks, devices, zero-dim tensors). Callers treat None as "cannot batch"
    and fall back to another representation.
    """
    try:
        import torch
    except Exception:
        return None
    if not (torch.is_tensor(first) and torch.is_tensor(second)):
        return None
    try:
        return torch.cat((first, second), dim=0)
    except (RuntimeError, IndexError, TypeError):
        # Incompatible tensors: report "not batchable" instead of aborting the
        # whole loop, so the list fallbacks in the callers can take over.
        return None


def _latent_cat(first: Any, second: Any) -> Any | None:
    """Concatenate two latent dicts on their ``"samples"`` entries.

    The result is a shallow copy of *second* whose ``"samples"`` is the
    dim-0 concatenation of both samples. Returns None if either side is not
    a dict with a ``"samples"`` key, or the samples cannot be concatenated.
    """
    if not isinstance(first, dict) or not isinstance(second, dict):
        return None
    if "samples" not in first or "samples" not in second:
        return None
    samples = _torch_cat(first["samples"], second["samples"])
    if samples is None:
        return None
    merged = dict(second)
    merged["samples"] = samples
    return merged


def _torch_cat_many(values: list[Any]) -> Any | None:
    """Concatenate a list of tensors along dim 0 in a single ``torch.cat``.

    Returns None for an empty list, when torch is unavailable, when any item
    is not a tensor, or when the tensors cannot be concatenated. A
    single-item list is returned unchanged (its item, not a copy).
    """
    if not values:
        return None
    if len(values) == 1:
        return values[0]
    try:
        import torch
    except Exception:
        return None
    if not all(torch.is_tensor(value) for value in values):
        return None
    try:
        # One concatenation instead of pairwise ones, which would re-copy the
        # growing result on every step.
        return torch.cat(list(values), dim=0)
    except (RuntimeError, IndexError, TypeError):
        return None


def _is_image_tensor(value: Any) -> bool:
    """Return True for a rank-4 torch tensor (False if torch is unavailable)."""
    try:
        import torch
    except Exception:
        return False
    return torch.is_tensor(value) and len(value.shape) == 4


def _image_shape(value: Any) -> tuple[int, ...] | None:
    """Return the per-image shape (all dims after the batch dim), or None.

    NOTE(review): the rest of the module reads this as (height, width,
    channels), i.e. it assumes channel-last image tensors.
    """
    if not _is_image_tensor(value):
        return None
    return tuple(int(part) for part in value.shape[1:])


def _split_image_value(value: Any) -> list[Any]:
    """Flatten an image value into a list of batch-size-1 image tensors.

    Lists and tuples are flattened recursively. A rank-4 tensor with batch
    size N becomes N one-image slices, and an empty batch yields nothing.
    Anything else (including None) contributes no images.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        images: list[Any] = []
        for item in value:
            images.extend(_split_image_value(item))
        return images
    if not _is_image_tensor(value):
        return []
    batch_size = int(value.shape[0])
    if batch_size == 0:
        # An empty batch holds no images; don't record it as one.
        return []
    if batch_size == 1:
        return [value]
    return [value[index : index + 1] for index in range(batch_size)]


def _resize_image_to_shape(image: Any, shape: tuple[int, ...]) -> Any | None:
    """Resize *image* to ``shape[0]`` x ``shape[1]`` (height, width).

    Uses ``comfy.utils.common_upscale`` with "lanczos" and "center" after
    moving the last dim to position 1, and moves it back afterwards. Returns
    None if *image* is not an image tensor or ``comfy.utils`` is unavailable.
    Only height and width are changed; the channel count is not.
    """
    if not _is_image_tensor(image):
        return None
    try:
        import comfy.utils
    except Exception:
        return None
    height, width = int(shape[0]), int(shape[1])
    return comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)


def _image_batch_from_images(images: list[Any], mode: str = "same_size_only") -> Any | None:
    """Build one batch tensor from single images, or return None.

    In ``same_size_only`` mode any size mismatch with the first image yields
    None. In ``resize_to_first`` mode mismatching images are resized to the
    first image's size; None is returned if a resize or the final
    concatenation fails. Unknown modes behave like ``same_size_only``.
    """
    if not images:
        return None
    mode = mode if mode in ACCUMULATOR_IMAGE_BATCH_MODES else "same_size_only"
    first_shape = _image_shape(images[0])
    if first_shape is None:
        return None
    normalized = []
    for image in images:
        if _image_shape(image) != first_shape:
            if mode != "resize_to_first":
                return None
            image = _resize_image_to_shape(image, first_shape)
            if image is None:
                return None
        normalized.append(image)
    return _torch_cat_many(normalized)


def _group_image_batches(images: list[Any]) -> list[Any]:
    """Group images by per-image shape and concatenate each group.

    Groups are returned in order of first appearance. Non-image items are
    skipped, and groups that fail to concatenate are dropped.
    """
    grouped: dict[tuple[int, ...], list[Any]] = {}
    for image in images:
        shape = _image_shape(image)
        if shape is None:
            continue
        # dicts keep insertion order, so groups come out in first-seen order.
        grouped.setdefault(shape, []).append(image)
    batches = [_torch_cat_many(group) for group in grouped.values()]
    return [batch for batch in batches if batch is not None]


def _as_list(collection: Any) -> list[Any]:
    """Return a copy of a list, ``[]`` for None, or ``[collection]`` otherwise."""
    if collection is None:
        return []
    return list(collection) if isinstance(collection, list) else [collection]


def append_collected_value(collection: Any, value: Any, mode: str = "auto_batch", skip_none: bool = True) -> Any:
    """Return *collection* with *value* appended according to *mode*.

    A None *value* leaves *collection* unchanged when *skip_none* is True.
    Unknown modes fall back to ``auto_batch``.

    Modes:

    * ``string_lines``: append ``str(value)`` on a new line. None becomes an
      empty line, and an empty/None collection starts a new text.
    * ``list``: always return a new list ending with *value*.
    * ``image_batch``: concatenate tensors along dim 0, else fall back to a list.
    * ``latent_batch``: concatenate latent dicts' ``"samples"``, else fall
      back to a list.
    * ``auto_batch``: try tensor concatenation, then latent concatenation,
      then fall back to a list.

    In the batch modes a None *collection* means "empty", so the first value
    is returned as-is.
    """
    if value is None and skip_none:
        return collection
    mode = mode if mode in COLLECTION_MODES else "auto_batch"

    if mode == "string_lines":
        value_text = "" if value is None else str(value)
        if not collection:
            return value_text
        return f"{collection}\n{value_text}"

    if mode == "list":
        return _as_list(collection) + [value]

    if collection is None:
        return value

    if mode in ("auto_batch", "image_batch"):
        tensor_batch = _torch_cat(collection, value)
        if tensor_batch is not None:
            return tensor_batch
        if mode == "image_batch":
            return _as_list(collection) + [value]

    if mode in ("auto_batch", "latent_batch"):
        latent_batch = _latent_cat(collection, value)
        if latent_batch is not None:
            return latent_batch
        if mode == "latent_batch":
            return _as_list(collection) + [value]

    return _as_list(collection) + [value]


# ---------------------------------------------------------------------------
# While loop
# ---------------------------------------------------------------------------


class SxCPWhileLoopStart:
    """Opening node of a while loop.

    Passes ``initial_value0..N`` through as ``value0..N`` when *condition* is
    true, and emits execution blockers for every value when it is false.
    """

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "condition": ("BOOLEAN", {"default": True}),
            },
            "optional": {},
        }
        for index in range(MAX_LOOP_VALUES):
            inputs["optional"][f"initial_value{index}"] = (ANY_TYPE,)
        return inputs

    RETURN_TYPES = tuple(["FLOW_CONTROL"] + [ANY_TYPE] * MAX_LOOP_VALUES)
    RETURN_NAMES = tuple(["flow"] + [f"value{index}" for index in range(MAX_LOOP_VALUES)])
    FUNCTION = "open"
    CATEGORY = "prompt_builder/loop"

    def open(self, condition, **kwargs):
        values = []
        for index in range(MAX_LOOP_VALUES):
            values.append(kwargs.get(f"initial_value{index}") if condition else _execution_blocker())
        return tuple(["stub"] + values)


class SxCPWhileLoopEnd:
    """Closing node of a while loop.

    When *condition* is false, returns the incoming ``initial_value*`` inputs
    as the loop results. When it is true, clones every node between the
    opening node (``flow[0]``) and this node into a new graph. The clone of
    this node is named ``"Recurse"``, and the cloned opening node is re-fed
    with this node's inputs. It returns that graph for expansion, with the
    clone's outputs as results.
    """

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "flow": ("FLOW_CONTROL", {"rawLink": True}),
                "condition": ("BOOLEAN", {}),
            },
            "optional": {},
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "unique_id": "UNIQUE_ID",
                "extra_pnginfo": "EXTRA_PNGINFO",
            },
        }
        for index in range(MAX_LOOP_VALUES):
            inputs["optional"][f"initial_value{index}"] = (ANY_TYPE,)
        return inputs

    RETURN_TYPES = tuple([ANY_TYPE] * MAX_LOOP_VALUES)
    RETURN_NAMES = tuple([f"value{index}" for index in range(MAX_LOOP_VALUES)])
    FUNCTION = "close"
    CATEGORY = "prompt_builder/loop"

    def _explore_dependencies(self, node_id: str, dynprompt: Any, upstream: dict[str, list[str]], parent_ids: list[str]) -> None:
        """Walk input links upward from *node_id*, recording parent -> children edges.

        ``upstream[parent]`` collects the ids of nodes that link to *parent*.
        ``parent_ids`` collects display ids of visited parents, except those
        whose display node is a for/while loop end.
        """
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for value in node_info["inputs"].values():
            if not is_link(value):
                continue
            parent_id = value[0]
            display_id = dynprompt.get_display_node_id(parent_id)
            display_node = dynprompt.get_node(display_id)
            class_type = display_node["class_type"]
            if class_type not in ("SxCPForLoopEnd", "SxCPWhileLoopEnd"):
                parent_ids.append(display_id)
            if parent_id not in upstream:
                # First visit: create the edge list and recurse exactly once.
                upstream[parent_id] = []
                self._explore_dependencies(parent_id, dynprompt, upstream, parent_ids)
            upstream[parent_id].append(node_id)

    def _explore_output_nodes(
        self,
        dynprompt: Any,
        upstream: dict[str, list[str]],
        output_nodes: dict[str, Any],
        parent_ids: list[str],
    ) -> None:
        """Attach output nodes that hang off explored parents as extra children.

        This makes the output nodes part of the cloned loop body. For
        ephemeral parent ids containing ``"."``, the output node id is placed
        under the same dotted prefix.
        """
        for parent_id in upstream:
            display_id = dynprompt.get_display_node_id(parent_id)
            for output_id, link in output_nodes.items():
                linked_id = link[0]
                if linked_id in parent_ids and display_id == linked_id and output_id not in upstream[parent_id]:
                    if "." in parent_id:
                        parts = parent_id.split(".")
                        parts[-1] = output_id
                        upstream[parent_id].append(".".join(parts))
                    else:
                        upstream[parent_id].append(output_id)

    def _collect_contained(self, node_id: str, upstream: dict[str, list[str]], contained: dict[str, bool]) -> None:
        """Mark every descendant of *node_id* (via ``upstream`` edges) as contained."""
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id in contained:
                continue
            contained[child_id] = True
            self._collect_contained(child_id, upstream, contained)

    def close(self, flow, condition, dynprompt=None, unique_id=None, **kwargs):
        if not condition:
            # Loop finished: the current values are the final results.
            return tuple(kwargs.get(f"initial_value{index}") for index in range(MAX_LOOP_VALUES))

        _require_graph_builder()
        upstream: dict[str, list[str]] = {}
        parent_ids: list[str] = []
        self._explore_dependencies(unique_id, dynprompt, upstream, parent_ids)
        parent_ids = list(set(parent_ids))

        # Map each OUTPUT_NODE in the original prompt to one of its input links
        # (the last link found wins).
        output_nodes = {}
        for node_id, node in dynprompt.get_original_prompt().items():
            if "inputs" not in node:
                continue
            class_def = ALL_NODE_CLASS_MAPPINGS.get(node["class_type"])
            if not class_def or not getattr(class_def, "OUTPUT_NODE", False):
                continue
            for value in node["inputs"].values():
                if is_link(value):
                    output_nodes[node_id] = value

        graph = GraphBuilder()
        self._explore_output_nodes(dynprompt, upstream, output_nodes, parent_ids)
        contained: dict[str, bool] = {}
        open_node = flow[0]
        self._collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True

        # Pass 1: create a clone of every contained node.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
            node.set_override_display_id(node_id)
        # Pass 2: rewire inputs. Links inside the body point at the clones;
        # everything else is copied verbatim.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
            for key, value in original_node["inputs"].items():
                if is_link(value) and value[0] in contained:
                    parent = graph.lookup_node(value[0])
                    node.set_input(key, parent.out(value[1]))
                else:
                    node.set_input(key, value)

        # Feed this iteration's values into the cloned opening node.
        new_open = graph.lookup_node(open_node)
        original_open = dynprompt.get_node(open_node)
        if original_open["class_type"] == "SxCPForLoopStart":
            # For-loop slot layout: 0 = index, 1 = collection, 2.. = carries 1..
            new_open.set_input("initial_index", kwargs.get("initial_value0"))
            new_open.set_input("initial_collected", kwargs.get("initial_value1"))
            for carry_index in range(1, MAX_CARRY_VALUES + 1):
                new_open.set_input(f"initial_value{carry_index}", kwargs.get(f"initial_value{carry_index + 1}"))
        else:
            for index in range(MAX_LOOP_VALUES):
                new_open.set_input(f"initial_value{index}", kwargs.get(f"initial_value{index}"))

        my_clone = graph.lookup_node("Recurse")
        return {
            "result": tuple(my_clone.out(index) for index in range(MAX_LOOP_VALUES)),
            "expand": graph.finalize(),
        }


# ---------------------------------------------------------------------------
# For loop
# ---------------------------------------------------------------------------


class SxCPForLoopStart:
    """Opening node of a for loop running index ``skip + 1`` .. ``total``.

    Outputs the current index, the collection built so far and the carried
    values. ``initial_index`` and ``initial_collected`` are hidden inputs set
    by ``SxCPWhileLoopEnd`` when it re-creates this node for the next
    iteration.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}),
                "skip": ("INT", {"default": 0, "min": 0, "max": 100000, "step": 1}),
            },
            "optional": {
                f"initial_value{index}": (ANY_TYPE,) for index in range(1, MAX_CARRY_VALUES + 1)
            },
            "hidden": {
                "initial_index": (ANY_TYPE,),
                "initial_collected": (ANY_TYPE,),
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = tuple(["FLOW_CONTROL", "INT", ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES)
    RETURN_NAMES = tuple(["flow", "index", "collected"] + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)])
    FUNCTION = "start"
    CATEGORY = "prompt_builder/loop"

    def start(self, total, skip=0, initial_index=None, initial_collected=None, **kwargs):
        _require_graph_builder()
        total = max(1, int(total))
        skip = max(0, int(skip))
        first_index = skip + 1
        # Never start below the first index, even if a stale index is fed back.
        index = first_index if initial_index is None else max(int(initial_index), first_index)
        collected = initial_collected
        carries = [kwargs.get(f"initial_value{carry_index}") for carry_index in range(1, MAX_CARRY_VALUES + 1)]

        # While-loop slot layout: 0 = index, 1 = collection, 2.. = carries 1..
        initial_values = {
            "initial_value0": index,
            "initial_value1": collected,
        }
        for slot, carry in enumerate(carries, start=2):
            initial_values[f"initial_value{slot}"] = carry

        graph = GraphBuilder()
        graph.node("SxCPWhileLoopStart", condition=index <= total, **initial_values)
        return {
            "result": tuple(["stub", index, collected] + carries),
            "expand": graph.finalize(),
        }


class SxCPLoopAppend:
    """Node wrapper around :func:`append_collected_value`."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mode": (COLLECTION_MODES, {"default": "auto_batch"}),
                "skip_none": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "collection": (ANY_TYPE,),
                "value": (ANY_TYPE,),
            },
        }

    RETURN_TYPES = (ANY_TYPE,)
    RETURN_NAMES = ("collected",)
    FUNCTION = "append"
    CATEGORY = "prompt_builder/loop"

    def append(self, mode, skip_none, collection=None, value=None):
        return (append_collected_value(collection, value, mode=mode, skip_none=skip_none),)


class SxCPAccumulator:
    """Accumulate images and/or values in a named store kept across executions.

    Each call can clear, append to, replace entries in, or just read the
    store selected by ``store_key`` (default: one store per node id). Only
    the newest ``max_items`` entries are kept. The outputs are the
    collection, a combined image batch (or None), the individual images,
    up to ``ACCUMULATOR_IMAGE_GROUPS`` per-shape batches, the entry count
    and a status line.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "store_key": ("STRING", {"default": "", "multiline": False}),
                "action": (ACCUMULATOR_ACTIONS, {"default": "replace_by_entry_id"}),
                "max_items": ("INT", {"default": 32, "min": 1, "max": 10000, "step": 1}),
                "image_batch_mode": (ACCUMULATOR_IMAGE_BATCH_MODES, {"default": "same_size_only"}),
                "skip_empty": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "image": ("IMAGE",),
                "value": (ANY_TYPE,),
                "entry_id": (ANY_TYPE,),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = tuple([ANY_TYPE, "IMAGE", "IMAGE"] + ["IMAGE"] * ACCUMULATOR_IMAGE_GROUPS + ["INT", "STRING"])
    RETURN_NAMES = tuple(
        ["collection", "image_batch", "image_list"]
        + [f"image_batch_{index}" for index in range(1, ACCUMULATOR_IMAGE_GROUPS + 1)]
        + ["count", "status"]
    )
    OUTPUT_IS_LIST = tuple([False, False, True] + [False] * ACCUMULATOR_IMAGE_GROUPS + [False, False])
    FUNCTION = "accumulate"
    CATEGORY = "prompt_builder/loop"

    @classmethod
    def IS_CHANGED(cls, *args, **kwargs):
        # A fresh random value on every call; presumably so the node is never
        # served from cache and the shared store is always updated.
        return random.random()

    def _store_key(self, store_key: str, unique_id: Any) -> str:
        """Return the stripped store key, or ``"node:<unique_id>"`` when blank."""
        key = str(store_key or "").strip()
        return key or f"node:{unique_id}"

    def _entry_id(self, entry_id: Any, image_index: int, image_count: int) -> str:
        """Return the entry id; multi-image inputs get a ``:<1-based index>`` suffix."""
        if entry_id is None:
            return ""
        text = str(entry_id).strip()
        if not text:
            return ""
        if image_count <= 1:
            return text
        return f"{text}:{image_index + 1}"

    def _value_for_image(self, value: Any, image_index: int, image_count: int) -> Any:
        """Pick the per-image value when *value* is a list/tuple matching the image count."""
        if image_count <= 1:
            return value
        if isinstance(value, (list, tuple)) and len(value) == image_count:
            return value[image_index]
        return value

    def _entry_records(self, image: Any, value: Any, entry_id: Any, skip_empty: bool) -> list[dict[str, Any]]:
        """Build store entries: one per image, or one value-only entry without images."""
        images = _split_image_value(image)
        if not images:
            if value is None and skip_empty:
                return []
            return [{"id": self._entry_id(entry_id, 0, 1), "image": None, "value": value}]
        image_count = len(images)
        return [
            {
                "id": self._entry_id(entry_id, index, image_count),
                "image": image_item,
                "value": self._value_for_image(value, index, image_count),
            }
            for index, image_item in enumerate(images)
        ]

    def _append_or_replace(self, store: list[dict[str, Any]], entries: list[dict[str, Any]], action: str) -> None:
        """Append entries, or replace the first entry with the same non-empty id."""
        replace = action == "replace_by_entry_id"
        for entry in entries:
            entry_id = entry.get("id") or ""
            if replace and entry_id:
                for index, existing in enumerate(store):
                    if existing.get("id") == entry_id:
                        store[index] = entry
                        break
                else:
                    store.append(entry)
            else:
                store.append(entry)

    def _collection(self, store: list[dict[str, Any]]) -> list[Any]:
        """Return each entry's value, or its image when the value is None."""
        collection = []
        for entry in store:
            value = entry.get("value")
            collection.append(value if value is not None else entry.get("image"))
        return collection

    def _status(self, key: str, store: list[dict[str, Any]], image_batch: Any, image_batches: list[Any]) -> str:
        """Return a one-line summary of the store (counts, WxH formats, batching)."""
        images = [entry.get("image") for entry in store if entry.get("image") is not None]
        shapes = []
        for image in images:
            shape = _image_shape(image)
            if shape is not None and shape not in shapes:
                shapes.append(shape)
        shape_text = ", ".join(f"{shape[1]}x{shape[0]}" for shape in shapes) or "no images"
        batch_state = "all images batched" if image_batch is not None else "mixed sizes or no image batch"
        return (
            f"key={key}; entries={len(store)}; image_entries={len(images)}; "
            f"formats={shape_text}; grouped_batches={len(image_batches)}; {batch_state}"
        )

    def accumulate(
        self,
        store_key,
        action,
        max_items,
        image_batch_mode,
        skip_empty,
        image=None,
        value=None,
        entry_id=None,
        unique_id=None,
    ):
        key = self._store_key(store_key, unique_id)
        action = action if action in ACCUMULATOR_ACTIONS else "replace_by_entry_id"
        store = _ACCUMULATOR_STORES.setdefault(key, [])

        if action in ("clear", "clear_then_append"):
            store.clear()
        if action not in ("clear", "read"):
            entries = self._entry_records(image, value, entry_id, bool(skip_empty))
            self._append_or_replace(store, entries, action)

        # Keep only the newest max_items entries.
        max_items = max(1, int(max_items))
        if len(store) > max_items:
            del store[: len(store) - max_items]

        images = [entry["image"] for entry in store if entry.get("image") is not None]
        image_batch = _image_batch_from_images(images, image_batch_mode)
        image_batches = _group_image_batches(images)
        grouped_outputs = image_batches[:ACCUMULATOR_IMAGE_GROUPS]
        grouped_outputs += [None] * (ACCUMULATOR_IMAGE_GROUPS - len(grouped_outputs))
        status = self._status(key, store, image_batch, image_batches)
        return tuple([self._collection(store), image_batch, images] + grouped_outputs + [len(store), status])


class SxCPForLoopEnd:
    """Closing node of a for loop.

    Expands into nodes that increment the index, append ``collect_value`` to
    the collection, and close the underlying while loop while
    ``index + 1 <= total``. The outputs are the final collection followed by
    the carried values.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "flow": ("FLOW_CONTROL", {"rawLink": True}),
                "collection_mode": (COLLECTION_MODES, {"default": "auto_batch"}),
                "skip_none": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "collected": (ANY_TYPE, {"rawLink": True}),
                "collect_value": (ANY_TYPE, {"rawLink": True}),
                **{
                    f"initial_value{index}": (ANY_TYPE, {"rawLink": True})
                    for index in range(1, MAX_CARRY_VALUES + 1)
                },
            },
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = tuple([ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES)
    RETURN_NAMES = tuple(["collected"] + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)])
    FUNCTION = "end"
    CATEGORY = "prompt_builder/loop"

    def end(self, flow, collection_mode, skip_none, dynprompt=None, **kwargs):
        _require_graph_builder()
        graph = GraphBuilder()
        loop_start = flow[0]
        start_node = dynprompt.get_node(loop_start)
        if start_node["class_type"] != "SxCPForLoopStart":
            raise ValueError("SxCP For Loop End must receive flow from SxCP For Loop Start.")
        total = start_node["inputs"]["total"]

        # Output 1 of the start node is the index, output 2 the collection.
        next_index = graph.node("SxCPLoopIntAdd", a=[loop_start, 1], b=1)
        condition = graph.node("SxCPLoopLessThanOrEqual", a=next_index.out(0), b=total)
        collection = kwargs.get("collected") or [loop_start, 2]
        collect_value = kwargs.get("collect_value")
        next_collection = graph.node(
            "SxCPLoopAppend",
            collection=collection,
            value=collect_value,
            mode=collection_mode,
            skip_none=skip_none,
        )

        # While-loop slot layout: 0 = index, 1 = collection, 2.. = carries 1..
        next_values = {
            "initial_value0": next_index.out(0),
            "initial_value1": next_collection.out(0),
        }
        for carry_index in range(1, MAX_CARRY_VALUES + 1):
            next_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}")

        while_close = graph.node("SxCPWhileLoopEnd", flow=flow, condition=condition.out(0), **next_values)
        # Drop slot 0 (the index): results are the collection plus the carries.
        return {
            "result": tuple(while_close.out(index) for index in range(1, MAX_LOOP_VALUES)),
            "expand": graph.finalize(),
        }


# ---------------------------------------------------------------------------
# Internal arithmetic/comparison nodes used by the for-loop expansion
# ---------------------------------------------------------------------------


class SxCPLoopIntAdd:
    """Return ``int(a) + int(b)``."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("INT", {"default": 0}),
                "b": ("INT", {"default": 1}),
            }
        }

    RETURN_TYPES = ("INT",)
    RETURN_NAMES = ("int",)
    FUNCTION = "add"
    CATEGORY = "prompt_builder/loop/internal"

    def add(self, a, b):
        return (int(a) + int(b),)


class SxCPLoopLessThan:
    """Return ``int(a) < int(b)``."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("INT", {"default": 0}),
                "b": ("INT", {"default": 1}),
            }
        }

    RETURN_TYPES = ("BOOLEAN",)
    RETURN_NAMES = ("boolean",)
    FUNCTION = "compare"
    CATEGORY = "prompt_builder/loop/internal"

    def compare(self, a, b):
        return (int(a) < int(b),)


class SxCPLoopLessThanOrEqual:
    """Return ``int(a) <= int(b)``."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("INT", {"default": 0}),
                "b": ("INT", {"default": 0}),
            }
        }

    RETURN_TYPES = ("BOOLEAN",)
    # Named consistently with SxCPLoopLessThan.
    RETURN_NAMES = ("boolean",)
    FUNCTION = "compare"
    CATEGORY = "prompt_builder/loop/internal"

    def compare(self, a, b):
        return (int(a) <= int(b),)


LOOP_NODE_CLASS_MAPPINGS = {
    "SxCPWhileLoopStart": SxCPWhileLoopStart,
    "SxCPWhileLoopEnd": SxCPWhileLoopEnd,
    "SxCPForLoopStart": SxCPForLoopStart,
    "SxCPForLoopEnd": SxCPForLoopEnd,
    "SxCPLoopAppend": SxCPLoopAppend,
    "SxCPAccumulator": SxCPAccumulator,
    "SxCPLoopIntAdd": SxCPLoopIntAdd,
    "SxCPLoopLessThan": SxCPLoopLessThan,
    "SxCPLoopLessThanOrEqual": SxCPLoopLessThanOrEqual,
}

LOOP_NODE_DISPLAY_NAME_MAPPINGS = {
    "SxCPWhileLoopStart": "SxCP While Loop Start",
    "SxCPWhileLoopEnd": "SxCP While Loop End",
    "SxCPForLoopStart": "SxCP For Loop Start",
    "SxCPForLoopEnd": "SxCP For Loop End",
    "SxCPLoopAppend": "SxCP Loop Append",
    "SxCPAccumulator": "SxCP Accumulator",
    "SxCPLoopIntAdd": "SxCP Loop Int Add",
    "SxCPLoopLessThan": "SxCP Loop Less Than",
    "SxCPLoopLessThanOrEqual": "SxCP Loop Less Than Or Equal",
}