From f3cabcad90aa71be15e6f278c192a58132f741fb Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 5 Apr 2026 12:52:03 +0200 Subject: [PATCH] experiment: crop-to-mask mode on feature extractor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of squishing the full frame to a square, optionally crops a square region around the mask bounding box (union across all frames) before resizing. Preserves aspect ratio of the subject and gives the model a focused, undistorted view. New optional inputs on SelVA Feature Extractor: - crop_to_mask (BOOLEAN, default false) — enable the crop path - crop_margin (FLOAT 0–1, default 0.1) — padding around the bbox as a fraction of the bounding box side _compute_mask_bbox: resizes mask to frame resolution, takes union over all mask frames, expands to square + margin, shifts into frame bounds to preserve square shape before clamping. Falls back to center square crop if mask is empty. Bbox is computed once from the original-resolution mask and reused for both CLIP (384px) and sync (224px) frame sequences. Combine with mask_clip/mask_sync for full background suppression on top of the crop. Cache hash includes crop_to_mask and crop_margin when mask is connected. Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_feature_extractor.py | 85 ++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 3 deletions(-) diff --git a/nodes/selva_feature_extractor.py b/nodes/selva_feature_extractor.py index ada6d8a..690100f 100644 --- a/nodes/selva_feature_extractor.py +++ b/nodes/selva_feature_extractor.py @@ -35,6 +35,56 @@ def _resize_frames(frames, size): return x.clamp(0.0, 1.0) # [N, C, H, W] +def _compute_mask_bbox(mask, frame_h, frame_w, margin=0.1): + """ + Compute a square bounding box around the union of all mask frames. + + mask: [M, H', W'] float [0,1] + Returns (y0, x0, y1, x1) in pixel coords relative to (frame_h, frame_w). + Falls back to a center square crop if the mask is empty. + """ + if mask.shape[1] != frame_h or mask.shape[2] != frame_w: + m = F.interpolate( + mask.float().unsqueeze(1), size=(frame_h, frame_w), mode="nearest-exact" + ).squeeze(1) + else: + m = mask.float() + + union = (m > 0.5).max(dim=0).values # [H, W] bool + + if not union.any(): + # Empty mask — center square crop + side = min(frame_h, frame_w) + cy, cx = frame_h // 2, frame_w // 2 + y0 = max(0, cy - side // 2) + x0 = max(0, cx - side // 2) + return y0, x0, min(frame_h, y0 + side), min(frame_w, x0 + side) + + ys = union.any(dim=1).nonzero(as_tuple=True)[0] + xs = union.any(dim=0).nonzero(as_tuple=True)[0] + y0, y1 = int(ys[0]), int(ys[-1]) + 1 + x0, x1 = int(xs[0]), int(xs[-1]) + 1 + + side = max(y1 - y0, x1 - x0) + pad = int(side * margin) + side += 2 * pad + + cy = (y0 + y1) // 2 + cx = (x0 + x1) // 2 + y0n = cy - side // 2 + x0n = cx - side // 2 + y1n = y0n + side + x1n = x0n + side + + # Shift into frame bounds to preserve square shape + if y0n < 0: y1n -= y0n; y0n = 0 + if y1n > frame_h: y0n -= y1n - frame_h; y1n = frame_h + if x0n < 0: x1n -= x0n; x0n = 0 + if x1n > frame_w: x0n -= x1n - frame_w; x1n = frame_w + + return max(0, int(y0n)), max(0, int(x0n)), min(frame_h, int(y1n)), min(frame_w, int(x1n)) + + def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0): """ Apply a ComfyUI MASK to resized frames. @@ -69,7 +119,8 @@ def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0): def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None, - mask_strength=1.0, mask_clip=True, mask_sync=True): + mask_strength=1.0, mask_clip=True, mask_sync=True, + crop_to_mask=False, crop_margin=0.1): h = hashlib.sha256() raw = video_tensor.cpu().numpy().tobytes() n = len(raw) @@ -87,6 +138,9 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None, h.update(str(round(mask_strength, 4)).encode()) h.update(str(mask_clip).encode()) h.update(str(mask_sync).encode()) + h.update(str(crop_to_mask).encode()) + if crop_to_mask: + h.update(str(round(crop_margin, 4)).encode()) h.update(prompt.encode()) h.update(str(fps).encode()) h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count @@ -131,6 +185,14 @@ class SelvaFeatureExtractor: "default": True, "tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.", }), + "crop_to_mask": ("BOOLEAN", { + "default": False, + "tooltip": "Experimental. When enabled, crops frames to a square region around the mask bounding box before resizing, instead of squishing the full frame. Requires mask. Combine with mask_clip/mask_sync for full isolation.", + }), + "crop_margin": ("FLOAT", { + "default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05, + "tooltip": "Fraction of the bounding box side to add as padding around the crop. 0.1 = 10% margin on each side.", + }), }, } @@ -147,7 +209,8 @@ class SelvaFeatureExtractor: def extract_features(self, model, video, prompt, video_info=None, fps=30.0, duration=0.0, cache_dir="", mask=None, - mask_strength=1.0, mask_clip=True, mask_sync=True): + mask_strength=1.0, mask_clip=True, mask_sync=True, + crop_to_mask=False, crop_margin=0.1): if video_info is not None: fps = video_info["loaded_fps"] @@ -164,7 +227,8 @@ class SelvaFeatureExtractor: cache_dir = os.path.join(tempfile.gettempdir(), "selva_features") os.makedirs(cache_dir, exist_ok=True) cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask, - mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync) + mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync, + crop_to_mask=crop_to_mask, crop_margin=crop_margin) cached_path = os.path.join(cache_dir, f"{cache_key}.npz") if os.path.exists(cached_path): @@ -186,10 +250,22 @@ class SelvaFeatureExtractor: print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True) pbar = comfy.utils.ProgressBar(3) + # Pre-compute crop bbox once from the original-resolution mask + crop_bbox = None + if mask is not None and crop_to_mask: + H_vid, W_vid = video.shape[1], video.shape[2] + crop_bbox = _compute_mask_bbox(mask, H_vid, W_vid, crop_margin) + cy0, cx0, cy1, cx1 = crop_bbox + print(f"[SelVA] Mask crop: y={cy0}:{cy1} x={cx0}:{cx1} " + f"({cy1-cy0}×{cx1-cx0}px from {H_vid}×{W_vid})", flush=True) + try: with torch.no_grad(): # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C] + if crop_bbox is not None: + cy0, cx0, cy1, cx1 = crop_bbox + clip_frames = clip_frames[:, cy0:cy1, cx0:cx1, :] clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] if mask is not None and mask_clip: clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength) @@ -202,6 +278,9 @@ class SelvaFeatureExtractor: # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] + if crop_bbox is not None: + cy0, cx0, cy1, cx1 = crop_bbox + sync_frames = sync_frames[:, cy0:cy1, cx0:cx1, :] sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] if mask is not None and mask_sync: sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)