feat: improve mask support with neutral fill, mask_strength, and per-path toggles

- Replace zero-fill with neutral gray (0.5) fill so masked background
  pixels stay in-distribution: 0.5 maps to ~0 in CLIP normalized space
  and exactly 0 after sync's [-1,1] normalization
- Add mask_strength float (0–1) for partial background suppression
- Add mask_clip / mask_sync booleans to toggle masking independently
  on the CLIP (384px) and TextSynchformer (224px) encoding paths
- Fix temporal mask sampling: use fps-accurate index formula (same as
  _sample_frames) instead of proportional int(i*M/N)
- Include mask_strength, mask_clip, mask_sync in cache hash when mask
  is connected, so changing any param correctly busts the cache
- Log lines now report masked/skipped state and strength per path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-05 10:43:01 +02:00
parent 3dd6badfd9
commit f28759f1e3
+47 -17
View File
@@ -35,30 +35,41 @@ def _resize_frames(frames, size):
return x.clamp(0.0, 1.0) # [N, C, H, W]
def _apply_mask(frames, mask):
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
"""
Apply a ComfyUI MASK to resized frames.
frames: [N, C, H, W] float
mask: [M, H', W'] float [0,1] — M=1 static or M≥N per-frame
frames: [N, C, H, W] float [0,1]
mask: [M, H', W'] float [0,1] — M=1 static or M=T per-frame
source_fps: original video fps (for accurate temporal sampling)
target_fps: sampling fps of this frame set (CLIP_FPS or SYNC_FPS)
mask_strength: 0=no effect, 1=full masking; background filled with 0.5 (neutral gray)
Resizes mask spatially with nearest-exact, samples temporally to N frames,
then multiplies. Background pixels become 0 (→ -1 after [-1,1] normalization).
Background pixels are filled with 0.5 rather than 0 — less out-of-distribution
for CLIP, and maps to 0 (neutral) after [-1,1] normalization on the sync path.
"""
N, C, H, W = frames.shape
M = mask.shape[0]
mask_f = mask.float().unsqueeze(1) # [M, 1, H', W']
if mask_f.shape[2] != H or mask_f.shape[3] != W:
mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact") # [M, 1, H, W]
# Temporal sampling — use same index formula as _sample_frames for accuracy
if M == 1:
mask_f = mask_f.expand(N, -1, -1, -1)
elif M != N:
indices = [min(int(i * M / N), M - 1) for i in range(N)]
else:
indices = [min(int(i / target_fps * source_fps), M - 1) for i in range(N)]
mask_f = mask_f[indices] # [N, 1, H, W]
return frames * mask_f.to(frames.device)
mask_f = mask_f.to(frames.device)
# alpha=1 on foreground, (1-strength) on background → blend toward neutral gray
alpha = 1.0 - mask_strength * (1.0 - mask_f)
return frames * alpha + 0.5 * (1.0 - alpha)
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None):
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True):
h = hashlib.sha256()
raw = video_tensor.cpu().numpy().tobytes()
n = len(raw)
@@ -73,6 +84,9 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None):
h.update(raw_m[:chunk_m])
h.update(raw_m[nm // 2: nm // 2 + chunk_m])
h.update(raw_m[max(0, nm - chunk_m):])
h.update(str(round(mask_strength, 4)).encode())
h.update(str(mask_clip).encode())
h.update(str(mask_sync).encode())
h.update(prompt.encode())
h.update(str(fps).encode())
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
@@ -105,6 +119,18 @@ class SelvaFeatureExtractor:
"mask": ("MASK", {
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
}),
"mask_strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "How strongly to suppress the background. 1.0 = full neutral fill; 0.0 = no masking effect. Values in between blend smoothly.",
}),
"mask_clip": ("BOOLEAN", {
"default": True,
"tooltip": "Apply the mask to CLIP visual features (384px). Disable if you want CLIP to see the full scene context while sync features stay focused.",
}),
"mask_sync": ("BOOLEAN", {
"default": True,
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
}),
},
}
@@ -120,7 +146,8 @@ class SelvaFeatureExtractor:
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
duration=0.0, cache_dir="", mask=None):
duration=0.0, cache_dir="", mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True):
if video_info is not None:
fps = video_info["loaded_fps"]
@@ -136,7 +163,8 @@ class SelvaFeatureExtractor:
if not cache_dir:
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
os.makedirs(cache_dir, exist_ok=True)
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask)
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
if os.path.exists(cached_path):
@@ -163,10 +191,11 @@ class SelvaFeatureExtractor:
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
if mask is not None:
clip_frames = _apply_mask(clip_frames, mask)
if mask is not None and mask_clip:
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True)
_clip_tag = f"(masked strength={mask_strength})" if mask is not None and mask_clip else ("(mask skipped)" if mask is not None else "")
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {_clip_tag}", flush=True)
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
pbar.update(1)
@@ -174,8 +203,8 @@ class SelvaFeatureExtractor:
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if mask is not None:
sync_frames = _apply_mask(sync_frames, mask)
if mask is not None and mask_sync:
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
# Pad to minimum 16 frames (TextSynchformer segment size)
if sync_frames.shape[0] < 16:
pad = 16 - sync_frames.shape[0]
@@ -185,7 +214,8 @@ class SelvaFeatureExtractor:
std = _SYNC_STD.to(sync_frames.device)
sync_frames = (sync_frames - mean) / std
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True)
_sync_tag = f"(masked strength={mask_strength})" if mask is not None and mask_sync else ("(mask skipped)" if mask is not None else "")
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {_sync_tag}", flush=True)
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]