From d8dbc4fb4b66266202f1cede8d0a16773a45cee1 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 17:42:54 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20ImageGate=20node=20=E2=80=94=20pause,?= =?UTF-8?q?=20route=20via=20ExecutionBlocker,=20mask=20out?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.8 --- gates/gate.py | 42 ++++++++++++++++++++++++++++++++++++++++++ tests/test_gate.py | 10 ++++++++++ 2 files changed, 52 insertions(+) diff --git a/gates/gate.py b/gates/gate.py index 6ac5e29..ae62b52 100644 --- a/gates/gate.py +++ b/gates/gate.py @@ -22,3 +22,45 @@ def mask_from_stash(data, image): m = Image.open(io.BytesIO(data)).convert("L") arr = np.array(m, dtype=np.float32) / 255.0 return torch.from_numpy(arr).unsqueeze(0) + + +class ImageGate: + CATEGORY = "Datasete Gates" + FUNCTION = "run" + RETURN_TYPES = ("MASK",) + ("IMAGE",) * MAX_ROUTES + RETURN_NAMES = ("mask",) + tuple(f"route_{i + 1}" for i in range(MAX_ROUTES)) + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "routes": ("INT", {"default": 2, "min": 1, "max": MAX_ROUTES}), + }, + "hidden": {"unique_id": "UNIQUE_ID"}, + } + + @classmethod + def IS_CHANGED(cls, **kwargs): + return float("nan") # always pause; never cached + + def run(self, image, routes, unique_id): + from comfy_execution.graph_utils import ExecutionBlocker + from . import gate_server + + gate_bus.GateBus.arm(unique_id) + gate_server.send_preview(unique_id, image, routes) + try: + chosen_1 = gate_bus.GateBus.wait(unique_id) + except gate_bus.GateCancelled: + import comfy.model_management as mm + raise mm.InterruptProcessingException() + + mask = mask_from_stash(gate_bus.GateBus.pop_mask(unique_id), image) + chosen = max(0, min(chosen_1 - 1, routes - 1)) + blocker = ExecutionBlocker(None) + return (mask,) + route_tuple(chosen, image, blocker, MAX_ROUTES) + + +NODE_CLASS_MAPPINGS = {"ImageGate": ImageGate} +NODE_DISPLAY_NAME_MAPPINGS = {"ImageGate": "Image Gate (Manual Router)"} diff --git a/tests/test_gate.py b/tests/test_gate.py index b0c3bdc..273427a 100644 --- a/tests/test_gate.py +++ b/tests/test_gate.py @@ -1,5 +1,6 @@ # tests/test_gate.py import io +import math import torch from PIL import Image @@ -25,3 +26,12 @@ def test_mask_from_stash_decodes_png(): img = torch.zeros((1, 6, 4, 3)) m = gate.mask_from_stash(buf.getvalue(), img) assert m.shape == (1, 6, 4) and float(m.min()) > 0.99 + +def test_is_changed_always_nan(): + v = gate.ImageGate.IS_CHANGED(image=None, routes=2, unique_id="1") + assert math.isnan(v) + +def test_return_types_shape(): + assert gate.ImageGate.RETURN_TYPES[0] == "MASK" + assert len(gate.ImageGate.RETURN_TYPES) == gate.MAX_ROUTES + 1 + assert all(t == "IMAGE" for t in gate.ImageGate.RETURN_TYPES[1:])