diff --git a/__init__.py b/__init__.py index 741f569..eab1cbd 100644 --- a/__init__.py +++ b/__init__.py @@ -127,6 +127,12 @@ COMMON_INPUT_TOOLTIPS = { "save_path": "Folder to save the accumulator batch. Relative paths are inside ComfyUI output; absolute paths are used directly.", "filename_prefix": "Filename prefix for saved accumulator images.", "clear_after_save": "Clear the accumulator store after a successful batch save.", + "mode": "Switch direction: pick_input selects one input to value, route_output sends route_value to one output.", + "index": "Index used by SxCP Index Switch. For Loop Start outputs one_based indexes by default.", + "index_base": "one_based means index 1 selects input_1. zero_based means index 0 selects input_1.", + "missing_behavior": "What to do when the requested switch input is not connected: use fallback, output none, clamp, or wrap.", + "fallback": "Optional value used by SxCP Index Switch when the requested input is missing and missing_behavior is fallback.", + "route_value": "Value routed to output_N when mode is route_output.", "clothing": "Built-in clothing density for legacy direct generation. Category/profile nodes can override this.", "poses": "Built-in pose pool for legacy direct generation.", "backside_bias": "Legacy bias toward rear/backside poses where that category supports it.", @@ -281,6 +287,10 @@ def _tooltip_for_input(node_name: str, input_name: str) -> str: return f"Include {value} in this random pool." if input_name.startswith("initial_value"): return "Carry value passed into the loop body and returned on the matching output." + if re.match(r"^input_\d+$", input_name): + return "Autoscaling switch input. Connect the last visible input to reveal the next one." + if re.match(r"^output_\d+$", input_name): + return "Autoscaling routed output. Connect the last visible output to reveal the next one." if input_name.startswith("override_"): return "Optional loaded-profile override. Leave empty or keep_profile to preserve the profile value." return "" diff --git a/loop_nodes.py b/loop_nodes.py index c16ad1e..88d69e5 100644 --- a/loop_nodes.py +++ b/loop_nodes.py @@ -41,11 +41,15 @@ except Exception: MAX_LOOP_VALUES = 20 MAX_CARRY_VALUES = MAX_LOOP_VALUES - 2 +MAX_SWITCH_INPUTS = 64 COLLECTION_MODES = ["auto_batch", "list", "image_batch", "latent_batch", "string_lines"] ACCUMULATOR_ACTIONS = ["append_variant", "replace_by_entry_id", "append", "clear_then_append", "clear", "read"] ACCUMULATOR_IMAGE_BATCH_MODES = ["same_size_only", "resize_to_first"] ACCUMULATOR_IMAGE_GROUPS = 4 ACCUMULATOR_PREVIEW_DELETE_ACTIONS = ["none", "delete_entry_id", "delete_index", "clear"] +INDEX_SWITCH_MODES = ["pick_input", "route_output"] +INDEX_SWITCH_BASES = ["one_based", "zero_based"] +INDEX_SWITCH_MISSING_BEHAVIORS = ["fallback", "none", "clamp", "wrap"] _ACCUMULATOR_STORES: dict[str, list[dict[str, Any]]] = {} @@ -431,6 +435,44 @@ def append_collected_value(collection: Any, value: Any, mode: str = "auto_batch" return _as_list(collection) + [value] +def _switch_available_indices(kwargs: dict[str, Any]) -> list[int]: + indices = [] + for key in kwargs: + match = re.match(r"^input_(\d+)$", str(key)) + if match: + indices.append(int(match.group(1))) + return sorted(set(indices)) + + +def _switch_requested_index(index: Any, index_base: str) -> int: + requested = int(index) + return requested + 1 if index_base == "zero_based" else requested + + +def _switch_resolved_index(requested: int, available: list[int], missing_behavior: str) -> int | None: + if requested in available: + return requested + if missing_behavior in ("fallback", "none") or not available: + return None + if missing_behavior == "wrap": + return available[(requested - 1) % len(available)] + if requested <= available[0]: + return available[0] + if requested >= available[-1]: + return available[-1] + lower = [value for value in available if value <= requested] + return lower[-1] if lower else available[0] + + +def _switch_status(requested: int, selected: int | None, used_fallback: bool, available: list[int]) -> str: + available_text = ",".join(str(index) for index in available) or "none" + if used_fallback: + return f"requested=input_{requested}; selected=fallback; available={available_text}" + if selected is None: + return f"requested=input_{requested}; selected=none; available={available_text}" + return f"requested=input_{requested}; selected=input_{selected}; available={available_text}" + + class SxCPWhileLoopStart: @classmethod def INPUT_TYPES(cls): @@ -655,6 +697,99 @@ class SxCPLoopAppend: return (append_collected_value(collection, value, mode=mode, skip_none=skip_none),) +class SxCPIndexSwitch: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "index": ("INT", {"default": 1, "min": -100000, "max": 100000, "step": 1}), + "mode": (INDEX_SWITCH_MODES, {"default": "pick_input"}), + "index_base": (INDEX_SWITCH_BASES, {"default": "one_based"}), + "missing_behavior": (INDEX_SWITCH_MISSING_BEHAVIORS, {"default": "fallback"}), + }, + "optional": { + "fallback": (ANY_TYPE, {"lazy": True}), + "route_value": (ANY_TYPE, {"lazy": True}), + **{ + f"input_{index}": (ANY_TYPE, {"lazy": True}) + for index in range(1, MAX_SWITCH_INPUTS + 1) + }, + }, + } + + RETURN_TYPES = tuple([ANY_TYPE, "INT", "STRING"] + [ANY_TYPE] * MAX_SWITCH_INPUTS) + RETURN_NAMES = tuple(["value", "selected_index", "status"] + [f"output_{index}" for index in range(1, MAX_SWITCH_INPUTS + 1)]) + FUNCTION = "switch" + CATEGORY = "prompt_builder/loop" + + def _input_selection( + self, + index: Any, + index_base: str, + missing_behavior: str, + kwargs: dict[str, Any], + ) -> tuple[int, int | None, list[int]]: + index_base = index_base if index_base in INDEX_SWITCH_BASES else "one_based" + missing_behavior = missing_behavior if missing_behavior in INDEX_SWITCH_MISSING_BEHAVIORS else "fallback" + requested = _switch_requested_index(index, index_base) + available = _switch_available_indices(kwargs) + selected = _switch_resolved_index(requested, available, missing_behavior) + return requested, selected, available + + def _route_selection(self, index: Any, index_base: str, missing_behavior: str) -> tuple[int, int | None]: + index_base = index_base if index_base in INDEX_SWITCH_BASES else "one_based" + missing_behavior = missing_behavior if missing_behavior in INDEX_SWITCH_MISSING_BEHAVIORS else "fallback" + requested = _switch_requested_index(index, index_base) + if 1 <= requested <= MAX_SWITCH_INPUTS: + return requested, requested + if missing_behavior == "wrap": + return requested, ((requested - 1) % MAX_SWITCH_INPUTS) + 1 + if missing_behavior == "clamp": + return requested, min(max(requested, 1), MAX_SWITCH_INPUTS) + return requested, None + + def _blocked_outputs(self) -> list[Any]: + return [_execution_blocker() for _index in range(MAX_SWITCH_INPUTS)] + + def check_lazy_status(self, index, mode, index_base, missing_behavior, **kwargs): + mode = mode if mode in INDEX_SWITCH_MODES else "pick_input" + if mode == "route_output": + return ["route_value"] if "route_value" in kwargs else [] + requested, selected, _available = self._input_selection(index, index_base, missing_behavior, kwargs) + selected_name = f"input_{selected}" if selected is not None else f"input_{requested}" + if selected_name in kwargs: + return [selected_name] + if missing_behavior == "fallback" and "fallback" in kwargs: + return ["fallback"] + return [] + + def switch(self, index, mode, index_base, missing_behavior, **kwargs): + mode = mode if mode in INDEX_SWITCH_MODES else "pick_input" + missing_behavior = missing_behavior if missing_behavior in INDEX_SWITCH_MISSING_BEHAVIORS else "fallback" + if mode == "route_output": + requested, selected = self._route_selection(index, index_base, missing_behavior) + value = kwargs.get("route_value") + outputs = self._blocked_outputs() + if selected is not None and "route_value" in kwargs: + outputs[selected - 1] = value + status = f"mode=route_output; requested=output_{requested}; selected={'none' if selected is None else f'output_{selected}'}; range=1-{MAX_SWITCH_INPUTS}" + selected_index = selected or 0 + return tuple([value if "route_value" in kwargs else None, selected_index, status] + outputs) + + requested, selected, available = self._input_selection(index, index_base, missing_behavior, kwargs) + if selected is not None: + selected_name = f"input_{selected}" + if selected_name in kwargs: + value = kwargs.get(selected_name) + status = f"mode=pick_input; {_switch_status(requested, selected, False, available)}" + return tuple([value, selected, status] + self._blocked_outputs()) + if missing_behavior == "fallback" and "fallback" in kwargs: + status = f"mode=pick_input; {_switch_status(requested, None, True, available)}" + return tuple([kwargs.get("fallback"), 0, status] + self._blocked_outputs()) + status = f"mode=pick_input; {_switch_status(requested, None, False, available)}" + return tuple([None, 0, status] + self._blocked_outputs()) + + class SxCPAccumulator: @classmethod def INPUT_TYPES(cls): @@ -1049,6 +1184,7 @@ LOOP_NODE_CLASS_MAPPINGS = { "SxCPForLoopStart": SxCPForLoopStart, "SxCPForLoopEnd": SxCPForLoopEnd, "SxCPLoopAppend": SxCPLoopAppend, + "SxCPIndexSwitch": SxCPIndexSwitch, "SxCPAccumulator": SxCPAccumulator, "SxCPAccumulatorPreview": SxCPAccumulatorPreview, "SxCPLoopIntAdd": SxCPLoopIntAdd, @@ -1062,6 +1198,7 @@ LOOP_NODE_DISPLAY_NAME_MAPPINGS = { "SxCPForLoopStart": "SxCP For Loop Start", "SxCPForLoopEnd": "SxCP For Loop End", "SxCPLoopAppend": "SxCP Loop Append", + "SxCPIndexSwitch": "SxCP Index Switch", "SxCPAccumulator": "SxCP Accumulator", "SxCPAccumulatorPreview": "SxCP Accumulator Preview", "SxCPLoopIntAdd": "SxCP Loop Int Add", diff --git a/web/index_switch_slots.js b/web/index_switch_slots.js new file mode 100644 index 0000000..2feea3b --- /dev/null +++ b/web/index_switch_slots.js @@ -0,0 +1,147 @@ +import { app } from "../../scripts/app.js"; + +const EXTENSION = "ethanfel.prompt_builder.index_switch_slots"; +const NODE_NAME = "SxCPIndexSwitch"; +const MAX_INPUTS = 64; + +function isSwitchInput(input) { + return /^input_\d+$/.test(input?.name || ""); +} + +function isSwitchOutput(output) { + return /^output_\d+$/.test(output?.name || ""); +} + +function slotNumber(slot) { + const match = String(slot?.name || "").match(/\d+$/); + return match ? Number(match[0]) : -1; +} + +function resizeNode(node) { + const size = node.computeSize?.(); + if (size) node.setSize?.(size); + app.graph?.setDirtyCanvas(true, true); +} + +function addSwitchInput(node, number) { + if (number < 1 || number > MAX_INPUTS) return; + const name = `input_${number}`; + if (!node.inputs?.some((input) => input.name === name)) { + node.addInput(name, "*"); + } +} + +function removeSwitchInput(node, number) { + const inputIndex = node.inputs?.findIndex((input) => input.name === `input_${number}`) ?? -1; + if (inputIndex >= 0 && !node.inputs[inputIndex]?.link) { + node.removeInput(inputIndex); + } +} + +function addSwitchOutput(node, number) { + if (number < 1 || number > MAX_INPUTS) return; + const name = `output_${number}`; + if (!node.outputs?.some((output) => output.name === name)) { + node.addOutput(name, "*"); + } +} + +function removeSwitchOutput(node, number) { + const outputIndex = node.outputs?.findIndex((output) => output.name === `output_${number}`) ?? -1; + if (outputIndex >= 0 && !(node.outputs[outputIndex]?.links?.length)) { + node.removeOutput(outputIndex); + } +} + +function trimInputTail(node) { + for (let number = MAX_INPUTS; number > 1; number--) { + const input = node.inputs?.find((slot) => slot.name === `input_${number}`); + const previous = node.inputs?.find((slot) => slot.name === `input_${number - 1}`); + if (!input?.link && !previous?.link) removeSwitchInput(node, number); + } +} + +function trimOutputTail(node) { + for (let number = MAX_INPUTS; number > 1; number--) { + const output = node.outputs?.find((slot) => slot.name === `output_${number}`); + const previous = node.outputs?.find((slot) => slot.name === `output_${number - 1}`); + if (!(output?.links?.length) && !(previous?.links?.length)) removeSwitchOutput(node, number); + } +} + +function setupNodeSlots(node) { + addSwitchInput(node, 1); + addSwitchOutput(node, 1); + for (let number = 2; number <= MAX_INPUTS; number++) { + const input = node.inputs?.find((slot) => slot.name === `input_${number}`); + if (!input?.link) removeSwitchInput(node, number); + const output = node.outputs?.find((slot) => slot.name === `output_${number}`); + if (!(output?.links?.length)) removeSwitchOutput(node, number); + } + trimInputTail(node); + trimOutputTail(node); + resizeNode(node); +} + +function maybeGrowInput(node) { + const switchInputs = (node.inputs || []).filter(isSwitchInput); + const last = switchInputs.reduce((max, input) => Math.max(max, slotNumber(input)), 1); + const lastInput = node.inputs?.find((slot) => slot.name === `input_${last}`); + if (lastInput?.link && last < MAX_INPUTS) { + addSwitchInput(node, last + 1); + resizeNode(node); + } +} + +function maybeGrowOutput(node) { + const switchOutputs = (node.outputs || []).filter(isSwitchOutput); + const last = switchOutputs.reduce((max, output) => Math.max(max, slotNumber(output)), 1); + const lastOutput = node.outputs?.find((slot) => slot.name === `output_${last}`); + if (lastOutput?.links?.length && last < MAX_INPUTS) { + addSwitchOutput(node, last + 1); + resizeNode(node); + } +} + +app.registerExtension({ + name: EXTENSION, + + async beforeRegisterNodeDef(nodeType, nodeData) { + if (nodeData.name !== NODE_NAME) return; + + const onNodeCreated = nodeType.prototype.onNodeCreated; + nodeType.prototype.onNodeCreated = function () { + const result = onNodeCreated?.apply(this, arguments); + queueMicrotask(() => setupNodeSlots(this)); + return result; + }; + + const onConfigure = nodeType.prototype.onConfigure; + nodeType.prototype.onConfigure = function () { + const result = onConfigure?.apply(this, arguments); + queueMicrotask(() => setupNodeSlots(this)); + return result; + }; + + const onConnectionsChange = nodeType.prototype.onConnectionsChange; + nodeType.prototype.onConnectionsChange = function (type, index, connected, linkInfo) { + const result = onConnectionsChange?.apply(this, arguments); + if (!linkInfo) return result; + const slot = type === LiteGraph.INPUT ? this.inputs?.[index] : this.outputs?.[index]; + if (type === LiteGraph.INPUT && isSwitchInput(slot)) { + if (connected) maybeGrowInput(this); + else { + trimInputTail(this); + resizeNode(this); + } + } else if (isSwitchOutput(slot)) { + if (connected) maybeGrowOutput(this); + else { + trimOutputTail(this); + resizeNode(this); + } + } + return result; + }; + }, +});