diff --git a/README.md b/README.md index bb4328e..b47487e 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@ The node is registered as: - `prompt_builder / SxCP Seed Locker` - `prompt_builder / SxCP Camera Control` - `prompt_builder / SxCP Camera Orbit Control` +- `prompt_builder / SxCP For Loop Start` +- `prompt_builder / SxCP For Loop End` +- `prompt_builder / SxCP Loop Append` - `prompt_builder / SxCP Category Preset` - `prompt_builder / SxCP Cast Control` - `prompt_builder / SxCP Generation Profile` @@ -71,6 +74,35 @@ as one long chain: manually into either generation lane, but they are not part of the default main path. +## Loop Nodes + +`SxCP For Loop Start` and `SxCP For Loop End` provide a lightweight replacement +for the easy-use for-loop dependency. They use the same recursive ComfyUI loop +pattern, but add a dedicated collector output for building a result sequence. + +Basic loop wiring: + +1. Connect `For Loop Start.flow` to `For Loop End.flow`. +2. Use `For Loop Start.index` inside the loop for seed/index changes. +3. Connect the per-iteration output you want to keep, such as an image, latent, + prompt, or metadata string, to `For Loop End.collect_value`. +4. Optionally connect `For Loop Start.collected` to `For Loop End.collected`. + If omitted, the end node uses the start collector internally. +5. After the loop finishes, use `For Loop End.collected` as the combined output. + +`collection_mode` controls how values are stored: + +- `auto_batch`: concatenates image tensors or latent samples when possible, + otherwise falls back to a Python list. +- `image_batch`: prefers image tensor batching. +- `latent_batch`: prefers latent `samples` batching. +- `list`: always appends each iteration result to a list. +- `string_lines`: joins each collected value with newlines. + +`value1`, `value2`, and later slots are normal carry-through channels for state +you want to update each iteration. They are separate from the collector and grow +dynamically in the UI as you connect them. + ## Character Profiles `SxCP Woman Slot` and `SxCP Man Slot` are the scalable per-participant control diff --git a/__init__.py b/__init__.py index 2fd6889..9add51a 100644 --- a/__init__.py +++ b/__init__.py @@ -4,6 +4,7 @@ import json import random try: + from .loop_nodes import LOOP_NODE_CLASS_MAPPINGS, LOOP_NODE_DISPLAY_NAME_MAPPINGS from .prompt_builder import ( build_camera_config_json, build_camera_orbit_config_json, @@ -53,6 +54,7 @@ try: from .caption_naturalizer import naturalize_caption from .krea_formatter import format_krea2_prompt except ImportError: + from loop_nodes import LOOP_NODE_CLASS_MAPPINGS, LOOP_NODE_DISPLAY_NAME_MAPPINGS from prompt_builder import ( build_camera_config_json, build_camera_orbit_config_json, @@ -1317,6 +1319,7 @@ NODE_CLASS_MAPPINGS = { "SxCPInstaOFOptions": SxCPInstaOFOptions, "SxCPInstaOFPromptPair": SxCPInstaOFPromptPair, } +NODE_CLASS_MAPPINGS.update(LOOP_NODE_CLASS_MAPPINGS) NODE_DISPLAY_NAME_MAPPINGS = { "SxCPPromptBuilder": "SxCP Prompt Builder", @@ -1339,6 +1342,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SxCPInstaOFOptions": "SxCP Insta/OF Options", "SxCPInstaOFPromptPair": "SxCP Insta/OF Prompt Pair", } +NODE_DISPLAY_NAME_MAPPINGS.update(LOOP_NODE_DISPLAY_NAME_MAPPINGS) WEB_DIRECTORY = "./web" diff --git a/loop_nodes.py b/loop_nodes.py new file mode 100644 index 0000000..d940039 --- /dev/null +++ b/loop_nodes.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +from typing import Any + +try: + from comfy_execution.graph import ExecutionBlocker + from comfy_execution.graph_utils import GraphBuilder, is_link +except Exception: # Allows local syntax/import checks outside ComfyUI. + ExecutionBlocker = None + GraphBuilder = None + + def is_link(value: Any) -> bool: + return isinstance(value, list) and len(value) == 2 + +try: + from nodes import NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS +except Exception: + ALL_NODE_CLASS_MAPPINGS = {} + + +MAX_LOOP_VALUES = 20 +MAX_CARRY_VALUES = MAX_LOOP_VALUES - 2 +COLLECTION_MODES = ["auto_batch", "list", "image_batch", "latent_batch", "string_lines"] + + +class AnyType(str): + def __ne__(self, _other: object) -> bool: + return False + + +ANY_TYPE = AnyType("*") + + +def _require_graph_builder() -> None: + if GraphBuilder is None: + raise RuntimeError("SxCP loop nodes require ComfyUI's comfy_execution GraphBuilder.") + + +def _execution_blocker() -> Any: + return ExecutionBlocker(None) if ExecutionBlocker is not None else None + + +def _torch_cat(first: Any, second: Any) -> Any | None: + try: + import torch + except Exception: + return None + if torch.is_tensor(first) and torch.is_tensor(second): + return torch.cat((first, second), dim=0) + return None + + +def _latent_cat(first: Any, second: Any) -> Any | None: + if not isinstance(first, dict) or not isinstance(second, dict): + return None + if "samples" not in first or "samples" not in second: + return None + samples = _torch_cat(first["samples"], second["samples"]) + if samples is None: + return None + merged = dict(second) + merged["samples"] = samples + return merged + + +def _as_list(collection: Any) -> list[Any]: + if collection is None: + return [] + return list(collection) if isinstance(collection, list) else [collection] + + +def append_collected_value(collection: Any, value: Any, mode: str = "auto_batch", skip_none: bool = True) -> Any: + if value is None and skip_none: + return collection + mode = mode if mode in COLLECTION_MODES else "auto_batch" + if mode == "string_lines": + value_text = "" if value is None else str(value) + if not collection: + return value_text + return f"{collection}\n{value_text}" + if mode == "list": + return _as_list(collection) + [value] + if collection is None: + return value + if mode in ("auto_batch", "image_batch"): + tensor_batch = _torch_cat(collection, value) + if tensor_batch is not None: + return tensor_batch + if mode == "image_batch": + return _as_list(collection) + [value] + if mode in ("auto_batch", "latent_batch"): + latent_batch = _latent_cat(collection, value) + if latent_batch is not None: + return latent_batch + if mode == "latent_batch": + return _as_list(collection) + [value] + return _as_list(collection) + [value] + + +class SxCPWhileLoopStart: + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "condition": ("BOOLEAN", {"default": True}), + }, + "optional": {}, + } + for index in range(MAX_LOOP_VALUES): + inputs["optional"][f"initial_value{index}"] = (ANY_TYPE,) + return inputs + + RETURN_TYPES = tuple(["FLOW_CONTROL"] + [ANY_TYPE] * MAX_LOOP_VALUES) + RETURN_NAMES = tuple(["flow"] + [f"value{index}" for index in range(MAX_LOOP_VALUES)]) + FUNCTION = "open" + CATEGORY = "prompt_builder/loop" + + def open(self, condition, **kwargs): + values = [] + for index in range(MAX_LOOP_VALUES): + values.append(kwargs.get(f"initial_value{index}") if condition else _execution_blocker()) + return tuple(["stub"] + values) + + +class SxCPWhileLoopEnd: + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "flow": ("FLOW_CONTROL", {"rawLink": True}), + "condition": ("BOOLEAN", {}), + }, + "optional": {}, + "hidden": { + "dynprompt": "DYNPROMPT", + "unique_id": "UNIQUE_ID", + "extra_pnginfo": "EXTRA_PNGINFO", + }, + } + for index in range(MAX_LOOP_VALUES): + inputs["optional"][f"initial_value{index}"] = (ANY_TYPE,) + return inputs + + RETURN_TYPES = tuple([ANY_TYPE] * MAX_LOOP_VALUES) + RETURN_NAMES = tuple([f"value{index}" for index in range(MAX_LOOP_VALUES)]) + FUNCTION = "close" + CATEGORY = "prompt_builder/loop" + + def _explore_dependencies(self, node_id: str, dynprompt: Any, upstream: dict[str, list[str]], parent_ids: list[str]) -> None: + node_info = dynprompt.get_node(node_id) + if "inputs" not in node_info: + return + for value in node_info["inputs"].values(): + if not is_link(value): + continue + parent_id = value[0] + display_id = dynprompt.get_display_node_id(parent_id) + display_node = dynprompt.get_node(display_id) + class_type = display_node["class_type"] + if class_type not in ("SxCPForLoopEnd", "SxCPWhileLoopEnd"): + parent_ids.append(display_id) + if parent_id not in upstream: + upstream[parent_id] = [] + self._explore_dependencies(parent_id, dynprompt, upstream, parent_ids) + upstream[parent_id].append(node_id) + + def _explore_output_nodes( + self, + dynprompt: Any, + upstream: dict[str, list[str]], + output_nodes: dict[str, Any], + parent_ids: list[str], + ) -> None: + for parent_id in upstream: + display_id = dynprompt.get_display_node_id(parent_id) + for output_id, link in output_nodes.items(): + linked_id = link[0] + if linked_id in parent_ids and display_id == linked_id and output_id not in upstream[parent_id]: + if "." in parent_id: + parts = parent_id.split(".") + parts[-1] = output_id + upstream[parent_id].append(".".join(parts)) + else: + upstream[parent_id].append(output_id) + + def _collect_contained(self, node_id: str, upstream: dict[str, list[str]], contained: dict[str, bool]) -> None: + if node_id not in upstream: + return + for child_id in upstream[node_id]: + if child_id in contained: + continue + contained[child_id] = True + self._collect_contained(child_id, upstream, contained) + + def close(self, flow, condition, dynprompt=None, unique_id=None, **kwargs): + if not condition: + return tuple(kwargs.get(f"initial_value{index}") for index in range(MAX_LOOP_VALUES)) + + _require_graph_builder() + upstream: dict[str, list[str]] = {} + parent_ids: list[str] = [] + self._explore_dependencies(unique_id, dynprompt, upstream, parent_ids) + parent_ids = list(set(parent_ids)) + + output_nodes = {} + for node_id, node in dynprompt.get_original_prompt().items(): + if "inputs" not in node: + continue + class_def = ALL_NODE_CLASS_MAPPINGS.get(node["class_type"]) + if not class_def or not getattr(class_def, "OUTPUT_NODE", False): + continue + for value in node["inputs"].values(): + if is_link(value): + output_nodes[node_id] = value + + graph = GraphBuilder() + self._explore_output_nodes(dynprompt, upstream, output_nodes, parent_ids) + contained: dict[str, bool] = {} + open_node = flow[0] + self._collect_contained(open_node, upstream, contained) + contained[unique_id] = True + contained[open_node] = True + + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id) + node.set_override_display_id(node_id) + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.lookup_node("Recurse" if node_id == unique_id else node_id) + for key, value in original_node["inputs"].items(): + if is_link(value) and value[0] in contained: + parent = graph.lookup_node(value[0]) + node.set_input(key, parent.out(value[1])) + else: + node.set_input(key, value) + + new_open = graph.lookup_node(open_node) + original_open = dynprompt.get_node(open_node) + if original_open["class_type"] == "SxCPForLoopStart": + new_open.set_input("initial_index", kwargs.get("initial_value0")) + new_open.set_input("initial_collected", kwargs.get("initial_value1")) + for carry_index in range(1, MAX_CARRY_VALUES + 1): + new_open.set_input(f"initial_value{carry_index}", kwargs.get(f"initial_value{carry_index + 1}")) + else: + for index in range(MAX_LOOP_VALUES): + new_open.set_input(f"initial_value{index}", kwargs.get(f"initial_value{index}")) + my_clone = graph.lookup_node("Recurse") + return { + "result": tuple(my_clone.out(index) for index in range(MAX_LOOP_VALUES)), + "expand": graph.finalize(), + } + + +class SxCPForLoopStart: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}), + }, + "optional": { + f"initial_value{index}": (ANY_TYPE,) for index in range(1, MAX_CARRY_VALUES + 1) + }, + "hidden": { + "initial_index": (ANY_TYPE,), + "initial_collected": (ANY_TYPE,), + "prompt": "PROMPT", + "extra_pnginfo": "EXTRA_PNGINFO", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = tuple(["FLOW_CONTROL", "INT", ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES) + RETURN_NAMES = tuple(["flow", "index", "collected"] + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)]) + FUNCTION = "start" + CATEGORY = "prompt_builder/loop" + + def start(self, total, initial_index=None, initial_collected=None, **kwargs): + _require_graph_builder() + index = 0 if initial_index is None else initial_index + collected = initial_collected + initial_values = { + "initial_value0": index, + "initial_value1": collected, + } + for carry_index in range(1, MAX_CARRY_VALUES + 1): + initial_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}") + graph = GraphBuilder() + graph.node("SxCPWhileLoopStart", condition=total, **initial_values) + return { + "result": tuple(["stub", index, collected] + [kwargs.get(f"initial_value{index}") for index in range(1, MAX_CARRY_VALUES + 1)]), + "expand": graph.finalize(), + } + + +class SxCPLoopAppend: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mode": (COLLECTION_MODES, {"default": "auto_batch"}), + "skip_none": ("BOOLEAN", {"default": True}), + }, + "optional": { + "collection": (ANY_TYPE,), + "value": (ANY_TYPE,), + }, + } + + RETURN_TYPES = (ANY_TYPE,) + RETURN_NAMES = ("collected",) + FUNCTION = "append" + CATEGORY = "prompt_builder/loop" + + def append(self, mode, skip_none, collection=None, value=None): + return (append_collected_value(collection, value, mode=mode, skip_none=skip_none),) + + +class SxCPForLoopEnd: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "flow": ("FLOW_CONTROL", {"rawLink": True}), + "collection_mode": (COLLECTION_MODES, {"default": "auto_batch"}), + "skip_none": ("BOOLEAN", {"default": True}), + }, + "optional": { + "collected": (ANY_TYPE, {"rawLink": True}), + "collect_value": (ANY_TYPE, {"rawLink": True}), + **{ + f"initial_value{index}": (ANY_TYPE, {"rawLink": True}) + for index in range(1, MAX_CARRY_VALUES + 1) + }, + }, + "hidden": { + "dynprompt": "DYNPROMPT", + "extra_pnginfo": "EXTRA_PNGINFO", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = tuple([ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES) + RETURN_NAMES = tuple(["collected"] + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)]) + FUNCTION = "end" + CATEGORY = "prompt_builder/loop" + + def end(self, flow, collection_mode, skip_none, dynprompt=None, **kwargs): + _require_graph_builder() + graph = GraphBuilder() + loop_start = flow[0] + start_node = dynprompt.get_node(loop_start) + if start_node["class_type"] != "SxCPForLoopStart": + raise ValueError("SxCP For Loop End must receive flow from SxCP For Loop Start.") + total = start_node["inputs"]["total"] + next_index = graph.node("SxCPLoopIntAdd", a=[loop_start, 1], b=1) + condition = graph.node("SxCPLoopLessThan", a=next_index.out(0), b=total) + collection = kwargs.get("collected") or [loop_start, 2] + collect_value = kwargs.get("collect_value") + next_collection = graph.node( + "SxCPLoopAppend", + collection=collection, + value=collect_value, + mode=collection_mode, + skip_none=skip_none, + ) + next_values = { + "initial_value0": next_index.out(0), + "initial_value1": next_collection.out(0), + } + for carry_index in range(1, MAX_CARRY_VALUES + 1): + next_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}") + while_close = graph.node("SxCPWhileLoopEnd", flow=flow, condition=condition.out(0), **next_values) + return { + "result": tuple(while_close.out(index) for index in range(1, MAX_LOOP_VALUES)), + "expand": graph.finalize(), + } + + +class SxCPLoopIntAdd: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0}), + "b": ("INT", {"default": 1}), + } + } + + RETURN_TYPES = ("INT",) + RETURN_NAMES = ("int",) + FUNCTION = "add" + CATEGORY = "prompt_builder/loop/internal" + + def add(self, a, b): + return (int(a) + int(b),) + + +class SxCPLoopLessThan: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0}), + "b": ("INT", {"default": 1}), + } + } + + RETURN_TYPES = ("BOOLEAN",) + RETURN_NAMES = ("boolean",) + FUNCTION = "compare" + CATEGORY = "prompt_builder/loop/internal" + + def compare(self, a, b): + return (int(a) < int(b),) + + +LOOP_NODE_CLASS_MAPPINGS = { + "SxCPWhileLoopStart": SxCPWhileLoopStart, + "SxCPWhileLoopEnd": SxCPWhileLoopEnd, + "SxCPForLoopStart": SxCPForLoopStart, + "SxCPForLoopEnd": SxCPForLoopEnd, + "SxCPLoopAppend": SxCPLoopAppend, + "SxCPLoopIntAdd": SxCPLoopIntAdd, + "SxCPLoopLessThan": SxCPLoopLessThan, +} + +LOOP_NODE_DISPLAY_NAME_MAPPINGS = { + "SxCPWhileLoopStart": "SxCP While Loop Start", + "SxCPWhileLoopEnd": "SxCP While Loop End", + "SxCPForLoopStart": "SxCP For Loop Start", + "SxCPForLoopEnd": "SxCP For Loop End", + "SxCPLoopAppend": "SxCP Loop Append", + "SxCPLoopIntAdd": "SxCP Loop Int Add", + "SxCPLoopLessThan": "SxCP Loop Less Than", +} diff --git a/web/loop_slots.js b/web/loop_slots.js new file mode 100644 index 0000000..b037232 --- /dev/null +++ b/web/loop_slots.js @@ -0,0 +1,130 @@ +import { app } from "../../scripts/app.js"; + +const EXTENSION = "ethanfel.prompt_builder.loop_slots"; +const LOOP_NODES = new Set(["SxCPForLoopStart", "SxCPForLoopEnd", "SxCPWhileLoopStart", "SxCPWhileLoopEnd"]); +const MAX_CARRY = 18; + +function isCarryInput(input) { + return /^initial_value\d+$/.test(input?.name || ""); +} + +function isCarryOutput(output) { + return /^value\d+$/.test(output?.name || ""); +} + +function carryNumber(slot) { + const match = String(slot?.name || "").match(/\d+$/); + return match ? Number(match[0]) : -1; +} + +function resizeNode(node) { + const size = node.computeSize?.(); + if (size) node.setSize?.(size); + app.graph?.setDirtyCanvas(true, true); +} + +function getCarryLimit(nodeName) { + return nodeName === "SxCPWhileLoopStart" || nodeName === "SxCPWhileLoopEnd" ? 19 : MAX_CARRY; +} + +function getFirstCarry(nodeName) { + return nodeName === "SxCPWhileLoopStart" || nodeName === "SxCPWhileLoopEnd" ? 0 : 1; +} + +function addCarryPair(node, nodeName, number) { + if (number > getCarryLimit(nodeName)) return; + const inputName = `initial_value${number}`; + const outputName = `value${number}`; + if (!node.inputs?.some((input) => input.name === inputName)) node.addInput(inputName, "*"); + if (!node.outputs?.some((output) => output.name === outputName)) node.addOutput(outputName, "*"); +} + +function removeCarryPair(node, number) { + const inputIndex = node.inputs?.findIndex((input) => input.name === `initial_value${number}`) ?? -1; + if (inputIndex >= 0 && !node.inputs[inputIndex]?.link) node.removeInput(inputIndex); + const outputIndex = node.outputs?.findIndex((output) => output.name === `value${number}`) ?? -1; + if (outputIndex >= 0 && !(node.outputs[outputIndex]?.links?.length)) node.removeOutput(outputIndex); +} + +function trimCarryTail(node, nodeName) { + const first = getFirstCarry(nodeName); + for (let number = getCarryLimit(nodeName); number > first; number--) { + const input = node.inputs?.find((slot) => slot.name === `initial_value${number}`); + const output = node.outputs?.find((slot) => slot.name === `value${number}`); + const previousInput = node.inputs?.find((slot) => slot.name === `initial_value${number - 1}`); + const previousOutput = node.outputs?.find((slot) => slot.name === `value${number - 1}`); + const currentUsed = Boolean(input?.link || output?.links?.length); + const previousUsed = Boolean(previousInput?.link || previousOutput?.links?.length); + if (!currentUsed && !previousUsed) removeCarryPair(node, number); + } +} + +function setupNodeSlots(node, nodeName) { + const first = getFirstCarry(nodeName); + const limit = getCarryLimit(nodeName); + addCarryPair(node, nodeName, first); + for (let number = first + 1; number <= limit; number++) { + const input = node.inputs?.find((slot) => slot.name === `initial_value${number}`); + const output = node.outputs?.find((slot) => slot.name === `value${number}`); + if (!input?.link && !output?.links?.length) removeCarryPair(node, number); + } + for (const output of node.outputs || []) { + if (output.name === "flow") output.shape = 5; + } + for (const input of node.inputs || []) { + if (input.name === "flow") input.shape = 5; + } + trimCarryTail(node, nodeName); + resizeNode(node); +} + +function maybeGrow(node, nodeName) { + const carryInputs = (node.inputs || []).filter(isCarryInput); + const carryOutputs = (node.outputs || []).filter(isCarryOutput); + const lastInput = carryInputs.reduce((max, input) => Math.max(max, carryNumber(input)), -1); + const lastOutput = carryOutputs.reduce((max, output) => Math.max(max, carryNumber(output)), -1); + const last = Math.max(lastInput, lastOutput, getFirstCarry(nodeName)); + const input = node.inputs?.find((slot) => slot.name === `initial_value${last}`); + const output = node.outputs?.find((slot) => slot.name === `value${last}`); + if ((input?.link || output?.links?.length) && last < getCarryLimit(nodeName)) { + addCarryPair(node, nodeName, last + 1); + resizeNode(node); + } +} + +app.registerExtension({ + name: EXTENSION, + + async beforeRegisterNodeDef(nodeType, nodeData) { + if (!LOOP_NODES.has(nodeData.name)) return; + + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + const result = onNodeCreated?.apply(this, arguments); + queueMicrotask(() => setupNodeSlots(this, nodeData.name)); + return result; + }; + + const onConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function () { + const result = onConfigure?.apply(this, arguments); + queueMicrotask(() => setupNodeSlots(this, nodeData.name)); + return result; + }; + + const onConnectionsChange = nodeType.prototype.onConnectionsChange; + nodeType.prototype.onConnectionsChange = function (type, index, connected, linkInfo) { + const result = onConnectionsChange?.apply(this, arguments); + if (!linkInfo) return result; + const slot = type === LiteGraph.INPUT ? this.inputs?.[index] : this.outputs?.[index]; + if (isCarryInput(slot) || isCarryOutput(slot)) { + if (connected) maybeGrow(this, nodeData.name); + else { + trimCarryTail(this, nodeData.name); + resizeNode(this); + } + } + return result; + }; + }, +});