From 608746ce7b136302c09fd73de9e736218a73985d Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Fri, 10 Apr 2026 00:18:05 +0200 Subject: [PATCH] fix: cast input mel to model dtype before vocoder forward pass mel_converter outputs float32 (cuFFT requirement) but vocoder weights are bfloat16 from model loading. Cast input_mel back to model dtype before feeding the vocoder to avoid conv1d dtype mismatch. Co-Authored-By: Claude Opus 4.6 --- nodes/selva_bigvgan_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 2de8eff..465711f 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -1036,7 +1036,7 @@ def _do_train(vocoder, mel_converter, clips, # ~2x compute for a large reduction in activation memory, allowing # batch_size > 1 without OOM. pred_wav = torch.utils.checkpoint.checkpoint( - vocoder, input_mel, use_reentrant=False + vocoder, input_mel.to(dtype), use_reentrant=False ) # [B, 1, T_wav] T = min(pred_wav.shape[-1], target_wav.shape[-1])