From 8e8eb317f77435ea6640ac2f7c3e6d0d45c719d8 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 17:43:53 +0200 Subject: [PATCH] feat: gate server routes + preview + register ImageGate Co-Authored-By: Claude Opus 4.8 --- __init__.py | 7 +++++-- gates/gate_server.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 gates/gate_server.py diff --git a/__init__.py b/__init__.py index 5c5c5da..be141f7 100644 --- a/__init__.py +++ b/__init__.py @@ -12,10 +12,13 @@ if __package__: NODE_DISPLAY_NAME_MAPPINGS as _POOL_NAMES from .gates.loader import NODE_CLASS_MAPPINGS as _LOADER_NODES, \ NODE_DISPLAY_NAME_MAPPINGS as _LOADER_NAMES + from .gates.gate import NODE_CLASS_MAPPINGS as _GATE_NODES, \ + NODE_DISPLAY_NAME_MAPPINGS as _GATE_NAMES from .gates import routes # noqa: F401 (registers aiohttp routes on import) + from .gates import gate_server # noqa: F401 (registers /datasete_gate/* routes) - NODE_CLASS_MAPPINGS = {**_POOL_NODES, **_LOADER_NODES} - NODE_DISPLAY_NAME_MAPPINGS = {**_POOL_NAMES, **_LOADER_NAMES} + NODE_CLASS_MAPPINGS = {**_POOL_NODES, **_LOADER_NODES, **_GATE_NODES} + NODE_DISPLAY_NAME_MAPPINGS = {**_POOL_NAMES, **_LOADER_NAMES, **_GATE_NAMES} else: # pragma: no cover - exercised only under pytest collection NODE_CLASS_MAPPINGS = {} NODE_DISPLAY_NAME_MAPPINGS = {} diff --git a/gates/gate_server.py b/gates/gate_server.py new file mode 100644 index 0000000..43b020a --- /dev/null +++ b/gates/gate_server.py @@ -0,0 +1,44 @@ +# gates/gate_server.py +import base64 +import io + +import numpy as np +from aiohttp import web +from PIL import Image +from server import PromptServer + +from .gate_bus import GateBus + +routes = PromptServer.instance.routes + + +def send_preview(node_id, image, n_routes): + arr = (image[0].cpu().numpy() * 255.0).clip(0, 255).astype("uint8") + buf = io.BytesIO() + Image.fromarray(arr).save(buf, "PNG") + b64 = base64.b64encode(buf.getvalue()).decode() + PromptServer.instance.send_sync( + "datasete-gate-show", + {"id": str(node_id), "image": b64, "routes": int(n_routes)}, + ) + + +@routes.post("/datasete_gate/choice") +async def _choice(request): + post = await request.post() + GateBus.put(post.get("id"), post.get("message")) + return web.json_response({}) + + +@routes.post("/datasete_gate/mask") +async def _mask(request): + reader = await request.multipart() + node_id, data = None, None + async for part in reader: + if part.name == "id": + node_id = await part.text() + elif part.name == "mask": + data = await part.read(decode=False) + if node_id is not None: + GateBus.put_mask(node_id, data) + return web.json_response({})