feat: gate mask_from_stash (paint or zeros)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -13,3 +13,12 @@ MAX_ROUTES = 10
|
|||||||
|
|
||||||
def route_tuple(chosen, image, blocker, max_routes=MAX_ROUTES):
|
def route_tuple(chosen, image, blocker, max_routes=MAX_ROUTES):
|
||||||
return tuple(image if i == chosen else blocker for i in range(max_routes))
|
return tuple(image if i == chosen else blocker for i in range(max_routes))
|
||||||
|
|
||||||
|
|
||||||
|
def mask_from_stash(data, image):
|
||||||
|
b, h, w = image.shape[0], image.shape[1], image.shape[2]
|
||||||
|
if not data:
|
||||||
|
return torch.zeros((b, h, w), dtype=torch.float32)
|
||||||
|
m = Image.open(io.BytesIO(data)).convert("L")
|
||||||
|
arr = np.array(m, dtype=np.float32) / 255.0
|
||||||
|
return torch.from_numpy(arr).unsqueeze(0)
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
# tests/test_gate.py
|
# tests/test_gate.py
|
||||||
|
import io
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from gates import gate
|
from gates import gate
|
||||||
|
|
||||||
def test_route_tuple_places_image_at_chosen():
|
def test_route_tuple_places_image_at_chosen():
|
||||||
@@ -9,3 +14,14 @@ def test_route_tuple_places_image_at_chosen():
|
|||||||
def test_route_tuple_length_is_max():
|
def test_route_tuple_length_is_max():
|
||||||
B = object()
|
B = object()
|
||||||
assert len(gate.route_tuple(0, "IMG", B, max_routes=10)) == 10
|
assert len(gate.route_tuple(0, "IMG", B, max_routes=10)) == 10
|
||||||
|
|
||||||
|
def test_mask_from_stash_none_is_zeros():
|
||||||
|
img = torch.zeros((1, 6, 4, 3))
|
||||||
|
m = gate.mask_from_stash(None, img)
|
||||||
|
assert m.shape == (1, 6, 4) and float(m.max()) == 0.0
|
||||||
|
|
||||||
|
def test_mask_from_stash_decodes_png():
|
||||||
|
buf = io.BytesIO(); Image.new("L", (4, 6), 255).save(buf, "PNG")
|
||||||
|
img = torch.zeros((1, 6, 4, 3))
|
||||||
|
m = gate.mask_from_stash(buf.getvalue(), img)
|
||||||
|
assert m.shape == (1, 6, 4) and float(m.min()) > 0.99
|
||||||
|
|||||||
Reference in New Issue
Block a user