Add optional loop schedule input

This commit is contained in:
2026-06-28 08:19:43 +02:00
parent e434bd66ad
commit debb6d6f38
4 changed files with 190 additions and 17 deletions
+4 -4
View File
@@ -191,10 +191,10 @@ Basic loop wiring:
5. After the loop finishes, use `For Loop End.collected` as the combined output.
`For Loop Start.index` is 1-based so it can be wired directly into prompt-builder
`row_number` inputs. `For Loop Start.skip` skips the first N iterations while
keeping the remaining row numbers stable. For example, `total=10` and `skip=1`
runs indexes `2..10`; `skip=5` runs indexes `6..10`. This is useful when you
want to resume a loop without changing index-derived seeds or row numbers.
`row_number` inputs. `For Loop Start.schedule` is an optional input for choosing
which indexes run while keeping row numbers stable. Omit it to run `1..total`,
connect a list such as `[2, 5, 8]`, or connect text such as `2,5,8` or `2-8`.
Indexes outside `1..total` are ignored.
`collection_mode` controls how values are stored:
+163 -11
View File
@@ -634,6 +634,131 @@ def append_collected_value(collection: Any, value: Any, mode: str = "auto_batch"
return _as_list(collection) + [value]
def _coerce_loop_int(value: Any) -> int | None:
if value is None or isinstance(value, bool):
return None
if isinstance(value, int):
return value
if isinstance(value, float):
return int(value) if value.is_integer() else None
text = str(value).strip()
if re.fullmatch(r"-?\d+(?:\.0+)?", text):
return int(float(text))
return None
def _raw_loop_schedule_values(schedule: Any) -> list[Any]:
if schedule is None:
return []
if hasattr(schedule, "tolist"):
try:
return _raw_loop_schedule_values(schedule.tolist())
except Exception:
pass
if isinstance(schedule, str):
text = schedule.strip()
if not text:
return []
try:
loaded = json.loads(text)
except Exception:
loaded = None
else:
return _raw_loop_schedule_values(loaded)
values: list[int] = []
def add_range(match: re.Match[str]) -> str:
start = int(match.group(1))
end = int(match.group(2))
step = 1 if end >= start else -1
values.extend(range(start, end + step, step))
return " "
remainder = re.sub(r"(?<!\d)(\d+)\s*(?:\.\.|-|:)\s*(\d+)(?!\d)", add_range, text)
values.extend(int(match.group(0)) for match in re.finditer(r"-?\d+", remainder))
return values
if isinstance(schedule, dict):
for key in ("schedule", "indexes", "indices", "rows", "values", "items"):
if key in schedule:
return _raw_loop_schedule_values(schedule[key])
values: list[Any] = []
for item in schedule.values():
values.extend(_raw_loop_schedule_values(item))
return values
if isinstance(schedule, (list, tuple, set)):
values = []
for item in schedule:
values.extend(_raw_loop_schedule_values(item))
return values
value = _coerce_loop_int(schedule)
return [] if value is None else [value]
def _explicit_loop_schedule(schedule: Any, total: int) -> list[int] | None:
if schedule is None:
return None
if isinstance(schedule, str) and not schedule.strip():
return None
total = max(1, int(total))
seen: set[int] = set()
values = []
for raw_value in _raw_loop_schedule_values(schedule):
value = _coerce_loop_int(raw_value)
if value is None or value < 1 or value > total or value in seen:
continue
seen.add(value)
values.append(value)
return values
def _first_loop_index(total: int, schedule: Any = None) -> int:
total = max(1, int(total))
explicit = _explicit_loop_schedule(schedule, total)
if explicit is not None:
return explicit[0] if explicit else total + 1
return 1
def _loop_index_active(index: Any, total: int, schedule: Any = None) -> bool:
total = max(1, int(total))
value = _coerce_loop_int(index)
if value is None:
return False
explicit = _explicit_loop_schedule(schedule, total)
if explicit is not None:
return value in explicit
return 1 <= value <= total
def _next_loop_index(current_index: Any, total: int, schedule: Any = None) -> tuple[int, bool]:
total = max(1, int(total))
current = _coerce_loop_int(current_index)
if current is None:
current = 0
explicit = _explicit_loop_schedule(schedule, total)
if explicit is None:
next_index = current + 1
return next_index, next_index <= total
if not explicit:
return total + 1, False
try:
position = explicit.index(current)
except ValueError:
for value in explicit:
if value > current:
return value, True
return total + 1, False
next_position = position + 1
if next_position >= len(explicit):
return total + 1, False
return explicit[next_position], True
class SxCPWhileLoopStart:
@classmethod
def INPUT_TYPES(cls):
@@ -795,10 +920,10 @@ class SxCPForLoopStart:
return {
"required": {
"total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}),
"skip": ("INT", {"default": 0, "min": 0, "max": 100000, "step": 1}),
},
"optional": {
f"initial_value{index}": (ANY_TYPE,) for index in range(1, MAX_CARRY_VALUES + 1)
"schedule": (ANY_TYPE,),
**{f"initial_value{index}": (ANY_TYPE,) for index in range(1, MAX_CARRY_VALUES + 1)},
},
"hidden": {
"initial_index": (ANY_TYPE,),
@@ -814,12 +939,10 @@ class SxCPForLoopStart:
FUNCTION = "start"
CATEGORY = "prompt_builder/loop"
def start(self, total, skip=0, initial_index=None, initial_collected=None, **kwargs):
def start(self, total, schedule=None, initial_index=None, initial_collected=None, **kwargs):
_require_graph_builder()
total = max(1, int(total))
skip = max(0, int(skip))
first_index = skip + 1
index = first_index if initial_index is None else max(int(initial_index), first_index)
index = _first_loop_index(total, schedule=schedule) if initial_index is None else int(initial_index)
collected = initial_collected
initial_values = {
"initial_value0": index,
@@ -828,7 +951,7 @@ class SxCPForLoopStart:
for carry_index in range(1, MAX_CARRY_VALUES + 1):
initial_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}")
graph = GraphBuilder()
graph.node("SxCPWhileLoopStart", condition=index <= total, **initial_values)
graph.node("SxCPWhileLoopStart", condition=_loop_index_active(index, total, schedule=schedule), **initial_values)
return {
"result": tuple(["stub", index, collected] + [kwargs.get(f"initial_value{index}") for index in range(1, MAX_CARRY_VALUES + 1)]),
"expand": graph.finalize(),
@@ -1281,9 +1404,14 @@ class SxCPForLoopEnd:
start_node = dynprompt.get_node(loop_start)
if start_node["class_type"] != "SxCPForLoopStart":
raise ValueError("SxCP For Loop End must receive flow from SxCP For Loop Start.")
total = start_node["inputs"]["total"]
next_index = graph.node("SxCPLoopIntAdd", a=[loop_start, 1], b=1)
condition = graph.node("SxCPLoopLessThanOrEqual", a=next_index.out(0), b=total)
start_inputs = start_node["inputs"]
total = start_inputs["total"]
next_index = graph.node(
"SxCPLoopNextIndex",
current_index=[loop_start, 1],
total=total,
schedule=start_inputs.get("schedule"),
)
collection = kwargs.get("collected") or [loop_start, 2]
collect_value = kwargs.get("collect_value")
next_collection = graph.node(
@@ -1299,13 +1427,35 @@ class SxCPForLoopEnd:
}
for carry_index in range(1, MAX_CARRY_VALUES + 1):
next_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}")
while_close = graph.node("SxCPWhileLoopEnd", flow=flow, condition=condition.out(0), **next_values)
while_close = graph.node("SxCPWhileLoopEnd", flow=flow, condition=next_index.out(1), **next_values)
return {
"result": tuple(while_close.out(index) for index in range(1, MAX_LOOP_VALUES)),
"expand": graph.finalize(),
}
class SxCPLoopNextIndex:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"current_index": ("INT", {"default": 1}),
"total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}),
},
"optional": {
"schedule": (ANY_TYPE,),
},
}
RETURN_TYPES = ("INT", "BOOLEAN")
RETURN_NAMES = ("index", "condition")
FUNCTION = "next_index"
CATEGORY = "prompt_builder/loop/internal"
def next_index(self, current_index, total, schedule=None):
return _next_loop_index(current_index, total, schedule=schedule)
class SxCPLoopIntAdd:
@classmethod
def INPUT_TYPES(cls):
@@ -1372,6 +1522,7 @@ LOOP_NODE_CLASS_MAPPINGS = {
"SxCPAccumulator": SxCPAccumulator,
"SxCPAccumulatorPreview": SxCPAccumulatorPreview,
"SxCPPreviewAnyAsText": SxCPPreviewAnyAsText,
"SxCPLoopNextIndex": SxCPLoopNextIndex,
"SxCPLoopIntAdd": SxCPLoopIntAdd,
"SxCPLoopLessThan": SxCPLoopLessThan,
"SxCPLoopLessThanOrEqual": SxCPLoopLessThanOrEqual,
@@ -1387,6 +1538,7 @@ LOOP_NODE_DISPLAY_NAME_MAPPINGS = {
"SxCPAccumulator": "SxCP Accumulator",
"SxCPAccumulatorPreview": "SxCP Accumulator Preview",
"SxCPPreviewAnyAsText": "SxCP Preview Any As Text",
"SxCPLoopNextIndex": "SxCP Loop Next Index",
"SxCPLoopIntAdd": "SxCP Loop Int Add",
"SxCPLoopLessThan": "SxCP Loop Less Than",
"SxCPLoopLessThanOrEqual": "SxCP Loop Less Than Or Equal",
+7 -2
View File
@@ -174,7 +174,7 @@ COMMON_INPUT_TOOLTIPS = {
"custom_hardcore_clothing": "One custom hardcore clothing/body exposure state per line.",
"condition": "Loop condition. When false, the loop stops and passes current values through.",
"total": "Total number of loop iterations.",
"skip": "Number of leading loop indexes to skip. skip=1 starts generation at index 2.",
"schedule": "Optional loop index schedule. Connect a list or text like 1,3,5 or 2-6; omitted runs 1 through total.",
"collection": "Existing accumulated value or batch.",
"value": "Value to append, store, or pass through.",
"store_key": "Accumulator memory key. Leave blank for node-local storage, or use the same text to share one store across nodes.",
@@ -427,9 +427,14 @@ NODE_INPUT_TOOLTIPS = {
"include_trigger": "Keep this true for LoRA/training captions so the trigger token is learned.",
},
"SxCPForLoopStart": {
"index": "Output loop index. First generated index is skip + 1.",
"schedule": "Optional 1-based indexes to run. Accepts lists, JSON arrays, comma-separated text, and ranges like 2-6.",
"index": "Output loop index. With a schedule, this follows the scheduled 1-based indexes.",
"collected": "Current accumulated value carried through the loop.",
},
"SxCPLoopNextIndex": {
"current_index": "Current loop index used to choose the next scheduled index.",
"schedule": "Optional 1-based indexes to run. Omitted advances by one until total.",
},
"SxCPLoopAppend": {
"mode": "auto_batch tries tensor/latent batching first, then falls back to a list.",
},
+16
View File
@@ -7752,6 +7752,22 @@ def smoke_node_runtime_contracts() -> None:
"Node class and display registries are out of sync",
)
_expect(len(node_names) >= 50, "Node registry unexpectedly small")
_expect(
loop_nodes._explicit_loop_schedule("1,3,5", 5) == [1, 3, 5],
"Loop schedule should parse comma-separated indexes",
)
_expect(
loop_nodes._explicit_loop_schedule("2-4", 5) == [2, 3, 4],
"Loop schedule should expand inclusive ranges",
)
_expect(
loop_nodes._next_loop_index(4, 10, schedule="4,2") == (2, True),
"Loop schedule should preserve explicit order",
)
_expect(
loop_nodes._next_loop_index(2, 10, schedule="4,2") == (11, False),
"Loop schedule should stop after the last scheduled index",
)
for node_name in node_names:
node_class = sxcp_nodes.NODE_CLASS_MAPPINGS[node_name]