diff --git a/nodes/selva_feature_extractor.py b/nodes/selva_feature_extractor.py index 43e6f0f..80a002e 100644 --- a/nodes/selva_feature_extractor.py +++ b/nodes/selva_feature_extractor.py @@ -35,7 +35,30 @@ def _resize_frames(frames, size): return x.clamp(0.0, 1.0) # [N, C, H, W] -def _hash_inputs(video_tensor, prompt, fps, duration, variant): +def _apply_mask(frames, mask): + """ + Apply a ComfyUI MASK to resized frames. + + frames: [N, C, H, W] float + mask: [M, H', W'] float [0,1] — M=1 static or M≥N per-frame + + Resizes mask spatially with nearest-exact, samples temporally to N frames, + then multiplies. Background pixels become 0 (→ -1 after [-1,1] normalization). + """ + N, C, H, W = frames.shape + M = mask.shape[0] + mask_f = mask.float().unsqueeze(1) # [M, 1, H', W'] + if mask_f.shape[2] != H or mask_f.shape[3] != W: + mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact") # [M, 1, H, W] + if M == 1: + mask_f = mask_f.expand(N, -1, -1, -1) + elif M != N: + indices = [min(int(i * M / N), M - 1) for i in range(N)] + mask_f = mask_f[indices] # [N, 1, H, W] + return frames * mask_f.to(frames.device) + + +def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None): h = hashlib.sha256() raw = video_tensor.cpu().numpy().tobytes() n = len(raw) @@ -43,6 +66,13 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant): h.update(raw[:chunk]) h.update(raw[n // 2: n // 2 + chunk]) h.update(raw[max(0, n - chunk):]) + if mask is not None: + raw_m = mask.cpu().numpy().tobytes() + nm = len(raw_m) + chunk_m = 256 * 1024 + h.update(raw_m[:chunk_m]) + h.update(raw_m[nm // 2: nm // 2 + chunk_m]) + h.update(raw_m[max(0, nm - chunk_m):]) h.update(prompt.encode()) h.update(str(fps).encode()) h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count @@ -72,6 +102,9 @@ class SelvaFeatureExtractor: "tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}), "cache_dir": ("STRING", {"default": "", "tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}), + "mask": ("MASK", { + "tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.", + }), }, } @@ -87,7 +120,7 @@ class SelvaFeatureExtractor: DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant." def extract_features(self, model, video, prompt, video_info=None, fps=30.0, - duration=0.0, cache_dir=""): + duration=0.0, cache_dir="", mask=None): if video_info is not None: fps = video_info["loaded_fps"] @@ -103,7 +136,7 @@ class SelvaFeatureExtractor: if not cache_dir: cache_dir = os.path.join(tempfile.gettempdir(), "selva_features") os.makedirs(cache_dir, exist_ok=True) - cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"]) + cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask) cached_path = os.path.join(cache_dir, f"{cache_key}.npz") if os.path.exists(cached_path): @@ -129,8 +162,10 @@ class SelvaFeatureExtractor: # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C] clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] + if mask is not None: + clip_frames = _apply_mask(clip_frames, mask) clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384] - print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px", flush=True) + print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True) clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024] pbar.update(1) @@ -138,6 +173,8 @@ class SelvaFeatureExtractor: # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] + if mask is not None: + sync_frames = _apply_mask(sync_frames, mask) # Pad to minimum 16 frames (TextSynchformer segment size) if sync_frames.shape[0] < 16: pad = 16 - sync_frames.shape[0]