Files
ComfyUI-SelVA/nodes/selva_feature_extractor.py
T
Ethanfel ab8e1e5b7b feat: SelvaFeatureExtractor outputs prompt as STRING
Users can now wire the prompt output directly to SelvaSampler's prompt input,
making the data flow explicit instead of relying on the implicit features fallback.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-04 16:27:49 +02:00

176 lines
7.4 KiB
Python

import os
import hashlib
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache
# Video preprocessing constants for SelVA (values mirror selva/utils/eval_utils.py).
# Square edge length, in pixels, that frames are resized to for each encoder.
_CLIP_SIZE, _SYNC_SIZE = 384, 224
# Frame rate each encoder's input is resampled to.
_CLIP_FPS, _SYNC_FPS = 8, 25
# Per-channel statistics for the sync branch; (x - 0.5) / 0.5 maps [0,1] onto [-1,1].
# Shaped [1, 3, 1, 1] so they broadcast over [N, C, H, W] frame batches.
_SYNC_MEAN = torch.tensor([0.5] * 3).view(1, 3, 1, 1)
_SYNC_STD = torch.tensor([0.5] * 3).view(1, 3, 1, 1)
def _sample_frames(video, source_fps, target_fps, duration):
    """Pick frames from a [T,H,W,C] clip so the result plays at target_fps.

    Output frame k is taken from source index int(k / target_fps * source_fps),
    capped at the last available frame. At least one frame is always returned;
    the result is [N,H,W,C] with N = max(1, int(duration * target_fps)).
    """
    last_idx = video.shape[0] - 1
    frame_count = max(1, int(duration * target_fps))
    picks = []
    for k in range(frame_count):
        src_idx = int(k / target_fps * source_fps)
        # Requests past the end of the clip repeat the final frame.
        picks.append(src_idx if src_idx < last_idx else last_idx)
    return video[picks]
def _resize_frames(frames, size):
    """Convert [N,H,W,C] frames in [0,1] to a square [N,C,size,size] float32 batch.

    Bicubic interpolation can overshoot the input range, so the result is
    clamped back into [0,1] before being returned.
    """
    channels_first = frames.permute(0, 3, 1, 2)  # [N, C, H, W]
    resized = F.interpolate(
        channels_first.float(),
        size=(size, size),
        mode="bicubic",
        align_corners=False,
    )
    return resized.clamp(min=0.0, max=1.0)  # [N, C, size, size]
def _hash_inputs(video_tensor, prompt, fps, duration, variant):
    """Return a 16-hex-character cache key for one feature-extraction request.

    The key covers the video's shape and dtype, a bounded sample of its
    contents, the prompt, the source fps, the resolved duration (rounded to
    three decimals) and the model variant.

    Compared with hashing only the leading bytes of the tensor:
      * the content sample is spread evenly over the whole tensor, so clips
        that share their opening pixels but differ later get different keys;
      * the full tensor is never serialised to bytes, only the sample is;
      * each field is length-prefixed, so adjacent fields cannot run together
        (e.g. prompt "a1" + fps 2.0 versus prompt "a" + fps 12.0).

    Keys produced here differ from those of the earlier scheme, so cache
    files written under the old keys are simply not found again.
    """
    sample_budget_bytes = 1024 * 1024  # upper bound on hashed pixel data

    flat = video_tensor.detach().reshape(-1)
    total = flat.numel()
    budget = sample_budget_bytes // flat.element_size()
    if total > budget:
        # Integer-exact, evenly spaced positions from element 0 to the last
        # element. The non-integer stride keeps the sample from locking onto
        # a single colour channel.
        positions = torch.arange(budget, dtype=torch.int64, device=flat.device)
        positions = torch.div(positions * (total - 1), budget - 1, rounding_mode="floor")
        flat = flat[positions]
    sample_bytes = flat.cpu().contiguous().numpy().tobytes()

    fields = (
        repr(tuple(video_tensor.shape)).encode(),
        str(video_tensor.dtype).encode(),
        sample_bytes,
        prompt.encode(),
        str(fps).encode(),
        str(round(duration, 3)).encode(),  # resolved duration affects frame count
        variant.encode(),
    )
    digest = hashlib.sha256()
    for field in fields:
        # Length prefix makes the field boundaries unambiguous.
        digest.update(len(field).to_bytes(8, "little"))
        digest.update(field)
    return digest.hexdigest()[:16]
class SelvaFeatureExtractor:
    """ComfyUI node that converts a video and a text prompt into SelVA features.

    Produces CLIP features and text-conditioned sync features, caches them on
    disk as an .npz file keyed by the inputs, and also passes the fps and the
    prompt through as outputs so they can be wired to downstream nodes.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required and optional inputs."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "video": ("IMAGE",),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Text prompt used by TextSynchformer to focus sync features on the relevant sound source. Should match the prompt used in SelvaSampler.",
                }),
            },
            "optional": {
                "video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info to auto-set fps."}),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001}),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
                                       "tooltip": "Override duration in seconds. 0 = infer from video length and fps."}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory for cached .npz features. Empty = temp dir."}),
            },
        }

    RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
    RETURN_NAMES = ("features", "fps", "prompt")
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
                         duration=0.0, cache_dir=""):
        """Extract CLIP and sync features for a video, or load them from cache.

        Args:
            model: Mapping read for the keys "variant", "dtype", "strategy",
                "feature_utils" and "video_enc".
            video: Frame tensor indexed as [T, H, W, C].
            prompt: Text used to condition the sync features; a blank prompt
                only triggers a warning.
            video_info: Optional mapping; when given, its "loaded_fps" entry
                replaces ``fps``.
            fps: Frame rate of ``video``.
            duration: Seconds to process. Values <= 0 mean "whole clip"; the
                value is always clamped to the clip length T / fps.
            cache_dir: Directory for cached features. Empty selects a
                "selva_features" folder under the system temp directory.

        Returns:
            ``(features, fps, prompt)`` where ``features`` is a dict holding
            "clip_features", "sync_features", "duration" and "prompt". On a
            cache hit the prompt stored in the cache file is returned when
            one is present.

        Raises:
            ValueError: if the video has no frames or fps is not positive.
        """
        if video_info is not None:
            fps = video_info["loaded_fps"]

        T = video.shape[0]
        # Fail with a clear message instead of a ZeroDivisionError / IndexError
        # further down.
        if T == 0:
            raise ValueError("[SelVA] Input video contains no frames.")
        if fps <= 0:
            raise ValueError(f"[SelVA] fps must be positive, got {fps!r}.")

        if duration <= 0:
            duration = T / fps
        duration = min(duration, T / fps)  # clamp to actual video length

        if not prompt.strip():
            print("[SelVA] Warning: empty prompt — TextSynchformer sync features will be unfocused.", flush=True)

        # Cache
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
        os.makedirs(cache_dir, exist_ok=True)

        cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"])
        cached_path = os.path.join(cache_dir, f"{cache_key}.npz")

        if os.path.exists(cached_path):
            print(f"[SelVA] Using cached features: {cached_path}", flush=True)
            cached = _load_cached(cached_path)
            return (cached, float(fps), cached.get("prompt", prompt))

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        feature_utils = model["feature_utils"]
        net_video_enc = model["video_enc"]
        offload = strategy == "offload_to_cpu"

        try:
            if offload:
                feature_utils.to(device)
                net_video_enc.to(device)
                soft_empty_cache()

            print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)

            with torch.no_grad():
                # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
                clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration)  # [N, H, W, C]
                clip_frames = _resize_frames(clip_frames, _CLIP_SIZE)  # [N, C, 384, 384]
                clip_input = clip_frames.unsqueeze(0).to(device, dtype)  # [1, N, C, 384, 384]
                print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px", flush=True)
                clip_features = feature_utils.encode_video_with_clip(clip_input)  # [1, N, 1024]

                # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
                sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration)  # [N, H, W, C]
                sync_frames = _resize_frames(sync_frames, _SYNC_SIZE)  # [N, C, 224, 224]

                # Pad to minimum 16 frames (TextSynchformer segment size)
                if sync_frames.shape[0] < 16:
                    pad = 16 - sync_frames.shape[0]
                    sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)

                # Normalize [0,1] → [-1,1]
                mean = _SYNC_MEAN.to(sync_frames.device)
                std = _SYNC_STD.to(sync_frames.device)
                sync_frames = (sync_frames - mean) / std
                sync_input = sync_frames.unsqueeze(0).to(device, dtype)  # [1, N, C, 224, 224]
                print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px", flush=True)

                # Encode T5 text + prepend supplementary tokens → text-conditioned sync features
                text_f, text_mask = feature_utils.encode_text_t5([prompt])  # [1, L, D], [1, L]
                text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
                sync_features = net_video_enc.encode_video_with_sync(
                    sync_input, text_f=text_f, text_mask=text_mask
                )  # [1, T_sync, 768]

            print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
            print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
        finally:
            # Always hand the encoders back to the offload device, even when
            # extraction fails, so an error does not leave them occupying the
            # compute device.
            if offload:
                feature_utils.to(get_offload_device())
                net_video_enc.to(get_offload_device())
                soft_empty_cache()

        clip_cpu = clip_features.cpu()
        sync_cpu = sync_features.cpu()

        # Write to a scratch file first and rename it into place, so an
        # interrupted write can never leave a truncated .npz that a later run
        # would treat as a valid cache hit.
        tmp_path = f"{cached_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "wb") as fh:
                np.savez(
                    fh,
                    clip_features=clip_cpu.float().numpy(),
                    sync_features=sync_cpu.float().numpy(),
                    duration=float(duration),
                    prompt=np.array(prompt),
                )
            os.replace(tmp_path, cached_path)
        finally:
            # After a successful rename the scratch file no longer exists.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print(f"[SelVA] Features cached: {cached_path}", flush=True)

        # NOTE(review): these tensors keep the encoders' output dtype, while a
        # cache hit yields float32 tensors — confirm downstream handles both.
        return ({
            "clip_features": clip_cpu,
            "sync_features": sync_cpu,
            "duration": float(duration),
            "prompt": prompt,
        }, float(fps), prompt)
def _load_cached(path):
    """Load a cached feature archive from ``path`` into a features dict.

    Returns a dict with "clip_features" and "sync_features" as torch tensors
    and "duration" as a float. A "prompt" string is included only when the
    archive contains one.

    The archive is opened in a ``with`` block so its file handle is released
    as soon as the arrays have been read, rather than whenever the object
    happens to be garbage-collected.
    """
    with np.load(path, allow_pickle=False) as archive:
        features = {
            "clip_features": torch.from_numpy(archive["clip_features"]),
            "sync_features": torch.from_numpy(archive["sync_features"]),
            "duration": float(archive["duration"]),
        }
        if "prompt" in archive:
            features["prompt"] = str(archive["prompt"])
    return features