fix: cast discriminator inputs to match bfloat16 dtype in BigVGAN FM loss

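The frozen discriminators are loaded in the model dtype (bfloat16), but the
vocoder's waveform outputs are float32, which causes a Conv2d dtype mismatch
when the waveforms are passed to the discriminators. The fix casts the target
and predicted waveforms to the dtype of the MPD parameters before every
discriminator call (the same dtype is used for both MPD and MRD).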

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 11:36:02 +02:00
parent af6c225f53
commit c28e090196
+6 -4
View File
@@ -1161,11 +1161,13 @@ def _do_train(vocoder, mel_converter, clips,
 if mpd is not None and mrd is not None:
 # Perceptual feature matching via frozen discriminators
+# Discriminators are in model dtype (bfloat16); waveforms are float32
+disc_dtype = next(mpd.parameters()).dtype
 with torch.no_grad():
-fmaps_real_mpd = mpd(target_t)
+fmaps_real_mpd = mpd(target_t.to(disc_dtype))
-fmaps_real_mrd = mrd(target_t)
+fmaps_real_mrd = mrd(target_t.to(disc_dtype))
-fmaps_gen_mpd = mpd(pred_t)
+fmaps_gen_mpd = mpd(pred_t.to(disc_dtype))
-fmaps_gen_mrd = mrd(pred_t)
+fmaps_gen_mrd = mrd(pred_t.to(disc_dtype))
 fm_loss = (
 _feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) +
 _feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)