From 40be11cd953a8d6ed39d12c01398810e6a193549 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 13:00:30 +0200 Subject: [PATCH] feat: imaging tensor loaders + change hash Co-Authored-By: Claude Opus 4.8 --- gates/imaging.py | 30 ++++++++++++++++++++++++++++++ tests/test_imaging.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 gates/imaging.py create mode 100644 tests/test_imaging.py diff --git a/gates/imaging.py b/gates/imaging.py new file mode 100644 index 0000000..3339ef6 --- /dev/null +++ b/gates/imaging.py @@ -0,0 +1,30 @@ +"""Tensor/imaging helpers (torch + PIL). No comfy imports.""" +import hashlib +import numpy as np +import torch +from PIL import Image, ImageOps + + +def load_image_tensor(path): + img = Image.open(path) + img = ImageOps.exif_transpose(img).convert("RGB") + arr = np.array(img, dtype=np.float32) / 255.0 + return torch.from_numpy(arr).unsqueeze(0) # [1,H,W,3] + + +def load_mask_tensor(path, h, w): + if not path: + return torch.zeros((1, h, w), dtype=torch.float32) + m = Image.open(path).convert("L") + arr = np.array(m, dtype=np.float32) / 255.0 + return torch.from_numpy(arr).unsqueeze(0) # [1,H,W] + + +def empty_outputs(): + return (torch.zeros((1, 1, 1, 3), dtype=torch.float32), + torch.zeros((1, 1, 1), dtype=torch.float32)) + + +def change_hash(pool_id, index, mtimes): + key = f"{pool_id}|{index}|" + "|".join(f"{t:.3f}" for t in mtimes) + return hashlib.sha256(key.encode()).hexdigest() diff --git a/tests/test_imaging.py b/tests/test_imaging.py new file mode 100644 index 0000000..ff03173 --- /dev/null +++ b/tests/test_imaging.py @@ -0,0 +1,43 @@ +import numpy as np, torch +from PIL import Image +from gates import imaging + + +def _png(tmp_path, name, color, size=(4, 6)): # size = (w, h) + p = tmp_path / name + Image.new("RGB", size, color).save(p) + return str(p) + + +def test_load_image_tensor_shape_and_range(tmp_path): + t = imaging.load_image_tensor(_png(tmp_path, "a.png", (255, 0, 0))) + assert t.shape == (1, 6, 4, 3) # [B,H,W,C] + assert t.dtype == torch.float32 + assert 0.0 <= float(t.min()) and float(t.max()) <= 1.0 + assert float(t[0, 0, 0, 0]) > 0.99 # red channel + + +def test_load_mask_none_is_zeros(): + m = imaging.load_mask_tensor(None, h=6, w=4) + assert m.shape == (1, 6, 4) + assert float(m.max()) == 0.0 + + +def test_load_mask_from_file(tmp_path): + p = tmp_path / "m.png" + Image.new("L", (4, 6), 255).save(p) + m = imaging.load_mask_tensor(str(p), h=6, w=4) + assert m.shape == (1, 6, 4) + assert float(m.min()) > 0.99 + + +def test_empty_image_is_1x1_black(): + img, mask = imaging.empty_outputs() + assert img.shape == (1, 1, 1, 3) and float(img.max()) == 0.0 + assert mask.shape == (1, 1, 1) + + +def test_change_hash_changes_with_mtime(): + h1 = imaging.change_hash("p", 0, [1000.0]) + h2 = imaging.change_hash("p", 0, [1001.0]) + assert h1 != h2