# SxCP loop nodes: while/for loop building blocks for ComfyUI graphs.
#
# The loop "end" nodes re-create the nodes between a loop start and end with
# GraphBuilder and return them under the "expand" key. This relies on
# ComfyUI's node-expansion contract (see ComfyUI docs; not verifiable here).
from __future__ import annotations

from typing import Any

try:
    from comfy_execution.graph import ExecutionBlocker
    from comfy_execution.graph_utils import GraphBuilder, is_link
except Exception:
    # Allows local syntax/import checks outside ComfyUI.
    ExecutionBlocker = None
    GraphBuilder = None

    def is_link(value: Any) -> bool:
        """Return True when *value* has the shape of a graph link.

        A link is a two-element list ``[node_id, output_index]`` with a string
        node id and a numeric index. The element-type checks mirror ComfyUI's
        own ``is_link`` so this fallback does not mistake arbitrary two-item
        lists (for example ``[1, 2]``) for links.
        """
        if not isinstance(value, list) or len(value) != 2:
            return False
        return isinstance(value[0], str) and isinstance(value[1], (int, float))


try:
    from nodes import NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS
except Exception:
    ALL_NODE_CLASS_MAPPINGS = {}

# Number of value slots a while loop carries from one iteration to the next.
MAX_LOOP_VALUES = 20
# The for loop uses while-slot 0 for the index and slot 1 for the collection,
# leaving the remaining slots for user values.
MAX_CARRY_VALUES = MAX_LOOP_VALUES - 2
COLLECTION_MODES = ["auto_batch", "list", "image_batch", "latent_batch", "string_lines"]


class AnyType(str):
    """String subclass whose ``!=`` comparison always reports "not different".

    NOTE(review): presumably used so ComfyUI's socket type comparison accepts
    this type against any other type name - confirm against ComfyUI.
    """

    def __ne__(self, _other: object) -> bool:
        return False


ANY_TYPE = AnyType("*")


def _require_graph_builder() -> None:
    """Raise RuntimeError when GraphBuilder could not be imported."""
    if GraphBuilder is None:
        raise RuntimeError("SxCP loop nodes require ComfyUI's comfy_execution GraphBuilder.")


def _execution_blocker() -> Any:
    """Return a fresh ``ExecutionBlocker(None)``, or None when it is unavailable."""
    return ExecutionBlocker(None) if ExecutionBlocker is not None else None


def _torch_cat(first: Any, second: Any) -> Any | None:
    """Concatenate two tensors along dim 0.

    Returns None when torch cannot be imported or either argument is not a
    tensor. Errors raised by ``torch.cat`` itself are not caught.
    """
    try:
        import torch
    except Exception:
        return None
    if torch.is_tensor(first) and torch.is_tensor(second):
        return torch.cat((first, second), dim=0)
    return None


def _latent_cat(first: Any, second: Any) -> Any | None:
    """Merge two dicts that both carry a tensor under the "samples" key.

    The result is a shallow copy of *second* whose "samples" entry is the
    dim-0 concatenation of both inputs; every other key is taken from
    *second* unchanged. Returns None when either input is not such a dict or
    the samples cannot be concatenated as tensors.
    """
    if not isinstance(first, dict) or not isinstance(second, dict):
        return None
    if "samples" not in first or "samples" not in second:
        return None
    samples = _torch_cat(first["samples"], second["samples"])
    if samples is None:
        return None
    merged = dict(second)
    merged["samples"] = samples
    return merged


def _as_list(collection: Any) -> list[Any]:
    """Return *collection* as a new list: [] for None, a copy for a list, else [collection]."""
    if collection is None:
        return []
    return list(collection) if isinstance(collection, list) else [collection]


def append_collected_value(
    collection: Any,
    value: Any,
    mode: str = "auto_batch",
    skip_none: bool = True,
) -> Any:
    """Return *collection* extended by *value* according to *mode*.

    Args:
        collection: Values gathered so far; None means nothing collected yet.
        value: The value to add.
        mode: One of COLLECTION_MODES; any other value is treated as
            "auto_batch".
        skip_none: When True, a None *value* leaves *collection* untouched.

    Returns:
        "string_lines": *value* as text (None becomes ""), joined to a
        non-empty collection with a newline.
        "list": a new list with *value* appended.
        Other modes: *value* itself when *collection* is None; otherwise a
        tensor batch ("auto_batch"/"image_batch") or a merged latent dict
        ("auto_batch"/"latent_batch") when the inputs allow it, else a new
        list with *value* appended.
    """
    if value is None and skip_none:
        return collection

    if mode not in COLLECTION_MODES:
        mode = "auto_batch"

    if mode == "string_lines":
        value_text = "" if value is None else str(value)
        if not collection:
            return value_text
        return f"{collection}\n{value_text}"

    if mode == "list":
        return _as_list(collection) + [value]

    if collection is None:
        return value

    if mode in ("auto_batch", "image_batch"):
        tensor_batch = _torch_cat(collection, value)
        if tensor_batch is not None:
            return tensor_batch

    if mode in ("auto_batch", "latent_batch"):
        latent_batch = _latent_cat(collection, value)
        if latent_batch is not None:
            return latent_batch

    # Nothing could be batched: every remaining mode falls back to a list.
    return _as_list(collection) + [value]


class SxCPWhileLoopStart:
    """Loop entry node: passes the initial values through while the condition holds."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a required boolean condition and MAX_LOOP_VALUES optional value inputs."""
        return {
            "required": {
                "condition": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                f"initial_value{index}": (ANY_TYPE,) for index in range(MAX_LOOP_VALUES)
            },
        }

    RETURN_TYPES = tuple(["FLOW_CONTROL"] + [ANY_TYPE] * MAX_LOOP_VALUES)
    RETURN_NAMES = tuple(["flow"] + [f"value{index}" for index in range(MAX_LOOP_VALUES)])
    FUNCTION = "open"
    CATEGORY = "prompt_builder/loop"

    def open(self, condition, **kwargs):
        """Return ("stub", value0, ..., valueN).

        When *condition* is falsy every value slot holds the result of
        ``_execution_blocker()`` instead of the supplied initial value.
        """
        values = [
            kwargs.get(f"initial_value{index}") if condition else _execution_blocker()
            for index in range(MAX_LOOP_VALUES)
        ]
        return tuple(["stub"] + values)


class SxCPWhileLoopEnd:
    """Loop exit node: returns the values, or re-expands the loop body for another pass."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the flow link, the continue condition and the carried values."""
        return {
            "required": {
                "flow": ("FLOW_CONTROL", {"rawLink": True}),
                "condition": ("BOOLEAN", {}),
            },
            "optional": {
                f"initial_value{index}": (ANY_TYPE,) for index in range(MAX_LOOP_VALUES)
            },
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "unique_id": "UNIQUE_ID",
                "extra_pnginfo": "EXTRA_PNGINFO",
            },
        }

    RETURN_TYPES = tuple([ANY_TYPE] * MAX_LOOP_VALUES)
    RETURN_NAMES = tuple([f"value{index}" for index in range(MAX_LOOP_VALUES)])
    FUNCTION = "close"
    CATEGORY = "prompt_builder/loop"

    def _explore_dependencies(
        self,
        node_id: str,
        dynprompt: Any,
        upstream: dict[str, list[str]],
        parent_ids: list[str],
    ) -> None:
        """Recursively record, for every ancestor of *node_id*, which nodes consume it.

        *upstream* maps a parent id to the ids of the nodes that link to it.
        *parent_ids* receives the display id of each parent whose class is not
        one of the loop-end node types (duplicates are possible).
        """
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for value in node_info["inputs"].values():
            if not is_link(value):
                continue
            parent_id = value[0]
            display_id = dynprompt.get_display_node_id(parent_id)
            class_type = dynprompt.get_node(display_id)["class_type"]
            if class_type not in ("SxCPForLoopEnd", "SxCPWhileLoopEnd"):
                parent_ids.append(display_id)
            if parent_id not in upstream:
                # First visit: register the parent before recursing into it.
                upstream[parent_id] = []
                self._explore_dependencies(parent_id, dynprompt, upstream, parent_ids)
            upstream[parent_id].append(node_id)

    def _explore_output_nodes(
        self,
        dynprompt: Any,
        upstream: dict[str, list[str]],
        output_nodes: dict[str, Any],
        parent_ids: list[str],
    ) -> None:
        """Attach output nodes to the parents in *upstream* that feed them.

        *output_nodes* maps an output node id to either one link or a list of
        links. An output node is added as a child of a parent when one of its
        links points at that parent's display id and that id is in
        *parent_ids*. For a dotted parent id the last dotted component is
        replaced by the output node id.
        """
        for parent_id, children in upstream.items():
            display_id = dynprompt.get_display_node_id(parent_id)
            if display_id not in parent_ids:
                continue
            for output_id, entry in output_nodes.items():
                if not entry or output_id in children:
                    continue
                # Accept both a single link and a list of links.
                links = entry if isinstance(entry[0], list) else [entry]
                if all(link[0] != display_id for link in links):
                    continue
                if "." in parent_id:
                    parts = parent_id.split(".")
                    parts[-1] = output_id
                    children.append(".".join(parts))
                else:
                    children.append(output_id)

    def _collect_contained(
        self,
        node_id: str,
        upstream: dict[str, list[str]],
        contained: dict[str, bool],
    ) -> None:
        """Mark every node reachable downstream of *node_id* via *upstream* in *contained*."""
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id in contained:
                continue
            contained[child_id] = True
            self._collect_contained(child_id, upstream, contained)

    def _output_node_links(self, dynprompt: Any) -> dict[str, list[Any]]:
        """Return every link input of each OUTPUT_NODE class in the original prompt.

        All links are kept, so an output node is still found when the relevant
        parent is wired to an input other than its last linked one.
        """
        output_nodes: dict[str, list[Any]] = {}
        for node_id, node in dynprompt.get_original_prompt().items():
            if "inputs" not in node:
                continue
            class_def = ALL_NODE_CLASS_MAPPINGS.get(node["class_type"])
            if not class_def or not getattr(class_def, "OUTPUT_NODE", False):
                continue
            links = [value for value in node["inputs"].values() if is_link(value)]
            if links:
                output_nodes[node_id] = links
        return output_nodes

    def close(self, flow, condition, dynprompt=None, unique_id=None, **kwargs):
        """Finish the loop or schedule another iteration.

        When *condition* is falsy the carried values are returned as a tuple.
        Otherwise the nodes between the loop start (``flow[0]``) and this node
        are rebuilt with GraphBuilder, the start node is fed the carried
        values, and a dict with "result" and "expand" is returned.

        Raises:
            RuntimeError: if GraphBuilder is unavailable and another
                iteration is required.
        """
        if not condition:
            return tuple(kwargs.get(f"initial_value{index}") for index in range(MAX_LOOP_VALUES))

        _require_graph_builder()

        upstream: dict[str, list[str]] = {}
        parent_ids: list[str] = []
        self._explore_dependencies(unique_id, dynprompt, upstream, parent_ids)
        parent_ids = list(set(parent_ids))

        output_nodes = self._output_node_links(dynprompt)

        graph = GraphBuilder()
        self._explore_output_nodes(dynprompt, upstream, output_nodes, parent_ids)

        contained: dict[str, bool] = {}
        open_node = flow[0]
        self._collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True

        # First pass: create a clone of every contained node. This node's own
        # clone is registered under the fixed id "Recurse".
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            clone_id = "Recurse" if node_id == unique_id else node_id
            clone = graph.node(original_node["class_type"], clone_id)
            clone.set_override_display_id(node_id)

        # Second pass: wire inputs. Links to contained nodes point at the
        # clones; everything else is copied through unchanged.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            clone = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
            for key, value in original_node["inputs"].items():
                if is_link(value) and value[0] in contained:
                    parent = graph.lookup_node(value[0])
                    clone.set_input(key, parent.out(value[1]))
                else:
                    clone.set_input(key, value)

        new_open = graph.lookup_node(open_node)
        original_open = dynprompt.get_node(open_node)
        if original_open["class_type"] == "SxCPForLoopStart":
            # The for-loop start takes the index and collection from
            # while-slots 0 and 1; its carry inputs are shifted by one.
            new_open.set_input("initial_index", kwargs.get("initial_value0"))
            new_open.set_input("initial_collected", kwargs.get("initial_value1"))
            for carry_index in range(1, MAX_CARRY_VALUES + 1):
                new_open.set_input(
                    f"initial_value{carry_index}",
                    kwargs.get(f"initial_value{carry_index + 1}"),
                )
        else:
            for index in range(MAX_LOOP_VALUES):
                new_open.set_input(f"initial_value{index}", kwargs.get(f"initial_value{index}"))

        my_clone = graph.lookup_node("Recurse")
        return {
            "result": tuple(my_clone.out(index) for index in range(MAX_LOOP_VALUES)),
            "expand": graph.finalize(),
        }


class SxCPForLoopStart:
    """Counted-loop entry node exposing the current index, a collection and carried values."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare total/skip counters, the carried values and the hidden loop state."""
        return {
            "required": {
                "total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}),
                "skip": ("INT", {"default": 0, "min": 0, "max": 100000, "step": 1}),
            },
            "optional": {
                f"initial_value{index}": (ANY_TYPE,) for index in range(1, MAX_CARRY_VALUES + 1)
            },
            "hidden": {
                "initial_index": (ANY_TYPE,),
                "initial_collected": (ANY_TYPE,),
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = tuple(["FLOW_CONTROL", "INT", ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES)
    RETURN_NAMES = tuple(
        ["flow", "index", "collected"]
        + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)]
    )
    FUNCTION = "start"
    CATEGORY = "prompt_builder/loop"

    def start(self, total, skip=0, initial_index=None, initial_collected=None, **kwargs):
        """Return the loop state and an expansion containing a SxCPWhileLoopStart node.

        The index starts at *skip* (clamped to >= 0) and is never lower than
        *skip* even when *initial_index* is supplied. The while node's
        condition is ``index < int(total)``.

        Raises:
            RuntimeError: if GraphBuilder is unavailable.
        """
        _require_graph_builder()

        skip = max(0, int(skip))
        index = skip if initial_index is None else max(int(initial_index), skip)
        collected = initial_collected

        carried = [kwargs.get(f"initial_value{slot}") for slot in range(1, MAX_CARRY_VALUES + 1)]

        # While-slot 0 holds the index, slot 1 the collection, and carry slot
        # N is stored in while-slot N + 1.
        initial_values = {
            "initial_value0": index,
            "initial_value1": collected,
        }
        for slot, carried_value in enumerate(carried, start=1):
            initial_values[f"initial_value{slot + 1}"] = carried_value

        graph = GraphBuilder()
        graph.node("SxCPWhileLoopStart", condition=index < int(total), **initial_values)
        return {
            "result": tuple(["stub", index, collected] + carried),
            "expand": graph.finalize(),
        }


class SxCPLoopAppend:
    """Node wrapper around append_collected_value."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the collection mode, the skip_none flag and the two optional operands."""
        return {
            "required": {
                "mode": (COLLECTION_MODES, {"default": "auto_batch"}),
                "skip_none": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "collection": (ANY_TYPE,),
                "value": (ANY_TYPE,),
            },
        }

    RETURN_TYPES = (ANY_TYPE,)
    RETURN_NAMES = ("collected",)
    FUNCTION = "append"
    CATEGORY = "prompt_builder/loop"

    def append(self, mode, skip_none, collection=None, value=None):
        """Return a one-tuple holding *collection* extended by *value*."""
        return (append_collected_value(collection, value, mode=mode, skip_none=skip_none),)


class SxCPForLoopEnd:
    """Counted-loop exit node: increments the index, appends to the collection and loops."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the flow link, collection settings and the raw-link value inputs."""
        return {
            "required": {
                "flow": ("FLOW_CONTROL", {"rawLink": True}),
                "collection_mode": (COLLECTION_MODES, {"default": "auto_batch"}),
                "skip_none": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "collected": (ANY_TYPE, {"rawLink": True}),
                "collect_value": (ANY_TYPE, {"rawLink": True}),
                **{
                    f"initial_value{index}": (ANY_TYPE, {"rawLink": True})
                    for index in range(1, MAX_CARRY_VALUES + 1)
                },
            },
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = tuple([ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES)
    RETURN_NAMES = tuple(
        ["collected"] + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)]
    )
    FUNCTION = "end"
    CATEGORY = "prompt_builder/loop"

    def end(self, flow, collection_mode, skip_none, dynprompt=None, **kwargs):
        """Expand into the nodes that advance the loop by one step.

        Builds SxCPLoopIntAdd (index + 1), SxCPLoopLessThan (next index <
        total), SxCPLoopAppend and a SxCPWhileLoopEnd fed with the results.
        The returned outputs skip the while node's slot 0 (the index).

        Raises:
            RuntimeError: if GraphBuilder is unavailable.
            ValueError: if *flow* does not come from a SxCPForLoopStart node.
        """
        _require_graph_builder()
        graph = GraphBuilder()

        loop_start = flow[0]
        start_node = dynprompt.get_node(loop_start)
        if start_node["class_type"] != "SxCPForLoopStart":
            raise ValueError("SxCP For Loop End must receive flow from SxCP For Loop Start.")

        total = start_node["inputs"]["total"]
        # Output 1 of the start node is the current index.
        next_index = graph.node("SxCPLoopIntAdd", a=[loop_start, 1], b=1)
        condition = graph.node("SxCPLoopLessThan", a=next_index.out(0), b=total)

        # Without an explicit collection, use the start node's own
        # "collected" output (output 2).
        collection = kwargs.get("collected") or [loop_start, 2]
        collect_value = kwargs.get("collect_value")
        next_collection = graph.node(
            "SxCPLoopAppend",
            collection=collection,
            value=collect_value,
            mode=collection_mode,
            skip_none=skip_none,
        )

        next_values = {
            "initial_value0": next_index.out(0),
            "initial_value1": next_collection.out(0),
        }
        for carry_index in range(1, MAX_CARRY_VALUES + 1):
            next_values[f"initial_value{carry_index + 1}"] = kwargs.get(
                f"initial_value{carry_index}"
            )

        while_close = graph.node(
            "SxCPWhileLoopEnd",
            flow=flow,
            condition=condition.out(0),
            **next_values,
        )
        return {
            "result": tuple(while_close.out(index) for index in range(1, MAX_LOOP_VALUES)),
            "expand": graph.finalize(),
        }


class SxCPLoopIntAdd:
    """Internal helper node: integer addition."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two integer operands."""
        return {
            "required": {
                "a": ("INT", {"default": 0}),
                "b": ("INT", {"default": 1}),
            }
        }

    RETURN_TYPES = ("INT",)
    RETURN_NAMES = ("int",)
    FUNCTION = "add"
    CATEGORY = "prompt_builder/loop/internal"

    def add(self, a, b):
        """Return ``(int(a) + int(b),)``."""
        return (int(a) + int(b),)


class SxCPLoopLessThan:
    """Internal helper node: integer less-than comparison."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two integer operands."""
        return {
            "required": {
                "a": ("INT", {"default": 0}),
                "b": ("INT", {"default": 1}),
            }
        }

    RETURN_TYPES = ("BOOLEAN",)
    RETURN_NAMES = ("boolean",)
    FUNCTION = "compare"
    CATEGORY = "prompt_builder/loop/internal"

    def compare(self, a, b):
        """Return ``(int(a) < int(b),)``."""
        return (int(a) < int(b),)


LOOP_NODE_CLASS_MAPPINGS = {
    "SxCPWhileLoopStart": SxCPWhileLoopStart,
    "SxCPWhileLoopEnd": SxCPWhileLoopEnd,
    "SxCPForLoopStart": SxCPForLoopStart,
    "SxCPForLoopEnd": SxCPForLoopEnd,
    "SxCPLoopAppend": SxCPLoopAppend,
    "SxCPLoopIntAdd": SxCPLoopIntAdd,
    "SxCPLoopLessThan": SxCPLoopLessThan,
}

LOOP_NODE_DISPLAY_NAME_MAPPINGS = {
    "SxCPWhileLoopStart": "SxCP While Loop Start",
    "SxCPWhileLoopEnd": "SxCP While Loop End",
    "SxCPForLoopStart": "SxCP For Loop Start",
    "SxCPForLoopEnd": "SxCP For Loop End",
    "SxCPLoopAppend": "SxCP Loop Append",
    "SxCPLoopIntAdd": "SxCP Loop Int Add",
    "SxCPLoopLessThan": "SxCP Loop Less Than",
}