diff --git a/gates/bucket_node.py b/gates/bucket_node.py new file mode 100644 index 0000000..9e62962 --- /dev/null +++ b/gates/bucket_node.py @@ -0,0 +1,72 @@ +"""BucketResize node: cover-crop an image (and optional mask) onto a Klein +training bucket. Pure compute (torch + PIL); no comfy imports in run().""" +import numpy as np +import torch +from PIL import Image + +from . import buckets + +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + + +def _resize_crop_pil(pil, new_w, new_h, left, top, W, H): + pil = pil.resize((new_w, new_h), Image.LANCZOS) + return pil.crop((left, top, left + W, top + H)) + + +def fit_image(image, W, H): + """image [B,H,W,3] -> [B,H,W,3] at (W,H) using the first image's geometry.""" + b, ih, iw = image.shape[0], image.shape[1], image.shape[2] + new_w, new_h, left, top, scale = buckets.cover_crop_params(iw, ih, W, H) + out = [] + for i in range(b): + arr = (image[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8") + pil = _resize_crop_pil(Image.fromarray(arr), new_w, new_h, left, top, W, H) + out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0)) + return torch.stack(out, 0), scale + + +def fit_mask(mask, W, H): + b, ih, iw = mask.shape[0], mask.shape[1], mask.shape[2] + new_w, new_h, left, top, _ = buckets.cover_crop_params(iw, ih, W, H) + out = [] + for i in range(b): + arr = (mask[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8") + pil = _resize_crop_pil(Image.fromarray(arr, mode="L"), new_w, new_h, left, top, W, H) + out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0)) + return torch.stack(out, 0) + + +class BucketResize: + CATEGORY = "Datasete Gates" + FUNCTION = "run" + RETURN_TYPES = ("IMAGE", "MASK", "INT", "INT", "STRING") + RETURN_NAMES = ("image", "mask", "width", "height", "label") + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "resolution": ("INT", {"default": 1280, "min": 64, "max": 8192}), + "divisible": ("INT", {"default": 64, "min": 8, "max": 256}), + "max_upscale": ("FLOAT", {"default": 1.5, "min": 1.0, "max": 8.0, "step": 0.1}), + }, + "optional": {"mask": ("MASK",)}, + } + + def run(self, image, resolution=1280, divisible=64, max_upscale=1.5, mask=None): + ih, iw = int(image.shape[1]), int(image.shape[2]) + W, H = buckets.pick_bucket(iw, ih, resolution, divisible) + out_img, scale = fit_image(image, W, H) + if scale > max_upscale: + print(f"[BucketResize] cover scale {scale:.2f}x exceeds max_upscale " + f"{max_upscale} for {iw}x{ih} -> {W}x{H}") + out_mask = fit_mask(mask, W, H) if mask is not None \ + else torch.zeros((out_img.shape[0], H, W), dtype=torch.float32) + return (out_img, out_mask, W, H, f"{W}x{H}") + + +NODE_CLASS_MAPPINGS = {"BucketResize": BucketResize} +NODE_DISPLAY_NAME_MAPPINGS = {"BucketResize": "Bucket Resize (Klein 9B)"} diff --git a/tests/test_bucket_node.py b/tests/test_bucket_node.py new file mode 100644 index 0000000..1966283 --- /dev/null +++ b/tests/test_bucket_node.py @@ -0,0 +1,31 @@ +import torch +from gates import bucket_node as bn + + +def test_square_to_1280(): + out, m, w, h, label = bn.BucketResize().run(image=torch.rand((1, 1000, 1000, 3))) + assert (w, h) == (1280, 1280) + assert out.shape == (1, 1280, 1280, 3) + assert m.shape == (1, 1280, 1280) and float(m.max()) == 0.0 # no mask -> zeros + assert label == "1280x1280" + + +def test_landscape_bucket_shapes(): + # tensor [B,H,W,3] with H=1000,W=2000 -> aspect 2.0 -> 1792x896 + out, m, w, h, label = bn.BucketResize().run(image=torch.rand((1, 1000, 2000, 3))) + assert (w, h) == (1792, 896) + assert out.shape == (1, 896, 1792, 3) + assert label == "1792x896" + + +def test_mask_resized_and_aligned(): + out, m, w, h, _ = bn.BucketResize().run( + image=torch.rand((1, 1000, 1000, 3)), mask=torch.ones((1, 1000, 1000))) + assert m.shape == (1, 1280, 1280) and float(m.min()) > 0.9 + + +def test_outputs_are_on_grid(): + out, m, w, h, _ = bn.BucketResize().run( + image=torch.rand((1, 777, 1333, 3)), resolution=1280, divisible=64) + assert w % 64 == 0 and h % 64 == 0 + assert out.shape[1] == h and out.shape[2] == w