fix: cast discriminator inputs to match bfloat16 dtype in BigVGAN FM loss
The frozen discriminators are loaded in model dtype (bfloat16) but vocoder waveform outputs are float32, causing a Conv2d dtype mismatch. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1161,11 +1161,13 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
|
||||
if mpd is not None and mrd is not None:
|
||||
# Perceptual feature matching via frozen discriminators
|
||||
# Discriminators are in model dtype (bfloat16); waveforms are float32
|
||||
disc_dtype = next(mpd.parameters()).dtype
|
||||
with torch.no_grad():
|
||||
fmaps_real_mpd = mpd(target_t)
|
||||
fmaps_real_mrd = mrd(target_t)
|
||||
fmaps_gen_mpd = mpd(pred_t)
|
||||
fmaps_gen_mrd = mrd(pred_t)
|
||||
fmaps_real_mpd = mpd(target_t.to(disc_dtype))
|
||||
fmaps_real_mrd = mrd(target_t.to(disc_dtype))
|
||||
fmaps_gen_mpd = mpd(pred_t.to(disc_dtype))
|
||||
fmaps_gen_mrd = mrd(pred_t.to(disc_dtype))
|
||||
fm_loss = (
|
||||
_feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) +
|
||||
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
|
||||
|
||||
Reference in New Issue
Block a user