Merge feat/bucket-resize: Bucket Resize (Klein 9B) node
Auto-snaps images onto ÷64 ≤1.64MP training buckets (cover + center-crop, Lanczos), transforms an optional mask identically, outputs width/height/label. Pure bucket math tested against KLEIN_BUCKET_SIZES.md. 99 tests pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
+4
-2
@@ -18,14 +18,16 @@ if __package__:
|
|||||||
NODE_DISPLAY_NAME_MAPPINGS as _TEXT_NAMES
|
NODE_DISPLAY_NAME_MAPPINGS as _TEXT_NAMES
|
||||||
from .gates.profile_node import NODE_CLASS_MAPPINGS as _PROF_NODES, \
|
from .gates.profile_node import NODE_CLASS_MAPPINGS as _PROF_NODES, \
|
||||||
NODE_DISPLAY_NAME_MAPPINGS as _PROF_NAMES
|
NODE_DISPLAY_NAME_MAPPINGS as _PROF_NAMES
|
||||||
|
from .gates.bucket_node import NODE_CLASS_MAPPINGS as _BUCKET_NODES, \
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS as _BUCKET_NAMES
|
||||||
from .gates import routes # noqa: F401 (registers aiohttp routes on import)
|
from .gates import routes # noqa: F401 (registers aiohttp routes on import)
|
||||||
from .gates import gate_server # noqa: F401 (registers /datasete_gate/* + text routes)
|
from .gates import gate_server # noqa: F401 (registers /datasete_gate/* + text routes)
|
||||||
from .gates import profiles_routes # noqa: F401 (registers /grid_pool/profiles/*)
|
from .gates import profiles_routes # noqa: F401 (registers /grid_pool/profiles/*)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {**_POOL_NODES, **_LOADER_NODES, **_GATE_NODES,
|
NODE_CLASS_MAPPINGS = {**_POOL_NODES, **_LOADER_NODES, **_GATE_NODES,
|
||||||
**_TEXT_NODES, **_PROF_NODES}
|
**_TEXT_NODES, **_PROF_NODES, **_BUCKET_NODES}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {**_POOL_NAMES, **_LOADER_NAMES, **_GATE_NAMES,
|
NODE_DISPLAY_NAME_MAPPINGS = {**_POOL_NAMES, **_LOADER_NAMES, **_GATE_NAMES,
|
||||||
**_TEXT_NAMES, **_PROF_NAMES}
|
**_TEXT_NAMES, **_PROF_NAMES, **_BUCKET_NAMES}
|
||||||
else: # pragma: no cover - exercised only under pytest collection
|
else: # pragma: no cover - exercised only under pytest collection
|
||||||
NODE_CLASS_MAPPINGS = {}
|
NODE_CLASS_MAPPINGS = {}
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
|||||||
@@ -0,0 +1,72 @@
|
|||||||
|
"""BucketResize node: cover-crop an image (and optional mask) onto a Klein
|
||||||
|
training bucket. Pure compute (torch + PIL); no comfy imports in run()."""
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from . import buckets
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_crop_pil(pil, new_w, new_h, left, top, W, H):
|
||||||
|
pil = pil.resize((new_w, new_h), Image.LANCZOS)
|
||||||
|
return pil.crop((left, top, left + W, top + H))
|
||||||
|
|
||||||
|
|
||||||
|
def fit_image(image, W, H):
|
||||||
|
"""image [B,H,W,3] -> [B,H,W,3] at (W,H) using the first image's geometry."""
|
||||||
|
b, ih, iw = image.shape[0], image.shape[1], image.shape[2]
|
||||||
|
new_w, new_h, left, top, scale = buckets.cover_crop_params(iw, ih, W, H)
|
||||||
|
out = []
|
||||||
|
for i in range(b):
|
||||||
|
arr = (image[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
|
||||||
|
pil = _resize_crop_pil(Image.fromarray(arr), new_w, new_h, left, top, W, H)
|
||||||
|
out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0))
|
||||||
|
return torch.stack(out, 0), scale
|
||||||
|
|
||||||
|
|
||||||
|
def fit_mask(mask, W, H):
|
||||||
|
b, ih, iw = mask.shape[0], mask.shape[1], mask.shape[2]
|
||||||
|
new_w, new_h, left, top, _ = buckets.cover_crop_params(iw, ih, W, H)
|
||||||
|
out = []
|
||||||
|
for i in range(b):
|
||||||
|
arr = (mask[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8")
|
||||||
|
pil = _resize_crop_pil(Image.fromarray(arr), new_w, new_h, left, top, W, H)
|
||||||
|
out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0))
|
||||||
|
return torch.stack(out, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class BucketResize:
|
||||||
|
CATEGORY = "Datasete Gates"
|
||||||
|
FUNCTION = "run"
|
||||||
|
RETURN_TYPES = ("IMAGE", "MASK", "INT", "INT", "STRING")
|
||||||
|
RETURN_NAMES = ("image", "mask", "width", "height", "label")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE",),
|
||||||
|
"resolution": ("INT", {"default": 1280, "min": 64, "max": 8192}),
|
||||||
|
"divisible": ("INT", {"default": 64, "min": 8, "max": 256}),
|
||||||
|
"max_upscale": ("FLOAT", {"default": 1.5, "min": 1.0, "max": 8.0, "step": 0.1}),
|
||||||
|
},
|
||||||
|
"optional": {"mask": ("MASK",)},
|
||||||
|
}
|
||||||
|
|
||||||
|
def run(self, image, resolution=1280, divisible=64, max_upscale=1.5, mask=None):
|
||||||
|
ih, iw = int(image.shape[1]), int(image.shape[2])
|
||||||
|
W, H = buckets.pick_bucket(iw, ih, resolution, divisible)
|
||||||
|
out_img, scale = fit_image(image, W, H)
|
||||||
|
if scale > max_upscale:
|
||||||
|
print(f"[BucketResize] cover scale {scale:.2f}x exceeds max_upscale "
|
||||||
|
f"{max_upscale} for {iw}x{ih} -> {W}x{H}")
|
||||||
|
out_mask = fit_mask(mask, W, H) if mask is not None \
|
||||||
|
else torch.zeros((out_img.shape[0], H, W), dtype=torch.float32)
|
||||||
|
return (out_img, out_mask, W, H, f"{W}x{H}")
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {"BucketResize": BucketResize}
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {"BucketResize": "Bucket Resize (Klein 9B)"}
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
"""Pure bucket math for KLEIN_BUCKET_SIZES.md. Stdlib only."""
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def pick_bucket(iw, ih, resolution=1280, divisible=64):
|
||||||
|
"""Choose the on-grid bucket (W,H), area <= resolution^2, nearest to the
|
||||||
|
image aspect (log distance; tie-break larger area)."""
|
||||||
|
budget = resolution * resolution
|
||||||
|
target = iw / ih
|
||||||
|
best = None
|
||||||
|
w = divisible
|
||||||
|
w_max = budget // divisible
|
||||||
|
while w <= w_max:
|
||||||
|
h = (budget // w // divisible) * divisible # largest on-grid h within budget
|
||||||
|
if h >= divisible:
|
||||||
|
err = abs(math.log(w / h) - math.log(target))
|
||||||
|
cand = (err, -(w * h), w, h) # min err, then max area
|
||||||
|
if best is None or cand < best:
|
||||||
|
best = cand
|
||||||
|
w += divisible
|
||||||
|
return best[2], best[3]
|
||||||
|
|
||||||
|
|
||||||
|
def cover_crop_params(iw, ih, W, H):
|
||||||
|
"""Cover-scale + centered crop to land (iw,ih) exactly on (W,H)."""
|
||||||
|
scale = max(W / iw, H / ih)
|
||||||
|
new_w = max(W, round(iw * scale))
|
||||||
|
new_h = max(H, round(ih * scale))
|
||||||
|
left = (new_w - W) // 2
|
||||||
|
top = (new_h - H) // 2
|
||||||
|
return new_w, new_h, left, top, scale
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
from gates import bucket_node as bn
|
||||||
|
|
||||||
|
|
||||||
|
def test_square_to_1280():
|
||||||
|
out, m, w, h, label = bn.BucketResize().run(image=torch.rand((1, 1000, 1000, 3)))
|
||||||
|
assert (w, h) == (1280, 1280)
|
||||||
|
assert out.shape == (1, 1280, 1280, 3)
|
||||||
|
assert m.shape == (1, 1280, 1280) and float(m.max()) == 0.0 # no mask -> zeros
|
||||||
|
assert label == "1280x1280"
|
||||||
|
|
||||||
|
|
||||||
|
def test_landscape_bucket_shapes():
|
||||||
|
# tensor [B,H,W,3] with H=1000,W=2000 -> aspect 2.0 -> 1792x896
|
||||||
|
out, m, w, h, label = bn.BucketResize().run(image=torch.rand((1, 1000, 2000, 3)))
|
||||||
|
assert (w, h) == (1792, 896)
|
||||||
|
assert out.shape == (1, 896, 1792, 3)
|
||||||
|
assert label == "1792x896"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mask_resized_and_aligned():
|
||||||
|
out, m, w, h, _ = bn.BucketResize().run(
|
||||||
|
image=torch.rand((1, 1000, 1000, 3)), mask=torch.ones((1, 1000, 1000)))
|
||||||
|
assert m.shape == (1, 1280, 1280) and float(m.min()) > 0.9
|
||||||
|
|
||||||
|
|
||||||
|
def test_outputs_are_on_grid():
|
||||||
|
out, m, w, h, _ = bn.BucketResize().run(
|
||||||
|
image=torch.rand((1, 777, 1333, 3)), resolution=1280, divisible=64)
|
||||||
|
assert w % 64 == 0 and h % 64 == 0
|
||||||
|
assert out.shape[1] == h and out.shape[2] == w
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
from gates import buckets
|
||||||
|
|
||||||
|
# (iw, ih) -> expected (W, H) from KLEIN_BUCKET_SIZES.md, budget 1280, ÷64
|
||||||
|
CASES = [
|
||||||
|
(1000, 1000, 1280, 1280), # square
|
||||||
|
(1000, 2000, 896, 1792), # a=0.50 portrait
|
||||||
|
(1000, 1730, 960, 1664), # a≈0.58
|
||||||
|
(1000, 1100, 1216, 1344), # a≈0.90 -> portrait-leaning
|
||||||
|
(2000, 1000, 1792, 896), # a=2.00 landscape
|
||||||
|
(1500, 1000, 1536, 1024), # a=1.50
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pick_bucket_matches_table():
|
||||||
|
for iw, ih, W, H in CASES:
|
||||||
|
assert buckets.pick_bucket(iw, ih, 1280, 64) == (W, H)
|
||||||
|
|
||||||
|
|
||||||
|
def test_buckets_are_on_grid_and_within_budget():
|
||||||
|
for iw, ih, *_ in CASES:
|
||||||
|
W, H = buckets.pick_bucket(iw, ih, 1280, 64)
|
||||||
|
assert W % 64 == 0 and H % 64 == 0
|
||||||
|
assert W * H <= 1280 * 1280
|
||||||
|
|
||||||
|
|
||||||
|
def test_square_is_exactly_1280():
|
||||||
|
assert buckets.pick_bucket(512, 512, 1280, 64) == (1280, 1280)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cover_crop_exact_aspect_no_crop():
|
||||||
|
# a=2.0 image onto 1792x896 bucket -> scale 0.896, no crop
|
||||||
|
new_w, new_h, left, top, scale = buckets.cover_crop_params(2000, 1000, 1792, 896)
|
||||||
|
assert (new_w, new_h) == (1792, 896)
|
||||||
|
assert (left, top) == (0, 0)
|
||||||
|
assert round(scale, 3) == 0.896
|
||||||
|
|
||||||
|
|
||||||
|
def test_cover_crop_square_into_landscape_crops_height():
|
||||||
|
new_w, new_h, left, top, scale = buckets.cover_crop_params(1000, 1000, 1792, 896)
|
||||||
|
assert new_w == 1792 and new_h >= 896
|
||||||
|
assert left == 0 and top == (new_h - 896) // 2 # centered vertical crop
|
||||||
|
assert scale > 1.0 # upscaled to cover width
|
||||||
|
|
||||||
|
|
||||||
|
def test_cover_crop_upscale_square():
|
||||||
|
*_, scale = buckets.cover_crop_params(1000, 1000, 1280, 1280)
|
||||||
|
assert round(scale, 2) == 1.28
|
||||||
Reference in New Issue
Block a user