feat: improve mask support with neutral fill, mask_strength, and per-path toggles

- Replace zero-fill with neutral gray (0.5) fill so masked background pixels stay in-distribution: 0.5 maps to ~0 in CLIP normalized space and exactly 0 after sync's [-1,1] normalization
- Add mask_strength float (0–1) for partial background suppression
- Add mask_clip / mask_sync booleans to toggle masking independently on the CLIP (384px) and TextSynchformer (224px) encoding paths
- Fix temporal mask sampling: use fps-accurate index formula (same as _sample_frames) instead of proportional int(i*M/N)
- Include mask_strength, mask_clip, mask_sync in cache hash when mask is connected, so changing any param correctly busts the cache
- Log lines now report masked/skipped state and strength per path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -35,30 +35,41 @@ def _resize_frames(frames, size):
|
||||
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
||||
|
||||
|
||||
def _apply_mask(frames, mask):
|
||||
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
||||
"""
|
||||
Apply a ComfyUI MASK to resized frames.
|
||||
|
||||
frames: [N, C, H, W] float
|
||||
mask: [M, H', W'] float [0,1] — M=1 static or M≥N per-frame
|
||||
frames: [N, C, H, W] float [0,1]
|
||||
mask: [M, H', W'] float [0,1] — M=1 static or M=T per-frame
|
||||
source_fps: original video fps (for accurate temporal sampling)
|
||||
target_fps: sampling fps of this frame set (CLIP_FPS or SYNC_FPS)
|
||||
mask_strength: 0=no effect, 1=full masking; background filled with 0.5 (neutral gray)
|
||||
|
||||
Resizes mask spatially with nearest-exact, samples temporally to N frames,
|
||||
then multiplies. Background pixels become 0 (→ -1 after [-1,1] normalization).
|
||||
Background pixels are filled with 0.5 rather than 0 — less out-of-distribution
|
||||
for CLIP, and maps to 0 (neutral) after [-1,1] normalization on the sync path.
|
||||
"""
|
||||
N, C, H, W = frames.shape
|
||||
M = mask.shape[0]
|
||||
mask_f = mask.float().unsqueeze(1) # [M, 1, H', W']
|
||||
if mask_f.shape[2] != H or mask_f.shape[3] != W:
|
||||
mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact") # [M, 1, H, W]
|
||||
|
||||
# Temporal sampling — use same index formula as _sample_frames for accuracy
|
||||
if M == 1:
|
||||
mask_f = mask_f.expand(N, -1, -1, -1)
|
||||
elif M != N:
|
||||
indices = [min(int(i * M / N), M - 1) for i in range(N)]
|
||||
else:
|
||||
indices = [min(int(i / target_fps * source_fps), M - 1) for i in range(N)]
|
||||
mask_f = mask_f[indices] # [N, 1, H, W]
|
||||
return frames * mask_f.to(frames.device)
|
||||
|
||||
mask_f = mask_f.to(frames.device)
|
||||
|
||||
# alpha=1 on foreground, (1-strength) on background → blend toward neutral gray
|
||||
alpha = 1.0 - mask_strength * (1.0 - mask_f)
|
||||
return frames * alpha + 0.5 * (1.0 - alpha)
|
||||
|
||||
|
||||
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None):
|
||||
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
||||
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
||||
h = hashlib.sha256()
|
||||
raw = video_tensor.cpu().numpy().tobytes()
|
||||
n = len(raw)
|
||||
@@ -73,6 +84,9 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None):
|
||||
h.update(raw_m[:chunk_m])
|
||||
h.update(raw_m[nm // 2: nm // 2 + chunk_m])
|
||||
h.update(raw_m[max(0, nm - chunk_m):])
|
||||
h.update(str(round(mask_strength, 4)).encode())
|
||||
h.update(str(mask_clip).encode())
|
||||
h.update(str(mask_sync).encode())
|
||||
h.update(prompt.encode())
|
||||
h.update(str(fps).encode())
|
||||
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
||||
@@ -105,6 +119,18 @@ class SelvaFeatureExtractor:
|
||||
"mask": ("MASK", {
|
||||
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
|
||||
}),
|
||||
"mask_strength": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||
"tooltip": "How strongly to suppress the background. 1.0 = full neutral fill; 0.0 = no masking effect. Values in between blend smoothly.",
|
||||
}),
|
||||
"mask_clip": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Apply the mask to CLIP visual features (384px). Disable if you want CLIP to see the full scene context while sync features stay focused.",
|
||||
}),
|
||||
"mask_sync": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -120,7 +146,8 @@ class SelvaFeatureExtractor:
|
||||
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
|
||||
|
||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||
duration=0.0, cache_dir="", mask=None):
|
||||
duration=0.0, cache_dir="", mask=None,
|
||||
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
||||
if video_info is not None:
|
||||
fps = video_info["loaded_fps"]
|
||||
|
||||
@@ -136,7 +163,8 @@ class SelvaFeatureExtractor:
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask)
|
||||
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
|
||||
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
|
||||
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
||||
|
||||
if os.path.exists(cached_path):
|
||||
@@ -163,10 +191,11 @@ class SelvaFeatureExtractor:
|
||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||
if mask is not None:
|
||||
clip_frames = _apply_mask(clip_frames, mask)
|
||||
if mask is not None and mask_clip:
|
||||
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
|
||||
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True)
|
||||
_clip_tag = f"(masked strength={mask_strength})" if mask is not None and mask_clip else ("(mask skipped)" if mask is not None else "")
|
||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {_clip_tag}", flush=True)
|
||||
|
||||
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
||||
pbar.update(1)
|
||||
@@ -174,8 +203,8 @@ class SelvaFeatureExtractor:
|
||||
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
||||
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
||||
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||
if mask is not None:
|
||||
sync_frames = _apply_mask(sync_frames, mask)
|
||||
if mask is not None and mask_sync:
|
||||
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
|
||||
# Pad to minimum 16 frames (TextSynchformer segment size)
|
||||
if sync_frames.shape[0] < 16:
|
||||
pad = 16 - sync_frames.shape[0]
|
||||
@@ -185,7 +214,8 @@ class SelvaFeatureExtractor:
|
||||
std = _SYNC_STD.to(sync_frames.device)
|
||||
sync_frames = (sync_frames - mean) / std
|
||||
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
|
||||
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {'(masked)' if mask is not None else ''}", flush=True)
|
||||
_sync_tag = f"(masked strength={mask_strength})" if mask is not None and mask_sync else ("(mask skipped)" if mask is not None else "")
|
||||
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {_sync_tag}", flush=True)
|
||||
|
||||
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
|
||||
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
|
||||
|
||||
Reference in New Issue
Block a user