diff --git a/gates/handlers.py b/gates/handlers.py new file mode 100644 index 0000000..9e9e1b7 --- /dev/null +++ b/gates/handlers.py @@ -0,0 +1,26 @@ +"""Pure request handlers — no aiohttp. Each returns the updated manifest dict.""" +from . import pool + + +def handle_add(base, pool_id, data, ext, ts=0): + return pool.add_image(base, pool_id, data, ts=ts) + + +def handle_remove(base, pool_id, index): + return pool.remove_slot(base, pool_id, index) + + +def handle_active(base, pool_id, index): + return pool.set_active(base, pool_id, index) + + +def handle_label(base, pool_id, index, label): + return pool.set_label(base, pool_id, index, label) + + +def handle_list(base, pool_id): + return pool.read_manifest(base, pool_id) + + +def handle_set_mask(base, pool_id, index, mask_png_bytes): + return pool.set_mask(base, pool_id, index, mask_png_bytes) # Task 12 diff --git a/gates/routes.py b/gates/routes.py index 91f44aa..e22f8ab 100644 --- a/gates/routes.py +++ b/gates/routes.py @@ -1 +1,51 @@ -# gates/routes.py — stub (filled in Task 10) +"""aiohttp routes for the Image Pool node. Imported only inside ComfyUI.""" +import json +from aiohttp import web +from server import PromptServer +from . import handlers +from .gates_compat import grid_pool_base + +routes = PromptServer.instance.routes + + +def _base(): + return grid_pool_base() + + +@routes.post("/grid_pool/add") +async def _add(request): + reader = await request.multipart() + pool_id, ts, data = "default", 0, None + async for part in reader: + if part.name == "pool_id": + pool_id = (await part.text()) + elif part.name == "ts": + ts = int(await part.text()) + elif part.name == "image": + data = await part.read(decode=False) + m = handlers.handle_add(_base(), pool_id, data, "png", ts=ts) + return web.json_response(m) + + +@routes.post("/grid_pool/remove") +async def _remove(request): + body = await request.json() + return web.json_response(handlers.handle_remove(_base(), body["pool_id"], int(body["index"]))) + + +@routes.post("/grid_pool/active") +async def _active(request): + body = await request.json() + return web.json_response(handlers.handle_active(_base(), body["pool_id"], int(body["index"]))) + + +@routes.post("/grid_pool/label") +async def _label(request): + body = await request.json() + return web.json_response(handlers.handle_label(_base(), body["pool_id"], int(body["index"]), body["label"])) + + +@routes.get("/grid_pool/list") +async def _list(request): + pool_id = request.query.get("pool_id", "default") + return web.json_response(handlers.handle_list(_base(), pool_id)) diff --git a/tests/test_routes_logic.py b/tests/test_routes_logic.py new file mode 100644 index 0000000..1a48e17 --- /dev/null +++ b/tests/test_routes_logic.py @@ -0,0 +1,23 @@ +import io +from PIL import Image +from gates import handlers + + +def _png_bytes(color=(1, 2, 3)): + b = io.BytesIO(); Image.new("RGB", (4, 4), color).save(b, "PNG"); return b.getvalue() + + +def test_handle_add_then_list(tmp_path): + base = str(tmp_path) + m = handlers.handle_add(base, "p1", _png_bytes(), "png", ts=5) + assert len(m["slots"]) == 1 + assert handlers.handle_list(base, "p1")["slots"][0]["image"] == "img_0001.png" + + +def test_handle_active_label_remove(tmp_path): + base = str(tmp_path) + handlers.handle_add(base, "p1", _png_bytes(), "png", ts=1) + handlers.handle_add(base, "p1", _png_bytes(), "png", ts=2) + assert handlers.handle_active(base, "p1", 1)["active"] == 1 + assert handlers.handle_label(base, "p1", 0, "hi")["slots"][0]["label"] == "hi" + assert len(handlers.handle_remove(base, "p1", 0)["slots"]) == 1