fix(bigvgan): add 44k/BigVGANv2 support to trainer and loader
The 44k variants use BigVGANv2 directly as the vocoder (no wrapper and no @inference_mode decorator), so it is accessible at feature_utils.tod.vocoder. The 16k variant wraps BigVGANVocoder inside BigVGAN, so the underlying module is accessed at feature_utils.tod.vocoder.vocoder. Both the trainer and the loader now branch on model["mode"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -22,8 +22,9 @@ class SelvaBigvganLoader:
|
|||||||
RETURN_NAMES = ("model",)
|
RETURN_NAMES = ("model",)
|
||||||
OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
|
OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
|
||||||
DESCRIPTION = (
|
DESCRIPTION = (
|
||||||
"Loads a fine-tuned BigVGAN vocoder checkpoint from SelVA BigVGAN Trainer "
|
"Loads a fine-tuned BigVGAN/BigVGANv2 vocoder checkpoint from SelVA BigVGAN Trainer "
|
||||||
"and replaces the vocoder weights in the SELVA_MODEL. "
|
"and replaces the vocoder weights in the SELVA_MODEL in-place. "
|
||||||
|
"Supports both 16k and 44k models. "
|
||||||
"Connect the output to SelVA Sampler instead of the base model loader."
|
"Connect the output to SelVA Sampler instead of the base model loader."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,16 +48,18 @@ class SelvaBigvganLoader:
|
|||||||
if not p.exists():
|
if not p.exists():
|
||||||
raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {p}")
|
raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {p}")
|
||||||
|
|
||||||
if model["mode"] != "16k":
|
|
||||||
raise NotImplementedError(
|
|
||||||
"[BigVGAN] Fine-tuned loader only supports 16k mode."
|
|
||||||
)
|
|
||||||
|
|
||||||
ckpt = torch.load(str(p), map_location="cpu", weights_only=False)
|
ckpt = torch.load(str(p), map_location="cpu", weights_only=False)
|
||||||
if "generator" not in ckpt:
|
if "generator" not in ckpt:
|
||||||
raise ValueError(f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {list(ckpt.keys())}")
|
raise ValueError(f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {list(ckpt.keys())}")
|
||||||
|
|
||||||
vocoder = model["feature_utils"].tod.vocoder.vocoder
|
mode = model["mode"]
|
||||||
|
if mode == "16k":
|
||||||
|
vocoder = model["feature_utils"].tod.vocoder.vocoder # BigVGANVocoder
|
||||||
|
elif mode == "44k":
|
||||||
|
vocoder = model["feature_utils"].tod.vocoder # BigVGANv2 directly
|
||||||
|
else:
|
||||||
|
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
|
||||||
|
|
||||||
vocoder.load_state_dict(ckpt["generator"])
|
vocoder.load_state_dict(ckpt["generator"])
|
||||||
vocoder.eval()
|
vocoder.eval()
|
||||||
|
|
||||||
|
|||||||
@@ -63,7 +63,8 @@ class SelvaBigvganTrainer:
|
|||||||
DESCRIPTION = (
|
DESCRIPTION = (
|
||||||
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using "
|
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using "
|
||||||
"spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. "
|
"spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. "
|
||||||
"16k mode only. Load the result with SelVA BigVGAN Loader."
|
"Supports both 16k (BigVGAN) and 44k (BigVGANv2) models. "
|
||||||
|
"Load the result with SelVA BigVGAN Loader."
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -101,20 +102,22 @@ class SelvaBigvganTrainer:
|
|||||||
segment_seconds, save_every, seed):
|
segment_seconds, save_every, seed):
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
if model["mode"] != "16k":
|
|
||||||
raise NotImplementedError(
|
|
||||||
"[BigVGAN] Only 16k mode is supported. "
|
|
||||||
"44k uses BigVGANv2 which requires a different training setup."
|
|
||||||
)
|
|
||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
|
mode = model["mode"]
|
||||||
feature_utils = model["feature_utils"]
|
feature_utils = model["feature_utils"]
|
||||||
mel_converter = feature_utils.mel_converter
|
mel_converter = feature_utils.mel_converter
|
||||||
sample_rate = 16_000
|
|
||||||
strategy = model["strategy"]
|
strategy = model["strategy"]
|
||||||
|
|
||||||
# BigVGANVocoder nn.Module — bypass the @inference_mode wrapper on BigVGAN.forward
|
if mode == "16k":
|
||||||
|
# BigVGANVocoder wrapped inside BigVGAN — bypass the @inference_mode on the wrapper
|
||||||
vocoder = feature_utils.tod.vocoder.vocoder
|
vocoder = feature_utils.tod.vocoder.vocoder
|
||||||
|
sample_rate = 16_000
|
||||||
|
elif mode == "44k":
|
||||||
|
# BigVGANv2 is the vocoder directly (no wrapper); no @inference_mode decorator
|
||||||
|
vocoder = feature_utils.tod.vocoder
|
||||||
|
sample_rate = 44_100
|
||||||
|
else:
|
||||||
|
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
|
||||||
|
|
||||||
# Resolve paths
|
# Resolve paths
|
||||||
data_dir = Path(data_dir.strip())
|
data_dir = Path(data_dir.strip())
|
||||||
|
|||||||
Reference in New Issue
Block a user