diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 01853bc..965d3ad 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -1161,11 +1161,13 @@ def _do_train(vocoder, mel_converter, clips, if mpd is not None and mrd is not None: # Perceptual feature matching via frozen discriminators + # Discriminators are in model dtype (bfloat16); waveforms are float32 + disc_dtype = next(mpd.parameters()).dtype with torch.no_grad(): - fmaps_real_mpd = mpd(target_t) - fmaps_real_mrd = mrd(target_t) - fmaps_gen_mpd = mpd(pred_t) - fmaps_gen_mrd = mrd(pred_t) + fmaps_real_mpd = mpd(target_t.to(disc_dtype)) + fmaps_real_mrd = mrd(target_t.to(disc_dtype)) + fmaps_gen_mpd = mpd(pred_t.to(disc_dtype)) + fmaps_gen_mrd = mrd(pred_t.to(disc_dtype)) fm_loss = ( _feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) + _feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)