Files
ComfyUI-SelVA/nodes/selva_feature_extractor.py
T
Ethanfel f4a7292cde feat: add optional MASK input to SelVA Feature Extractor
Allows per-frame or static segmentation masks to be applied before CLIP
and sync encoding, zeroing background pixels. Useful when multiple objects
compete for the same sound and text prompting alone is insufficient.

- _apply_mask(): resizes mask spatially (nearest-exact), samples temporally
  to match sampled frame count, multiplies into frames
- _hash_inputs(): includes mask bytes in cache key (begin/mid/end sampling)
- INPUT_TYPES: mask added to optional inputs with tooltip
- extract_features(): mask=None parameter, applied after _resize_frames for
  both CLIP (384px) and sync (224px) paths, before normalization
- Log line notes when masking is active

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 08:34:13 +02:00

237 lines
10 KiB
Python

import os
import hashlib
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
import comfy.utils
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
_CLIP_SIZE = 384  # side length (px) of the square frames handed to the CLIP encoder
_SYNC_SIZE = 224  # side length (px) of the square frames handed to the sync encoder
_CLIP_FPS = 8  # sampling rate (frames/s) of the CLIP stream
_SYNC_FPS = 25  # sampling rate (frames/s) of the sync stream
# Sync normalization is applied externally (in extract_features): with per-channel
# mean = std = 0.5, (x - 0.5) / 0.5 maps [0, 1] → [-1, 1]. Shaped [1, C, 1, 1] so it
# broadcasts over [N, C, H, W] frame batches.
_SYNC_MEAN = torch.full((1, 3, 1, 1), 0.5)
_SYNC_STD = torch.full((1, 3, 1, 1), 0.5)
def _sample_frames(video, source_fps, target_fps, duration):
    """Resample a frame stack to ``target_fps`` by picking source frames.

    Args:
        video: tensor indexed along dim 0 by frame (documented as [T,H,W,C] float32).
        source_fps: frame rate of ``video``.
        target_fps: desired output frame rate.
        duration: seconds of output; at least one frame is always produced.

    Returns:
        ``video`` indexed at the chosen frames ([N,H,W,C]). Output frame k takes
        source frame floor(k / target_fps * source_fps), clamped to the last frame.
    """
    last_index = video.shape[0] - 1
    frame_count = max(1, int(duration * target_fps))
    picks = []
    for k in range(frame_count):
        # Timestamp of output frame k, converted to a (floored) source frame index.
        src = int(k / target_fps * source_fps)
        picks.append(src if src < last_index else last_index)
    return video[picks]
def _resize_frames(frames, size, antialias=True):
    """Resize [N,H,W,C] float32 [0,1] frames to square [N,C,size,size].

    Bicubic interpolation is anti-aliased by default. Plain bicubic sampling of a
    large downscale (e.g. 1080p → 224 px) skips most source pixels and aliases
    fine detail. torchvision's Resize anti-aliases by default.
    NOTE(review): the upstream eval_utils preprocessing is assumed to use
    torchvision Resize(BICUBIC) — confirm.

    Args:
        frames: [N, H, W, C] tensor with values in [0, 1].
        size: target side length in pixels (output is size x size).
        antialias: set False to reproduce the previous plain-bicubic behaviour.

    Returns:
        [N, C, size, size] float tensor clamped to [0, 1] (bicubic can overshoot).
    """
    x = frames.permute(0, 3, 1, 2).float()  # [N, C, H, W]
    x = F.interpolate(
        x,
        size=(size, size),
        mode="bicubic",
        align_corners=False,
        antialias=antialias,
    )
    return x.clamp(0.0, 1.0)  # [N, C, size, size]
def _apply_mask(frames, mask):
    """
    Apply a ComfyUI MASK to resized frames.

    frames: [N, C, H, W] float
    mask:   [M, H', W'] (or a single [H', W']) float in [0, 1].
            M == 1 is broadcast as a static mask, M == N is used frame-for-frame,
            any other M is mapped proportionally (frame i uses mask int(i * M / N)).

    The mask is resized spatially with nearest-exact interpolation and multiplied
    into the frames, so background pixels become 0 (→ -1 after [-1,1] normalization).

    Raises:
        ValueError: if the mask is not 2-D / 3-D, or contains no frames.
    """
    if mask.dim() == 2:
        mask = mask.unsqueeze(0)  # single [H', W'] mask → [1, H', W']
    if mask.dim() != 3:
        raise ValueError(f"[SelVA] mask must be [M,H,W] or [H,W], got shape {tuple(mask.shape)}")
    N, _, H, W = frames.shape
    M = mask.shape[0]
    if M == 0:
        raise ValueError("[SelVA] mask has no frames")
    # Move/cast first so interpolation and the product happen on the frames' device.
    mask_f = mask.to(device=frames.device, dtype=frames.dtype).unsqueeze(1)  # [M, 1, H', W']
    if mask_f.shape[-2:] != (H, W):
        mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact")  # [M, 1, H, W]
    if M == 1:
        mask_f = mask_f.expand(N, -1, -1, -1)
    elif M != N:
        indices = [min(int(i * M / N), M - 1) for i in range(N)]
        mask_f = mask_f[indices]  # [N, 1, H, W]
    return frames * mask_f
def _tensor_byte_samples(tensor, chunk):
    """Return the head, middle and tail ``chunk``-byte slices of a tensor's raw data.

    The tensor is viewed as a flat C-order uint8 array (the same bytes
    ``tensor.numpy().tobytes()`` would give). Only the three slices are copied,
    so this avoids materialising a ``bytes`` copy of a whole (possibly multi-GB) video.
    """
    flat = np.ascontiguousarray(tensor.detach().cpu().numpy()).reshape(-1).view(np.uint8)
    n = flat.size
    return (
        flat[:chunk].tobytes(),
        flat[n // 2: n // 2 + chunk].tobytes(),
        flat[max(0, n - chunk):].tobytes(),
    )


def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None):
    """Build the 32-hex-char cache key for a feature-extraction request.

    Hashes the shape/dtype and head/middle/tail byte samples of the video (512 KB
    each) and of the optional mask (256 KB each), plus prompt, fps, duration
    (rounded to 3 decimals) and model variant. Only sampled bytes are hashed —
    a deliberate speed trade-off — so inputs differing solely outside the
    samples share a key.
    """
    h = hashlib.sha256()

    def _field(tag, payload):
        # Length-prefix every field so adjacent values can't run together
        # (e.g. prompt "dog3" + fps "1.5" vs prompt "dog" + fps "31.5").
        h.update(tag.encode())
        h.update(len(payload).to_bytes(8, "little"))
        h.update(payload)

    _field("video_meta", f"{tuple(video_tensor.shape)}|{video_tensor.dtype}".encode())
    for part in _tensor_byte_samples(video_tensor, 512 * 1024):
        _field("video", part)
    if mask is not None:
        _field("mask_meta", f"{tuple(mask.shape)}|{mask.dtype}".encode())
        for part in _tensor_byte_samples(mask, 256 * 1024):
            _field("mask", part)
    _field("prompt", prompt.encode())
    _field("fps", str(fps).encode())
    _field("duration", str(round(duration, 3)).encode())  # resolved duration affects frame count
    _field("variant", variant.encode())
    return h.hexdigest()[:32]
class SelvaFeatureExtractor:
    """ComfyUI node: extract SelVA CLIP and text-conditioned sync features from a video.

    Results are cached on disk as ``<key>.npz``, keyed by ``_hash_inputs`` over the
    video, optional mask, prompt, fps, resolved duration and model variant.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "video": ("IMAGE",),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Describes the sounds to generate. Used to focus the visual sync features on motion relevant to the prompt — more specific prompts produce cleaner audio sync. Wire the prompt output directly to the Sampler so you only type it once.",
                }),
            },
            "optional": {
                "video_info": ("VHS_VIDEOINFO", {
                    "tooltip": "VHS_VIDEOINFO from VHS LoadVideo. Automatically sets the correct source fps — always connect this when loading video with VHS nodes.",
                }),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001,
                                  "tooltip": "Source fps of the input video. Ignored when video_info is connected."}),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
                                       "tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
                "cache_dir": ("STRING", {"default": "",
                                         "tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
                "mask": ("MASK", {
                    "tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
                }),
            },
        }

    RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
    RETURN_NAMES = ("features", "fps", "prompt")
    OUTPUT_TOOLTIPS = (
        "Extracted feature bundle — connect to Sampler.",
        "Source fps of the video — wire to VHS_VideoCombine frame_rate.",
        "The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
    )
    FUNCTION = "extract_features"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."

    @staticmethod
    def _align_mask(mask, num_video_frames, fps, target_fps, duration):
        """Return the mask frames that line up with one stream's sampled frames.

        A mask with exactly one frame per video frame is sampled with the same
        timestamps as the video (via ``_sample_frames``). This keeps it aligned
        even when ``duration`` trims the clip. Other masks are returned unchanged
        for ``_apply_mask`` to broadcast or map proportionally.
        """
        if mask.dim() == 3 and mask.shape[0] == num_video_frames:
            return _sample_frames(mask, fps, target_fps, duration)
        return mask

    @classmethod
    def _prepare_clip_frames(cls, video, fps, duration, mask):
        """Sample at _CLIP_FPS, resize to _CLIP_SIZE and mask; returns [N, C, 384, 384] in [0, 1]."""
        frames = _sample_frames(video, fps, _CLIP_FPS, duration)  # [N, H, W, C]
        frames = _resize_frames(frames, _CLIP_SIZE)  # [N, C, 384, 384]
        if mask is not None:
            frames = _apply_mask(frames, cls._align_mask(mask, video.shape[0], fps, _CLIP_FPS, duration))
        return frames

    @classmethod
    def _prepare_sync_frames(cls, video, fps, duration, mask):
        """Sample at _SYNC_FPS, resize, mask, pad to >= 16 frames, normalize to [-1, 1].

        Returns [N, C, 224, 224] with N >= 16.
        """
        frames = _sample_frames(video, fps, _SYNC_FPS, duration)  # [N, H, W, C]
        frames = _resize_frames(frames, _SYNC_SIZE)  # [N, C, 224, 224]
        if mask is not None:
            frames = _apply_mask(frames, cls._align_mask(mask, video.shape[0], fps, _SYNC_FPS, duration))
        # Pad to minimum 16 frames (TextSynchformer segment size) by repeating the last frame
        if frames.shape[0] < 16:
            pad = 16 - frames.shape[0]
            frames = torch.cat([frames, frames[-1:].expand(pad, -1, -1, -1)], dim=0)
        # Normalize [0,1] → [-1,1]
        mean = _SYNC_MEAN.to(frames.device)
        std = _SYNC_STD.to(frames.device)
        return (frames - mean) / std

    @staticmethod
    def _try_load_cache(path):
        """Load a cached bundle via ``_load_cached``; return None if the file is unreadable.

        A truncated/corrupt file (e.g. from an interrupted run) is treated as a cache
        miss so features get recomputed and the file overwritten, instead of failing
        every subsequent run.
        """
        import zipfile

        try:
            return _load_cached(path)
        except (OSError, ValueError, KeyError, EOFError, zipfile.BadZipFile) as exc:
            print(f"[SelVA] Ignoring unreadable cache file {path}: {exc}", flush=True)
            return None

    @staticmethod
    def _write_cache(cached_path, features):
        """Save ``features`` to ``cached_path`` atomically (temp file + os.replace).

        A failed write is reported but not raised: the features were computed
        successfully and are still returned to the caller.
        """
        # Keep the .npz suffix so np.savez does not append another one.
        tmp_path = f"{cached_path}.{os.getpid()}.tmp.npz"
        try:
            np.savez(
                tmp_path,
                clip_features=features["clip_features"].float().numpy(),
                sync_features=features["sync_features"].float().numpy(),
                duration=features["duration"],
                prompt=np.array(features["prompt"]),
                variant=np.array(features["variant"]),
            )
            os.replace(tmp_path, cached_path)
        except OSError as exc:
            print(f"[SelVA] Warning: could not write feature cache {cached_path}: {exc}", flush=True)
        else:
            print(f"[SelVA] Features cached: {cached_path}", flush=True)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
                         duration=0.0, cache_dir="", mask=None):
        """Extract (or load cached) SelVA features for ``video``.

        Args:
            model: SELVA_MODEL dict; reads ``variant``, ``dtype``, ``strategy``,
                ``feature_utils`` and ``video_enc``.
            video: ComfyUI IMAGE frame stack (dim 0 = frames), sampled by ``_sample_frames``.
            prompt: sound description used to condition the sync features.
            video_info: optional VHS_VIDEOINFO; its ``loaded_fps`` overrides ``fps``.
            fps: source frame rate when ``video_info`` is not connected.
            duration: seconds to use; <= 0 means the whole video. Clamped to the video length.
            cache_dir: directory for ``.npz`` cache files; empty → ``<tempdir>/selva_features``.
            mask: optional ComfyUI MASK zeroing background pixels before encoding.

        Returns:
            ``(features, fps, prompt)``; ``features`` holds ``clip_features``,
            ``sync_features``, ``duration``, ``prompt`` and ``variant``.
        """
        if video_info is not None:
            fps = video_info["loaded_fps"]
        T = video.shape[0]
        if duration <= 0:
            duration = T / fps
        duration = min(duration, T / fps)  # clamp to actual video length
        if not prompt.strip():
            print("[SelVA] Warning: empty prompt — TextSynchformer sync features will be unfocused.", flush=True)

        # Cache
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
        os.makedirs(cache_dir, exist_ok=True)
        cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask)
        cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
        if os.path.exists(cached_path):
            cached = self._try_load_cache(cached_path)
            if cached is not None:
                print(f"[SelVA] Using cached features: {cached_path}", flush=True)
                return (cached, float(fps), cached.get("prompt", prompt))

        device = get_device()
        dtype = model["dtype"]
        offload = model["strategy"] == "offload_to_cpu"
        feature_utils = model["feature_utils"]
        net_video_enc = model["video_enc"]

        # try/finally so an exception (e.g. OOM) never leaves offloaded models on the GPU.
        try:
            if offload:
                feature_utils.to(device)
                net_video_enc.to(device)
                soft_empty_cache()
            print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
            pbar = comfy.utils.ProgressBar(3)
            with torch.no_grad():
                # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
                clip_frames = self._prepare_clip_frames(video, fps, duration, mask)
                clip_input = clip_frames.unsqueeze(0).to(device, dtype)  # [1, N, C, 384, 384]
                print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {'(masked)' if mask is not None else ''}", flush=True)
                clip_features = feature_utils.encode_video_with_clip(clip_input)  # [1, N, 1024]
                pbar.update(1)

                # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
                sync_frames = self._prepare_sync_frames(video, fps, duration, mask)
                sync_input = sync_frames.unsqueeze(0).to(device, dtype)  # [1, N, C, 224, 224]
                print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px", flush=True)

                # Encode T5 text + prepend supplementary tokens → text-conditioned sync features
                text_f, text_mask = feature_utils.encode_text_t5([prompt])  # [1, L, D], [1, L]
                pbar.update(1)
                text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
                sync_features = net_video_enc.encode_video_with_sync(
                    sync_input, text_f=text_f, text_mask=text_mask
                )  # [1, T_sync, 768]
                pbar.update(1)
            print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
            print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
        finally:
            if offload:
                feature_utils.to(get_offload_device())
                net_video_enc.to(get_offload_device())
                soft_empty_cache()

        features = {
            "clip_features": clip_features.cpu(),
            "sync_features": sync_features.cpu(),
            "duration": float(duration),
            "prompt": prompt,
            "variant": model["variant"],
        }
        self._write_cache(cached_path, features)
        return (features, float(fps), prompt)
def _load_cached(path):
    """Load a feature bundle saved by ``SelvaFeatureExtractor`` from an ``.npz`` file.

    Returns:
        dict with ``clip_features`` / ``sync_features`` (torch tensors) and
        ``duration`` (float), plus ``prompt`` / ``variant`` (str) when stored.

    The NpzFile returned by ``np.load`` holds an open file handle. It is used as a
    context manager so the handle is released as soon as the arrays are read.
    """
    with np.load(path, allow_pickle=False) as data:
        features = {
            "clip_features": torch.from_numpy(data["clip_features"]),
            "sync_features": torch.from_numpy(data["sync_features"]),
            "duration": float(data["duration"]),
        }
        # Optional string metadata; stored as 0-d unicode arrays.
        for key in ("prompt", "variant"):
            if key in data:
                features[key] = str(data[key])
    return features