Files
ComfyUI-SelVA/nodes/selva_model_loader.py
T
Ethanfel 056a7b973d fix: enable VAE encoder in model loader — required for DITTO reference encoding
need_vae_encoder=False was deleting the encoder to save a small amount of VRAM.
DITTO now needs it to encode reference clips to latent space for style loss.
The spectrogram VAE encoder is small enough that the overhead is negligible.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-09 18:15:27 +02:00

172 lines
7.5 KiB
Python

import os
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy
# Variant → (generator filename, mode, has_bigvgan)
_VARIANTS = {
"small_16k": ("generator_small_16k_sup_5.pth", "16k", True),
"small_44k": ("generator_small_44k_sup_5.pth", "44k", False),
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
}
# Local storage roots under folder_paths.models_dir, plus the HF source repo.
# models/prismaudio is only read: _synchformer_path reuses a Synchformer
# checkpoint found there instead of downloading another copy.
_SELVA_DIR = Path(folder_paths.models_dir).joinpath("selva")
_PRISMAUDIO_DIR = Path(folder_paths.models_dir).joinpath("prismaudio")
_HF_REPO = "jnwnlee/SelVA"
# filename → (hf_repo_path, expected_md5 or None to skip check)
# Note: 44k generators are named 44khz in the HF repo; md5=None since the
# original download_utils had the wrong filenames so those md5s are unverified.
_WEIGHTS = {
"video_enc_sup_5.pth": ("weights/video_enc_sup_5.pth", "ff09a6dc36148536ee4db97eba081d05"),
"generator_small_16k_sup_5.pth": ("weights/generator_small_16k_sup_5.pth", "1cb0f0deec52de37f67b1fd9965337d0"),
"generator_small_44k_sup_5.pth": ("weights/generator_small_44khz_sup_5.pth", None),
"generator_medium_44k_sup_5.pth":("weights/generator_medium_44khz_sup_5.pth", None),
"generator_large_44k_sup_5.pth": ("weights/generator_large_44khz_sup_5.pth", None),
"v1-16.pth": ("ext_weights/v1-16.pth", "69f56803f59a549a1a507c93859fd4d7"),
"v1-44.pth": ("ext_weights/v1-44.pth", "fab020275fa44c6589820ce025191600"),
"best_netG.pt": ("ext_weights/best_netG.pt", "eeaf372a38a9c31c362120aba2dde292"),
"synchformer_state_dict.pth": ("ext_weights/synchformer_state_dict.pth", "5b2f5594b0730f70e41e549b7c94390c"),
}
def _md5(path):
import hashlib
h = hashlib.md5()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(8 * 1024 * 1024), b""):
h.update(chunk)
return h.hexdigest()
def _ensure(filename, subdir=None):
    """Return the local path to a known weight file, downloading it if needed.

    The file is stored at ``_SELVA_DIR / subdir / filename`` (or directly in
    ``_SELVA_DIR`` when *subdir* is falsy). An existing file is reused as-is
    when its ``_WEIGHTS`` entry has no MD5, or when its MD5 matches. On a
    mismatch it is deleted and downloaded again from ``_HF_REPO``.

    Args:
        filename: Local file name; must be a key of ``_WEIGHTS``.
        subdir: Optional sub-directory of ``_SELVA_DIR`` to store the file in.

    Returns:
        The destination path as a ``str``.

    Raises:
        ValueError: If *filename* has no entry in ``_WEIGHTS``.
    """
    import shutil

    entry = _WEIGHTS.get(filename)
    if entry is None:
        raise ValueError(f"[SelVA] Unknown weight file: {filename}")
    repo_path, expected_md5 = entry

    dest_dir = _SELVA_DIR / subdir if subdir else _SELVA_DIR
    dest_path = dest_dir / filename

    if dest_path.exists():
        if expected_md5 is None:
            return str(dest_path)
        actual = _md5(dest_path)
        if actual == expected_md5:
            return str(dest_path)
        print(
            f"[SelVA] {filename}: MD5 mismatch (got {actual}, expected {expected_md5}), re-downloading...",
            flush=True,
        )
        dest_path.unlink()

    # Imported lazily: an already-present file does not need huggingface_hub.
    from huggingface_hub import hf_hub_download

    print(f"[SelVA] Downloading {filename} from {_HF_REPO}...", flush=True)
    dest_dir.mkdir(parents=True, exist_ok=True)
    cached = hf_hub_download(repo_id=_HF_REPO, filename=repo_path)

    # Copy into a sibling temp file, then rename it into place. This way an
    # interrupted copy can never leave a truncated file at dest_path. Entries
    # without an MD5 are reused above with no integrity check, so a partial
    # file there would otherwise be loaded on every later run.
    tmp_path = dest_path.with_name(dest_path.name + ".part")
    try:
        shutil.copy2(cached, tmp_path)
        tmp_path.replace(dest_path)
    finally:
        if tmp_path.exists():
            tmp_path.unlink()

    if expected_md5 is not None:
        actual = _md5(dest_path)
        if actual != expected_md5:
            # Keep the file (the table's MD5 may be stale) but surface the
            # problem. The next call will fail the check above and re-download.
            print(
                f"[SelVA] Warning: downloaded {filename} has MD5 {actual}, expected {expected_md5}.",
                flush=True,
            )

    print(f"[SelVA] Saved to {dest_path}", flush=True)
    return str(dest_path)
def _synchformer_path():
    """Return the Synchformer checkpoint path as a ``str``.

    A copy already present under models/prismaudio/ is preferred, so the same
    checkpoint is not downloaded twice. Otherwise the file is resolved through
    ``_ensure``, which may download it into the SelVA model directory.
    """
    shared_copy = _PRISMAUDIO_DIR / "synchformer_state_dict.pth"
    return str(shared_copy) if shared_copy.exists() else _ensure("synchformer_state_dict.pth")
class SelvaModelLoader:
    """Node that loads all SelVA sub-models into a single SELVA_MODEL bundle."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "variant": (list(_VARIANTS.keys()), {
                    "tooltip": "Model size and output sample rate. small_16k is fastest (16 kHz). 44k variants output 44.1 kHz. larger = better quality, more VRAM.",
                }),
                "precision": (["bf16", "fp16", "fp32"], {
                    "tooltip": "Compute dtype. bf16 is recommended on Ampere+ GPUs. fp16 for older NVIDIA hardware. fp32 if you see NaN outputs.",
                }),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"], {
                    "tooltip": "auto picks keep_in_vram if ≥16 GB VRAM is free, otherwise offload_to_cpu. offload_to_cpu moves weights to RAM between nodes, saving VRAM at the cost of speed.",
                }),
            }
        }

    RETURN_TYPES = ("SELVA_MODEL",)
    RETURN_NAMES = ("model",)
    OUTPUT_TOOLTIPS = ("Loaded model bundle — connect to Feature Extractor and Sampler.",)
    FUNCTION = "load_model"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Loads the SelVA generator, TextSynchformer encoder, CLIP, T5, and VAE. Weights are auto-downloaded from HuggingFace on first use."

    def load_model(self, variant, precision, offload_strategy):
        """Resolve (downloading if needed) and load every SelVA model for *variant*.

        Args:
            variant: Key of ``_VARIANTS`` selecting generator size and mode.
            precision: ``"bf16"``, ``"fp16"`` or ``"fp32"``. bf16 falls back to
                fp16 on CUDA GPUs where ``torch.cuda.is_bf16_supported()`` is False.
            offload_strategy: ``"auto"``, ``"keep_in_vram"`` or
                ``"offload_to_cpu"``, resolved via ``determine_offload_strategy``.

        Returns:
            A 1-tuple holding the model dict with keys ``generator``,
            ``video_enc``, ``feature_utils``, ``variant``, ``mode``,
            ``strategy``, ``dtype`` and ``seq_cfg``.
        """
        from selva_core.model.networks_generator import get_my_mmaudio
        from selva_core.model.networks_video_enc import get_my_textsynch
        from selva_core.model.utils.features_utils import FeaturesUtils
        from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K

        gen_filename, mode, has_bigvgan = _VARIANTS[variant]

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
            print("[SelVA] Warning: bf16 not supported on this GPU — falling back to fp16.", flush=True)
            precision = "fp16"
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        strategy = determine_offload_strategy(offload_strategy)

        # Build the modules directly on the device they will end up on.
        # offload_to_cpu is the low-VRAM path (auto picks it when <16 GB is
        # free). Materialising everything on the GPU first and moving it off
        # afterwards could run out of VRAM before the offload ever happens.
        load_device = get_offload_device() if strategy == "offload_to_cpu" else device

        print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True)
        video_enc_path = _ensure("video_enc_sup_5.pth")
        gen_path = _ensure(gen_filename)
        vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
        vae_path = _ensure(vae_name, subdir="ext")
        synch_path = _synchformer_path()
        bigvgan_path = _ensure("best_netG.pt", subdir="ext") if has_bigvgan else None

        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code embedded in the checkpoint. Only load
        # trusted files (the official HF repo or copies you control).
        print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
        net_video_enc = get_my_textsynch("depth1").to(load_device, dtype).eval()
        net_video_enc.load_weights(
            torch.load(video_enc_path, map_location="cpu", weights_only=False)
        )

        print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
        seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
        net_generator = get_my_mmaudio(variant).to(load_device, dtype).eval()
        net_generator.load_weights(
            torch.load(gen_path, map_location="cpu", weights_only=False)
        )

        print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
        # need_vae_encoder=True keeps the VAE encoder, so reference clips can be
        # encoded to latents (required by DITTO).
        feature_utils = FeaturesUtils(
            tod_vae_ckpt=vae_path,
            synchformer_ckpt=synch_path,
            enable_conditions=True,
            mode=mode,
            bigvgan_vocoder_ckpt=bigvgan_path,
            need_vae_encoder=True,
        ).to(load_device, dtype).eval()

        print(f"[SelVA] Model ready: variant={variant} dtype={dtype} strategy={strategy}", flush=True)
        return ({
            "generator": net_generator,
            "video_enc": net_video_enc,
            "feature_utils": feature_utils,
            "variant": variant,
            "mode": mode,
            "strategy": strategy,
            "dtype": dtype,
            "seq_cfg": seq_cfg,
        },)