feat: BucketResize node (cover-crop onto Klein buckets)
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,72 @@
|
||||
"""BucketResize node: cover-crop an image (and optional mask) onto a Klein
|
||||
training bucket. Pure compute (torch + PIL); no comfy imports in run()."""
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from . import buckets
|
||||
|
||||
NODE_CLASS_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
|
||||
|
||||
def _resize_crop_pil(pil, new_w, new_h, left, top, W, H):
|
||||
pil = pil.resize((new_w, new_h), Image.LANCZOS)
|
||||
return pil.crop((left, top, left + W, top + H))
|
||||
|
||||
|
||||
def fit_image(image, W, H):
|
||||
"""image [B,H,W,3] -> [B,H,W,3] at (W,H) using the first image's geometry."""
|
||||
b, ih, iw = image.shape[0], image.shape[1], image.shape[2]
|
||||
new_w, new_h, left, top, scale = buckets.cover_crop_params(iw, ih, W, H)
|
||||
out = []
|
||||
for i in range(b):
|
||||
arr = (image[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
|
||||
pil = _resize_crop_pil(Image.fromarray(arr), new_w, new_h, left, top, W, H)
|
||||
out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0))
|
||||
return torch.stack(out, 0), scale
|
||||
|
||||
|
||||
def fit_mask(mask, W, H):
|
||||
b, ih, iw = mask.shape[0], mask.shape[1], mask.shape[2]
|
||||
new_w, new_h, left, top, _ = buckets.cover_crop_params(iw, ih, W, H)
|
||||
out = []
|
||||
for i in range(b):
|
||||
arr = (mask[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
|
||||
pil = _resize_crop_pil(Image.fromarray(arr, mode="L"), new_w, new_h, left, top, W, H)
|
||||
out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0))
|
||||
return torch.stack(out, 0)
|
||||
|
||||
|
||||
class BucketResize:
|
||||
CATEGORY = "Datasete Gates"
|
||||
FUNCTION = "run"
|
||||
RETURN_TYPES = ("IMAGE", "MASK", "INT", "INT", "STRING")
|
||||
RETURN_NAMES = ("image", "mask", "width", "height", "label")
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"resolution": ("INT", {"default": 1280, "min": 64, "max": 8192}),
|
||||
"divisible": ("INT", {"default": 64, "min": 8, "max": 256}),
|
||||
"max_upscale": ("FLOAT", {"default": 1.5, "min": 1.0, "max": 8.0, "step": 0.1}),
|
||||
},
|
||||
"optional": {"mask": ("MASK",)},
|
||||
}
|
||||
|
||||
def run(self, image, resolution=1280, divisible=64, max_upscale=1.5, mask=None):
|
||||
ih, iw = int(image.shape[1]), int(image.shape[2])
|
||||
W, H = buckets.pick_bucket(iw, ih, resolution, divisible)
|
||||
out_img, scale = fit_image(image, W, H)
|
||||
if scale > max_upscale:
|
||||
print(f"[BucketResize] cover scale {scale:.2f}x exceeds max_upscale "
|
||||
f"{max_upscale} for {iw}x{ih} -> {W}x{H}")
|
||||
out_mask = fit_mask(mask, W, H) if mask is not None \
|
||||
else torch.zeros((out_img.shape[0], H, W), dtype=torch.float32)
|
||||
return (out_img, out_mask, W, H, f"{W}x{H}")
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {"BucketResize": BucketResize}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {"BucketResize": "Bucket Resize (Klein 9B)"}
|
||||
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
from gates import bucket_node as bn
|
||||
|
||||
|
||||
def test_square_to_1280():
|
||||
out, m, w, h, label = bn.BucketResize().run(image=torch.rand((1, 1000, 1000, 3)))
|
||||
assert (w, h) == (1280, 1280)
|
||||
assert out.shape == (1, 1280, 1280, 3)
|
||||
assert m.shape == (1, 1280, 1280) and float(m.max()) == 0.0 # no mask -> zeros
|
||||
assert label == "1280x1280"
|
||||
|
||||
|
||||
def test_landscape_bucket_shapes():
|
||||
# tensor [B,H,W,3] with H=1000,W=2000 -> aspect 2.0 -> 1792x896
|
||||
out, m, w, h, label = bn.BucketResize().run(image=torch.rand((1, 1000, 2000, 3)))
|
||||
assert (w, h) == (1792, 896)
|
||||
assert out.shape == (1, 896, 1792, 3)
|
||||
assert label == "1792x896"
|
||||
|
||||
|
||||
def test_mask_resized_and_aligned():
|
||||
out, m, w, h, _ = bn.BucketResize().run(
|
||||
image=torch.rand((1, 1000, 1000, 3)), mask=torch.ones((1, 1000, 1000)))
|
||||
assert m.shape == (1, 1280, 1280) and float(m.min()) > 0.9
|
||||
|
||||
|
||||
def test_outputs_are_on_grid():
|
||||
out, m, w, h, _ = bn.BucketResize().run(
|
||||
image=torch.rand((1, 777, 1333, 3)), resolution=1280, divisible=64)
|
||||
assert w % 64 == 0 and h % 64 == 0
|
||||
assert out.shape[1] == h and out.shape[2] == w
|
||||
Reference in New Issue
Block a user