Add collected for-loop nodes
This commit is contained in:
@@ -10,6 +10,9 @@ The node is registered as:
|
|||||||
- `prompt_builder / SxCP Seed Locker`
|
- `prompt_builder / SxCP Seed Locker`
|
||||||
- `prompt_builder / SxCP Camera Control`
|
- `prompt_builder / SxCP Camera Control`
|
||||||
- `prompt_builder / SxCP Camera Orbit Control`
|
- `prompt_builder / SxCP Camera Orbit Control`
|
||||||
|
- `prompt_builder / SxCP For Loop Start`
|
||||||
|
- `prompt_builder / SxCP For Loop End`
|
||||||
|
- `prompt_builder / SxCP Loop Append`
|
||||||
- `prompt_builder / SxCP Category Preset`
|
- `prompt_builder / SxCP Category Preset`
|
||||||
- `prompt_builder / SxCP Cast Control`
|
- `prompt_builder / SxCP Cast Control`
|
||||||
- `prompt_builder / SxCP Generation Profile`
|
- `prompt_builder / SxCP Generation Profile`
|
||||||
@@ -71,6 +74,35 @@ as one long chain:
|
|||||||
manually into either generation lane, but they are not part of the default
|
manually into either generation lane, but they are not part of the default
|
||||||
main path.
|
main path.
|
||||||
|
|
||||||
|
## Loop Nodes
|
||||||
|
|
||||||
|
`SxCP For Loop Start` and `SxCP For Loop End` provide a lightweight replacement
|
||||||
|
for the easy-use for-loop dependency. They use the same recursive ComfyUI loop
|
||||||
|
pattern, but add a dedicated collector output for building a result sequence.
|
||||||
|
|
||||||
|
Basic loop wiring:
|
||||||
|
|
||||||
|
1. Connect `For Loop Start.flow` to `For Loop End.flow`.
|
||||||
|
2. Use `For Loop Start.index` inside the loop for seed/index changes.
|
||||||
|
3. Connect the per-iteration output you want to keep, such as an image, latent,
|
||||||
|
prompt, or metadata string, to `For Loop End.collect_value`.
|
||||||
|
4. Optionally connect `For Loop Start.collected` to `For Loop End.collected`.
|
||||||
|
If omitted, the end node uses the start collector internally.
|
||||||
|
5. After the loop finishes, use `For Loop End.collected` as the combined output.
|
||||||
|
|
||||||
|
`collection_mode` controls how values are stored:
|
||||||
|
|
||||||
|
- `auto_batch`: concatenates image tensors or latent samples when possible,
|
||||||
|
otherwise falls back to a Python list.
|
||||||
|
- `image_batch`: prefers image tensor batching.
|
||||||
|
- `latent_batch`: prefers latent `samples` batching.
|
||||||
|
- `list`: always appends each iteration result to a list.
|
||||||
|
- `string_lines`: joins each collected value with newlines.
|
||||||
|
|
||||||
|
`value1`, `value2`, and later slots are normal carry-through channels for state
|
||||||
|
you want to update each iteration. They are separate from the collector and grow
|
||||||
|
dynamically in the UI as you connect them.
|
||||||
|
|
||||||
## Character Profiles
|
## Character Profiles
|
||||||
|
|
||||||
`SxCP Woman Slot` and `SxCP Man Slot` are the scalable per-participant control
|
`SxCP Woman Slot` and `SxCP Man Slot` are the scalable per-participant control
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from .loop_nodes import LOOP_NODE_CLASS_MAPPINGS, LOOP_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
from .prompt_builder import (
|
from .prompt_builder import (
|
||||||
build_camera_config_json,
|
build_camera_config_json,
|
||||||
build_camera_orbit_config_json,
|
build_camera_orbit_config_json,
|
||||||
@@ -53,6 +54,7 @@ try:
|
|||||||
from .caption_naturalizer import naturalize_caption
|
from .caption_naturalizer import naturalize_caption
|
||||||
from .krea_formatter import format_krea2_prompt
|
from .krea_formatter import format_krea2_prompt
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
from loop_nodes import LOOP_NODE_CLASS_MAPPINGS, LOOP_NODE_DISPLAY_NAME_MAPPINGS
|
||||||
from prompt_builder import (
|
from prompt_builder import (
|
||||||
build_camera_config_json,
|
build_camera_config_json,
|
||||||
build_camera_orbit_config_json,
|
build_camera_orbit_config_json,
|
||||||
@@ -1317,6 +1319,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"SxCPInstaOFOptions": SxCPInstaOFOptions,
|
"SxCPInstaOFOptions": SxCPInstaOFOptions,
|
||||||
"SxCPInstaOFPromptPair": SxCPInstaOFPromptPair,
|
"SxCPInstaOFPromptPair": SxCPInstaOFPromptPair,
|
||||||
}
|
}
|
||||||
|
NODE_CLASS_MAPPINGS.update(LOOP_NODE_CLASS_MAPPINGS)
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"SxCPPromptBuilder": "SxCP Prompt Builder",
|
"SxCPPromptBuilder": "SxCP Prompt Builder",
|
||||||
@@ -1339,6 +1342,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"SxCPInstaOFOptions": "SxCP Insta/OF Options",
|
"SxCPInstaOFOptions": "SxCP Insta/OF Options",
|
||||||
"SxCPInstaOFPromptPair": "SxCP Insta/OF Prompt Pair",
|
"SxCPInstaOFPromptPair": "SxCP Insta/OF Prompt Pair",
|
||||||
}
|
}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS.update(LOOP_NODE_DISPLAY_NAME_MAPPINGS)
|
||||||
|
|
||||||
WEB_DIRECTORY = "./web"
|
WEB_DIRECTORY = "./web"
|
||||||
|
|
||||||
|
|||||||
+437
@@ -0,0 +1,437 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
try:
|
||||||
|
from comfy_execution.graph import ExecutionBlocker
|
||||||
|
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||||
|
except Exception: # Allows local syntax/import checks outside ComfyUI.
|
||||||
|
ExecutionBlocker = None
|
||||||
|
GraphBuilder = None
|
||||||
|
|
||||||
|
def is_link(value: Any) -> bool:
|
||||||
|
return isinstance(value, list) and len(value) == 2
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nodes import NODE_CLASS_MAPPINGS as ALL_NODE_CLASS_MAPPINGS
|
||||||
|
except Exception:
|
||||||
|
ALL_NODE_CLASS_MAPPINGS = {}
|
||||||
|
|
||||||
|
|
||||||
|
MAX_LOOP_VALUES = 20
|
||||||
|
MAX_CARRY_VALUES = MAX_LOOP_VALUES - 2
|
||||||
|
COLLECTION_MODES = ["auto_batch", "list", "image_batch", "latent_batch", "string_lines"]
|
||||||
|
|
||||||
|
|
||||||
|
class AnyType(str):
|
||||||
|
def __ne__(self, _other: object) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
ANY_TYPE = AnyType("*")
|
||||||
|
|
||||||
|
|
||||||
|
def _require_graph_builder() -> None:
|
||||||
|
if GraphBuilder is None:
|
||||||
|
raise RuntimeError("SxCP loop nodes require ComfyUI's comfy_execution GraphBuilder.")
|
||||||
|
|
||||||
|
|
||||||
|
def _execution_blocker() -> Any:
|
||||||
|
return ExecutionBlocker(None) if ExecutionBlocker is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
def _torch_cat(first: Any, second: Any) -> Any | None:
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if torch.is_tensor(first) and torch.is_tensor(second):
|
||||||
|
return torch.cat((first, second), dim=0)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _latent_cat(first: Any, second: Any) -> Any | None:
|
||||||
|
if not isinstance(first, dict) or not isinstance(second, dict):
|
||||||
|
return None
|
||||||
|
if "samples" not in first or "samples" not in second:
|
||||||
|
return None
|
||||||
|
samples = _torch_cat(first["samples"], second["samples"])
|
||||||
|
if samples is None:
|
||||||
|
return None
|
||||||
|
merged = dict(second)
|
||||||
|
merged["samples"] = samples
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _as_list(collection: Any) -> list[Any]:
|
||||||
|
if collection is None:
|
||||||
|
return []
|
||||||
|
return list(collection) if isinstance(collection, list) else [collection]
|
||||||
|
|
||||||
|
|
||||||
|
def append_collected_value(collection: Any, value: Any, mode: str = "auto_batch", skip_none: bool = True) -> Any:
|
||||||
|
if value is None and skip_none:
|
||||||
|
return collection
|
||||||
|
mode = mode if mode in COLLECTION_MODES else "auto_batch"
|
||||||
|
if mode == "string_lines":
|
||||||
|
value_text = "" if value is None else str(value)
|
||||||
|
if not collection:
|
||||||
|
return value_text
|
||||||
|
return f"{collection}\n{value_text}"
|
||||||
|
if mode == "list":
|
||||||
|
return _as_list(collection) + [value]
|
||||||
|
if collection is None:
|
||||||
|
return value
|
||||||
|
if mode in ("auto_batch", "image_batch"):
|
||||||
|
tensor_batch = _torch_cat(collection, value)
|
||||||
|
if tensor_batch is not None:
|
||||||
|
return tensor_batch
|
||||||
|
if mode == "image_batch":
|
||||||
|
return _as_list(collection) + [value]
|
||||||
|
if mode in ("auto_batch", "latent_batch"):
|
||||||
|
latent_batch = _latent_cat(collection, value)
|
||||||
|
if latent_batch is not None:
|
||||||
|
return latent_batch
|
||||||
|
if mode == "latent_batch":
|
||||||
|
return _as_list(collection) + [value]
|
||||||
|
return _as_list(collection) + [value]
|
||||||
|
|
||||||
|
|
||||||
|
class SxCPWhileLoopStart:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
inputs = {
|
||||||
|
"required": {
|
||||||
|
"condition": ("BOOLEAN", {"default": True}),
|
||||||
|
},
|
||||||
|
"optional": {},
|
||||||
|
}
|
||||||
|
for index in range(MAX_LOOP_VALUES):
|
||||||
|
inputs["optional"][f"initial_value{index}"] = (ANY_TYPE,)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
RETURN_TYPES = tuple(["FLOW_CONTROL"] + [ANY_TYPE] * MAX_LOOP_VALUES)
|
||||||
|
RETURN_NAMES = tuple(["flow"] + [f"value{index}" for index in range(MAX_LOOP_VALUES)])
|
||||||
|
FUNCTION = "open"
|
||||||
|
CATEGORY = "prompt_builder/loop"
|
||||||
|
|
||||||
|
def open(self, condition, **kwargs):
|
||||||
|
values = []
|
||||||
|
for index in range(MAX_LOOP_VALUES):
|
||||||
|
values.append(kwargs.get(f"initial_value{index}") if condition else _execution_blocker())
|
||||||
|
return tuple(["stub"] + values)
|
||||||
|
|
||||||
|
|
||||||
|
class SxCPWhileLoopEnd:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
inputs = {
|
||||||
|
"required": {
|
||||||
|
"flow": ("FLOW_CONTROL", {"rawLink": True}),
|
||||||
|
"condition": ("BOOLEAN", {}),
|
||||||
|
},
|
||||||
|
"optional": {},
|
||||||
|
"hidden": {
|
||||||
|
"dynprompt": "DYNPROMPT",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
"extra_pnginfo": "EXTRA_PNGINFO",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for index in range(MAX_LOOP_VALUES):
|
||||||
|
inputs["optional"][f"initial_value{index}"] = (ANY_TYPE,)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
RETURN_TYPES = tuple([ANY_TYPE] * MAX_LOOP_VALUES)
|
||||||
|
RETURN_NAMES = tuple([f"value{index}" for index in range(MAX_LOOP_VALUES)])
|
||||||
|
FUNCTION = "close"
|
||||||
|
CATEGORY = "prompt_builder/loop"
|
||||||
|
|
||||||
|
def _explore_dependencies(self, node_id: str, dynprompt: Any, upstream: dict[str, list[str]], parent_ids: list[str]) -> None:
|
||||||
|
node_info = dynprompt.get_node(node_id)
|
||||||
|
if "inputs" not in node_info:
|
||||||
|
return
|
||||||
|
for value in node_info["inputs"].values():
|
||||||
|
if not is_link(value):
|
||||||
|
continue
|
||||||
|
parent_id = value[0]
|
||||||
|
display_id = dynprompt.get_display_node_id(parent_id)
|
||||||
|
display_node = dynprompt.get_node(display_id)
|
||||||
|
class_type = display_node["class_type"]
|
||||||
|
if class_type not in ("SxCPForLoopEnd", "SxCPWhileLoopEnd"):
|
||||||
|
parent_ids.append(display_id)
|
||||||
|
if parent_id not in upstream:
|
||||||
|
upstream[parent_id] = []
|
||||||
|
self._explore_dependencies(parent_id, dynprompt, upstream, parent_ids)
|
||||||
|
upstream[parent_id].append(node_id)
|
||||||
|
|
||||||
|
def _explore_output_nodes(
|
||||||
|
self,
|
||||||
|
dynprompt: Any,
|
||||||
|
upstream: dict[str, list[str]],
|
||||||
|
output_nodes: dict[str, Any],
|
||||||
|
parent_ids: list[str],
|
||||||
|
) -> None:
|
||||||
|
for parent_id in upstream:
|
||||||
|
display_id = dynprompt.get_display_node_id(parent_id)
|
||||||
|
for output_id, link in output_nodes.items():
|
||||||
|
linked_id = link[0]
|
||||||
|
if linked_id in parent_ids and display_id == linked_id and output_id not in upstream[parent_id]:
|
||||||
|
if "." in parent_id:
|
||||||
|
parts = parent_id.split(".")
|
||||||
|
parts[-1] = output_id
|
||||||
|
upstream[parent_id].append(".".join(parts))
|
||||||
|
else:
|
||||||
|
upstream[parent_id].append(output_id)
|
||||||
|
|
||||||
|
def _collect_contained(self, node_id: str, upstream: dict[str, list[str]], contained: dict[str, bool]) -> None:
|
||||||
|
if node_id not in upstream:
|
||||||
|
return
|
||||||
|
for child_id in upstream[node_id]:
|
||||||
|
if child_id in contained:
|
||||||
|
continue
|
||||||
|
contained[child_id] = True
|
||||||
|
self._collect_contained(child_id, upstream, contained)
|
||||||
|
|
||||||
|
def close(self, flow, condition, dynprompt=None, unique_id=None, **kwargs):
|
||||||
|
if not condition:
|
||||||
|
return tuple(kwargs.get(f"initial_value{index}") for index in range(MAX_LOOP_VALUES))
|
||||||
|
|
||||||
|
_require_graph_builder()
|
||||||
|
upstream: dict[str, list[str]] = {}
|
||||||
|
parent_ids: list[str] = []
|
||||||
|
self._explore_dependencies(unique_id, dynprompt, upstream, parent_ids)
|
||||||
|
parent_ids = list(set(parent_ids))
|
||||||
|
|
||||||
|
output_nodes = {}
|
||||||
|
for node_id, node in dynprompt.get_original_prompt().items():
|
||||||
|
if "inputs" not in node:
|
||||||
|
continue
|
||||||
|
class_def = ALL_NODE_CLASS_MAPPINGS.get(node["class_type"])
|
||||||
|
if not class_def or not getattr(class_def, "OUTPUT_NODE", False):
|
||||||
|
continue
|
||||||
|
for value in node["inputs"].values():
|
||||||
|
if is_link(value):
|
||||||
|
output_nodes[node_id] = value
|
||||||
|
|
||||||
|
graph = GraphBuilder()
|
||||||
|
self._explore_output_nodes(dynprompt, upstream, output_nodes, parent_ids)
|
||||||
|
contained: dict[str, bool] = {}
|
||||||
|
open_node = flow[0]
|
||||||
|
self._collect_contained(open_node, upstream, contained)
|
||||||
|
contained[unique_id] = True
|
||||||
|
contained[open_node] = True
|
||||||
|
|
||||||
|
for node_id in contained:
|
||||||
|
original_node = dynprompt.get_node(node_id)
|
||||||
|
node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
|
||||||
|
node.set_override_display_id(node_id)
|
||||||
|
for node_id in contained:
|
||||||
|
original_node = dynprompt.get_node(node_id)
|
||||||
|
node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
|
||||||
|
for key, value in original_node["inputs"].items():
|
||||||
|
if is_link(value) and value[0] in contained:
|
||||||
|
parent = graph.lookup_node(value[0])
|
||||||
|
node.set_input(key, parent.out(value[1]))
|
||||||
|
else:
|
||||||
|
node.set_input(key, value)
|
||||||
|
|
||||||
|
new_open = graph.lookup_node(open_node)
|
||||||
|
original_open = dynprompt.get_node(open_node)
|
||||||
|
if original_open["class_type"] == "SxCPForLoopStart":
|
||||||
|
new_open.set_input("initial_index", kwargs.get("initial_value0"))
|
||||||
|
new_open.set_input("initial_collected", kwargs.get("initial_value1"))
|
||||||
|
for carry_index in range(1, MAX_CARRY_VALUES + 1):
|
||||||
|
new_open.set_input(f"initial_value{carry_index}", kwargs.get(f"initial_value{carry_index + 1}"))
|
||||||
|
else:
|
||||||
|
for index in range(MAX_LOOP_VALUES):
|
||||||
|
new_open.set_input(f"initial_value{index}", kwargs.get(f"initial_value{index}"))
|
||||||
|
my_clone = graph.lookup_node("Recurse")
|
||||||
|
return {
|
||||||
|
"result": tuple(my_clone.out(index) for index in range(MAX_LOOP_VALUES)),
|
||||||
|
"expand": graph.finalize(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SxCPForLoopStart:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
f"initial_value{index}": (ANY_TYPE,) for index in range(1, MAX_CARRY_VALUES + 1)
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"initial_index": (ANY_TYPE,),
|
||||||
|
"initial_collected": (ANY_TYPE,),
|
||||||
|
"prompt": "PROMPT",
|
||||||
|
"extra_pnginfo": "EXTRA_PNGINFO",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = tuple(["FLOW_CONTROL", "INT", ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES)
|
||||||
|
RETURN_NAMES = tuple(["flow", "index", "collected"] + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)])
|
||||||
|
FUNCTION = "start"
|
||||||
|
CATEGORY = "prompt_builder/loop"
|
||||||
|
|
||||||
|
def start(self, total, initial_index=None, initial_collected=None, **kwargs):
|
||||||
|
_require_graph_builder()
|
||||||
|
index = 0 if initial_index is None else initial_index
|
||||||
|
collected = initial_collected
|
||||||
|
initial_values = {
|
||||||
|
"initial_value0": index,
|
||||||
|
"initial_value1": collected,
|
||||||
|
}
|
||||||
|
for carry_index in range(1, MAX_CARRY_VALUES + 1):
|
||||||
|
initial_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}")
|
||||||
|
graph = GraphBuilder()
|
||||||
|
graph.node("SxCPWhileLoopStart", condition=total, **initial_values)
|
||||||
|
return {
|
||||||
|
"result": tuple(["stub", index, collected] + [kwargs.get(f"initial_value{index}") for index in range(1, MAX_CARRY_VALUES + 1)]),
|
||||||
|
"expand": graph.finalize(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SxCPLoopAppend:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"mode": (COLLECTION_MODES, {"default": "auto_batch"}),
|
||||||
|
"skip_none": ("BOOLEAN", {"default": True}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"collection": (ANY_TYPE,),
|
||||||
|
"value": (ANY_TYPE,),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = (ANY_TYPE,)
|
||||||
|
RETURN_NAMES = ("collected",)
|
||||||
|
FUNCTION = "append"
|
||||||
|
CATEGORY = "prompt_builder/loop"
|
||||||
|
|
||||||
|
def append(self, mode, skip_none, collection=None, value=None):
|
||||||
|
return (append_collected_value(collection, value, mode=mode, skip_none=skip_none),)
|
||||||
|
|
||||||
|
|
||||||
|
class SxCPForLoopEnd:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"flow": ("FLOW_CONTROL", {"rawLink": True}),
|
||||||
|
"collection_mode": (COLLECTION_MODES, {"default": "auto_batch"}),
|
||||||
|
"skip_none": ("BOOLEAN", {"default": True}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"collected": (ANY_TYPE, {"rawLink": True}),
|
||||||
|
"collect_value": (ANY_TYPE, {"rawLink": True}),
|
||||||
|
**{
|
||||||
|
f"initial_value{index}": (ANY_TYPE, {"rawLink": True})
|
||||||
|
for index in range(1, MAX_CARRY_VALUES + 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"hidden": {
|
||||||
|
"dynprompt": "DYNPROMPT",
|
||||||
|
"extra_pnginfo": "EXTRA_PNGINFO",
|
||||||
|
"unique_id": "UNIQUE_ID",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = tuple([ANY_TYPE] + [ANY_TYPE] * MAX_CARRY_VALUES)
|
||||||
|
RETURN_NAMES = tuple(["collected"] + [f"value{index}" for index in range(1, MAX_CARRY_VALUES + 1)])
|
||||||
|
FUNCTION = "end"
|
||||||
|
CATEGORY = "prompt_builder/loop"
|
||||||
|
|
||||||
|
def end(self, flow, collection_mode, skip_none, dynprompt=None, **kwargs):
|
||||||
|
_require_graph_builder()
|
||||||
|
graph = GraphBuilder()
|
||||||
|
loop_start = flow[0]
|
||||||
|
start_node = dynprompt.get_node(loop_start)
|
||||||
|
if start_node["class_type"] != "SxCPForLoopStart":
|
||||||
|
raise ValueError("SxCP For Loop End must receive flow from SxCP For Loop Start.")
|
||||||
|
total = start_node["inputs"]["total"]
|
||||||
|
next_index = graph.node("SxCPLoopIntAdd", a=[loop_start, 1], b=1)
|
||||||
|
condition = graph.node("SxCPLoopLessThan", a=next_index.out(0), b=total)
|
||||||
|
collection = kwargs.get("collected") or [loop_start, 2]
|
||||||
|
collect_value = kwargs.get("collect_value")
|
||||||
|
next_collection = graph.node(
|
||||||
|
"SxCPLoopAppend",
|
||||||
|
collection=collection,
|
||||||
|
value=collect_value,
|
||||||
|
mode=collection_mode,
|
||||||
|
skip_none=skip_none,
|
||||||
|
)
|
||||||
|
next_values = {
|
||||||
|
"initial_value0": next_index.out(0),
|
||||||
|
"initial_value1": next_collection.out(0),
|
||||||
|
}
|
||||||
|
for carry_index in range(1, MAX_CARRY_VALUES + 1):
|
||||||
|
next_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}")
|
||||||
|
while_close = graph.node("SxCPWhileLoopEnd", flow=flow, condition=condition.out(0), **next_values)
|
||||||
|
return {
|
||||||
|
"result": tuple(while_close.out(index) for index in range(1, MAX_LOOP_VALUES)),
|
||||||
|
"expand": graph.finalize(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SxCPLoopIntAdd:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"a": ("INT", {"default": 0}),
|
||||||
|
"b": ("INT", {"default": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("INT",)
|
||||||
|
RETURN_NAMES = ("int",)
|
||||||
|
FUNCTION = "add"
|
||||||
|
CATEGORY = "prompt_builder/loop/internal"
|
||||||
|
|
||||||
|
def add(self, a, b):
|
||||||
|
return (int(a) + int(b),)
|
||||||
|
|
||||||
|
|
||||||
|
class SxCPLoopLessThan:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"a": ("INT", {"default": 0}),
|
||||||
|
"b": ("INT", {"default": 1}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("BOOLEAN",)
|
||||||
|
RETURN_NAMES = ("boolean",)
|
||||||
|
FUNCTION = "compare"
|
||||||
|
CATEGORY = "prompt_builder/loop/internal"
|
||||||
|
|
||||||
|
def compare(self, a, b):
|
||||||
|
return (int(a) < int(b),)
|
||||||
|
|
||||||
|
|
||||||
|
LOOP_NODE_CLASS_MAPPINGS = {
|
||||||
|
"SxCPWhileLoopStart": SxCPWhileLoopStart,
|
||||||
|
"SxCPWhileLoopEnd": SxCPWhileLoopEnd,
|
||||||
|
"SxCPForLoopStart": SxCPForLoopStart,
|
||||||
|
"SxCPForLoopEnd": SxCPForLoopEnd,
|
||||||
|
"SxCPLoopAppend": SxCPLoopAppend,
|
||||||
|
"SxCPLoopIntAdd": SxCPLoopIntAdd,
|
||||||
|
"SxCPLoopLessThan": SxCPLoopLessThan,
|
||||||
|
}
|
||||||
|
|
||||||
|
LOOP_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"SxCPWhileLoopStart": "SxCP While Loop Start",
|
||||||
|
"SxCPWhileLoopEnd": "SxCP While Loop End",
|
||||||
|
"SxCPForLoopStart": "SxCP For Loop Start",
|
||||||
|
"SxCPForLoopEnd": "SxCP For Loop End",
|
||||||
|
"SxCPLoopAppend": "SxCP Loop Append",
|
||||||
|
"SxCPLoopIntAdd": "SxCP Loop Int Add",
|
||||||
|
"SxCPLoopLessThan": "SxCP Loop Less Than",
|
||||||
|
}
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
|
||||||
|
const EXTENSION = "ethanfel.prompt_builder.loop_slots";
|
||||||
|
const LOOP_NODES = new Set(["SxCPForLoopStart", "SxCPForLoopEnd", "SxCPWhileLoopStart", "SxCPWhileLoopEnd"]);
|
||||||
|
const MAX_CARRY = 18;
|
||||||
|
|
||||||
|
function isCarryInput(input) {
|
||||||
|
return /^initial_value\d+$/.test(input?.name || "");
|
||||||
|
}
|
||||||
|
|
||||||
|
function isCarryOutput(output) {
|
||||||
|
return /^value\d+$/.test(output?.name || "");
|
||||||
|
}
|
||||||
|
|
||||||
|
function carryNumber(slot) {
|
||||||
|
const match = String(slot?.name || "").match(/\d+$/);
|
||||||
|
return match ? Number(match[0]) : -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
function resizeNode(node) {
|
||||||
|
const size = node.computeSize?.();
|
||||||
|
if (size) node.setSize?.(size);
|
||||||
|
app.graph?.setDirtyCanvas(true, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
function getCarryLimit(nodeName) {
|
||||||
|
return nodeName === "SxCPWhileLoopStart" || nodeName === "SxCPWhileLoopEnd" ? 19 : MAX_CARRY;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getFirstCarry(nodeName) {
|
||||||
|
return nodeName === "SxCPWhileLoopStart" || nodeName === "SxCPWhileLoopEnd" ? 0 : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
function addCarryPair(node, nodeName, number) {
|
||||||
|
if (number > getCarryLimit(nodeName)) return;
|
||||||
|
const inputName = `initial_value${number}`;
|
||||||
|
const outputName = `value${number}`;
|
||||||
|
if (!node.inputs?.some((input) => input.name === inputName)) node.addInput(inputName, "*");
|
||||||
|
if (!node.outputs?.some((output) => output.name === outputName)) node.addOutput(outputName, "*");
|
||||||
|
}
|
||||||
|
|
||||||
|
function removeCarryPair(node, number) {
|
||||||
|
const inputIndex = node.inputs?.findIndex((input) => input.name === `initial_value${number}`) ?? -1;
|
||||||
|
if (inputIndex >= 0 && !node.inputs[inputIndex]?.link) node.removeInput(inputIndex);
|
||||||
|
const outputIndex = node.outputs?.findIndex((output) => output.name === `value${number}`) ?? -1;
|
||||||
|
if (outputIndex >= 0 && !(node.outputs[outputIndex]?.links?.length)) node.removeOutput(outputIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
function trimCarryTail(node, nodeName) {
|
||||||
|
const first = getFirstCarry(nodeName);
|
||||||
|
for (let number = getCarryLimit(nodeName); number > first; number--) {
|
||||||
|
const input = node.inputs?.find((slot) => slot.name === `initial_value${number}`);
|
||||||
|
const output = node.outputs?.find((slot) => slot.name === `value${number}`);
|
||||||
|
const previousInput = node.inputs?.find((slot) => slot.name === `initial_value${number - 1}`);
|
||||||
|
const previousOutput = node.outputs?.find((slot) => slot.name === `value${number - 1}`);
|
||||||
|
const currentUsed = Boolean(input?.link || output?.links?.length);
|
||||||
|
const previousUsed = Boolean(previousInput?.link || previousOutput?.links?.length);
|
||||||
|
if (!currentUsed && !previousUsed) removeCarryPair(node, number);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function setupNodeSlots(node, nodeName) {
|
||||||
|
const first = getFirstCarry(nodeName);
|
||||||
|
const limit = getCarryLimit(nodeName);
|
||||||
|
addCarryPair(node, nodeName, first);
|
||||||
|
for (let number = first + 1; number <= limit; number++) {
|
||||||
|
const input = node.inputs?.find((slot) => slot.name === `initial_value${number}`);
|
||||||
|
const output = node.outputs?.find((slot) => slot.name === `value${number}`);
|
||||||
|
if (!input?.link && !output?.links?.length) removeCarryPair(node, number);
|
||||||
|
}
|
||||||
|
for (const output of node.outputs || []) {
|
||||||
|
if (output.name === "flow") output.shape = 5;
|
||||||
|
}
|
||||||
|
for (const input of node.inputs || []) {
|
||||||
|
if (input.name === "flow") input.shape = 5;
|
||||||
|
}
|
||||||
|
trimCarryTail(node, nodeName);
|
||||||
|
resizeNode(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
function maybeGrow(node, nodeName) {
|
||||||
|
const carryInputs = (node.inputs || []).filter(isCarryInput);
|
||||||
|
const carryOutputs = (node.outputs || []).filter(isCarryOutput);
|
||||||
|
const lastInput = carryInputs.reduce((max, input) => Math.max(max, carryNumber(input)), -1);
|
||||||
|
const lastOutput = carryOutputs.reduce((max, output) => Math.max(max, carryNumber(output)), -1);
|
||||||
|
const last = Math.max(lastInput, lastOutput, getFirstCarry(nodeName));
|
||||||
|
const input = node.inputs?.find((slot) => slot.name === `initial_value${last}`);
|
||||||
|
const output = node.outputs?.find((slot) => slot.name === `value${last}`);
|
||||||
|
if ((input?.link || output?.links?.length) && last < getCarryLimit(nodeName)) {
|
||||||
|
addCarryPair(node, nodeName, last + 1);
|
||||||
|
resizeNode(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
app.registerExtension({
|
||||||
|
name: EXTENSION,
|
||||||
|
|
||||||
|
async beforeRegisterNodeDef(nodeType, nodeData) {
|
||||||
|
if (!LOOP_NODES.has(nodeData.name)) return;
|
||||||
|
|
||||||
|
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
||||||
|
nodeType.prototype.onNodeCreated = function () {
|
||||||
|
const result = onNodeCreated?.apply(this, arguments);
|
||||||
|
queueMicrotask(() => setupNodeSlots(this, nodeData.name));
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
const onConfigure = nodeType.prototype.onConfigure;
|
||||||
|
nodeType.prototype.onConfigure = function () {
|
||||||
|
const result = onConfigure?.apply(this, arguments);
|
||||||
|
queueMicrotask(() => setupNodeSlots(this, nodeData.name));
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
const onConnectionsChange = nodeType.prototype.onConnectionsChange;
|
||||||
|
nodeType.prototype.onConnectionsChange = function (type, index, connected, linkInfo) {
|
||||||
|
const result = onConnectionsChange?.apply(this, arguments);
|
||||||
|
if (!linkInfo) return result;
|
||||||
|
const slot = type === LiteGraph.INPUT ? this.inputs?.[index] : this.outputs?.[index];
|
||||||
|
if (isCarryInput(slot) || isCarryOutput(slot)) {
|
||||||
|
if (connected) maybeGrow(this, nodeData.name);
|
||||||
|
else {
|
||||||
|
trimCarryTail(this, nodeData.name);
|
||||||
|
resizeNode(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
},
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user