fix(bigvgan): add 44k/BigVGANv2 support to trainer and loader

44k variants use BigVGANv2 directly as the vocoder (no wrapper, no
@inference_mode decorator), accessible at feature_utils.tod.vocoder.
16k wraps BigVGANVocoder inside BigVGAN, accessed at .vocoder.vocoder.
Both trainer and loader now branch on model["mode"].

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 01:28:32 +02:00
parent 9c784b4bdb
commit 790a53e3df
2 changed files with 28 additions and 22 deletions
+17 -14
View File
@@ -63,7 +63,8 @@ class SelvaBigvganTrainer:
DESCRIPTION = (
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using "
"spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. "
"16k mode only. Load the result with SelVA BigVGAN Loader."
"Supports both 16k (BigVGAN) and 44k (BigVGANv2) models. "
"Load the result with SelVA BigVGAN Loader."
)
@classmethod
@@ -101,20 +102,22 @@ class SelvaBigvganTrainer:
segment_seconds, save_every, seed):
import traceback
if model["mode"] != "16k":
raise NotImplementedError(
"[BigVGAN] Only 16k mode is supported. "
"44k uses BigVGANv2 which requires a different training setup."
)
device = get_device()
mode = model["mode"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
strategy = model["strategy"]
device = get_device()
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
sample_rate = 16_000
strategy = model["strategy"]
# BigVGANVocoder nn.Module — bypass the @inference_mode wrapper on BigVGAN.forward
vocoder = feature_utils.tod.vocoder.vocoder
if mode == "16k":
# BigVGANVocoder wrapped inside BigVGAN — bypass the @inference_mode on the wrapper
vocoder = feature_utils.tod.vocoder.vocoder
sample_rate = 16_000
elif mode == "44k":
# BigVGANv2 is the vocoder directly (no wrapper); no @inference_mode decorator
vocoder = feature_utils.tod.vocoder
sample_rate = 44_100
else:
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
# Resolve paths
data_dir = Path(data_dir.strip())