diff --git a/nodes/selva_feature_extractor.py b/nodes/selva_feature_extractor.py index 958875c..43e6f0f 100644 --- a/nodes/selva_feature_extractor.py +++ b/nodes/selva_feature_extractor.py @@ -5,6 +5,7 @@ import tempfile import numpy as np import torch import torch.nn.functional as F +import comfy.utils from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache @@ -46,7 +47,7 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant): h.update(str(fps).encode()) h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count h.update(variant.encode()) - return h.hexdigest()[:16] + return h.hexdigest()[:32] class SelvaFeatureExtractor: @@ -76,8 +77,14 @@ class SelvaFeatureExtractor: RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING") RETURN_NAMES = ("features", "fps", "prompt") + OUTPUT_TOOLTIPS = ( + "Extracted feature bundle — connect to Sampler.", + "Source fps of the video — wire to VHS_VideoCombine frame_rate.", + "The prompt used during extraction — wire to Sampler prompt to avoid re-typing.", + ) FUNCTION = "extract_features" CATEGORY = SELVA_CATEGORY + DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant." def extract_features(self, model, video, prompt, video_info=None, fps=30.0, duration=0.0, cache_dir=""): @@ -116,6 +123,7 @@ class SelvaFeatureExtractor: soft_empty_cache() print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True) + pbar = comfy.utils.ProgressBar(3) with torch.no_grad(): # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- @@ -125,6 +133,7 @@ class SelvaFeatureExtractor: print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px", flush=True) clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024] + pbar.update(1) # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] @@ -142,10 +151,12 @@ class SelvaFeatureExtractor: # Encode T5 text + prepend supplementary tokens → text-conditioned sync features text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L] + pbar.update(1) text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask) sync_features = net_video_enc.encode_video_with_sync( sync_input, text_f=text_f, text_mask=text_mask ) # [1, T_sync, 768] + pbar.update(1) print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True) print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True) @@ -161,6 +172,7 @@ class SelvaFeatureExtractor: sync_features=sync_features.cpu().float().numpy(), duration=float(duration), prompt=np.array(prompt), + variant=np.array(model["variant"]), ) print(f"[SelVA] Features cached: {cached_path}", flush=True) @@ -169,6 +181,7 @@ class SelvaFeatureExtractor: "sync_features": sync_features.cpu(), "duration": float(duration), "prompt": prompt, + "variant": model["variant"], }, float(fps), prompt) @@ -181,4 +194,6 @@ def _load_cached(path): } if "prompt" in data: features["prompt"] = str(data["prompt"]) + if "variant" in data: + features["variant"] = str(data["variant"]) return features diff --git a/nodes/selva_model_loader.py b/nodes/selva_model_loader.py index 2eec4a8..3fc8497 100644 --- a/nodes/selva_model_loader.py +++ b/nodes/selva_model_loader.py @@ -101,8 +101,10 @@ class SelvaModelLoader: RETURN_TYPES = ("SELVA_MODEL",) RETURN_NAMES = ("model",) + OUTPUT_TOOLTIPS = ("Loaded model bundle — connect to Feature Extractor and Sampler.",) FUNCTION = "load_model" CATEGORY = SELVA_CATEGORY + DESCRIPTION = "Loads the SelVA generator, TextSynchformer encoder, CLIP, T5, and VAE. Weights are auto-downloaded from HuggingFace on first use." def load_model(self, variant, precision, offload_strategy): from selva_core.model.networks_generator import get_my_mmaudio @@ -112,9 +114,12 @@ class SelvaModelLoader: gen_filename, mode, has_bigvgan = _VARIANTS[variant] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported(): + print("[SelVA] Warning: bf16 not supported on this GPU — falling back to fp16.", flush=True) + precision = "fp16" dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] strategy = determine_offload_strategy(offload_strategy) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True) video_enc_path = _ensure("video_enc_sup_5.pth") diff --git a/nodes/selva_sampler.py b/nodes/selva_sampler.py index fb9a40f..aa5fb83 100644 --- a/nodes/selva_sampler.py +++ b/nodes/selva_sampler.py @@ -1,5 +1,6 @@ import torch import comfy.utils +import comfy.model_management from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache @@ -29,15 +30,22 @@ class SelvaSampler: "tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strictly but can introduce artifacts. SelVA default is 4.5; useful range is roughly 3–7."}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), }, - "optional": {}, + "optional": { + "normalize": ("BOOLEAN", { + "default": True, + "tooltip": "Peak-normalize output to [-1, 1]. Disable to preserve the raw decoder output level.", + }), + }, } RETURN_TYPES = ("AUDIO",) RETURN_NAMES = ("audio",) + OUTPUT_TOOLTIPS = ("Generated audio waveform — connect to VHS_VideoCombine or Save Audio.",) FUNCTION = "generate" CATEGORY = SELVA_CATEGORY + DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance." - def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed): + def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True): import dataclasses from selva_core.model.flow_matching import FlowMatching @@ -48,6 +56,14 @@ class SelvaSampler: feature_utils = model["feature_utils"] mode = model["mode"] + # Validate that features were extracted with the same model variant + feat_variant = features.get("variant") + if feat_variant is not None and feat_variant != model["variant"]: + raise ValueError( + f"[SelVA] Variant mismatch: features were extracted with '{feat_variant}' " + f"but model is '{model['variant']}'. Re-run the Feature Extractor with the current model." + ) + # Resolve prompt: use override if given, otherwise fall back to features prompt if not prompt or not prompt.strip(): prompt = features.get("prompt", "") @@ -112,10 +128,17 @@ class SelvaSampler: pbar = comfy.utils.ProgressBar(steps) def ode_wrapper_tracked(t, x): + comfy.model_management.throw_exception_if_processing_interrupted() pbar.update(1) return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength) - x1 = fm.to_data(ode_wrapper_tracked, x0) + try: + x1 = fm.to_data(ode_wrapper_tracked, x0) + except torch.cuda.OutOfMemoryError: + raise RuntimeError( + "[SelVA] CUDA out of memory during generation. Try switching offload_strategy " + "to 'offload_to_cpu', using a smaller variant, or reducing duration." + ) print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True) @@ -137,8 +160,9 @@ class SelvaSampler: elif audio.dim() == 3 and audio.shape[1] != 1: audio = audio.mean(dim=1, keepdim=True) # stereo → mono - peak = audio.abs().max().clamp(min=1e-8) - audio = (audio / peak).clamp(-1, 1) + if normalize: + peak = audio.abs().max().clamp(min=1e-8) + audio = (audio / peak).clamp(-1, 1) print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True) return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)