fix: detect silent discriminator load failure and fall back explicitly
If no matching key was found for MPD or MRD in the checkpoint, the for-loops completed silently and randomly-initialized discriminators were used as frozen feature extractors — producing meaningless feature matching loss while appearing to work. Now raises RuntimeError (caught by outer except) which triggers the existing fallback to mel+STFT losses with a clear warning. Also prints available checkpoint keys to help diagnose format mismatches. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -596,16 +596,26 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
mpd = _MultiPeriodDiscriminator()
|
||||
mrd = _MultiResolutionDiscriminator()
|
||||
# Try common key names used by different BigVGAN releases
|
||||
mpd_loaded = False
|
||||
for mpd_key in ("mpd", "discriminator_mpd", "MPD"):
|
||||
if mpd_key in ckpt_d:
|
||||
mpd.load_state_dict(ckpt_d[mpd_key], strict=False)
|
||||
print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True)
|
||||
mpd_loaded = True
|
||||
break
|
||||
mrd_loaded = False
|
||||
for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"):
|
||||
if mrd_key in ckpt_d:
|
||||
mrd.load_state_dict(ckpt_d[mrd_key], strict=False)
|
||||
print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True)
|
||||
mrd_loaded = True
|
||||
break
|
||||
if not (mpd_loaded and mrd_loaded):
|
||||
raise RuntimeError(
|
||||
f"[BigVGAN] Could not find discriminator keys in checkpoint. "
|
||||
f"MPD loaded={mpd_loaded}, MRD loaded={mrd_loaded}. "
|
||||
f"Available keys: {list(ckpt_d.keys())}"
|
||||
)
|
||||
mpd.to(device).eval()
|
||||
mrd.to(device).eval()
|
||||
for p in mpd.parameters():
|
||||
|
||||
Reference in New Issue
Block a user