From f28759f1e3b6fbc6e9a0ef3cf62af7260796caa9 Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Sun, 5 Apr 2026 10:43:01 +0200 Subject: [PATCH] feat: improve mask support with neutral fill, mask_strength, and per-path toggles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace zero-fill with neutral gray (0.5) fill so masked background pixels stay in-distribution: 0.5 maps to ~0 in CLIP normalized space and exactly 0 after sync's [-1,1] normalization - Add mask_strength float (0–1) for partial background suppression - Add mask_clip / mask_sync booleans to toggle masking independently on the CLIP (384px) and TextSynchformer (224px) encoding paths - Fix temporal mask sampling: use fps-accurate index formula (same as _sample_frames) instead of proportional int(i*M/N) - Include mask_strength, mask_clip, mask_sync in cache hash when mask is connected, so changing any param correctly busts the cache - Log lines now report masked/skipped state and strength per path Co-Authored-By: Claude Sonnet 4.6 --- nodes/selva_feature_extractor.py | 64 +++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/nodes/selva_feature_extractor.py b/nodes/selva_feature_extractor.py index c4a9f4a..ada6d8a 100644 --- a/nodes/selva_feature_extractor.py +++ b/nodes/selva_feature_extractor.py @@ -35,30 +35,41 @@ def _resize_frames(frames, size): return x.clamp(0.0, 1.0) # [N, C, H, W] -def _apply_mask(frames, mask): +def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0): """ Apply a ComfyUI MASK to resized frames. - frames: [N, C, H, W] float - mask: [M, H', W'] float [0,1] — M=1 static or M≥N per-frame + frames: [N, C, H, W] float [0,1] + mask: [M, H', W'] float [0,1] — M=1 static or M=T per-frame + source_fps: original video fps (for accurate temporal sampling) + target_fps: sampling fps of this frame set (CLIP_FPS or SYNC_FPS) + mask_strength: 0=no effect, 1=full masking; background filled with 0.5 (neutral gray) - Resizes mask spatially with nearest-exact, samples temporally to N frames, - then multiplies. Background pixels become 0 (→ -1 after [-1,1] normalization). + Background pixels are filled with 0.5 rather than 0 — less out-of-distribution + for CLIP, and maps to 0 (neutral) after [-1,1] normalization on the sync path. """ N, C, H, W = frames.shape M = mask.shape[0] mask_f = mask.float().unsqueeze(1) # [M, 1, H', W'] if mask_f.shape[2] != H or mask_f.shape[3] != W: mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact") # [M, 1, H, W] + + # Temporal sampling — use same index formula as _sample_frames for accuracy if M == 1: mask_f = mask_f.expand(N, -1, -1, -1) - elif M != N: - indices = [min(int(i * M / N), M - 1) for i in range(N)] + else: + indices = [min(int(i / target_fps * source_fps), M - 1) for i in range(N)] mask_f = mask_f[indices] # [N, 1, H, W] - return frames * mask_f.to(frames.device) + + mask_f = mask_f.to(frames.device) + + # alpha=1 on foreground, (1-strength) on background → blend toward neutral gray + alpha = 1.0 - mask_strength * (1.0 - mask_f) + return frames * alpha + 0.5 * (1.0 - alpha) -def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None): +def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None, + mask_strength=1.0, mask_clip=True, mask_sync=True): h = hashlib.sha256() raw = video_tensor.cpu().numpy().tobytes() n = len(raw) @@ -73,6 +84,9 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None): h.update(raw_m[:chunk_m]) h.update(raw_m[nm // 2: nm // 2 + chunk_m]) h.update(raw_m[max(0, nm - chunk_m):]) + h.update(str(round(mask_strength, 4)).encode()) + h.update(str(mask_clip).encode()) + h.update(str(mask_sync).encode()) h.update(prompt.encode()) h.update(str(fps).encode()) h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count @@ -105,6 +119,18 @@ class SelvaFeatureExtractor: "mask": ("MASK", { "tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.", }), + "mask_strength": ("FLOAT", { + "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05, + "tooltip": "How strongly to suppress the background. 1.0 = full neutral fill; 0.0 = no masking effect. Values in between blend smoothly.", + }), + "mask_clip": ("BOOLEAN", { + "default": True, + "tooltip": "Apply the mask to CLIP visual features (384px). Disable if you want CLIP to see the full scene context while sync features stay focused.", + }), + "mask_sync": ("BOOLEAN", { + "default": True, + "tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.", + }), }, } @@ -120,7 +146,8 @@ class SelvaFeatureExtractor: DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant." def extract_features(self, model, video, prompt, video_info=None, fps=30.0, - duration=0.0, cache_dir="", mask=None): + duration=0.0, cache_dir="", mask=None, + mask_strength=1.0, mask_clip=True, mask_sync=True): if video_info is not None: fps = video_info["loaded_fps"] @@ -136,7 +163,8 @@ class SelvaFeatureExtractor: if not cache_dir: cache_dir = os.path.join(tempfile.gettempdir(), "selva_features") os.makedirs(cache_dir, exist_ok=True) - cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask) + cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask, + mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync) cached_path = os.path.join(cache_dir, f"{cache_key}.npz") if os.path.exists(cached_path): @@ -163,10 +191,11 @@ class SelvaFeatureExtractor: # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C] clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] - if mask is not None: - clip_frames = _apply_mask(clip_frames, mask) + if mask is not None and mask_clip: + clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength) clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384] - print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True) + _clip_tag = f"(masked strength={mask_strength})" if mask is not None and mask_clip else ("(mask skipped)" if mask is not None else "") + print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {_clip_tag}", flush=True) clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024] pbar.update(1) @@ -174,8 +203,8 @@ class SelvaFeatureExtractor: # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] - if mask is not None: - sync_frames = _apply_mask(sync_frames, mask) + if mask is not None and mask_sync: + sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength) # Pad to minimum 16 frames (TextSynchformer segment size) if sync_frames.shape[0] < 16: pad = 16 - sync_frames.shape[0] @@ -185,7 +214,8 @@ class SelvaFeatureExtractor: std = _SYNC_STD.to(sync_frames.device) sync_frames = (sync_frames - mean) / std sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224] - print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True) + _sync_tag = f"(masked strength={mask_strength})" if mask is not None and mask_sync else ("(mask skipped)" if mask is not None else "") + print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {_sync_tag}", flush=True) # Encode T5 text + prepend supplementary tokens → text-conditioned sync features text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]