feat: comprehensive node improvements

Model Loader:
- bf16 support check — automatically falls back to fp16 on CUDA GPUs that lack bf16 support
- DESCRIPTION and OUTPUT_TOOLTIPS

Feature Extractor:
- Store variant in features dict and .npz cache
- Progress bar (3 steps: CLIP encode, T5 encode, sync encode)
- Expand cache hash from 16 to 32 hex chars
- DESCRIPTION and OUTPUT_TOOLTIPS

Sampler:
- Variant mismatch validation against extracted features
- Cancellation support via throw_exception_if_processing_interrupted()
- OOM catch with actionable error message
- normalize toggle (optional BOOLEAN, default true) for peak normalization
- Replace the empty optional: {} block with the new normalize input
- DESCRIPTION and OUTPUT_TOOLTIPS

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-04 18:16:03 +02:00
parent 429810db5b
commit bd53744e2d
3 changed files with 51 additions and 7 deletions
+16 -1
View File
@@ -5,6 +5,7 @@ import tempfile
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import comfy.utils
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
@@ -46,7 +47,7 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant):
h.update(str(fps).encode()) h.update(str(fps).encode())
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
h.update(variant.encode()) h.update(variant.encode())
return h.hexdigest()[:16] return h.hexdigest()[:32]
class SelvaFeatureExtractor: class SelvaFeatureExtractor:
@@ -76,8 +77,14 @@ class SelvaFeatureExtractor:
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING") RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
RETURN_NAMES = ("features", "fps", "prompt") RETURN_NAMES = ("features", "fps", "prompt")
OUTPUT_TOOLTIPS = (
"Extracted feature bundle — connect to Sampler.",
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
)
FUNCTION = "extract_features" FUNCTION = "extract_features"
CATEGORY = SELVA_CATEGORY CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
def extract_features(self, model, video, prompt, video_info=None, fps=30.0, def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
duration=0.0, cache_dir=""): duration=0.0, cache_dir=""):
@@ -116,6 +123,7 @@ class SelvaFeatureExtractor:
soft_empty_cache() soft_empty_cache()
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True) print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
pbar = comfy.utils.ProgressBar(3)
with torch.no_grad(): with torch.no_grad():
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
@@ -125,6 +133,7 @@ class SelvaFeatureExtractor:
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px", flush=True) print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px", flush=True)
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024] clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
pbar.update(1)
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
@@ -142,10 +151,12 @@ class SelvaFeatureExtractor:
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features # Encode T5 text + prepend supplementary tokens → text-conditioned sync features
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L] text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
pbar.update(1)
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask) text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
sync_features = net_video_enc.encode_video_with_sync( sync_features = net_video_enc.encode_video_with_sync(
sync_input, text_f=text_f, text_mask=text_mask sync_input, text_f=text_f, text_mask=text_mask
) # [1, T_sync, 768] ) # [1, T_sync, 768]
pbar.update(1)
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True) print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True) print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
@@ -161,6 +172,7 @@ class SelvaFeatureExtractor:
sync_features=sync_features.cpu().float().numpy(), sync_features=sync_features.cpu().float().numpy(),
duration=float(duration), duration=float(duration),
prompt=np.array(prompt), prompt=np.array(prompt),
variant=np.array(model["variant"]),
) )
print(f"[SelVA] Features cached: {cached_path}", flush=True) print(f"[SelVA] Features cached: {cached_path}", flush=True)
@@ -169,6 +181,7 @@ class SelvaFeatureExtractor:
"sync_features": sync_features.cpu(), "sync_features": sync_features.cpu(),
"duration": float(duration), "duration": float(duration),
"prompt": prompt, "prompt": prompt,
"variant": model["variant"],
}, float(fps), prompt) }, float(fps), prompt)
@@ -181,4 +194,6 @@ def _load_cached(path):
} }
if "prompt" in data: if "prompt" in data:
features["prompt"] = str(data["prompt"]) features["prompt"] = str(data["prompt"])
if "variant" in data:
features["variant"] = str(data["variant"])
return features return features
+6 -1
View File
@@ -101,8 +101,10 @@ class SelvaModelLoader:
RETURN_TYPES = ("SELVA_MODEL",) RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",) RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("Loaded model bundle — connect to Feature Extractor and Sampler.",)
FUNCTION = "load_model" FUNCTION = "load_model"
CATEGORY = SELVA_CATEGORY CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Loads the SelVA generator, TextSynchformer encoder, CLIP, T5, and VAE. Weights are auto-downloaded from HuggingFace on first use."
def load_model(self, variant, precision, offload_strategy): def load_model(self, variant, precision, offload_strategy):
from selva_core.model.networks_generator import get_my_mmaudio from selva_core.model.networks_generator import get_my_mmaudio
@@ -112,9 +114,12 @@ class SelvaModelLoader:
gen_filename, mode, has_bigvgan = _VARIANTS[variant] gen_filename, mode, has_bigvgan = _VARIANTS[variant]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
print("[SelVA] Warning: bf16 not supported on this GPU — falling back to fp16.", flush=True)
precision = "fp16"
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
strategy = determine_offload_strategy(offload_strategy) strategy = determine_offload_strategy(offload_strategy)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True) print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True)
video_enc_path = _ensure("video_enc_sup_5.pth") video_enc_path = _ensure("video_enc_sup_5.pth")
+26 -2
View File
@@ -1,5 +1,6 @@
import torch import torch
import comfy.utils import comfy.utils
import comfy.model_management
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
@@ -29,15 +30,22 @@ class SelvaSampler:
"tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strictly but can introduce artifacts. SelVA default is 4.5; useful range is roughly 37."}), "tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strictly but can introduce artifacts. SelVA default is 4.5; useful range is roughly 37."}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
}, },
"optional": {}, "optional": {
"normalize": ("BOOLEAN", {
"default": True,
"tooltip": "Peak-normalize output to [-1, 1]. Disable to preserve the raw decoder output level.",
}),
},
} }
RETURN_TYPES = ("AUDIO",) RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",) RETURN_NAMES = ("audio",)
OUTPUT_TOOLTIPS = ("Generated audio waveform — connect to VHS_VideoCombine or Save Audio.",)
FUNCTION = "generate" FUNCTION = "generate"
CATEGORY = SELVA_CATEGORY CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed): def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True):
import dataclasses import dataclasses
from selva_core.model.flow_matching import FlowMatching from selva_core.model.flow_matching import FlowMatching
@@ -48,6 +56,14 @@ class SelvaSampler:
feature_utils = model["feature_utils"] feature_utils = model["feature_utils"]
mode = model["mode"] mode = model["mode"]
# Validate that features were extracted with the same model variant
feat_variant = features.get("variant")
if feat_variant is not None and feat_variant != model["variant"]:
raise ValueError(
f"[SelVA] Variant mismatch: features were extracted with '{feat_variant}' "
f"but model is '{model['variant']}'. Re-run the Feature Extractor with the current model."
)
# Resolve prompt: use override if given, otherwise fall back to features prompt # Resolve prompt: use override if given, otherwise fall back to features prompt
if not prompt or not prompt.strip(): if not prompt or not prompt.strip():
prompt = features.get("prompt", "") prompt = features.get("prompt", "")
@@ -112,10 +128,17 @@ class SelvaSampler:
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
def ode_wrapper_tracked(t, x): def ode_wrapper_tracked(t, x):
comfy.model_management.throw_exception_if_processing_interrupted()
pbar.update(1) pbar.update(1)
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength) return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
try:
x1 = fm.to_data(ode_wrapper_tracked, x0) x1 = fm.to_data(ode_wrapper_tracked, x0)
except torch.cuda.OutOfMemoryError:
raise RuntimeError(
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
)
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True) print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
@@ -137,6 +160,7 @@ class SelvaSampler:
elif audio.dim() == 3 and audio.shape[1] != 1: elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True) # stereo → mono audio = audio.mean(dim=1, keepdim=True) # stereo → mono
if normalize:
peak = audio.abs().max().clamp(min=1e-8) peak = audio.abs().max().clamp(min=1e-8)
audio = (audio / peak).clamp(-1, 1) audio = (audio / peak).clamp(-1, 1)
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True) print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)