From c28e090196c4157917f7adb70c8c70e8f21467cf Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 10 Apr 2026 11:36:02 +0200 Subject: [PATCH] fix: cast discriminator inputs to match bfloat16 dtype in BigVGAN FM loss The frozen discriminators are loaded in model dtype (bfloat16) but vocoder waveform outputs are float32, causing a Conv2d dtype mismatch. Co-Authored-By: Claude Opus 4.6 --- nodes/selva_bigvgan_trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 01853bc..965d3ad 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -1161,11 +1161,13 @@ def _do_train(vocoder, mel_converter, clips, if mpd is not None and mrd is not None: # Perceptual feature matching via frozen discriminators + # Discriminators are in model dtype (bfloat16); waveforms are float32 + disc_dtype = next(mpd.parameters()).dtype with torch.no_grad(): - fmaps_real_mpd = mpd(target_t) - fmaps_real_mrd = mrd(target_t) - fmaps_gen_mpd = mpd(pred_t) - fmaps_gen_mrd = mrd(pred_t) + fmaps_real_mpd = mpd(target_t.to(disc_dtype)) + fmaps_real_mrd = mrd(target_t.to(disc_dtype)) + fmaps_gen_mpd = mpd(pred_t.to(disc_dtype)) + fmaps_gen_mrd = mrd(pred_t.to(disc_dtype)) fm_loss = ( _feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) + _feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)