From ea3438567acf29f84fc6e5c03a70e8e253e7a20a Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 17:42:04 +0200 Subject: [PATCH] feat: gate mask_from_stash (paint or zeros) Co-Authored-By: Claude Opus 4.8 --- gates/gate.py | 9 +++++++++ tests/test_gate.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/gates/gate.py b/gates/gate.py index d7bcfc7..6ac5e29 100644 --- a/gates/gate.py +++ b/gates/gate.py @@ -13,3 +13,12 @@ MAX_ROUTES = 10 def route_tuple(chosen, image, blocker, max_routes=MAX_ROUTES): return tuple(image if i == chosen else blocker for i in range(max_routes)) + + +def mask_from_stash(data, image): + b, h, w = image.shape[0], image.shape[1], image.shape[2] + if not data: + return torch.zeros((b, h, w), dtype=torch.float32) + m = Image.open(io.BytesIO(data)).convert("L") + arr = np.array(m, dtype=np.float32) / 255.0 + return torch.from_numpy(arr).unsqueeze(0) diff --git a/tests/test_gate.py b/tests/test_gate.py index 2af7a61..b0c3bdc 100644 --- a/tests/test_gate.py +++ b/tests/test_gate.py @@ -1,4 +1,9 @@ # tests/test_gate.py +import io + +import torch +from PIL import Image + from gates import gate def test_route_tuple_places_image_at_chosen(): @@ -9,3 +14,14 @@ def test_route_tuple_places_image_at_chosen(): def test_route_tuple_length_is_max(): B = object() assert len(gate.route_tuple(0, "IMG", B, max_routes=10)) == 10 + +def test_mask_from_stash_none_is_zeros(): + img = torch.zeros((1, 6, 4, 3)) + m = gate.mask_from_stash(None, img) + assert m.shape == (1, 6, 4) and float(m.max()) == 0.0 + +def test_mask_from_stash_decodes_png(): + buf = io.BytesIO(); Image.new("L", (4, 6), 255).save(buf, "PNG") + img = torch.zeros((1, 6, 4, 3)) + m = gate.mask_from_stash(buf.getvalue(), img) + assert m.shape == (1, 6, 4) and float(m.min()) > 0.99