Files
ComfyUI-SelVA/nodes/selva_bigvgan_loader.py
Ethanfel ecf828b007 fix: move vocoder to correct device after GAFilter injection
inject_gafilters creates Conv1d modules on CPU. load_state_dict
preserves existing param devices but GAFilter params stay on CPU,
causing device mismatch during vocode. Save target device before
injection, then move entire vocoder after loading.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-09 20:28:55 +02:00

78 lines
2.8 KiB
Python

"""SelVA BigVGAN Loader.
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer
and replaces the vocoder weights in the loaded SELVA_MODEL in-place.
The model is modified in-place so ComfyUI's model cache is updated — no need to
reload the full SelVA model. Subsequent Sampler runs will use the fine-tuned vocoder.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
from .selva_bigvgan_trainer import inject_gafilters
class SelvaBigvganLoader:
    """Node that swaps a fine-tuned BigVGAN vocoder into a loaded SELVA_MODEL.

    The vocoder held by the incoming model is updated in place and the very
    same model object is returned, so anything already holding a reference to
    the model sees the fine-tuned weights as well.
    """

    # Node metadata (category, entry-point method name, output signature).
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "load"
    RETURN_TYPES = ("SELVA_MODEL",)
    RETURN_NAMES = ("model",)
    OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
    DESCRIPTION = (
        "Loads a fine-tuned BigVGAN/BigVGANv2 vocoder checkpoint from SelVA BigVGAN Trainer "
        "and replaces the vocoder weights in the SELVA_MODEL in-place. "
        "Supports both 16k and 44k models. "
        "Connect the output to SelVA Sampler instead of the base model loader."
    )

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the model to patch and the checkpoint path."""
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "path": ("STRING", {
                    "default": "bigvgan_bj.pt",
                    "tooltip": "Path to fine-tuned vocoder checkpoint (.pt). "
                               "Relative paths resolve to ComfyUI output directory.",
                }),
            },
        }

    def load(self, model, path):
        """Load a fine-tuned vocoder checkpoint into ``model`` in place.

        Args:
            model: SELVA_MODEL mapping; this method reads ``model["mode"]``
                ("16k" or "44k") and ``model["feature_utils"]``.
            path: Checkpoint path. Surrounding whitespace is stripped and a
                relative path is resolved against the output directory
                reported by ``folder_paths``.

        Returns:
            A 1-tuple containing the same ``model`` object that was passed in.

        Raises:
            FileNotFoundError: If the resolved path is not an existing file.
            ValueError: If the checkpoint is not a dict with a ``"generator"``
                entry, or if ``model["mode"]`` is not "16k" or "44k".
            RuntimeError: Propagated from ``load_state_dict`` when the
                checkpoint keys or shapes do not match the vocoder.
        """
        ckpt_path = Path(path.strip())
        if not ckpt_path.is_absolute():
            ckpt_path = Path(folder_paths.get_output_directory()) / ckpt_path
        # is_file() rather than exists(): an empty path resolves to the output
        # directory itself, which exists but cannot be loaded as a checkpoint.
        if not ckpt_path.is_file():
            raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {ckpt_path}")

        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
        # so only checkpoints from a trusted source should be loaded here.
        ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
        if not isinstance(ckpt, dict):
            raise ValueError(
                "[BigVGAN] Expected a dict checkpoint with a 'generator' entry, "
                f"got {type(ckpt).__name__}"
            )
        if "generator" not in ckpt:
            raise ValueError(f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {list(ckpt.keys())}")

        mode = model["mode"]
        if mode == "16k":
            vocoder = model["feature_utils"].tod.vocoder.vocoder  # BigVGANVocoder
        elif mode == "44k":
            vocoder = model["feature_utils"].tod.vocoder  # BigVGANv2 directly
        else:
            raise ValueError(f"[BigVGAN] Unknown mode: {mode}")

        # Remember placement before injecting new modules (which default to CPU
        # and, presumably, to torch's default float dtype).
        target_device = next(vocoder.parameters()).device
        float_dtypes = {
            param.dtype for param in vocoder.parameters() if param.is_floating_point()
        }
        # Only restore the dtype when the vocoder is uniformly one float dtype;
        # a mixed-precision vocoder is left alone apart from the device move.
        target_dtype = next(iter(float_dtypes)) if len(float_dtypes) == 1 else None

        if ckpt.get("has_gafilter", False):
            kernel_size = ckpt.get("gafilter_kernel_size", 9)
            n_gaf = inject_gafilters(vocoder, kernel_size)
            print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True)

        vocoder.load_state_dict(ckpt["generator"])

        # Bring every module, including freshly injected ones, back to the
        # placement the vocoder had before this call.
        if target_dtype is None:
            vocoder.to(target_device)
        else:
            vocoder.to(device=target_device, dtype=target_dtype)
        vocoder.eval()

        print(f"[BigVGAN] Loaded fine-tuned vocoder from: {ckpt_path}", flush=True)
        return (model,)