From cdd742c950d498e29182abe7f51cd506e4945bf9 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 22:47:46 +0200 Subject: [PATCH 1/5] feat: bucket selection matching Klein 9B table Co-Authored-By: Claude Opus 4.8 --- gates/buckets.py | 31 +++++++++++++++++++++++++++++++ tests/test_buckets.py | 27 +++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 gates/buckets.py create mode 100644 tests/test_buckets.py diff --git a/gates/buckets.py b/gates/buckets.py new file mode 100644 index 0000000..8f1540a --- /dev/null +++ b/gates/buckets.py @@ -0,0 +1,31 @@ +"""Pure bucket math for KLEIN_BUCKET_SIZES.md. Stdlib only.""" +import math + + +def pick_bucket(iw, ih, resolution=1280, divisible=64): + """Choose the on-grid bucket (W,H), area <= resolution^2, nearest to the + image aspect (log distance; tie-break larger area).""" + budget = resolution * resolution + target = iw / ih + best = None + w = divisible + w_max = budget // divisible + while w <= w_max: + h = (budget // w // divisible) * divisible # largest on-grid h within budget + if h >= divisible: + err = abs(math.log(w / h) - math.log(target)) + cand = (err, -(w * h), w, h) # min err, then max area + if best is None or cand < best: + best = cand + w += divisible + return best[2], best[3] + + +def cover_crop_params(iw, ih, W, H): + """Cover-scale + centered crop to land (iw,ih) exactly on (W,H).""" + scale = max(W / iw, H / ih) + new_w = max(W, round(iw * scale)) + new_h = max(H, round(ih * scale)) + left = (new_w - W) // 2 + top = (new_h - H) // 2 + return new_w, new_h, left, top, scale diff --git a/tests/test_buckets.py b/tests/test_buckets.py new file mode 100644 index 0000000..47bf7eb --- /dev/null +++ b/tests/test_buckets.py @@ -0,0 +1,27 @@ +from gates import buckets + +# (iw, ih) -> expected (W, H) from KLEIN_BUCKET_SIZES.md, budget 1280, ÷64 +CASES = [ + (1000, 1000, 1280, 1280), # square + (1000, 2000, 896, 1792), # a=0.50 portrait + (1000, 1730, 960, 1664), # a≈0.58 + (1000, 1100, 1216, 1344), # a≈0.90 -> portrait-leaning + (2000, 1000, 1792, 896), # a=2.00 landscape + (1500, 1000, 1536, 1024), # a=1.50 +] + + +def test_pick_bucket_matches_table(): + for iw, ih, W, H in CASES: + assert buckets.pick_bucket(iw, ih, 1280, 64) == (W, H) + + +def test_buckets_are_on_grid_and_within_budget(): + for iw, ih, *_ in CASES: + W, H = buckets.pick_bucket(iw, ih, 1280, 64) + assert W % 64 == 0 and H % 64 == 0 + assert W * H <= 1280 * 1280 + + +def test_square_is_exactly_1280(): + assert buckets.pick_bucket(512, 512, 1280, 64) == (1280, 1280) From 0413e25571db4b0c1402acc3e1ea173b45506616 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 22:48:13 +0200 Subject: [PATCH 2/5] test: bucket cover_crop_params geometry Co-Authored-By: Claude Opus 4.8 --- tests/test_buckets.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_buckets.py b/tests/test_buckets.py index 47bf7eb..f4777c9 100644 --- a/tests/test_buckets.py +++ b/tests/test_buckets.py @@ -25,3 +25,23 @@ def test_buckets_are_on_grid_and_within_budget(): def test_square_is_exactly_1280(): assert buckets.pick_bucket(512, 512, 1280, 64) == (1280, 1280) + + +def test_cover_crop_exact_aspect_no_crop(): + # a=2.0 image onto 1792x896 bucket -> scale 0.896, no crop + new_w, new_h, left, top, scale = buckets.cover_crop_params(2000, 1000, 1792, 896) + assert (new_w, new_h) == (1792, 896) + assert (left, top) == (0, 0) + assert round(scale, 3) == 0.896 + + +def test_cover_crop_square_into_landscape_crops_height(): + new_w, new_h, left, top, scale = buckets.cover_crop_params(1000, 1000, 1792, 896) + assert new_w == 1792 and new_h >= 896 + assert left == 0 and top == (new_h - 896) // 2 # centered vertical crop + assert scale > 1.0 # upscaled to cover width + + +def test_cover_crop_upscale_square(): + *_, scale = buckets.cover_crop_params(1000, 1000, 1280, 1280) + assert round(scale, 2) == 1.28 From 7f90b6878f64d6600de3501a0b5bc8882d907a82 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 22:49:01 +0200 Subject: [PATCH 3/5] feat: BucketResize node (cover-crop onto Klein buckets) Co-Authored-By: Claude Opus 4.8 --- gates/bucket_node.py | 72 +++++++++++++++++++++++++++++++++++++++ tests/test_bucket_node.py | 31 +++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 gates/bucket_node.py create mode 100644 tests/test_bucket_node.py diff --git a/gates/bucket_node.py b/gates/bucket_node.py new file mode 100644 index 0000000..9e62962 --- /dev/null +++ b/gates/bucket_node.py @@ -0,0 +1,72 @@ +"""BucketResize node: cover-crop an image (and optional mask) onto a Klein +training bucket. Pure compute (torch + PIL); no comfy imports in run().""" +import numpy as np +import torch +from PIL import Image + +from . import buckets + +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + + +def _resize_crop_pil(pil, new_w, new_h, left, top, W, H): + pil = pil.resize((new_w, new_h), Image.LANCZOS) + return pil.crop((left, top, left + W, top + H)) + + +def fit_image(image, W, H): + """image [B,H,W,3] -> [B,H,W,3] at (W,H) using the first image's geometry.""" + b, ih, iw = image.shape[0], image.shape[1], image.shape[2] + new_w, new_h, left, top, scale = buckets.cover_crop_params(iw, ih, W, H) + out = [] + for i in range(b): + arr = (image[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8") + pil = _resize_crop_pil(Image.fromarray(arr), new_w, new_h, left, top, W, H) + out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0)) + return torch.stack(out, 0), scale + + +def fit_mask(mask, W, H): + b, ih, iw = mask.shape[0], mask.shape[1], mask.shape[2] + new_w, new_h, left, top, _ = buckets.cover_crop_params(iw, ih, W, H) + out = [] + for i in range(b): + arr = (mask[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8") + pil = _resize_crop_pil(Image.fromarray(arr, mode="L"), new_w, new_h, left, top, W, H) + out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0)) + return torch.stack(out, 0) + + +class BucketResize: + CATEGORY = "Datasete Gates" + FUNCTION = "run" + RETURN_TYPES = ("IMAGE", "MASK", "INT", "INT", "STRING") + RETURN_NAMES = ("image", "mask", "width", "height", "label") + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + "resolution": ("INT", {"default": 1280, "min": 64, "max": 8192}), + "divisible": ("INT", {"default": 64, "min": 8, "max": 256}), + "max_upscale": ("FLOAT", {"default": 1.5, "min": 1.0, "max": 8.0, "step": 0.1}), + }, + "optional": {"mask": ("MASK",)}, + } + + def run(self, image, resolution=1280, divisible=64, max_upscale=1.5, mask=None): + ih, iw = int(image.shape[1]), int(image.shape[2]) + W, H = buckets.pick_bucket(iw, ih, resolution, divisible) + out_img, scale = fit_image(image, W, H) + if scale > max_upscale: + print(f"[BucketResize] cover scale {scale:.2f}x exceeds max_upscale " + f"{max_upscale} for {iw}x{ih} -> {W}x{H}") + out_mask = fit_mask(mask, W, H) if mask is not None \ + else torch.zeros((out_img.shape[0], H, W), dtype=torch.float32) + return (out_img, out_mask, W, H, f"{W}x{H}") + + +NODE_CLASS_MAPPINGS = {"BucketResize": BucketResize} +NODE_DISPLAY_NAME_MAPPINGS = {"BucketResize": "Bucket Resize (Klein 9B)"} diff --git a/tests/test_bucket_node.py b/tests/test_bucket_node.py new file mode 100644 index 0000000..1966283 --- /dev/null +++ b/tests/test_bucket_node.py @@ -0,0 +1,31 @@ +import torch +from gates import bucket_node as bn + + +def test_square_to_1280(): + out, m, w, h, label = bn.BucketResize().run(image=torch.rand((1, 1000, 1000, 3))) + assert (w, h) == (1280, 1280) + assert out.shape == (1, 1280, 1280, 3) + assert m.shape == (1, 1280, 1280) and float(m.max()) == 0.0 # no mask -> zeros + assert label == "1280x1280" + + +def test_landscape_bucket_shapes(): + # tensor [B,H,W,3] with H=1000,W=2000 -> aspect 2.0 -> 1792x896 + out, m, w, h, label = bn.BucketResize().run(image=torch.rand((1, 1000, 2000, 3))) + assert (w, h) == (1792, 896) + assert out.shape == (1, 896, 1792, 3) + assert label == "1792x896" + + +def test_mask_resized_and_aligned(): + out, m, w, h, _ = bn.BucketResize().run( + image=torch.rand((1, 1000, 1000, 3)), mask=torch.ones((1, 1000, 1000))) + assert m.shape == (1, 1280, 1280) and float(m.min()) > 0.9 + + +def test_outputs_are_on_grid(): + out, m, w, h, _ = bn.BucketResize().run( + image=torch.rand((1, 777, 1333, 3)), resolution=1280, divisible=64) + assert w % 64 == 0 and h % 64 == 0 + assert out.shape[1] == h and out.shape[2] == w From 969463a4e9b9de42b1db75f6d22b591b2dfe150d Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 22:52:45 +0200 Subject: [PATCH 4/5] fix: drop deprecated Pillow mode= arg in fit_mask uint8 2D arrays infer "L" automatically; silences Pillow 13 deprecation. Co-Authored-By: Claude Opus 4.8 --- gates/bucket_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gates/bucket_node.py b/gates/bucket_node.py index 9e62962..5d57f21 100644 --- a/gates/bucket_node.py +++ b/gates/bucket_node.py @@ -33,7 +33,7 @@ def fit_mask(mask, W, H): out = [] for i in range(b): arr = (mask[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8") - pil = _resize_crop_pil(Image.fromarray(arr, mode="L"), new_w, new_h, left, top, W, H) + pil = _resize_crop_pil(Image.fromarray(arr), new_w, new_h, left, top, W, H) out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0)) return torch.stack(out, 0) From 037cbf27dbe54cc8a8d273a2de452e0eaf59d825 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 21 Jun 2026 22:52:45 +0200 Subject: [PATCH 5/5] feat: register BucketResize Co-Authored-By: Claude Opus 4.8 --- __init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/__init__.py b/__init__.py index 8273bf2..3e230e9 100644 --- a/__init__.py +++ b/__init__.py @@ -18,14 +18,16 @@ if __package__: NODE_DISPLAY_NAME_MAPPINGS as _TEXT_NAMES from .gates.profile_node import NODE_CLASS_MAPPINGS as _PROF_NODES, \ NODE_DISPLAY_NAME_MAPPINGS as _PROF_NAMES + from .gates.bucket_node import NODE_CLASS_MAPPINGS as _BUCKET_NODES, \ + NODE_DISPLAY_NAME_MAPPINGS as _BUCKET_NAMES from .gates import routes # noqa: F401 (registers aiohttp routes on import) from .gates import gate_server # noqa: F401 (registers /datasete_gate/* + text routes) from .gates import profiles_routes # noqa: F401 (registers /grid_pool/profiles/*) NODE_CLASS_MAPPINGS = {**_POOL_NODES, **_LOADER_NODES, **_GATE_NODES, - **_TEXT_NODES, **_PROF_NODES} + **_TEXT_NODES, **_PROF_NODES, **_BUCKET_NODES} NODE_DISPLAY_NAME_MAPPINGS = {**_POOL_NAMES, **_LOADER_NAMES, **_GATE_NAMES, - **_TEXT_NAMES, **_PROF_NAMES} + **_TEXT_NAMES, **_PROF_NAMES, **_BUCKET_NAMES} else: # pragma: no cover - exercised only under pytest collection NODE_CLASS_MAPPINGS = {} NODE_DISPLAY_NAME_MAPPINGS = {}