feat: ImageGate node — pause, route via ExecutionBlocker, mask out
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -22,3 +22,45 @@ def mask_from_stash(data, image):
|
|||||||
m = Image.open(io.BytesIO(data)).convert("L")
|
m = Image.open(io.BytesIO(data)).convert("L")
|
||||||
arr = np.array(m, dtype=np.float32) / 255.0
|
arr = np.array(m, dtype=np.float32) / 255.0
|
||||||
return torch.from_numpy(arr).unsqueeze(0)
|
return torch.from_numpy(arr).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGate:
|
||||||
|
CATEGORY = "Datasete Gates"
|
||||||
|
FUNCTION = "run"
|
||||||
|
RETURN_TYPES = ("MASK",) + ("IMAGE",) * MAX_ROUTES
|
||||||
|
RETURN_NAMES = ("mask",) + tuple(f"route_{i + 1}" for i in range(MAX_ROUTES))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"routes": ("INT", {"default": 2, "min": 1, "max": MAX_ROUTES}),
|
||||||
|
},
|
||||||
|
"hidden": {"unique_id": "UNIQUE_ID"},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def IS_CHANGED(cls, **kwargs):
|
||||||
|
return float("nan") # always pause; never cached
|
||||||
|
|
||||||
|
def run(self, image, routes, unique_id):
|
||||||
|
from comfy_execution.graph_utils import ExecutionBlocker
|
||||||
|
from . import gate_server
|
||||||
|
|
||||||
|
gate_bus.GateBus.arm(unique_id)
|
||||||
|
gate_server.send_preview(unique_id, image, routes)
|
||||||
|
try:
|
||||||
|
chosen_1 = gate_bus.GateBus.wait(unique_id)
|
||||||
|
except gate_bus.GateCancelled:
|
||||||
|
import comfy.model_management as mm
|
||||||
|
raise mm.InterruptProcessingException()
|
||||||
|
|
||||||
|
mask = mask_from_stash(gate_bus.GateBus.pop_mask(unique_id), image)
|
||||||
|
chosen = max(0, min(chosen_1 - 1, routes - 1))
|
||||||
|
blocker = ExecutionBlocker(None)
|
||||||
|
return (mask,) + route_tuple(chosen, image, blocker, MAX_ROUTES)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {"ImageGate": ImageGate}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {"ImageGate": "Image Gate (Manual Router)"}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# tests/test_gate.py
|
# tests/test_gate.py
|
||||||
import io
|
import io
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -25,3 +26,12 @@ def test_mask_from_stash_decodes_png():
|
|||||||
img = torch.zeros((1, 6, 4, 3))
|
img = torch.zeros((1, 6, 4, 3))
|
||||||
m = gate.mask_from_stash(buf.getvalue(), img)
|
m = gate.mask_from_stash(buf.getvalue(), img)
|
||||||
assert m.shape == (1, 6, 4) and float(m.min()) > 0.99
|
assert m.shape == (1, 6, 4) and float(m.min()) > 0.99
|
||||||
|
|
||||||
|
def test_is_changed_always_nan():
|
||||||
|
v = gate.ImageGate.IS_CHANGED(image=None, routes=2, unique_id="1")
|
||||||
|
assert math.isnan(v)
|
||||||
|
|
||||||
|
def test_return_types_shape():
|
||||||
|
assert gate.ImageGate.RETURN_TYPES[0] == "MASK"
|
||||||
|
assert len(gate.ImageGate.RETURN_TYPES) == gate.MAX_ROUTES + 1
|
||||||
|
assert all(t == "IMAGE" for t in gate.ImageGate.RETURN_TYPES[1:])
|
||||||
|
|||||||
Reference in New Issue
Block a user