feat: comprehensive node improvements
Model Loader:
- bf16 support check — auto-falls back to fp16 on unsupported GPUs
- DESCRIPTION and OUTPUT_TOOLTIPS
Feature Extractor:
- Store variant in features dict and .npz cache
- Progress bar (3 steps: CLIP encode, T5 encode, sync encode)
- Expand cache hash to 32 hex chars
- DESCRIPTION and OUTPUT_TOOLTIPS
Sampler:
- Variant mismatch validation against extracted features
- Cancellation support via throw_exception_if_processing_interrupted()
- OOM catch with actionable error message
- normalize toggle (optional BOOLEAN, default true) for peak normalization
- Remove empty optional: {} block
- DESCRIPTION and OUTPUT_TOOLTIPS
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import tempfile
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.utils
|
||||
|
||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
|
||||
@@ -46,7 +47,7 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant):
|
||||
h.update(str(fps).encode())
|
||||
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
||||
h.update(variant.encode())
|
||||
return h.hexdigest()[:16]
|
||||
return h.hexdigest()[:32]
|
||||
|
||||
|
||||
class SelvaFeatureExtractor:
|
||||
@@ -76,8 +77,14 @@ class SelvaFeatureExtractor:
|
||||
|
||||
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
|
||||
RETURN_NAMES = ("features", "fps", "prompt")
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"Extracted feature bundle — connect to Sampler.",
|
||||
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
|
||||
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
|
||||
)
|
||||
FUNCTION = "extract_features"
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
|
||||
|
||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||
duration=0.0, cache_dir=""):
|
||||
@@ -116,6 +123,7 @@ class SelvaFeatureExtractor:
|
||||
soft_empty_cache()
|
||||
|
||||
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
||||
pbar = comfy.utils.ProgressBar(3)
|
||||
|
||||
with torch.no_grad():
|
||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||
@@ -125,6 +133,7 @@ class SelvaFeatureExtractor:
|
||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px", flush=True)
|
||||
|
||||
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
||||
pbar.update(1)
|
||||
|
||||
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
||||
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
||||
@@ -142,10 +151,12 @@ class SelvaFeatureExtractor:
|
||||
|
||||
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
|
||||
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
|
||||
pbar.update(1)
|
||||
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
|
||||
sync_features = net_video_enc.encode_video_with_sync(
|
||||
sync_input, text_f=text_f, text_mask=text_mask
|
||||
) # [1, T_sync, 768]
|
||||
pbar.update(1)
|
||||
|
||||
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
|
||||
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
|
||||
@@ -161,6 +172,7 @@ class SelvaFeatureExtractor:
|
||||
sync_features=sync_features.cpu().float().numpy(),
|
||||
duration=float(duration),
|
||||
prompt=np.array(prompt),
|
||||
variant=np.array(model["variant"]),
|
||||
)
|
||||
print(f"[SelVA] Features cached: {cached_path}", flush=True)
|
||||
|
||||
@@ -169,6 +181,7 @@ class SelvaFeatureExtractor:
|
||||
"sync_features": sync_features.cpu(),
|
||||
"duration": float(duration),
|
||||
"prompt": prompt,
|
||||
"variant": model["variant"],
|
||||
}, float(fps), prompt)
|
||||
|
||||
|
||||
@@ -181,4 +194,6 @@ def _load_cached(path):
|
||||
}
|
||||
if "prompt" in data:
|
||||
features["prompt"] = str(data["prompt"])
|
||||
if "variant" in data:
|
||||
features["variant"] = str(data["variant"])
|
||||
return features
|
||||
|
||||
@@ -101,8 +101,10 @@ class SelvaModelLoader:
|
||||
|
||||
RETURN_TYPES = ("SELVA_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
OUTPUT_TOOLTIPS = ("Loaded model bundle — connect to Feature Extractor and Sampler.",)
|
||||
FUNCTION = "load_model"
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
DESCRIPTION = "Loads the SelVA generator, TextSynchformer encoder, CLIP, T5, and VAE. Weights are auto-downloaded from HuggingFace on first use."
|
||||
|
||||
def load_model(self, variant, precision, offload_strategy):
|
||||
from selva_core.model.networks_generator import get_my_mmaudio
|
||||
@@ -112,9 +114,12 @@ class SelvaModelLoader:
|
||||
|
||||
gen_filename, mode, has_bigvgan = _VARIANTS[variant]
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
|
||||
print("[SelVA] Warning: bf16 not supported on this GPU — falling back to fp16.", flush=True)
|
||||
precision = "fp16"
|
||||
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||
strategy = determine_offload_strategy(offload_strategy)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True)
|
||||
video_enc_path = _ensure("video_enc_sup_5.pth")
|
||||
|
||||
+26
-2
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
|
||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
|
||||
@@ -29,15 +30,22 @@ class SelvaSampler:
|
||||
"tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strictly but can introduce artifacts. SelVA default is 4.5; useful range is roughly 3–7."}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||
},
|
||||
"optional": {},
|
||||
"optional": {
|
||||
"normalize": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Peak-normalize output to [-1, 1]. Disable to preserve the raw decoder output level.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
RETURN_NAMES = ("audio",)
|
||||
OUTPUT_TOOLTIPS = ("Generated audio waveform — connect to VHS_VideoCombine or Save Audio.",)
|
||||
FUNCTION = "generate"
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
|
||||
|
||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed):
|
||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True):
|
||||
import dataclasses
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
|
||||
@@ -48,6 +56,14 @@ class SelvaSampler:
|
||||
feature_utils = model["feature_utils"]
|
||||
mode = model["mode"]
|
||||
|
||||
# Validate that features were extracted with the same model variant
|
||||
feat_variant = features.get("variant")
|
||||
if feat_variant is not None and feat_variant != model["variant"]:
|
||||
raise ValueError(
|
||||
f"[SelVA] Variant mismatch: features were extracted with '{feat_variant}' "
|
||||
f"but model is '{model['variant']}'. Re-run the Feature Extractor with the current model."
|
||||
)
|
||||
|
||||
# Resolve prompt: use override if given, otherwise fall back to features prompt
|
||||
if not prompt or not prompt.strip():
|
||||
prompt = features.get("prompt", "")
|
||||
@@ -112,10 +128,17 @@ class SelvaSampler:
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
def ode_wrapper_tracked(t, x):
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
pbar.update(1)
|
||||
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
||||
|
||||
try:
|
||||
x1 = fm.to_data(ode_wrapper_tracked, x0)
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
raise RuntimeError(
|
||||
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||
)
|
||||
|
||||
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||
|
||||
@@ -137,6 +160,7 @@ class SelvaSampler:
|
||||
elif audio.dim() == 3 and audio.shape[1] != 1:
|
||||
audio = audio.mean(dim=1, keepdim=True) # stereo → mono
|
||||
|
||||
if normalize:
|
||||
peak = audio.abs().max().clamp(min=1e-8)
|
||||
audio = (audio / peak).clamp(-1, 1)
|
||||
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
||||
|
||||
Reference in New Issue
Block a user