diff --git a/gates/bucket_node.py b/gates/bucket_node.py index 9e62962..5d57f21 100644 --- a/gates/bucket_node.py +++ b/gates/bucket_node.py @@ -33,7 +33,7 @@ def fit_mask(mask, W, H): out = [] for i in range(b): arr = (mask[i].cpu().numpy() * 255.0).clip(0, 255).astype("uint8") - pil = _resize_crop_pil(Image.fromarray(arr, mode="L"), new_w, new_h, left, top, W, H) + pil = _resize_crop_pil(Image.fromarray(arr), new_w, new_h, left, top, W, H) out.append(torch.from_numpy(np.array(pil, dtype=np.float32) / 255.0)) return torch.stack(out, 0)