fix: detect silent discriminator load failure and fall back explicitly
If no matching key was found for MPD or MRD in the checkpoint, the for-loops completed silently and randomly-initialized discriminators were used as frozen feature extractors — producing meaningless feature matching loss while appearing to work. Now raises RuntimeError (caught by outer except) which triggers the existing fallback to mel+STFT losses with a clear warning. Also prints available checkpoint keys to help diagnose format mismatches. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -596,16 +596,26 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
mpd = _MultiPeriodDiscriminator()
|
mpd = _MultiPeriodDiscriminator()
|
||||||
mrd = _MultiResolutionDiscriminator()
|
mrd = _MultiResolutionDiscriminator()
|
||||||
# Try common key names used by different BigVGAN releases
|
# Try common key names used by different BigVGAN releases
|
||||||
|
mpd_loaded = False
|
||||||
for mpd_key in ("mpd", "discriminator_mpd", "MPD"):
|
for mpd_key in ("mpd", "discriminator_mpd", "MPD"):
|
||||||
if mpd_key in ckpt_d:
|
if mpd_key in ckpt_d:
|
||||||
mpd.load_state_dict(ckpt_d[mpd_key], strict=False)
|
mpd.load_state_dict(ckpt_d[mpd_key], strict=False)
|
||||||
print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True)
|
print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True)
|
||||||
|
mpd_loaded = True
|
||||||
break
|
break
|
||||||
|
mrd_loaded = False
|
||||||
for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"):
|
for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"):
|
||||||
if mrd_key in ckpt_d:
|
if mrd_key in ckpt_d:
|
||||||
mrd.load_state_dict(ckpt_d[mrd_key], strict=False)
|
mrd.load_state_dict(ckpt_d[mrd_key], strict=False)
|
||||||
print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True)
|
print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True)
|
||||||
|
mrd_loaded = True
|
||||||
break
|
break
|
||||||
|
if not (mpd_loaded and mrd_loaded):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[BigVGAN] Could not find discriminator keys in checkpoint. "
|
||||||
|
f"MPD loaded={mpd_loaded}, MRD loaded={mrd_loaded}. "
|
||||||
|
f"Available keys: {list(ckpt_d.keys())}"
|
||||||
|
)
|
||||||
mpd.to(device).eval()
|
mpd.to(device).eval()
|
||||||
mrd.to(device).eval()
|
mrd.to(device).eval()
|
||||||
for p in mpd.parameters():
|
for p in mpd.parameters():
|
||||||
|
|||||||
Reference in New Issue
Block a user