diff --git a/gates/node.py b/gates/node.py index c59bd68..9dc1bba 100644 --- a/gates/node.py +++ b/gates/node.py @@ -26,6 +26,11 @@ class GridImagePool: "index": ("INT", {"default": -1, "min": -1, "max": 9999}), "pool_id": ("STRING", {"default": "default"}), }, + # optional companion input: a Pool Profile node feeds the profile id + # here; when connected it overrides pool_id (see `effective` below). + "optional": { + "profile": ("POOL_PROFILE",), + }, } @staticmethod @@ -35,13 +40,14 @@ class GridImagePool: idx = pool.resolve_slot(m, index) return base, m, idx - def run(self, index, pool_id="default"): - base, m, idx = self._resolve(index, pool_id) + def run(self, index, pool_id="default", profile=None): + effective = profile or pool_id + base, m, idx = self._resolve(index, effective) if idx < 0: img, mask = imaging.empty_outputs() return (img, mask, 0, 0, "") slot = m["slots"][idx] - d = pool.pool_dir(base, pool_id) + d = pool.pool_dir(base, effective) img = imaging.load_image_tensor(str(d / slot["image"])) h, w = int(img.shape[1]), int(img.shape[2]) mask_name = slot.get("mask") @@ -49,19 +55,20 @@ class GridImagePool: return (img, mask, idx, len(m["slots"]), slot.get("label", "")) @classmethod - def IS_CHANGED(cls, index, pool_id="default", **kwargs): - base, m, idx = cls._resolve(index, pool_id) + def IS_CHANGED(cls, index, pool_id="default", profile=None, **kwargs): + effective = profile or pool_id + base, m, idx = cls._resolve(index, effective) if idx < 0: - return imaging.change_hash(pool_id, -1, []) + return imaging.change_hash(effective, -1, []) slot = m["slots"][idx] - d = pool.pool_dir(base, pool_id) + d = pool.pool_dir(base, effective) mtimes = [] for key in ("image", "mask"): name = slot.get(key) p = d / name if name else None mtimes.append(os.path.getmtime(p) if p and p.exists() else 0.0) # include active so manual selection changes invalidate cache - return imaging.change_hash(pool_id, f"{idx}:{m.get('active')}", mtimes) + return imaging.change_hash(effective, f"{idx}:{m.get('active')}", mtimes) NODE_CLASS_MAPPINGS = {"GridImagePool": GridImagePool} diff --git a/tests/test_node.py b/tests/test_node.py index e500df3..5ae2cdc 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -54,3 +54,17 @@ def test_is_changed_differs_after_active_change(tmp_path, monkeypatch): pool.set_active(base, "p1", 1) h2 = node.GridImagePool.IS_CHANGED(index=-1, pool_id="p1") assert h1 != h2 + + +def test_profile_input_overrides_pool_id(tmp_path, monkeypatch): + base = str(tmp_path / "grid_pool") + monkeypatch.setattr(node, "_grid_pool_base", lambda: base) + import io + from PIL import Image + from gates import pool + buf = io.BytesIO(); Image.new("RGB", (4, 6), (255, 0, 0)).save(buf, "PNG") + pool.add_image(base, "prof1", buf.getvalue(), ts=1) # images under the PROFILE id + n = node.GridImagePool() + # pool_id is "default" (empty) but profile points at prof1 + img, mask, idx, count, label = n.run(index=-1, pool_id="default", profile="prof1") + assert count == 1 and idx == 0