Files
ComfyUI-Ethanfel-Prompt-Bui…/loop_nodes.py
T
2026-06-24 22:16:50 +02:00

462 lines
16 KiB
Python

from __future__ import annotations
from typing import Any
try:
    from comfy_execution.graph import ExecutionBlocker
    from comfy_execution.graph_utils import GraphBuilder, is_link
except Exception:  # Allows local syntax/import checks outside ComfyUI.
    ExecutionBlocker = None
    GraphBuilder = None

    def is_link(value: Any) -> bool:
        """Return True if *value* looks like a link ``[node_id, output_index]``.

        Fallback used only when ComfyUI's ``comfy_execution`` package is not
        importable. Besides the two-element-list shape it also requires a string
        node id and a numeric output index (intended to match ComfyUI's own
        check), so plain two-element value lists such as ``[512, 768]`` are not
        mistaken for links.
        """
        return (
            isinstance(value, list)
            and len(value) == 2
            and isinstance(value[0], str)
            and isinstance(value[1], (int, float))
        )

try:
    from nodes import NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS
except Exception:
    # Without ComfyUI's global node registry, class lookups by type name find nothing.
    ALL_NODE_CLASS_MAPPINGS = {}
# Number of value sockets carried by the while-loop start/end nodes.
MAX_LOOP_VALUES = 20
# The for-loop uses while-loop slot 0 for the index and slot 1 for the
# collected value; the remaining slots are user-visible carry values.
MAX_CARRY_VALUES = MAX_LOOP_VALUES - 2
# Accumulation strategies accepted by append_collected_value.
COLLECTION_MODES = [
    "auto_batch",
    "list",
    "image_batch",
    "latent_batch",
    "string_lines",
]
class AnyType(str):
    """A ``str`` whose ``!=`` comparison is always False.

    Intended as a wildcard socket type: any code that tests type compatibility
    with ``!=`` will treat this value as matching every other type name.
    """

    def __ne__(self, other: object) -> bool:
        return False


# Wildcard socket type used for all untyped loop inputs/outputs.
ANY_TYPE = AnyType("*")
def _require_graph_builder() -> None:
    """Raise RuntimeError when ComfyUI's GraphBuilder could not be imported."""
    if GraphBuilder is not None:
        return
    raise RuntimeError("SxCP loop nodes require ComfyUI's comfy_execution GraphBuilder.")
def _execution_blocker() -> Any:
    """Return ``ExecutionBlocker(None)``, or plain None when ComfyUI is unavailable."""
    if ExecutionBlocker is None:
        return None
    return ExecutionBlocker(None)
def _torch_cat(first: Any, second: Any) -> Any | None:
try:
import torch
except Exception:
return None
if torch.is_tensor(first) and torch.is_tensor(second):
return torch.cat((first, second), dim=0)
return None
def _latent_cat(first: Any, second: Any) -> Any | None:
    """Batch two latent dicts by concatenating their ``samples`` entries.

    Returns None when either argument is not a dict with a ``samples`` key, or
    when ``_torch_cat`` cannot concatenate the two ``samples`` values.
    All keys other than ``samples`` are copied from *second*, except
    ``batch_index``. If either dict has a ``batch_index``, the merged dict gets
    the concatenation of both sides. A side without one contributes
    ``range(len(samples))``. This keeps ``batch_index`` the same length as the
    merged batch.
    """
    if not isinstance(first, dict) or not isinstance(second, dict):
        return None
    if "samples" not in first or "samples" not in second:
        return None
    samples = _torch_cat(first["samples"], second["samples"])
    if samples is None:
        return None
    merged = dict(second)
    merged["samples"] = samples
    if "batch_index" in first or "batch_index" in second:
        # A batch_index copied from one side only would no longer match the merged batch size.
        first_index = first.get("batch_index", range(first["samples"].shape[0]))
        second_index = second.get("batch_index", range(second["samples"].shape[0]))
        merged["batch_index"] = list(first_index) + list(second_index)
    return merged
def _as_list(collection: Any) -> list[Any]:
if collection is None:
return []
return list(collection) if isinstance(collection, list) else [collection]
def append_collected_value(collection: Any, value: Any, mode: str = "auto_batch", skip_none: bool = True) -> Any:
    """Append *value* to *collection* according to *mode* and return the result.

    - A None *value* is ignored (the collection is returned as-is) when *skip_none* is set.
    - Unknown modes are treated as ``"auto_batch"``.
    - ``string_lines``: join str(value) onto the collection with a newline
      (None becomes an empty line; an empty/None collection starts fresh).
    - ``list``: always produce a new list with *value* appended.
    - Other modes: a None collection is replaced by *value*. Otherwise the
      helper tries to concatenate tensors (``auto_batch``/``image_batch``) or
      latent dicts (``auto_batch``/``latent_batch``). If that is not possible,
      it falls back to appending to a list.
    """
    if skip_none and value is None:
        return collection
    if mode not in COLLECTION_MODES:
        mode = "auto_batch"

    if mode == "string_lines":
        line = str(value) if value is not None else ""
        return f"{collection}\n{line}" if collection else line

    # Batch-style modes seed the collection with the first value itself.
    if mode != "list" and collection is None:
        return value

    if mode in ("auto_batch", "image_batch"):
        batched = _torch_cat(collection, value)
        if batched is not None:
            return batched
    if mode in ("auto_batch", "latent_batch"):
        batched = _latent_cat(collection, value)
        if batched is not None:
            return batched

    # "list" mode, and the fallback whenever batching was not possible.
    return _as_list(collection) + [value]
class SxCPWhileLoopStart:
    """Opening node of an SxCP while loop.

    Forwards its ``initial_value0..N`` inputs to ``value0..N`` outputs. When
    ``condition`` is false, every value output is replaced by an execution
    blocker from ``_execution_blocker()`` (None when ComfyUI is unavailable).
    The ``flow`` output is a placeholder string; the loop end receives it as a
    raw link.
    """

    @classmethod
    def INPUT_TYPES(cls):
        optional = {f"initial_value{slot}": (ANY_TYPE,) for slot in range(MAX_LOOP_VALUES)}
        return {
            "required": {
                "condition": ("BOOLEAN", {"default": True}),
            },
            "optional": optional,
        }

    RETURN_TYPES = ("FLOW_CONTROL", *([ANY_TYPE] * MAX_LOOP_VALUES))
    RETURN_NAMES = ("flow", *(f"value{slot}" for slot in range(MAX_LOOP_VALUES)))
    FUNCTION = "open"
    CATEGORY = "prompt_builder/loop"

    def open(self, condition, **kwargs):
        if condition:
            forwarded = [kwargs.get(f"initial_value{slot}") for slot in range(MAX_LOOP_VALUES)]
        else:
            forwarded = [_execution_blocker() for _ in range(MAX_LOOP_VALUES)]
        return ("stub", *forwarded)
class SxCPWhileLoopEnd:
    """Closing node of an SxCP while loop.

    If ``condition`` is falsy, ``close`` returns its ``initial_value*`` inputs
    as the loop's final values. Otherwise it uses GraphBuilder node expansion to
    clone every node between the opening node (``flow[0]``) and itself. It feeds
    its current ``initial_value*`` inputs into the cloned opening node and
    returns the outputs of its own clone (named "Recurse"), so another iteration
    runs.
    """

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                # Raw link to the opening node's flow output; flow[0] is the opening node id.
                "flow": ("FLOW_CONTROL", {"rawLink": True}),
                "condition": ("BOOLEAN", {}),
            },
            "optional": {},
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "unique_id": "UNIQUE_ID",
                "extra_pnginfo": "EXTRA_PNGINFO",
            },
        }
        for index in range(MAX_LOOP_VALUES):
            inputs["optional"][f"initial_value{index}"] = (ANY_TYPE,)
        return inputs

    RETURN_TYPES = tuple([ANY_TYPE] * MAX_LOOP_VALUES)
    RETURN_NAMES = tuple([f"value{index}" for index in range(MAX_LOOP_VALUES)])
    FUNCTION = "close"
    CATEGORY = "prompt_builder/loop"

    def _explore_dependencies(self, node_id: str, dynprompt: Any, upstream: dict[str, list[str]], parent_ids: list[str]) -> None:
        """Walk input links upstream from *node_id*, filling *upstream* and *parent_ids*.

        ``upstream[parent]`` accumulates the ids of nodes that link to an output of
        ``parent``. The display id of every visited parent is appended to
        *parent_ids*, except when that display node is an SxCP loop end
        (presumably so nested loops are not treated as body parents — TODO confirm).
        """
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for value in node_info["inputs"].values():
            if not is_link(value):
                continue
            parent_id = value[0]
            display_id = dynprompt.get_display_node_id(parent_id)
            display_node = dynprompt.get_node(display_id)
            class_type = display_node["class_type"]
            if class_type not in ("SxCPForLoopEnd", "SxCPWhileLoopEnd"):
                parent_ids.append(display_id)
            if parent_id not in upstream:
                # First visit of this parent: create its entry and recurse once.
                upstream[parent_id] = []
                self._explore_dependencies(parent_id, dynprompt, upstream, parent_ids)
            upstream[parent_id].append(node_id)

    def _explore_output_nodes(
        self,
        dynprompt: Any,
        upstream: dict[str, list[str]],
        output_nodes: dict[str, Any],
        parent_ids: list[str],
    ) -> None:
        """Register output nodes as children of the explored nodes they read from.

        An output node ``output_id`` whose link source is in *parent_ids* and equals
        the display id of an explored ``parent_id`` is appended to
        ``upstream[parent_id]``. For dotted (expanded) parent ids, the last segment
        is replaced by ``output_id`` before appending.
        """
        for parent_id in upstream:
            display_id = dynprompt.get_display_node_id(parent_id)
            for output_id, link in output_nodes.items():
                linked_id = link[0]
                if linked_id in parent_ids and display_id == linked_id and output_id not in upstream[parent_id]:
                    if "." in parent_id:
                        parts = parent_id.split(".")
                        parts[-1] = output_id
                        upstream[parent_id].append(".".join(parts))
                    else:
                        upstream[parent_id].append(output_id)

    def _collect_contained(self, node_id: str, upstream: dict[str, list[str]], contained: dict[str, bool]) -> None:
        """Depth-first: mark every node reachable downstream of *node_id* in *contained*."""
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id in contained:
                continue
            contained[child_id] = True
            self._collect_contained(child_id, upstream, contained)

    def close(self, flow, condition, dynprompt=None, unique_id=None, **kwargs):
        """End the loop or expand into one more iteration.

        Returns a tuple of the ``initial_value*`` inputs when *condition* is falsy.
        Otherwise returns ``{"result": ..., "expand": ...}``. The result holds the
        outputs of this node's clone, and the expansion holds the cloned loop body.

        Raises RuntimeError (via _require_graph_builder) if GraphBuilder is unavailable.
        """
        if not condition:
            return tuple(kwargs.get(f"initial_value{index}") for index in range(MAX_LOOP_VALUES))
        _require_graph_builder()
        upstream: dict[str, list[str]] = {}
        parent_ids: list[str] = []
        self._explore_dependencies(unique_id, dynprompt, upstream, parent_ids)
        parent_ids = list(set(parent_ids))
        # Map every OUTPUT_NODE in the original prompt to (the last of) its input links.
        output_nodes = {}
        for node_id, node in dynprompt.get_original_prompt().items():
            if "inputs" not in node:
                continue
            class_def = ALL_NODE_CLASS_MAPPINGS.get(node["class_type"])
            if not class_def or not getattr(class_def, "OUTPUT_NODE", False):
                continue
            for value in node["inputs"].values():
                if is_link(value):
                    output_nodes[node_id] = value
        graph = GraphBuilder()
        self._explore_output_nodes(dynprompt, upstream, output_nodes, parent_ids)
        # Loop body: everything downstream of the opening node, plus both endpoints.
        contained: dict[str, bool] = {}
        open_node = flow[0]
        self._collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True
        # Pass 1: create a clone of each body node; this node's clone is named "Recurse".
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
            node.set_override_display_id(node_id)
        # Pass 2: copy inputs, redirecting links between body nodes to the clones.
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
            for key, value in original_node["inputs"].items():
                if is_link(value) and value[0] in contained:
                    parent = graph.lookup_node(value[0])
                    node.set_input(key, parent.out(value[1]))
                else:
                    node.set_input(key, value)
        # Feed this iteration's values into the cloned opening node.
        new_open = graph.lookup_node(open_node)
        original_open = dynprompt.get_node(open_node)
        if original_open["class_type"] == "SxCPForLoopStart":
            # For loops: slot 0 is the index, slot 1 the collected value, and
            # slots 2.. map to the start node's initial_value1..
            new_open.set_input("initial_index", kwargs.get("initial_value0"))
            new_open.set_input("initial_collected", kwargs.get("initial_value1"))
            for carry_index in range(1, MAX_CARRY_VALUES + 1):
                new_open.set_input(f"initial_value{carry_index}", kwargs.get(f"initial_value{carry_index + 1}"))
        else:
            for index in range(MAX_LOOP_VALUES):
                new_open.set_input(f"initial_value{index}", kwargs.get(f"initial_value{index}"))
        my_clone = graph.lookup_node("Recurse")
        return {
            "result": tuple(my_clone.out(index) for index in range(MAX_LOOP_VALUES)),
            "expand": graph.finalize(),
        }
class SxCPForLoopStart:
    """Opening node of an SxCP counted (for) loop.

    Outputs ``flow``, the current 1-based ``index``, the running ``collected``
    value, and the carry values ``value1..value{MAX_CARRY_VALUES}``. The first
    index is ``skip + 1``. On later iterations the while-loop end re-feeds the
    hidden ``initial_index`` / ``initial_collected`` inputs.
    """

    @classmethod
    def INPUT_TYPES(cls):
        carry_inputs = {f"initial_value{slot}": (ANY_TYPE,) for slot in range(1, MAX_CARRY_VALUES + 1)}
        return {
            "required": {
                "total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}),
                "skip": ("INT", {"default": 0, "min": 0, "max": 100000, "step": 1}),
            },
            "optional": carry_inputs,
            "hidden": {
                # Set by the while-loop end when it clones this node for the next iteration.
                "initial_index": (ANY_TYPE,),
                "initial_collected": (ANY_TYPE,),
                "prompt": "PROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("FLOW_CONTROL", "INT", ANY_TYPE, *([ANY_TYPE] * MAX_CARRY_VALUES))
    RETURN_NAMES = ("flow", "index", "collected", *(f"value{slot}" for slot in range(1, MAX_CARRY_VALUES + 1)))
    FUNCTION = "start"
    CATEGORY = "prompt_builder/loop"

    def start(self, total, skip=0, initial_index=None, initial_collected=None, **kwargs):
        _require_graph_builder()
        last_index = max(1, int(total))
        first_index = max(0, int(skip)) + 1
        if initial_index is None:
            current = first_index
        else:
            current = max(int(initial_index), first_index)

        carries = [kwargs.get(f"initial_value{slot}") for slot in range(1, MAX_CARRY_VALUES + 1)]
        # While-loop slot layout: 0 = index, 1 = collected, 2.. = carries.
        while_inputs = {"initial_value0": current, "initial_value1": initial_collected}
        while_inputs.update({f"initial_value{offset + 2}": carry for offset, carry in enumerate(carries)})

        graph = GraphBuilder()
        # NOTE(review): the result below is returned whatever the condition is.
        # The condition only reaches the expanded SxCPWhileLoopStart. Confirm the
        # loop body behaves as intended when skip >= total (first index > total).
        graph.node("SxCPWhileLoopStart", condition=current <= last_index, **while_inputs)
        return {
            "result": ("stub", current, initial_collected, *carries),
            "expand": graph.finalize(),
        }
class SxCPLoopAppend:
    """Node wrapper around ``append_collected_value``.

    Returns ``collection`` with ``value`` appended using the selected mode.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "mode": (COLLECTION_MODES, {"default": "auto_batch"}),
            "skip_none": ("BOOLEAN", {"default": True}),
        }
        optional = {
            "collection": (ANY_TYPE,),
            "value": (ANY_TYPE,),
        }
        return {"required": required, "optional": optional}

    RETURN_TYPES = (ANY_TYPE,)
    RETURN_NAMES = ("collected",)
    FUNCTION = "append"
    CATEGORY = "prompt_builder/loop"

    def append(self, mode, skip_none, collection=None, value=None):
        merged = append_collected_value(collection, value, mode=mode, skip_none=skip_none)
        return (merged,)
class SxCPForLoopEnd:
    """Closing node of an SxCP counted (for) loop.

    Expands into an SxCPWhileLoopEnd whose condition is ``index + 1 <= total``.
    Its slot 0 is the next index, slot 1 is the collection with
    ``collect_value`` appended, and slots 2.. are the carry values. Outputs
    the final ``collected`` value followed by the carries.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "flow": ("FLOW_CONTROL", {"rawLink": True}),
                "collection_mode": (COLLECTION_MODES, {"default": "auto_batch"}),
                "skip_none": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "collected": (ANY_TYPE, {"rawLink": True}),
                "collect_value": (ANY_TYPE, {"rawLink": True}),
                **{
                    f"initial_value{index}": (ANY_TYPE, {"rawLink": True})
                    for index in range(1, MAX_CARRY_VALUES + 1)
                },
            },
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "extra_pnginfo": "EXTRA_PNGINFO",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = tuple([ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES)
    RETURN_NAMES = tuple(["collected"] + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)])
    FUNCTION = "end"
    CATEGORY = "prompt_builder/loop"

    def end(self, flow, collection_mode, skip_none, dynprompt=None, **kwargs):
        """Build the while-loop close step for this for loop.

        The optional inputs are raw links (or None when unconnected). An
        unconnected ``collected`` defaults to the start node's ``collected``
        output (slot 2). An unconnected carry ``initial_valueN`` defaults to the
        start node's matching ``valueN`` output (slot N + 2), so the value
        passes through unchanged instead of becoming None after the first
        iteration.

        Raises:
            RuntimeError: if ComfyUI's GraphBuilder is unavailable.
            ValueError: if *flow* does not come from an SxCPForLoopStart node.
        """
        _require_graph_builder()
        graph = GraphBuilder()
        loop_start = flow[0]
        start_node = dynprompt.get_node(loop_start)
        if start_node["class_type"] != "SxCPForLoopStart":
            raise ValueError("SxCP For Loop End must receive flow from SxCP For Loop Start.")
        total = start_node["inputs"]["total"]
        # Start outputs: 0 = flow, 1 = index, 2 = collected, 3.. = value1..
        next_index = graph.node("SxCPLoopIntAdd", a=[loop_start, 1], b=1)
        condition = graph.node("SxCPLoopLessThanOrEqual", a=next_index.out(0), b=total)

        collection = kwargs.get("collected")
        if collection is None:
            collection = [loop_start, 2]
        next_collection = graph.node(
            "SxCPLoopAppend",
            collection=collection,
            value=kwargs.get("collect_value"),
            mode=collection_mode,
            skip_none=skip_none,
        )

        next_values = {
            "initial_value0": next_index.out(0),
            "initial_value1": next_collection.out(0),
        }
        for carry_index in range(1, MAX_CARRY_VALUES + 1):
            carry = kwargs.get(f"initial_value{carry_index}")
            if carry is None:
                # Unconnected carry: forward the start node's current value, mirroring the collected default.
                carry = [loop_start, carry_index + 2]
            next_values[f"initial_value{carry_index + 1}"] = carry

        while_close = graph.node("SxCPWhileLoopEnd", flow=flow, condition=condition.out(0), **next_values)
        # Drop while-loop slot 0 (the index); expose collected + carries.
        return {
            "result": tuple(while_close.out(index) for index in range(1, MAX_LOOP_VALUES)),
            "expand": graph.finalize(),
        }
class SxCPLoopIntAdd:
    """Internal helper node: outputs ``int(a) + int(b)``."""

    @classmethod
    def INPUT_TYPES(cls):
        operands = {
            "a": ("INT", {"default": 0}),
            "b": ("INT", {"default": 1}),
        }
        return {"required": operands}

    RETURN_TYPES = ("INT",)
    RETURN_NAMES = ("int",)
    FUNCTION = "add"
    CATEGORY = "prompt_builder/loop/internal"

    def add(self, a, b):
        left = int(a)
        right = int(b)
        return (left + right,)
class SxCPLoopLessThan:
    """Internal helper node: outputs ``int(a) < int(b)``."""

    @classmethod
    def INPUT_TYPES(cls):
        operands = {
            "a": ("INT", {"default": 0}),
            "b": ("INT", {"default": 1}),
        }
        return {"required": operands}

    RETURN_TYPES = ("BOOLEAN",)
    RETURN_NAMES = ("boolean",)
    FUNCTION = "compare"
    CATEGORY = "prompt_builder/loop/internal"

    def compare(self, a, b):
        left, right = int(a), int(b)
        return (left < right,)
class SxCPLoopLessThanOrEqual:
    """Internal helper node: outputs ``int(a) <= int(b)``.

    Used by SxCP For Loop End as the loop-continuation test (next index <= total).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("INT", {"default": 0}),
                "b": ("INT", {"default": 0}),
            }
        }

    RETURN_TYPES = ("BOOLEAN",)
    # Output label, consistent with the sibling SxCPLoopLessThan node.
    RETURN_NAMES = ("boolean",)
    FUNCTION = "compare"
    CATEGORY = "prompt_builder/loop/internal"

    def compare(self, a, b):
        return (int(a) <= int(b),)
# Loop node classes keyed by node type name (each key is the class's own name).
LOOP_NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (
        SxCPWhileLoopStart,
        SxCPWhileLoopEnd,
        SxCPForLoopStart,
        SxCPForLoopEnd,
        SxCPLoopAppend,
        SxCPLoopIntAdd,
        SxCPLoopLessThan,
        SxCPLoopLessThanOrEqual,
    )
}
# Human-readable titles for the loop nodes, keyed by node type name.
LOOP_NODE_DISPLAY_NAME_MAPPINGS = dict(
    SxCPWhileLoopStart="SxCP While Loop Start",
    SxCPWhileLoopEnd="SxCP While Loop End",
    SxCPForLoopStart="SxCP For Loop Start",
    SxCPForLoopEnd="SxCP For Loop End",
    SxCPLoopAppend="SxCP Loop Append",
    SxCPLoopIntAdd="SxCP Loop Int Add",
    SxCPLoopLessThan="SxCP Loop Less Than",
    SxCPLoopLessThanOrEqual="SxCP Loop Less Than Or Equal",
)