feat: add optional MASK input to SelVA Feature Extractor
Allows per-frame or static segmentation masks to be applied before CLIP and sync encoding, zeroing background pixels. Useful when multiple objects compete for the same sound and text prompting alone is insufficient. - _apply_mask(): resizes mask spatially (nearest-exact), samples temporally to match sampled frame count, multiplies into frames - _hash_inputs(): includes mask bytes in cache key (begin/mid/end sampling) - INPUT_TYPES: mask added to optional inputs with tooltip - extract_features(): mask=None parameter, applied after _resize_frames for both CLIP (384px) and sync (224px) paths, before normalization - Log line notes when masking is active Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -35,7 +35,30 @@ def _resize_frames(frames, size):
|
||||
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
||||
|
||||
|
||||
def _hash_inputs(video_tensor, prompt, fps, duration, variant):
|
||||
def _apply_mask(frames, mask):
|
||||
"""
|
||||
Apply a ComfyUI MASK to resized frames.
|
||||
|
||||
frames: [N, C, H, W] float
|
||||
mask: [M, H', W'] float [0,1] — M=1 static or M≥N per-frame
|
||||
|
||||
Resizes mask spatially with nearest-exact, samples temporally to N frames,
|
||||
then multiplies. Background pixels become 0 (→ -1 after [-1,1] normalization).
|
||||
"""
|
||||
N, C, H, W = frames.shape
|
||||
M = mask.shape[0]
|
||||
mask_f = mask.float().unsqueeze(1) # [M, 1, H', W']
|
||||
if mask_f.shape[2] != H or mask_f.shape[3] != W:
|
||||
mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact") # [M, 1, H, W]
|
||||
if M == 1:
|
||||
mask_f = mask_f.expand(N, -1, -1, -1)
|
||||
elif M != N:
|
||||
indices = [min(int(i * M / N), M - 1) for i in range(N)]
|
||||
mask_f = mask_f[indices] # [N, 1, H, W]
|
||||
return frames * mask_f.to(frames.device)
|
||||
|
||||
|
||||
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None):
|
||||
h = hashlib.sha256()
|
||||
raw = video_tensor.cpu().numpy().tobytes()
|
||||
n = len(raw)
|
||||
@@ -43,6 +66,13 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant):
|
||||
h.update(raw[:chunk])
|
||||
h.update(raw[n // 2: n // 2 + chunk])
|
||||
h.update(raw[max(0, n - chunk):])
|
||||
if mask is not None:
|
||||
raw_m = mask.cpu().numpy().tobytes()
|
||||
nm = len(raw_m)
|
||||
chunk_m = 256 * 1024
|
||||
h.update(raw_m[:chunk_m])
|
||||
h.update(raw_m[nm // 2: nm // 2 + chunk_m])
|
||||
h.update(raw_m[max(0, nm - chunk_m):])
|
||||
h.update(prompt.encode())
|
||||
h.update(str(fps).encode())
|
||||
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
||||
@@ -72,6 +102,9 @@ class SelvaFeatureExtractor:
|
||||
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
|
||||
"cache_dir": ("STRING", {"default": "",
|
||||
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
|
||||
"mask": ("MASK", {
|
||||
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -87,7 +120,7 @@ class SelvaFeatureExtractor:
|
||||
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
|
||||
|
||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||
duration=0.0, cache_dir=""):
|
||||
duration=0.0, cache_dir="", mask=None):
|
||||
if video_info is not None:
|
||||
fps = video_info["loaded_fps"]
|
||||
|
||||
@@ -103,7 +136,7 @@ class SelvaFeatureExtractor:
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"])
|
||||
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask)
|
||||
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
||||
|
||||
if os.path.exists(cached_path):
|
||||
@@ -129,8 +162,10 @@ class SelvaFeatureExtractor:
|
||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||
if mask is not None:
|
||||
clip_frames = _apply_mask(clip_frames, mask)
|
||||
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px", flush=True)
|
||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True)
|
||||
|
||||
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
||||
pbar.update(1)
|
||||
@@ -138,6 +173,8 @@ class SelvaFeatureExtractor:
|
||||
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
||||
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
||||
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||
if mask is not None:
|
||||
sync_frames = _apply_mask(sync_frames, mask)
|
||||
# Pad to minimum 16 frames (TextSynchformer segment size)
|
||||
if sync_frames.shape[0] < 16:
|
||||
pad = 16 - sync_frames.shape[0]
|
||||
|
||||
Reference in New Issue
Block a user