From 1db94dd57d754ffed705febae9f79d2cb5bcc008 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 13:01:39 +0200 Subject: [PATCH] feat: GridImagePool node (image/mask/index/count/label + IS_CHANGED) Co-Authored-By: Claude Opus 4.8 --- gates/gates_compat.py | 11 ++++++++ gates/node.py | 61 ++++++++++++++++++++++++++++++++++++++++++- tests/test_node.py | 56 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 gates/gates_compat.py create mode 100644 tests/test_node.py diff --git a/gates/gates_compat.py b/gates/gates_compat.py new file mode 100644 index 0000000..91a358c --- /dev/null +++ b/gates/gates_compat.py @@ -0,0 +1,11 @@ +"""Isolates the ComfyUI dependency so node.py stays unit-testable. + +node.py imports ``grid_pool_base`` from here; tests monkeypatch +``node._grid_pool_base`` so ``folder_paths`` is never needed. +""" +import os + + +def grid_pool_base(): + import folder_paths # imported lazily; only available inside ComfyUI + return os.path.join(folder_paths.get_input_directory(), "grid_pool") diff --git a/gates/node.py b/gates/node.py index d3d66cc..047428a 100644 --- a/gates/node.py +++ b/gates/node.py @@ -1,3 +1,62 @@ -# gates/node.py — stub (filled in Task 9) +"""GridImagePool — the Image Pool (Grid) ComfyUI node.""" +import os +from .gates_compat import grid_pool_base as _grid_pool_base +from . import pool, imaging + NODE_CLASS_MAPPINGS = {} NODE_DISPLAY_NAME_MAPPINGS = {} + + +class GridImagePool: + CATEGORY = "Datasete Gates" + FUNCTION = "run" + RETURN_TYPES = ("IMAGE", "MASK", "INT", "INT", "STRING") + RETURN_NAMES = ("image", "mask", "index", "count", "label") + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "index": ("INT", {"default": -1, "min": -1, "max": 9999}), + }, + "hidden": {"pool_id": "POOL_ID"}, + } + + @staticmethod + def _resolve(index, pool_id): + base = _grid_pool_base() + m = pool.read_manifest(base, pool_id) + idx = pool.resolve_slot(m, index) + return base, m, idx + + def run(self, index, pool_id="default"): + base, m, idx = self._resolve(index, pool_id) + if idx < 0: + img, mask = imaging.empty_outputs() + return (img, mask, 0, 0, "") + slot = m["slots"][idx] + d = pool.pool_dir(base, pool_id) + img = imaging.load_image_tensor(str(d / slot["image"])) + h, w = int(img.shape[1]), int(img.shape[2]) + mask_name = slot.get("mask") + mask = imaging.load_mask_tensor(str(d / mask_name) if mask_name else None, h, w) + return (img, mask, idx, len(m["slots"]), slot.get("label", "")) + + @classmethod + def IS_CHANGED(cls, index, pool_id="default", **kwargs): + base, m, idx = cls._resolve(index, pool_id) + if idx < 0: + return imaging.change_hash(pool_id, -1, []) + slot = m["slots"][idx] + d = pool.pool_dir(base, pool_id) + mtimes = [] + for key in ("image", "mask"): + name = slot.get(key) + p = d / name if name else None + mtimes.append(os.path.getmtime(p) if p and p.exists() else 0.0) + # include active so manual selection changes invalidate cache + return imaging.change_hash(pool_id, f"{idx}:{m.get('active')}", mtimes) + + +NODE_CLASS_MAPPINGS = {"GridImagePool": GridImagePool} +NODE_DISPLAY_NAME_MAPPINGS = {"GridImagePool": "Image Pool (Grid)"} diff --git a/tests/test_node.py b/tests/test_node.py new file mode 100644 index 0000000..e500df3 --- /dev/null +++ b/tests/test_node.py @@ -0,0 +1,56 @@ +import io +import numpy as np, torch +from PIL import Image +from gates import node, pool + + +def _seed_pool(tmp_path, monkeypatch): + base = str(tmp_path / "grid_pool") + monkeypatch.setattr(node, "_grid_pool_base", lambda: base) + return base + + +def _add_png(base, pid, name_bytes_color, ts): + # write a real PNG via pool.add_image + buf = io.BytesIO(); Image.new("RGB", (4, 6), name_bytes_color).save(buf, "PNG") + return pool.add_image(base, pid, buf.getvalue(), ts=ts) + + +def test_execute_empty_pool_returns_blank(tmp_path, monkeypatch): + _seed_pool(tmp_path, monkeypatch) + n = node.GridImagePool() + img, mask, idx, count, label = n.run(index=-1, pool_id="p1") + assert img.shape == (1, 1, 1, 3) + assert count == 0 and idx == 0 and label == "" + + +def test_execute_selects_active(tmp_path, monkeypatch): + base = _seed_pool(tmp_path, monkeypatch) + _add_png(base, "p1", (255, 0, 0), 1) + _add_png(base, "p1", (0, 255, 0), 2) + pool.set_active(base, "p1", 1) + pool.set_label(base, "p1", 1, "green") + n = node.GridImagePool() + img, mask, idx, count, label = n.run(index=-1, pool_id="p1") + assert img.shape == (1, 6, 4, 3) + assert idx == 1 and count == 2 and label == "green" + assert float(img[0, 0, 0, 1]) > 0.99 # green channel + assert float(mask.max()) == 0.0 # no mask yet + + +def test_execute_forced_index_clamps(tmp_path, monkeypatch): + base = _seed_pool(tmp_path, monkeypatch) + _add_png(base, "p1", (255, 0, 0), 1) + n = node.GridImagePool() + _, _, idx, count, _ = n.run(index=9, pool_id="p1") + assert idx == 0 and count == 1 + + +def test_is_changed_differs_after_active_change(tmp_path, monkeypatch): + base = _seed_pool(tmp_path, monkeypatch) + _add_png(base, "p1", (255, 0, 0), 1) + _add_png(base, "p1", (0, 255, 0), 2) + h1 = node.GridImagePool.IS_CHANGED(index=-1, pool_id="p1") + pool.set_active(base, "p1", 1) + h2 = node.GridImagePool.IS_CHANGED(index=-1, pool_id="p1") + assert h1 != h2