fix: detect silent discriminator load failure and fall back explicitly

If no matching key was found for MPD or MRD in the checkpoint, the for-loops
completed silently and randomly-initialized discriminators were used as frozen
feature extractors — producing meaningless feature matching loss while
appearing to work. Now raises RuntimeError (caught by outer except) which
triggers the existing fallback to mel+STFT losses with a clear warning.
Also prints available checkpoint keys to help diagnose format mismatches.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 14:39:55 +02:00
parent f50afa9796
commit b9f95cfd7e
+10
View File
@@ -596,16 +596,26 @@ def _do_train(vocoder, mel_converter, clips,
mpd = _MultiPeriodDiscriminator()
mrd = _MultiResolutionDiscriminator()
# Try common key names used by different BigVGAN releases
mpd_loaded = False
for mpd_key in ("mpd", "discriminator_mpd", "MPD"):
if mpd_key in ckpt_d:
mpd.load_state_dict(ckpt_d[mpd_key], strict=False)
print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True)
mpd_loaded = True
break
mrd_loaded = False
for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"):
if mrd_key in ckpt_d:
mrd.load_state_dict(ckpt_d[mrd_key], strict=False)
print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True)
mrd_loaded = True
break
if not (mpd_loaded and mrd_loaded):
raise RuntimeError(
f"[BigVGAN] Could not find discriminator keys in checkpoint. "
f"MPD loaded={mpd_loaded}, MRD loaded={mrd_loaded}. "
f"Available keys: {list(ckpt_d.keys())}"
)
mpd.to(device).eval()
mrd.to(device).eval()
for p in mpd.parameters():