Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ec8e1b9598 | |||
| 6e27da0dce | |||
| f9f924942e | |||
| 45e16e1134 | |||
| 63647d2488 | |||
| 8e8eb317f7 | |||
| d8dbc4fb4b | |||
| ea3438567a | |||
| f0f8676eaa | |||
| 11772bc29d | |||
| 9148dfec25 | |||
| 7e8878bade |
+5
-2
@@ -12,10 +12,13 @@ if __package__:
|
||||
NODE_DISPLAY_NAME_MAPPINGS as _POOL_NAMES
|
||||
from .gates.loader import NODE_CLASS_MAPPINGS as _LOADER_NODES, \
|
||||
NODE_DISPLAY_NAME_MAPPINGS as _LOADER_NAMES
|
||||
from .gates.gate import NODE_CLASS_MAPPINGS as _GATE_NODES, \
|
||||
NODE_DISPLAY_NAME_MAPPINGS as _GATE_NAMES
|
||||
from .gates import routes # noqa: F401 (registers aiohttp routes on import)
|
||||
from .gates import gate_server # noqa: F401 (registers /datasete_gate/* routes)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {**_POOL_NODES, **_LOADER_NODES}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {**_POOL_NAMES, **_LOADER_NAMES}
|
||||
NODE_CLASS_MAPPINGS = {**_POOL_NODES, **_LOADER_NODES, **_GATE_NODES}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {**_POOL_NAMES, **_LOADER_NAMES, **_GATE_NAMES}
|
||||
else: # pragma: no cover - exercised only under pytest collection
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
# Image Gate (Manual Router) — Design
|
||||
|
||||
Date: 2026-06-21
|
||||
Status: Approved (brainstorming complete, ready for implementation plan)
|
||||
|
||||
## 1. Purpose
|
||||
|
||||
An interactive "image chooser on steroids": during a prompt run the node **pauses**,
|
||||
shows the incoming image with a row of labeled **route buttons**, and waits for a human
|
||||
click. Clicking **route K** sends the image down output K (all other route branches are
|
||||
silently skipped). A **Stop** button cancels the whole run. Optionally, an **Edit mask**
|
||||
button opens ComfyUI's MaskEditor on the image and the painted mask is emitted on a
|
||||
single `mask` output. Built for manual dataset sorting/gating in the "Dataset Gates" suite.
|
||||
|
||||
Third node in the `ComfyUI-Datasete-Gates` package.
|
||||
|
||||
## 2. IO
|
||||
|
||||
| dir | name | type | notes |
|
||||
|---|---|---|---|
|
||||
| in | `image` | IMAGE | the image (or batch, routed as one unit) |
|
||||
| widget | `routes` | INT, default 2, 1..10 | number of visible route buttons/outputs |
|
||||
| widget | per-route labels | (frontend) | editable, default `1..N`; rename the visible output slots |
|
||||
| hidden | `unique_id` | UNIQUE_ID | node id, used to key the pause/choice |
|
||||
| out | `mask` | MASK | **fixed slot 0**; painted at the gate, zeros (sized to image) if none |
|
||||
| out | `route_1 … route_10` | IMAGE | dynamic; JS shows only `routes` of them, labeled |
|
||||
|
||||
`RETURN_TYPES = ("MASK",) + ("IMAGE",)*10`. The node always returns all 11 outputs; the
|
||||
chosen route carries the image, every other route returns `ExecutionBlocker(None)`. JS
|
||||
hides the unused slots (>`routes`); their `ExecutionBlocker` returns are harmless.
|
||||
|
||||
## 3. Behavior (the pause)
|
||||
|
||||
On execute:
|
||||
1. Push the image to the UI (`PromptServer.send_sync`, base64 or temp file) so the node
|
||||
body shows the preview + the N labeled route buttons + **🖌 Edit mask** + **■ Stop**.
|
||||
2. **Block** the executor thread on our own `GateBus.wait(unique_id)` (a `MessageHolder`-
|
||||
style singleton in a `sleep(0.1)` loop; separate namespace from cg-image-picker).
|
||||
3. Resolve:
|
||||
- **route K** → image to output `K`, `ExecutionBlocker(None)` to the other routes;
|
||||
`mask` = the painted mask (or zeros).
|
||||
- **🖌 Edit mask** → opens MaskEditor (reuse the pool node's clipspace flow); the mask
|
||||
is POSTed to `/datasete_gate/mask` keyed by `unique_id` and picked up on resume.
|
||||
- **■ Stop** → cancel the prompt cleanly via
|
||||
`comfy.model_management.InterruptProcessingException` (confirm exact symbol in plan).
|
||||
|
||||
`IS_CHANGED` returns `nan` → the gate pauses on **every** run (never cached).
|
||||
|
||||
## 4. Why the global mask is safe
|
||||
|
||||
Verified in `execution.py:257-266` + `305-306`: if **any** input of a node is an
|
||||
`ExecutionBlocker`, the node is skipped and the blocker propagates to all its outputs.
|
||||
So a non-chosen route's downstream (which consumes the blocked routed image) never runs,
|
||||
regardless of the live `mask` value. Caveat: a node wired to `mask` *only* (no routed
|
||||
image) would run unconditionally — not the intended wiring.
|
||||
|
||||
## 5. Code shape (same package)
|
||||
|
||||
- `gates/gate.py` — `ImageGate` node: `INPUT_TYPES`, `IS_CHANGED=nan`, `run()` (push
|
||||
preview → block → route via `ExecutionBlocker`). Pure helper `route_tuple(chosen, image,
|
||||
blocker, max_routes)` for unit testing.
|
||||
- `gates/gate_server.py` — `GateBus` (start/put/wait/cancel) + mask stash; aiohttp routes
|
||||
`/datasete_gate/choice` and `/datasete_gate/mask`; `send_preview()` helper.
|
||||
- `web/image_gate.js` — dynamic labeled outputs (show `routes` of 10), preview render,
|
||||
route/stop/mask buttons, posts the choice; reuses the pool's MaskEditor helper.
|
||||
|
||||
## 6. Edge cases
|
||||
|
||||
- `routes` changed between runs → JS re-syncs visible slots; Python clamps `chosen` to
|
||||
`routes`.
|
||||
- Stop while no mask painted → clean interrupt, no output.
|
||||
- Multiple gates in one graph → execute sequentially (single executor thread), so only one
|
||||
blocks at a time; still keyed by `unique_id`.
|
||||
- Batch input → previewed as the first image / small grid; routed as one unit.
|
||||
- External queue-cancel → `GateBus` honors the cancel flag and raises.
|
||||
|
||||
## 7. Testing
|
||||
|
||||
- pytest: `route_tuple` (image at chosen, blocker elsewhere, correct length); `GateBus`
|
||||
(pre-seeded message returns; cancel raises; `start` resets); mask zero-fallback sizing.
|
||||
- Manual (live): pause appears, buttons labeled, click routes image to the right branch
|
||||
and only that branch runs; Edit mask round-trips and feeds `mask`; Stop cancels cleanly;
|
||||
changing `routes` adds/removes slots.
|
||||
@@ -0,0 +1,431 @@
|
||||
# Image Gate (Manual Router) Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Build a ComfyUI custom node `Image Gate (Manual Router)` that pauses a running prompt, shows the image with up to 10 labeled route buttons + a mask-edit + a stop button, and routes the image down the clicked output (others `ExecutionBlocker`-ed), emitting any gate-painted mask on a fixed `mask` output.
|
||||
|
||||
**Architecture:** A pure, torch-free `gates/gate_bus.py` (a `MessageHolder`-style blocking waiter + mask stash) is unit-testable without ComfyUI. `gates/gate.py` holds the node plus pure helpers (`route_tuple`, `mask_from_stash`); it imports `ExecutionBlocker`/`model_management` lazily so tests don't need comfy. `gates/gate_server.py` is the aiohttp glue (choice/mask routes + `send_preview`). `web/image_gate.js` renders preview + dynamic labeled outputs + buttons and posts the choice; it reuses the pool node's MaskEditor helper.
|
||||
|
||||
**Tech Stack:** Python 3.12, torch 2.8, Pillow, numpy, aiohttp; pytest 9; vanilla JS frontend.
|
||||
|
||||
---
|
||||
|
||||
## Conventions (read once)
|
||||
|
||||
- **Test python:** `/media/p5/miniforge3/bin/python` (`PY=...`).
|
||||
- **Run tests:** `cd /media/p5/ComfyUI-Datasete-Gates && $PY -m pytest tests/test_gate_bus.py tests/test_gate.py -v`
|
||||
- **Concurrency:** other sessions may share this working tree. Stage only this node's paths
|
||||
when committing; re-Read `__init__.py` before editing (Task 6) and *extend*, don't overwrite.
|
||||
- `gates/gate_bus.py` MUST be import-safe without comfy/torch (stdlib only).
|
||||
- `gates/gate.py` MUST import `ExecutionBlocker` and `comfy.model_management` **lazily inside
|
||||
`run()`** (and `send_preview` lazily) so `import gates.gate` works under pytest.
|
||||
- Mask convention: grayscale `L`, white = painted; zeros sized to the image if none.
|
||||
- Commit style: Conventional Commits + repo Co-Authored-By trailer.
|
||||
- `MAX_ROUTES = 10`.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: `gate_bus.py` — `GateBus` (arm/put/wait/cancel)
|
||||
|
||||
**Files:** Create `gates/gate_bus.py`; Test `tests/test_gate_bus.py`
|
||||
|
||||
**Step 1: Failing test**
|
||||
|
||||
```python
|
||||
# tests/test_gate_bus.py
|
||||
import pytest
|
||||
from gates import gate_bus as gb
|
||||
|
||||
def test_put_and_wait_returns_choice():
|
||||
gb.GateBus.arm("7")
|
||||
gb.GateBus.put("7", "3")
|
||||
assert gb.GateBus.wait("7") == 3
|
||||
|
||||
def test_wait_consumes_message():
|
||||
gb.GateBus.arm("7")
|
||||
gb.GateBus.put("7", "2")
|
||||
gb.GateBus.wait("7")
|
||||
assert "7" not in gb.GateBus.messages
|
||||
|
||||
def test_cancel_raises_and_resets():
|
||||
gb.GateBus.arm("7")
|
||||
gb.GateBus.put("7", "__cancel__")
|
||||
with pytest.raises(gb.GateCancelled):
|
||||
gb.GateBus.wait("7")
|
||||
assert gb.GateBus.cancelled is False # reset after raising
|
||||
|
||||
def test_arm_clears_stale_state():
|
||||
gb.GateBus.put("1", "5")
|
||||
gb.GateBus.cancelled = True
|
||||
gb.GateBus.arm("1")
|
||||
assert "1" not in gb.GateBus.messages
|
||||
assert gb.GateBus.cancelled is False
|
||||
```
|
||||
|
||||
**Step 2: Run → FAIL.**
|
||||
|
||||
**Step 3: Implement**
|
||||
|
||||
```python
|
||||
# gates/gate_bus.py
|
||||
"""Blocking choice bus for the Image Gate node. Stdlib only — no comfy/torch."""
|
||||
import time
|
||||
|
||||
|
||||
class GateCancelled(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class GateBus:
|
||||
messages = {} # node_id(str) -> chosen int (1-based)
|
||||
masks = {} # node_id(str) -> PNG bytes
|
||||
cancelled = False
|
||||
|
||||
@classmethod
|
||||
def arm(cls, node_id):
|
||||
cls.messages.pop(str(node_id), None)
|
||||
cls.masks.pop(str(node_id), None)
|
||||
cls.cancelled = False
|
||||
|
||||
@classmethod
|
||||
def put(cls, node_id, message):
|
||||
if message == "__cancel__":
|
||||
cls.cancelled = True
|
||||
else:
|
||||
cls.messages[str(node_id)] = int(message)
|
||||
|
||||
@classmethod
|
||||
def wait(cls, node_id, period=0.1):
|
||||
sid = str(node_id)
|
||||
while sid not in cls.messages:
|
||||
if cls.cancelled:
|
||||
cls.cancelled = False
|
||||
raise GateCancelled()
|
||||
time.sleep(period)
|
||||
return cls.messages.pop(sid)
|
||||
```
|
||||
|
||||
**Step 4: Run → PASS.** **Step 5: Commit** `feat: gate_bus blocking choice waiter`
|
||||
|
||||
---
|
||||
|
||||
### Task 2: `gate_bus.py` — mask stash
|
||||
|
||||
**Files:** Modify `gates/gate_bus.py`, `tests/test_gate_bus.py`
|
||||
|
||||
**Step 1: Failing test**
|
||||
|
||||
```python
|
||||
def test_mask_stash_roundtrip():
|
||||
gb.GateBus.put_mask("9", b"PNGDATA")
|
||||
assert gb.GateBus.pop_mask("9") == b"PNGDATA"
|
||||
assert gb.GateBus.pop_mask("9") is None # popped
|
||||
|
||||
def test_arm_clears_mask():
|
||||
gb.GateBus.put_mask("9", b"x")
|
||||
gb.GateBus.arm("9")
|
||||
assert gb.GateBus.pop_mask("9") is None
|
||||
```
|
||||
|
||||
**Step 2: Run → FAIL.**
|
||||
|
||||
**Step 3: Implement (append to `GateBus`)**
|
||||
|
||||
```python
|
||||
@classmethod
|
||||
def put_mask(cls, node_id, data):
|
||||
cls.masks[str(node_id)] = data
|
||||
|
||||
@classmethod
|
||||
def pop_mask(cls, node_id):
|
||||
return cls.masks.pop(str(node_id), None)
|
||||
```
|
||||
|
||||
**Step 4: Run → PASS.** **Step 5: Commit** `feat: gate_bus mask stash`
|
||||
|
||||
---
|
||||
|
||||
### Task 3: `gate.py` — `route_tuple` pure helper
|
||||
|
||||
**Files:** Create `gates/gate.py`; Test `tests/test_gate.py`
|
||||
|
||||
**Step 1: Failing test**
|
||||
|
||||
```python
|
||||
# tests/test_gate.py
|
||||
from gates import gate
|
||||
|
||||
def test_route_tuple_places_image_at_chosen():
|
||||
B = object()
|
||||
t = gate.route_tuple(2, "IMG", B, max_routes=5)
|
||||
assert t == (B, B, "IMG", B, B)
|
||||
|
||||
def test_route_tuple_length_is_max():
|
||||
B = object()
|
||||
assert len(gate.route_tuple(0, "IMG", B, max_routes=10)) == 10
|
||||
```
|
||||
|
||||
**Step 2: Run → FAIL.**
|
||||
|
||||
**Step 3: Implement**
|
||||
|
||||
```python
|
||||
# gates/gate.py
|
||||
import io
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from . import gate_bus
|
||||
|
||||
MAX_ROUTES = 10
|
||||
|
||||
|
||||
def route_tuple(chosen, image, blocker, max_routes=MAX_ROUTES):
|
||||
return tuple(image if i == chosen else blocker for i in range(max_routes))
|
||||
```
|
||||
|
||||
**Step 4: Run → PASS.** **Step 5: Commit** `feat: gate route_tuple helper`
|
||||
|
||||
---
|
||||
|
||||
### Task 4: `gate.py` — `mask_from_stash`
|
||||
|
||||
**Files:** Modify `gates/gate.py`, `tests/test_gate.py`
|
||||
|
||||
**Step 1: Failing test**
|
||||
|
||||
```python
|
||||
import io, torch
|
||||
from PIL import Image
|
||||
|
||||
def test_mask_from_stash_none_is_zeros():
|
||||
img = torch.zeros((1, 6, 4, 3))
|
||||
m = gate.mask_from_stash(None, img)
|
||||
assert m.shape == (1, 6, 4) and float(m.max()) == 0.0
|
||||
|
||||
def test_mask_from_stash_decodes_png():
|
||||
buf = io.BytesIO(); Image.new("L", (4, 6), 255).save(buf, "PNG")
|
||||
img = torch.zeros((1, 6, 4, 3))
|
||||
m = gate.mask_from_stash(buf.getvalue(), img)
|
||||
assert m.shape == (1, 6, 4) and float(m.min()) > 0.99
|
||||
```
|
||||
|
||||
**Step 2: Run → FAIL.**
|
||||
|
||||
**Step 3: Implement (append)**
|
||||
|
||||
```python
|
||||
def mask_from_stash(data, image):
|
||||
b, h, w = image.shape[0], image.shape[1], image.shape[2]
|
||||
if not data:
|
||||
return torch.zeros((b, h, w), dtype=torch.float32)
|
||||
m = Image.open(io.BytesIO(data)).convert("L")
|
||||
arr = np.array(m, dtype=np.float32) / 255.0
|
||||
return torch.from_numpy(arr).unsqueeze(0)
|
||||
```
|
||||
|
||||
**Step 4: Run → PASS.** **Step 5: Commit** `feat: gate mask_from_stash (paint or zeros)`
|
||||
|
||||
---
|
||||
|
||||
### Task 5: `gate.py` — `ImageGate` node class
|
||||
|
||||
**Files:** Modify `gates/gate.py`, `tests/test_gate.py`
|
||||
|
||||
**Step 0: Verify the interrupt symbol** (so Stop cancels cleanly):
|
||||
`grep -n "class InterruptProcessingException\|def interrupt_current_processing" /media/p5/Comfyui/comfy/model_management.py`
|
||||
Use whatever exists (expected: `InterruptProcessingException`).
|
||||
|
||||
**Step 1: Failing test**
|
||||
|
||||
```python
|
||||
import math
|
||||
|
||||
def test_is_changed_always_nan():
|
||||
v = gate.ImageGate.IS_CHANGED(image=None, routes=2, unique_id="1")
|
||||
assert math.isnan(v)
|
||||
|
||||
def test_return_types_shape():
|
||||
assert gate.ImageGate.RETURN_TYPES[0] == "MASK"
|
||||
assert len(gate.ImageGate.RETURN_TYPES) == gate.MAX_ROUTES + 1
|
||||
assert all(t == "IMAGE" for t in gate.ImageGate.RETURN_TYPES[1:])
|
||||
```
|
||||
|
||||
**Step 2: Run → FAIL.**
|
||||
|
||||
**Step 3: Implement (append)**
|
||||
|
||||
```python
|
||||
class ImageGate:
|
||||
CATEGORY = "Datasete Gates"
|
||||
FUNCTION = "run"
|
||||
RETURN_TYPES = ("MASK",) + ("IMAGE",) * MAX_ROUTES
|
||||
RETURN_NAMES = ("mask",) + tuple(f"route_{i + 1}" for i in range(MAX_ROUTES))
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"routes": ("INT", {"default": 2, "min": 1, "max": MAX_ROUTES}),
|
||||
},
|
||||
"hidden": {"unique_id": "UNIQUE_ID"},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, **kwargs):
|
||||
return float("nan") # always pause; never cached
|
||||
|
||||
def run(self, image, routes, unique_id):
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from . import gate_server
|
||||
|
||||
gate_bus.GateBus.arm(unique_id)
|
||||
gate_server.send_preview(unique_id, image, routes)
|
||||
try:
|
||||
chosen_1 = gate_bus.GateBus.wait(unique_id)
|
||||
except gate_bus.GateCancelled:
|
||||
import comfy.model_management as mm
|
||||
raise mm.InterruptProcessingException() # confirm symbol in Step 0
|
||||
|
||||
mask = mask_from_stash(gate_bus.GateBus.pop_mask(unique_id), image)
|
||||
chosen = max(0, min(chosen_1 - 1, routes - 1))
|
||||
blocker = ExecutionBlocker(None)
|
||||
return (mask,) + route_tuple(chosen, image, blocker, MAX_ROUTES)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {"ImageGate": ImageGate}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"ImageGate": "Image Gate (Manual Router)"}
|
||||
```
|
||||
|
||||
**Step 4: Run → PASS.** (`run()` itself is covered by the live smoke test, not unit tests.)
|
||||
|
||||
**Step 5: Commit** `feat: ImageGate node — pause, route via ExecutionBlocker, mask out`
|
||||
|
||||
---
|
||||
|
||||
### Task 6: `gate_server.py` — routes + preview, and register (MERGE)
|
||||
|
||||
**Files:** Create `gates/gate_server.py`; Modify `__init__.py`
|
||||
|
||||
**Step 1: Implement `gates/gate_server.py`** (aiohttp glue — verified live, not unit-tested)
|
||||
|
||||
```python
|
||||
# gates/gate_server.py
|
||||
import base64
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
from aiohttp import web
|
||||
from PIL import Image
|
||||
from server import PromptServer
|
||||
|
||||
from .gate_bus import GateBus
|
||||
|
||||
routes = PromptServer.instance.routes
|
||||
|
||||
|
||||
def send_preview(node_id, image, n_routes):
|
||||
arr = (image[0].cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
|
||||
buf = io.BytesIO()
|
||||
Image.fromarray(arr).save(buf, "PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
PromptServer.instance.send_sync(
|
||||
"datasete-gate-show",
|
||||
{"id": str(node_id), "image": b64, "routes": int(n_routes)},
|
||||
)
|
||||
|
||||
|
||||
@routes.post("/datasete_gate/choice")
|
||||
async def _choice(request):
|
||||
post = await request.post()
|
||||
GateBus.put(post.get("id"), post.get("message"))
|
||||
return web.json_response({})
|
||||
|
||||
|
||||
@routes.post("/datasete_gate/mask")
|
||||
async def _mask(request):
|
||||
reader = await request.multipart()
|
||||
node_id, data = None, None
|
||||
async for part in reader:
|
||||
if part.name == "id":
|
||||
node_id = await part.text()
|
||||
elif part.name == "mask":
|
||||
data = await part.read(decode=False)
|
||||
if node_id is not None:
|
||||
GateBus.put_mask(node_id, data)
|
||||
return web.json_response({})
|
||||
```
|
||||
|
||||
**Step 2: Re-Read `__init__.py`** and extend the `if __package__:` block to merge the gate
|
||||
node and import its server (registers routes):
|
||||
|
||||
```python
|
||||
from .gates.gate import NODE_CLASS_MAPPINGS as _GATE_NODES, \
|
||||
NODE_DISPLAY_NAME_MAPPINGS as _GATE_NAMES
|
||||
from .gates import gate_server # noqa: F401 (registers /datasete_gate/* routes)
|
||||
NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **_GATE_NODES}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **_GATE_NAMES}
|
||||
```
|
||||
(Adapt to the file's current merge structure; the only requirement is the gate node ends up
|
||||
in the mappings and `gate_server` is imported.)
|
||||
|
||||
**Step 3:** `$PY -c "import gates.gate; print(gates.gate.NODE_CLASS_MAPPINGS)"` → shows ImageGate.
|
||||
|
||||
**Step 4:** Full suite green: `$PY -m pytest tests/ -v`
|
||||
|
||||
**Step 5: Commit** `feat: gate server routes + preview + register ImageGate`
|
||||
|
||||
---
|
||||
|
||||
### Task 7: `web/image_gate.js` — preview, dynamic outputs, buttons
|
||||
|
||||
**Files:** Create `web/image_gate.js`
|
||||
|
||||
Implement an `app.registerExtension` for `ImageGate`:
|
||||
- **Dynamic outputs:** on `nodeCreated` and when the `routes` widget changes, show only the
|
||||
first `routes` of the 10 `route_*` outputs (hide/remove the rest); give each visible output
|
||||
an editable label (default `1..N`) persisted in `widgets_values`; keep the `mask` output
|
||||
(slot 0) always visible. (Reuse your existing dynamic-slot pattern.)
|
||||
- **Preview + buttons:** listen for the `datasete-gate-show` socket event
|
||||
(`api.addEventListener`); when it fires for this node's id, render the image in a DOM widget
|
||||
with: one button per visible route (labeled), an **🖌 Edit mask** button, and a **■ Stop**
|
||||
button.
|
||||
- **Choice:** route button → POST `/datasete_gate/choice` `{id, message: <1-based index>}`.
|
||||
Stop → POST `{id, message: "__cancel__"}`.
|
||||
- **Mask:** 🖌 → open MaskEditor on the previewed image (reuse the pool node's clipspace
|
||||
helper); on save, export the grayscale mask PNG and POST it to `/datasete_gate/mask`
|
||||
(multipart `id`, `mask`) **before** clicking a route.
|
||||
|
||||
**Manual verification (live, Task 8 covers the run):** node shows N labeled outputs that
|
||||
track the `routes` widget; labels persist across reload.
|
||||
|
||||
**Commit** `feat: image gate frontend — preview, dynamic outputs, route/stop/mask`
|
||||
|
||||
---
|
||||
|
||||
### Task 8: Live smoke test in ComfyUI
|
||||
|
||||
Restart ComfyUI (repo already symlinked into `custom_nodes`). Build: `Folder Image Loader →
|
||||
Image Gate`, wire `route_1`/`route_2` to two `PreviewImage`/`SaveImage` nodes, `mask` to a
|
||||
`MaskPreview`. Verify:
|
||||
- [ ] "Image Gate (Manual Router)" appears under "Datasete Gates".
|
||||
- [ ] Queue → execution **pauses**, image preview + labeled buttons + 🖌 + ■ appear.
|
||||
- [ ] Click route 1 → only route-1's downstream runs; route-2's does not.
|
||||
- [ ] Click route 2 → only route-2's downstream runs.
|
||||
- [ ] 🖌 Edit mask → MaskEditor opens; paint, save; then click a route → `mask` output carries the painted mask; no mask painted → zeros.
|
||||
- [ ] ■ Stop → the run cancels cleanly (no scary traceback; queue stops).
|
||||
- [ ] Change `routes` from 2→4 → two more labeled outputs appear; reload keeps labels.
|
||||
- [ ] Run twice in a row → it pauses **both** times (not cached).
|
||||
|
||||
**Commit** (if fixes) `fix: image gate live-test adjustments`
|
||||
|
||||
---
|
||||
|
||||
## Definition of done
|
||||
|
||||
- `$PY -m pytest tests/test_gate_bus.py tests/test_gate.py -v` green; full `tests/` green.
|
||||
- Manual checklist passes: pause, route isolation (ExecutionBlocker), mask round-trip, clean Stop, dynamic labeled outputs.
|
||||
@@ -0,0 +1,66 @@
|
||||
# gates/gate.py
|
||||
import io
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from . import gate_bus
|
||||
|
||||
MAX_ROUTES = 10
|
||||
|
||||
|
||||
def route_tuple(chosen, image, blocker, max_routes=MAX_ROUTES):
|
||||
return tuple(image if i == chosen else blocker for i in range(max_routes))
|
||||
|
||||
|
||||
def mask_from_stash(data, image):
|
||||
b, h, w = image.shape[0], image.shape[1], image.shape[2]
|
||||
if not data:
|
||||
return torch.zeros((b, h, w), dtype=torch.float32)
|
||||
m = Image.open(io.BytesIO(data)).convert("L")
|
||||
arr = np.array(m, dtype=np.float32) / 255.0
|
||||
return torch.from_numpy(arr).unsqueeze(0)
|
||||
|
||||
|
||||
class ImageGate:
|
||||
CATEGORY = "Datasete Gates"
|
||||
FUNCTION = "run"
|
||||
RETURN_TYPES = ("MASK",) + ("IMAGE",) * MAX_ROUTES
|
||||
RETURN_NAMES = ("mask",) + tuple(f"route_{i + 1}" for i in range(MAX_ROUTES))
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"routes": ("INT", {"default": 2, "min": 1, "max": MAX_ROUTES}),
|
||||
},
|
||||
"hidden": {"unique_id": "UNIQUE_ID"},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, **kwargs):
|
||||
return float("nan") # always pause; never cached
|
||||
|
||||
def run(self, image, routes, unique_id):
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from . import gate_server
|
||||
|
||||
gate_bus.GateBus.arm(unique_id)
|
||||
gate_server.send_preview(unique_id, image, routes)
|
||||
try:
|
||||
chosen_1 = gate_bus.GateBus.wait(unique_id)
|
||||
except gate_bus.GateCancelled:
|
||||
import comfy.model_management as mm
|
||||
raise mm.InterruptProcessingException()
|
||||
|
||||
mask = mask_from_stash(gate_bus.GateBus.pop_mask(unique_id), image)
|
||||
chosen = max(0, min(chosen_1 - 1, routes - 1))
|
||||
blocker = ExecutionBlocker(None)
|
||||
return (mask,) + route_tuple(chosen, image, blocker, MAX_ROUTES)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {"ImageGate": ImageGate}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"ImageGate": "Image Gate (Manual Router)"}
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Blocking choice bus for the Image Gate node. Stdlib only — no comfy/torch."""
|
||||
import time
|
||||
|
||||
|
||||
class GateCancelled(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class GateBus:
|
||||
messages = {} # node_id(str) -> chosen int (1-based)
|
||||
masks = {} # node_id(str) -> PNG bytes
|
||||
cancelled = False
|
||||
|
||||
@classmethod
|
||||
def arm(cls, node_id):
|
||||
cls.messages.pop(str(node_id), None)
|
||||
cls.masks.pop(str(node_id), None)
|
||||
cls.cancelled = False
|
||||
|
||||
@classmethod
|
||||
def put(cls, node_id, message):
|
||||
if message == "__cancel__":
|
||||
cls.cancelled = True
|
||||
else:
|
||||
cls.messages[str(node_id)] = int(message)
|
||||
|
||||
@classmethod
|
||||
def wait(cls, node_id, period=0.1):
|
||||
sid = str(node_id)
|
||||
while sid not in cls.messages:
|
||||
if cls.cancelled:
|
||||
cls.cancelled = False
|
||||
raise GateCancelled()
|
||||
time.sleep(period)
|
||||
return cls.messages.pop(sid)
|
||||
|
||||
@classmethod
|
||||
def put_mask(cls, node_id, data):
|
||||
cls.masks[str(node_id)] = data
|
||||
|
||||
@classmethod
|
||||
def pop_mask(cls, node_id):
|
||||
return cls.masks.pop(str(node_id), None)
|
||||
@@ -0,0 +1,44 @@
|
||||
# gates/gate_server.py
|
||||
import base64
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
from aiohttp import web
|
||||
from PIL import Image
|
||||
from server import PromptServer
|
||||
|
||||
from .gate_bus import GateBus
|
||||
|
||||
routes = PromptServer.instance.routes
|
||||
|
||||
|
||||
def send_preview(node_id, image, n_routes):
|
||||
arr = (image[0].cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
|
||||
buf = io.BytesIO()
|
||||
Image.fromarray(arr).save(buf, "PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
PromptServer.instance.send_sync(
|
||||
"datasete-gate-show",
|
||||
{"id": str(node_id), "image": b64, "routes": int(n_routes)},
|
||||
)
|
||||
|
||||
|
||||
@routes.post("/datasete_gate/choice")
|
||||
async def _choice(request):
|
||||
post = await request.post()
|
||||
GateBus.put(post.get("id"), post.get("message"))
|
||||
return web.json_response({})
|
||||
|
||||
|
||||
@routes.post("/datasete_gate/mask")
|
||||
async def _mask(request):
|
||||
reader = await request.multipart()
|
||||
node_id, data = None, None
|
||||
async for part in reader:
|
||||
if part.name == "id":
|
||||
node_id = await part.text()
|
||||
elif part.name == "mask":
|
||||
data = await part.read(decode=False)
|
||||
if node_id is not None:
|
||||
GateBus.put_mask(node_id, data)
|
||||
return web.json_response({})
|
||||
@@ -0,0 +1,37 @@
|
||||
# tests/test_gate.py
|
||||
import io
|
||||
import math
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from gates import gate
|
||||
|
||||
def test_route_tuple_places_image_at_chosen():
|
||||
B = object()
|
||||
t = gate.route_tuple(2, "IMG", B, max_routes=5)
|
||||
assert t == (B, B, "IMG", B, B)
|
||||
|
||||
def test_route_tuple_length_is_max():
|
||||
B = object()
|
||||
assert len(gate.route_tuple(0, "IMG", B, max_routes=10)) == 10
|
||||
|
||||
def test_mask_from_stash_none_is_zeros():
|
||||
img = torch.zeros((1, 6, 4, 3))
|
||||
m = gate.mask_from_stash(None, img)
|
||||
assert m.shape == (1, 6, 4) and float(m.max()) == 0.0
|
||||
|
||||
def test_mask_from_stash_decodes_png():
|
||||
buf = io.BytesIO(); Image.new("L", (4, 6), 255).save(buf, "PNG")
|
||||
img = torch.zeros((1, 6, 4, 3))
|
||||
m = gate.mask_from_stash(buf.getvalue(), img)
|
||||
assert m.shape == (1, 6, 4) and float(m.min()) > 0.99
|
||||
|
||||
def test_is_changed_always_nan():
|
||||
v = gate.ImageGate.IS_CHANGED(image=None, routes=2, unique_id="1")
|
||||
assert math.isnan(v)
|
||||
|
||||
def test_return_types_shape():
|
||||
assert gate.ImageGate.RETURN_TYPES[0] == "MASK"
|
||||
assert len(gate.ImageGate.RETURN_TYPES) == gate.MAX_ROUTES + 1
|
||||
assert all(t == "IMAGE" for t in gate.ImageGate.RETURN_TYPES[1:])
|
||||
@@ -0,0 +1,38 @@
|
||||
# tests/test_gate_bus.py
|
||||
import pytest
|
||||
from gates import gate_bus as gb
|
||||
|
||||
def test_put_and_wait_returns_choice():
|
||||
gb.GateBus.arm("7")
|
||||
gb.GateBus.put("7", "3")
|
||||
assert gb.GateBus.wait("7") == 3
|
||||
|
||||
def test_wait_consumes_message():
|
||||
gb.GateBus.arm("7")
|
||||
gb.GateBus.put("7", "2")
|
||||
gb.GateBus.wait("7")
|
||||
assert "7" not in gb.GateBus.messages
|
||||
|
||||
def test_cancel_raises_and_resets():
|
||||
gb.GateBus.arm("7")
|
||||
gb.GateBus.put("7", "__cancel__")
|
||||
with pytest.raises(gb.GateCancelled):
|
||||
gb.GateBus.wait("7")
|
||||
assert gb.GateBus.cancelled is False # reset after raising
|
||||
|
||||
def test_arm_clears_stale_state():
|
||||
gb.GateBus.put("1", "5")
|
||||
gb.GateBus.cancelled = True
|
||||
gb.GateBus.arm("1")
|
||||
assert "1" not in gb.GateBus.messages
|
||||
assert gb.GateBus.cancelled is False
|
||||
|
||||
def test_mask_stash_roundtrip():
|
||||
gb.GateBus.put_mask("9", b"PNGDATA")
|
||||
assert gb.GateBus.pop_mask("9") == b"PNGDATA"
|
||||
assert gb.GateBus.pop_mask("9") is None # popped
|
||||
|
||||
def test_arm_clears_mask():
|
||||
gb.GateBus.put_mask("9", b"x")
|
||||
gb.GateBus.arm("9")
|
||||
assert gb.GateBus.pop_mask("9") is None
|
||||
@@ -0,0 +1,565 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
|
||||
// Image Gate (Manual Router) — pauses a running prompt, shows the image with N
|
||||
// labeled route buttons + an Edit-mask + a Stop button, and routes the image down
|
||||
// the clicked output (others ExecutionBlocker-ed server-side). The Python node
|
||||
// blocks in run() on GateBus.wait(); this extension renders the preview that the
|
||||
// server pushes via the "datasete-gate-show" socket event and POSTs the choice.
|
||||
|
||||
const NODE = "ImageGate";
|
||||
const MAX_ROUTES = 10;
|
||||
const R = "/datasete_gate";
|
||||
|
||||
const MIN_IMG_H = 140; // preview image area clamps (scales with node width)
|
||||
const MAX_IMG_H = 600;
|
||||
const BTN_ROW_H = 78; // buttons area (route buttons wrap + actions)
|
||||
const MARGIN = 10; // ComfyUI DOM-widget inset, matches the pool node
|
||||
|
||||
// ---- routes widget + label store -------------------------------------------
|
||||
|
||||
function routesWidget(node) {
|
||||
return node.widgets?.find((w) => w.name === "routes");
|
||||
}
|
||||
|
||||
function getRouteCount(node) {
|
||||
let n = parseInt(routesWidget(node)?.value ?? 2, 10);
|
||||
if (isNaN(n)) n = 2;
|
||||
return Math.max(1, Math.min(MAX_ROUTES, n));
|
||||
}
|
||||
|
||||
// Labels live in node.properties (litegraph serializes properties for free, so
|
||||
// they survive reload without a fake serializing widget — route_labels is not a
|
||||
// backend input, so we must NOT push it into widgets_values).
|
||||
function labelStore(node) {
|
||||
if (!Array.isArray(node.properties.routeLabels)) node.properties.routeLabels = [];
|
||||
return node.properties.routeLabels;
|
||||
}
|
||||
|
||||
function labelFor(node, route) { // route is 1-based
|
||||
const v = labelStore(node)[route - 1];
|
||||
return (v != null && String(v).trim()) || String(route);
|
||||
}
|
||||
|
||||
function setRouteLabel(node, route, text) {
|
||||
labelStore(node)[route - 1] = text;
|
||||
applyOutputLabels(node);
|
||||
if (node._gateState && node._gateState !== "idle") render(node); // live-update
|
||||
node.setDirtyCanvas?.(true, true);
|
||||
}
|
||||
|
||||
// ---- dynamic route outputs --------------------------------------------------
|
||||
// Slot 0 is the always-visible `mask` output; slots 1..N are route_1..route_N.
|
||||
// We only ever add/remove from the TAIL so existing slot indices (and the
|
||||
// backend's index→RETURN_TYPES mapping) stay stable and connections are kept.
|
||||
|
||||
function applyOutputLabels(node) {
|
||||
for (let i = 1; i < node.outputs.length; i++) {
|
||||
node.outputs[i].label = labelFor(node, i);
|
||||
}
|
||||
}
|
||||
|
||||
function applyRouteCount(node, n) {
|
||||
if (!node.outputs || node.outputs.length === 0) return;
|
||||
let cur = node.outputs.length - 1; // current route outputs
|
||||
while (cur < n) { node.addOutput(`route_${cur + 1}`, "IMAGE"); cur++; }
|
||||
while (cur > n) { node.removeOutput(node.outputs.length - 1); cur--; }
|
||||
applyOutputLabels(node);
|
||||
node.setDirtyCanvas?.(true, true);
|
||||
}
|
||||
|
||||
// ---- server calls -----------------------------------------------------------
|
||||
|
||||
async function postChoice(node, message) {
|
||||
const fd = new FormData();
|
||||
fd.append("id", String(node.id));
|
||||
fd.append("message", String(message));
|
||||
await api.fetchApi(`${R}/choice`, { method: "POST", body: fd });
|
||||
}
|
||||
|
||||
async function postMask(node, blob) {
|
||||
const fd = new FormData();
|
||||
fd.append("id", String(node.id));
|
||||
fd.append("mask", blob, "mask.png");
|
||||
await api.fetchApi(`${R}/mask`, { method: "POST", body: fd });
|
||||
}
|
||||
|
||||
// ---- preview DOM widget + state machine -------------------------------------
|
||||
// States: "idle" (collapsed, before the first run), "paused" (waiting for a
|
||||
// route choice — route buttons shown), "resolved" (a route was picked — image +
|
||||
// mask kept, a "Run from here" re-queue button shown). The node never blanks
|
||||
// once a run has happened, so the previewed image and the sticky mask stay for
|
||||
// context and the painted mask is reused on the next run until cleared.
|
||||
|
||||
function computeImgH(node) {
|
||||
// image area scales with node WIDTH and the image's aspect ratio, so a wider
|
||||
// node shows a bigger preview (getMinHeight is polled each layout frame).
|
||||
const w = Math.max(120, (node.size?.[0] || 220) - 2 * MARGIN);
|
||||
const h = Math.round(w * (node._imgAspect || 1));
|
||||
return Math.max(MIN_IMG_H, Math.min(h, MAX_IMG_H));
|
||||
}
|
||||
|
||||
function previewHeight(node) {
|
||||
if (!node._gateState || node._gateState === "idle") return 0;
|
||||
return 2 * MARGIN + computeImgH(node) + BTN_ROW_H;
|
||||
}
|
||||
|
||||
function resizePreview(node) {
|
||||
// Fully remove the preview element from layout when idle — collapsing the
|
||||
// widget height to 0 isn't enough: the <img> would still paint below the node.
|
||||
const shown = node._gateState && node._gateState !== "idle";
|
||||
if (node._gate) node._gate.wrap.style.display = shown ? "flex" : "none";
|
||||
const w = node.size?.[0] || 220;
|
||||
node.setSize([w, node.computeSize()[1]]);
|
||||
node.setDirtyCanvas(true, true);
|
||||
}
|
||||
|
||||
function hasMask(node) { return !!node._stickyMask; }
|
||||
|
||||
function maskControls(node) {
|
||||
// Edit / Clear buttons + a small "mask retained" badge, shared by both states.
|
||||
const els = [];
|
||||
const edit = document.createElement("button");
|
||||
edit.className = "dgate-edit";
|
||||
edit.textContent = "🖌 Edit mask";
|
||||
edit.onclick = () => openMaskEditor(node);
|
||||
els.push(edit);
|
||||
if (hasMask(node)) {
|
||||
const clr = document.createElement("button");
|
||||
clr.className = "dgate-clear";
|
||||
clr.textContent = "✕ Clear mask";
|
||||
clr.onclick = () => clearMask(node);
|
||||
els.push(clr);
|
||||
}
|
||||
const badge = document.createElement("span");
|
||||
badge.className = "dgate-status";
|
||||
badge.textContent = hasMask(node) ? "🎭 mask retained" : "no mask";
|
||||
badge.style.opacity = hasMask(node) ? "0.9" : "0.45";
|
||||
els.push(badge);
|
||||
return els;
|
||||
}
|
||||
|
||||
function render(node) {
|
||||
const { btns } = node._gate;
|
||||
btns.innerHTML = "";
|
||||
const routes = node._gateRoutes || getRouteCount(node);
|
||||
|
||||
if (node._gateState === "paused") {
|
||||
for (let i = 1; i <= routes; i++) {
|
||||
const b = document.createElement("button");
|
||||
b.className = "dgate-route";
|
||||
b.textContent = labelFor(node, i);
|
||||
b.onclick = async () => {
|
||||
await postChoice(node, i);
|
||||
showResolved(node, labelFor(node, i));
|
||||
};
|
||||
btns.appendChild(b);
|
||||
}
|
||||
maskControls(node).forEach((el) => btns.appendChild(el));
|
||||
const stop = document.createElement("button");
|
||||
stop.className = "dgate-stop";
|
||||
stop.textContent = "■ Stop";
|
||||
stop.onclick = async () => {
|
||||
await postChoice(node, "__cancel__");
|
||||
showResolved(node, "stopped");
|
||||
};
|
||||
btns.appendChild(stop);
|
||||
} else if (node._gateState === "resolved") {
|
||||
const status = document.createElement("span");
|
||||
status.className = "dgate-status";
|
||||
status.textContent = `✓ routed to ${node._gateChoice ?? "?"}`;
|
||||
btns.appendChild(status);
|
||||
const run = document.createElement("button");
|
||||
run.className = "dgate-run";
|
||||
run.textContent = "▶ Run from here";
|
||||
run.onclick = () => queueFromHere(node);
|
||||
btns.appendChild(run);
|
||||
maskControls(node).forEach((el) => btns.appendChild(el));
|
||||
}
|
||||
updateMaskOverlay(node);
|
||||
}
|
||||
|
||||
function showPaused(node, b64, routes) {
|
||||
node._gateState = "paused";
|
||||
node._gateRoutes = Math.max(1, Math.min(MAX_ROUTES, parseInt(routes, 10) || getRouteCount(node)));
|
||||
node._previewB64 = b64;
|
||||
node._gate.img.src = `data:image/png;base64,${b64}`;
|
||||
// sticky mask: re-stash the last painted mask for THIS run before the user
|
||||
// picks a route. run() does arm()→clear, then send_preview→this event, then
|
||||
// blocks in wait(), so this POST always lands before the choice is made.
|
||||
if (node._stickyMask) {
|
||||
postMask(node, b64ToBlob(node._stickyMask, "image/png")).catch(() => {});
|
||||
}
|
||||
render(node);
|
||||
resizePreview(node);
|
||||
}
|
||||
|
||||
function showResolved(node, choiceLabel) {
|
||||
node._gateState = "resolved";
|
||||
node._gateChoice = choiceLabel;
|
||||
render(node);
|
||||
resizePreview(node);
|
||||
}
|
||||
|
||||
async function queueFromHere(node) {
|
||||
try {
|
||||
await app.queuePrompt(0, 1);
|
||||
} catch (e) {
|
||||
try { await app.queuePrompt(0); } catch (e2) { console.error("[dgate] queue failed", e2); }
|
||||
}
|
||||
}
|
||||
|
||||
async function clearMask(node) {
|
||||
node._stickyMask = null;
|
||||
node._stickyMaskOverlay = null;
|
||||
// zero the current run's stash: an empty mask part -> server stores b"" ->
|
||||
// mask_from_stash() treats it as falsy -> zeros.
|
||||
try { await postMask(node, new Blob([], { type: "image/png" })); } catch (e) { /* ignore */ }
|
||||
render(node);
|
||||
}
|
||||
|
||||
// ---- mask overlay (show the painted region over the preview, semi-transparent)
|
||||
// The sticky mask is grayscale (white = painted). Recolor it into an RGBA layer
|
||||
// where alpha = paint intensity and RGB = a highlight color, so unpainted areas
|
||||
// are fully transparent and only the painted region tints the image.
|
||||
|
||||
function maskToOverlay(b64) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const im = new Image();
|
||||
im.onload = () => {
|
||||
const c = document.createElement("canvas");
|
||||
c.width = im.naturalWidth || im.width;
|
||||
c.height = im.naturalHeight || im.height;
|
||||
const ctx = c.getContext("2d");
|
||||
ctx.drawImage(im, 0, 0);
|
||||
const d = ctx.getImageData(0, 0, c.width, c.height);
|
||||
const px = d.data;
|
||||
for (let i = 0; i < px.length; i += 4) {
|
||||
const v = px[i]; // grayscale luminance (R=G=B)
|
||||
px[i] = 255; px[i + 1] = 64; px[i + 2] = 64; // highlight = red
|
||||
px[i + 3] = v; // alpha = paint intensity
|
||||
}
|
||||
ctx.putImageData(d, 0, 0);
|
||||
resolve(c.toDataURL("image/png"));
|
||||
};
|
||||
im.onerror = reject;
|
||||
im.src = `data:image/png;base64,${b64}`;
|
||||
});
|
||||
}
|
||||
|
||||
async function setStickyMask(node, b64) {
|
||||
node._stickyMask = b64;
|
||||
try {
|
||||
node._stickyMaskOverlay = b64 ? await maskToOverlay(b64) : null;
|
||||
} catch (e) {
|
||||
node._stickyMaskOverlay = null;
|
||||
}
|
||||
updateMaskOverlay(node);
|
||||
}
|
||||
|
||||
function updateMaskOverlay(node) {
|
||||
const mi = node._gate?.maskImg;
|
||||
if (!mi) return;
|
||||
if (node._gateState && node._gateState !== "idle" && node._stickyMaskOverlay) {
|
||||
mi.src = node._stickyMaskOverlay;
|
||||
mi.style.display = "block";
|
||||
} else {
|
||||
mi.removeAttribute("src");
|
||||
mi.style.display = "none";
|
||||
}
|
||||
}
|
||||
|
||||
// ---- mask editor (reuses ComfyUI MaskEditor, like the pool node) ------------
|
||||
// The preview arrives as base64 (no server file), so upload it to input/ first,
|
||||
// point the MaskEditor at it, then poll node.images for the saved clipspace ref.
|
||||
|
||||
function b64ToBlob(b64, type) {
|
||||
const bin = atob(b64);
|
||||
const arr = new Uint8Array(bin.length);
|
||||
for (let i = 0; i < bin.length; i++) arr[i] = bin.charCodeAt(i);
|
||||
return new Blob([arr], { type });
|
||||
}
|
||||
|
||||
function blobToImage(blob) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const img = new Image();
|
||||
img.onload = () => resolve(img);
|
||||
img.onerror = reject;
|
||||
img.src = URL.createObjectURL(blob);
|
||||
});
|
||||
}
|
||||
|
||||
function blobToB64(blob) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const fr = new FileReader();
|
||||
fr.onload = () => resolve(String(fr.result).split(",")[1] || "");
|
||||
fr.onerror = reject;
|
||||
fr.readAsDataURL(blob);
|
||||
});
|
||||
}
|
||||
|
||||
function comfyAppClass() {
|
||||
try { return app.constructor; } catch (e) { return null; }
|
||||
}
|
||||
|
||||
// MaskEditor registers the painted image as this node's output; clear those
|
||||
// stores so nothing repopulates node.imgs (we draw our own preview).
|
||||
function clearNodeOutputs(node) {
|
||||
try {
|
||||
for (const map of [app.nodeOutputs, app.nodePreviewImages]) {
|
||||
if (!map) continue;
|
||||
for (const k of Object.keys(map)) {
|
||||
if (k === String(node.id) || k.endsWith(`:${node.id}`)) delete map[k];
|
||||
}
|
||||
}
|
||||
} catch (e) { /* best effort */ }
|
||||
}
|
||||
|
||||
function cleanupMaskState(node) {
|
||||
if (node._maskPoll) { clearInterval(node._maskPoll); node._maskPoll = null; }
|
||||
node._maskActive = false;
|
||||
try {
|
||||
node.images = undefined;
|
||||
node.previewMediaType = undefined;
|
||||
} catch (e) { /* best effort */ }
|
||||
clearNodeOutputs(node);
|
||||
node.setDirtyCanvas?.(true, true);
|
||||
}
|
||||
|
||||
async function uploadPreview(node) {
|
||||
const blob = b64ToBlob(node._previewB64, "image/png");
|
||||
const fd = new FormData();
|
||||
fd.append("image", blob, `gate_${node.id}.png`);
|
||||
fd.append("subfolder", "datasete_gate");
|
||||
fd.append("type", "input");
|
||||
fd.append("overwrite", "true");
|
||||
const res = await api.fetchApi("/upload/image", { method: "POST", body: fd });
|
||||
const j = await res.json();
|
||||
return { filename: j.name, subfolder: j.subfolder || "datasete_gate", type: j.type || "input" };
|
||||
}
|
||||
|
||||
async function captureMask(node, ref) {
|
||||
try {
|
||||
const sub = ref.subfolder ?? "clipspace";
|
||||
const type = ref.type ?? "input";
|
||||
const url = `/view?filename=${encodeURIComponent(ref.filename)}&subfolder=${encodeURIComponent(sub)}&type=${encodeURIComponent(type)}&r=${Date.now()}`;
|
||||
const resp = await api.fetchApi(url);
|
||||
const blob = await resp.blob();
|
||||
const img = await blobToImage(blob);
|
||||
const c = document.createElement("canvas");
|
||||
c.width = img.naturalWidth || img.width;
|
||||
c.height = img.naturalHeight || img.height;
|
||||
const ctx = c.getContext("2d");
|
||||
ctx.drawImage(img, 0, 0);
|
||||
const d = ctx.getImageData(0, 0, c.width, c.height);
|
||||
const px = d.data;
|
||||
// MaskEditor stores the mask in the ALPHA channel; painted areas come through
|
||||
// as alpha 0, so invert (255 - a) into grayscale -> white = painted (MASK).
|
||||
for (let i = 0; i < px.length; i += 4) {
|
||||
const a = px[i + 3];
|
||||
px[i] = px[i + 1] = px[i + 2] = 255 - a;
|
||||
px[i + 3] = 255;
|
||||
}
|
||||
ctx.putImageData(d, 0, 0);
|
||||
const maskBlob = await new Promise((res) => c.toBlob(res, "image/png"));
|
||||
await postMask(node, maskBlob);
|
||||
// remember it so it auto-applies on the next run until the user clears it,
|
||||
// and build the colored overlay shown over the preview.
|
||||
try { await setStickyMask(node, await blobToB64(maskBlob)); } catch (e) { /* ignore */ }
|
||||
} catch (e) {
|
||||
console.error("[dgate] mask capture failed", e);
|
||||
} finally {
|
||||
cleanupMaskState(node);
|
||||
if (node._gateState && node._gateState !== "idle") render(node); // show badge
|
||||
}
|
||||
}
|
||||
|
||||
async function openMaskEditor(node) {
|
||||
if (!node._previewB64) return;
|
||||
cleanupMaskState(node);
|
||||
let ref;
|
||||
try {
|
||||
ref = await uploadPreview(node);
|
||||
} catch (e) {
|
||||
console.error("[dgate] preview upload failed", e);
|
||||
return;
|
||||
}
|
||||
|
||||
node.images = [ref];
|
||||
node.previewMediaType = "image";
|
||||
node.imageIndex = 0;
|
||||
node._maskActive = true;
|
||||
|
||||
const Comfy = comfyAppClass();
|
||||
try { if (Comfy) Comfy.clipspace_return_node = node; } catch (e) { /* ignore */ }
|
||||
|
||||
// No save callback in frontend 1.45 — poll for the editor writing clipspace.
|
||||
let waited = 0;
|
||||
node._maskPoll = setInterval(() => {
|
||||
waited += 300;
|
||||
const r = node.images && node.images[0];
|
||||
if (node._maskActive && r && r.subfolder === "clipspace") {
|
||||
clearInterval(node._maskPoll); node._maskPoll = null;
|
||||
captureMask(node, r);
|
||||
} else if (waited > 10 * 60 * 1000) {
|
||||
cleanupMaskState(node);
|
||||
}
|
||||
}, 300);
|
||||
|
||||
try { app.canvas?.selectNode?.(node); } catch (e) { /* ignore */ }
|
||||
const cmd = app.extensionManager?.command;
|
||||
if (cmd?.execute) {
|
||||
cmd.execute("Comfy.MaskEditor.OpenMaskEditor");
|
||||
} else if (Comfy?.open_maskeditor) {
|
||||
Comfy.open_maskeditor();
|
||||
} else {
|
||||
console.error("[dgate] no MaskEditor entry point found");
|
||||
cleanupMaskState(node);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- styles + node setup ----------------------------------------------------
|
||||
|
||||
function injectStyles() {
|
||||
if (document.getElementById("dgate-styles")) return;
|
||||
const css = `
|
||||
.dgate-wrap { display:flex; flex-direction:column; gap:6px; box-sizing:border-box;
|
||||
height:100%; min-height:0; }
|
||||
.dgate-imgbox { position:relative; flex:1 1 auto; min-height:0; width:100%;
|
||||
background:rgba(0,0,0,0.25); border-radius:4px; overflow:hidden; }
|
||||
.dgate-img { position:absolute; inset:0; width:100%; height:100%; object-fit:contain;
|
||||
display:block; }
|
||||
.dgate-mask { position:absolute; inset:0; width:100%; height:100%; object-fit:contain;
|
||||
opacity:0.5; pointer-events:none; }
|
||||
.dgate-btns { display:flex; flex-wrap:wrap; gap:6px; align-items:center; flex:0 0 auto; }
|
||||
.dgate-btns button { font-size:12px; padding:3px 10px; cursor:pointer; border-radius:3px;
|
||||
border:1px solid #555; color:#fff; }
|
||||
.dgate-route { background:rgba(40,90,140,0.9); }
|
||||
.dgate-route:hover { background:rgba(60,120,180,0.95); }
|
||||
.dgate-edit { background:rgba(40,40,40,0.9); }
|
||||
.dgate-clear { background:rgba(90,60,30,0.9); }
|
||||
.dgate-run { background:rgba(40,130,70,0.95); }
|
||||
.dgate-stop { background:rgba(160,40,40,0.9); margin-left:auto; }
|
||||
.dgate-status { font-size:11px; opacity:0.8; padding:0 4px; align-self:center; }
|
||||
`;
|
||||
const style = document.createElement("style");
|
||||
style.id = "dgate-styles";
|
||||
style.textContent = css;
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
function setupGateNode(node) {
|
||||
injectStyles();
|
||||
|
||||
// Never let the MaskEditor's source image render as an output preview on us —
|
||||
// we draw the preview ourselves in the DOM widget below.
|
||||
try {
|
||||
Object.defineProperty(node, "imgs", {
|
||||
configurable: true,
|
||||
get() { return undefined; },
|
||||
set() { /* suppress */ },
|
||||
});
|
||||
} catch (e) { /* ignore */ }
|
||||
|
||||
const wrap = document.createElement("div");
|
||||
wrap.className = "dgate-wrap";
|
||||
|
||||
// image + mask overlay share a container so both letterbox identically and
|
||||
// stay pixel-aligned (object-fit:contain on same-size, same-aspect layers).
|
||||
const imgbox = document.createElement("div");
|
||||
imgbox.className = "dgate-imgbox";
|
||||
const img = document.createElement("img");
|
||||
img.className = "dgate-img";
|
||||
// capture the image aspect so the preview area scales with the node width
|
||||
img.onload = () => {
|
||||
const w = img.naturalWidth || 1;
|
||||
const h = img.naturalHeight || 1;
|
||||
node._imgAspect = h / w;
|
||||
resizePreview(node);
|
||||
};
|
||||
const maskImg = document.createElement("img");
|
||||
maskImg.className = "dgate-mask";
|
||||
maskImg.style.display = "none";
|
||||
imgbox.appendChild(img);
|
||||
imgbox.appendChild(maskImg);
|
||||
|
||||
const btns = document.createElement("div");
|
||||
btns.className = "dgate-btns";
|
||||
wrap.appendChild(imgbox);
|
||||
wrap.appendChild(btns);
|
||||
node._gate = { wrap, imgbox, img, maskImg, btns };
|
||||
|
||||
node._previewWidget = node.addDOMWidget("gate_preview", "div", wrap, {
|
||||
serialize: false,
|
||||
getMinHeight: () => previewHeight(node),
|
||||
});
|
||||
|
||||
// sync visible route outputs to the routes widget, now and on change
|
||||
applyRouteCount(node, getRouteCount(node));
|
||||
const rw = routesWidget(node);
|
||||
if (rw) {
|
||||
const prev = rw.callback;
|
||||
rw.callback = function () {
|
||||
const r = prev?.apply(this, arguments);
|
||||
applyRouteCount(node, getRouteCount(node));
|
||||
return r;
|
||||
};
|
||||
}
|
||||
|
||||
node._gateState = "idle";
|
||||
resizePreview(node);
|
||||
}
|
||||
|
||||
app.registerExtension({
|
||||
name: "datasete.gates.imagegate",
|
||||
|
||||
// one global socket listener: route the server's pause event to the node
|
||||
setup() {
|
||||
api.addEventListener("datasete-gate-show", (e) => {
|
||||
const d = e.detail || {};
|
||||
const node = app.graph?.getNodeById?.(parseInt(d.id, 10));
|
||||
if (!node || node.type !== NODE) return;
|
||||
showPaused(node, d.image, d.routes);
|
||||
});
|
||||
},
|
||||
|
||||
async beforeRegisterNodeDef(nodeType, nodeData) {
|
||||
if (nodeData.name !== NODE) return;
|
||||
|
||||
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
||||
nodeType.prototype.onNodeCreated = function () {
|
||||
const r = onNodeCreated?.apply(this, arguments);
|
||||
setupGateNode(this);
|
||||
return r;
|
||||
};
|
||||
|
||||
// loaded workflows restore the routes widget + properties after create —
|
||||
// re-sync output count/labels to match.
|
||||
const onConfigure = nodeType.prototype.onConfigure;
|
||||
nodeType.prototype.onConfigure = function () {
|
||||
const r = onConfigure?.apply(this, arguments);
|
||||
if (this.outputs) {
|
||||
applyRouteCount(this, getRouteCount(this));
|
||||
}
|
||||
return r;
|
||||
};
|
||||
|
||||
// per-route "Rename…" entries (editable labels, persisted in properties)
|
||||
const getExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
|
||||
nodeType.prototype.getExtraMenuOptions = function (canvas, options) {
|
||||
const r = getExtraMenuOptions?.apply(this, arguments);
|
||||
const node = this;
|
||||
const routes = getRouteCount(node);
|
||||
for (let i = 1; i <= routes; i++) {
|
||||
options.push({
|
||||
content: `Rename route ${i} (“${labelFor(node, i)}”)…`,
|
||||
callback: () => {
|
||||
const text = prompt(`Label for route ${i}:`, labelFor(node, i));
|
||||
if (text != null) setRouteLabel(node, i, text);
|
||||
},
|
||||
});
|
||||
}
|
||||
return r;
|
||||
};
|
||||
},
|
||||
});
|
||||
Reference in New Issue
Block a user