diff --git a/README.md b/README.md index 3b5d68f..dc3d77a 100644 --- a/README.md +++ b/README.md @@ -191,10 +191,10 @@ Basic loop wiring: 5. After the loop finishes, use `For Loop End.collected` as the combined output. `For Loop Start.index` is 1-based so it can be wired directly into prompt-builder -`row_number` inputs. `For Loop Start.skip` skips the first N iterations while -keeping the remaining row numbers stable. For example, `total=10` and `skip=1` -runs indexes `2..10`; `skip=5` runs indexes `6..10`. This is useful when you -want to resume a loop without changing index-derived seeds or row numbers. +`row_number` inputs. `For Loop Start.schedule` is an optional input for choosing +which indexes run while keeping row numbers stable. Omit it to run `1..total`, +connect a list such as `[2, 5, 8]`, or connect text such as `2,5,8` or `2-8`. +Indexes outside `1..total` are ignored. `collection_mode` controls how values are stored: diff --git a/loop_nodes.py b/loop_nodes.py index 4edaf69..16501e2 100644 --- a/loop_nodes.py +++ b/loop_nodes.py @@ -634,6 +634,131 @@ def append_collected_value(collection: Any, value: Any, mode: str = "auto_batch" return _as_list(collection) + [value] +def _coerce_loop_int(value: Any) -> int | None: + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) if value.is_integer() else None + text = str(value).strip() + if re.fullmatch(r"-?\d+(?:\.0+)?", text): + return int(float(text)) + return None + + +def _raw_loop_schedule_values(schedule: Any) -> list[Any]: + if schedule is None: + return [] + if hasattr(schedule, "tolist"): + try: + return _raw_loop_schedule_values(schedule.tolist()) + except Exception: + pass + if isinstance(schedule, str): + text = schedule.strip() + if not text: + return [] + try: + loaded = json.loads(text) + except Exception: + loaded = None + else: + return _raw_loop_schedule_values(loaded) + + values: list[int] = [] + + def add_range(match: re.Match[str]) -> str: + start = int(match.group(1)) + end = int(match.group(2)) + step = 1 if end >= start else -1 + values.extend(range(start, end + step, step)) + return " " + + remainder = re.sub(r"(? list[int] | None: + if schedule is None: + return None + if isinstance(schedule, str) and not schedule.strip(): + return None + + total = max(1, int(total)) + seen: set[int] = set() + values = [] + for raw_value in _raw_loop_schedule_values(schedule): + value = _coerce_loop_int(raw_value) + if value is None or value < 1 or value > total or value in seen: + continue + seen.add(value) + values.append(value) + return values + + +def _first_loop_index(total: int, schedule: Any = None) -> int: + total = max(1, int(total)) + explicit = _explicit_loop_schedule(schedule, total) + if explicit is not None: + return explicit[0] if explicit else total + 1 + return 1 + + +def _loop_index_active(index: Any, total: int, schedule: Any = None) -> bool: + total = max(1, int(total)) + value = _coerce_loop_int(index) + if value is None: + return False + explicit = _explicit_loop_schedule(schedule, total) + if explicit is not None: + return value in explicit + return 1 <= value <= total + + +def _next_loop_index(current_index: Any, total: int, schedule: Any = None) -> tuple[int, bool]: + total = max(1, int(total)) + current = _coerce_loop_int(current_index) + if current is None: + current = 0 + + explicit = _explicit_loop_schedule(schedule, total) + if explicit is None: + next_index = current + 1 + return next_index, next_index <= total + if not explicit: + return total + 1, False + + try: + position = explicit.index(current) + except ValueError: + for value in explicit: + if value > current: + return value, True + return total + 1, False + + next_position = position + 1 + if next_position >= len(explicit): + return total + 1, False + return explicit[next_position], True + + class SxCPWhileLoopStart: @classmethod def INPUT_TYPES(cls): @@ -795,10 +920,10 @@ class SxCPForLoopStart: return { "required": { "total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}), - "skip": ("INT", {"default": 0, "min": 0, "max": 100000, "step": 1}), }, "optional": { - f"initial_value{index}": (ANY_TYPE,) for index in range(1, MAX_CARRY_VALUES + 1) + "schedule": (ANY_TYPE,), + **{f"initial_value{index}": (ANY_TYPE,) for index in range(1, MAX_CARRY_VALUES + 1)}, }, "hidden": { "initial_index": (ANY_TYPE,), @@ -814,12 +939,10 @@ class SxCPForLoopStart: FUNCTION = "start" CATEGORY = "prompt_builder/loop" - def start(self, total, skip=0, initial_index=None, initial_collected=None, **kwargs): + def start(self, total, schedule=None, initial_index=None, initial_collected=None, **kwargs): _require_graph_builder() total = max(1, int(total)) - skip = max(0, int(skip)) - first_index = skip + 1 - index = first_index if initial_index is None else max(int(initial_index), first_index) + index = _first_loop_index(total, schedule=schedule) if initial_index is None else int(initial_index) collected = initial_collected initial_values = { "initial_value0": index, @@ -828,7 +951,7 @@ class SxCPForLoopStart: for carry_index in range(1, MAX_CARRY_VALUES + 1): initial_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}") graph = GraphBuilder() - graph.node("SxCPWhileLoopStart", condition=index <= total, **initial_values) + graph.node("SxCPWhileLoopStart", condition=_loop_index_active(index, total, schedule=schedule), **initial_values) return { "result": tuple(["stub", index, collected] + [kwargs.get(f"initial_value{index}") for index in range(1, MAX_CARRY_VALUES + 1)]), "expand": graph.finalize(), @@ -1281,9 +1404,14 @@ class SxCPForLoopEnd: start_node = dynprompt.get_node(loop_start) if start_node["class_type"] != "SxCPForLoopStart": raise ValueError("SxCP For Loop End must receive flow from SxCP For Loop Start.") - total = start_node["inputs"]["total"] - next_index = graph.node("SxCPLoopIntAdd", a=[loop_start, 1], b=1) - condition = graph.node("SxCPLoopLessThanOrEqual", a=next_index.out(0), b=total) + start_inputs = start_node["inputs"] + total = start_inputs["total"] + next_index = graph.node( + "SxCPLoopNextIndex", + current_index=[loop_start, 1], + total=total, + schedule=start_inputs.get("schedule"), + ) collection = kwargs.get("collected") or [loop_start, 2] collect_value = kwargs.get("collect_value") next_collection = graph.node( @@ -1299,13 +1427,35 @@ class SxCPForLoopEnd: } for carry_index in range(1, MAX_CARRY_VALUES + 1): next_values[f"initial_value{carry_index + 1}"] = kwargs.get(f"initial_value{carry_index}") - while_close = graph.node("SxCPWhileLoopEnd", flow=flow, condition=condition.out(0), **next_values) + while_close = graph.node("SxCPWhileLoopEnd", flow=flow, condition=next_index.out(1), **next_values) return { "result": tuple(while_close.out(index) for index in range(1, MAX_LOOP_VALUES)), "expand": graph.finalize(), } +class SxCPLoopNextIndex: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "current_index": ("INT", {"default": 1}), + "total": ("INT", {"default": 2, "min": 1, "max": 100000, "step": 1}), + }, + "optional": { + "schedule": (ANY_TYPE,), + }, + } + + RETURN_TYPES = ("INT", "BOOLEAN") + RETURN_NAMES = ("index", "condition") + FUNCTION = "next_index" + CATEGORY = "prompt_builder/loop/internal" + + def next_index(self, current_index, total, schedule=None): + return _next_loop_index(current_index, total, schedule=schedule) + + class SxCPLoopIntAdd: @classmethod def INPUT_TYPES(cls): @@ -1372,6 +1522,7 @@ LOOP_NODE_CLASS_MAPPINGS = { "SxCPAccumulator": SxCPAccumulator, "SxCPAccumulatorPreview": SxCPAccumulatorPreview, "SxCPPreviewAnyAsText": SxCPPreviewAnyAsText, + "SxCPLoopNextIndex": SxCPLoopNextIndex, "SxCPLoopIntAdd": SxCPLoopIntAdd, "SxCPLoopLessThan": SxCPLoopLessThan, "SxCPLoopLessThanOrEqual": SxCPLoopLessThanOrEqual, @@ -1387,6 +1538,7 @@ LOOP_NODE_DISPLAY_NAME_MAPPINGS = { "SxCPAccumulator": "SxCP Accumulator", "SxCPAccumulatorPreview": "SxCP Accumulator Preview", "SxCPPreviewAnyAsText": "SxCP Preview Any As Text", + "SxCPLoopNextIndex": "SxCP Loop Next Index", "SxCPLoopIntAdd": "SxCP Loop Int Add", "SxCPLoopLessThan": "SxCP Loop Less Than", "SxCPLoopLessThanOrEqual": "SxCP Loop Less Than Or Equal", diff --git a/node_tooltips.py b/node_tooltips.py index 51e30ce..96f7f9b 100644 --- a/node_tooltips.py +++ b/node_tooltips.py @@ -174,7 +174,7 @@ COMMON_INPUT_TOOLTIPS = { "custom_hardcore_clothing": "One custom hardcore clothing/body exposure state per line.", "condition": "Loop condition. When false, the loop stops and passes current values through.", "total": "Total number of loop iterations.", - "skip": "Number of leading loop indexes to skip. skip=1 starts generation at index 2.", + "schedule": "Optional loop index schedule. Connect a list or text like 1,3,5 or 2-6; omitted runs 1 through total.", "collection": "Existing accumulated value or batch.", "value": "Value to append, store, or pass through.", "store_key": "Accumulator memory key. Leave blank for node-local storage, or use the same text to share one store across nodes.", @@ -427,9 +427,14 @@ NODE_INPUT_TOOLTIPS = { "include_trigger": "Keep this true for LoRA/training captions so the trigger token is learned.", }, "SxCPForLoopStart": { - "index": "Output loop index. First generated index is skip + 1.", + "schedule": "Optional 1-based indexes to run. Accepts lists, JSON arrays, comma-separated text, and ranges like 2-6.", + "index": "Output loop index. With a schedule, this follows the scheduled 1-based indexes.", "collected": "Current accumulated value carried through the loop.", }, + "SxCPLoopNextIndex": { + "current_index": "Current loop index used to choose the next scheduled index.", + "schedule": "Optional 1-based indexes to run. Omitted advances by one until total.", + }, "SxCPLoopAppend": { "mode": "auto_batch tries tensor/latent batching first, then falls back to a list.", }, diff --git a/tools/prompt_smoke.py b/tools/prompt_smoke.py index 1d17f35..61ab6fe 100644 --- a/tools/prompt_smoke.py +++ b/tools/prompt_smoke.py @@ -7752,6 +7752,22 @@ def smoke_node_runtime_contracts() -> None: "Node class and display registries are out of sync", ) _expect(len(node_names) >= 50, "Node registry unexpectedly small") + _expect( + loop_nodes._explicit_loop_schedule("1,3,5", 5) == [1, 3, 5], + "Loop schedule should parse comma-separated indexes", + ) + _expect( + loop_nodes._explicit_loop_schedule("2-4", 5) == [2, 3, 4], + "Loop schedule should expand inclusive ranges", + ) + _expect( + loop_nodes._next_loop_index(4, 10, schedule="4,2") == (2, True), + "Loop schedule should preserve explicit order", + ) + _expect( + loop_nodes._next_loop_index(2, 10, schedule="4,2") == (11, False), + "Loop schedule should stop after the last scheduled index", + ) for node_name in node_names: node_class = sxcp_nodes.NODE_CLASS_MAPPINGS[node_name]