diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 2de8eff..465711f 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -1036,7 +1036,7 @@ def _do_train(vocoder, mel_converter, clips, # ~2x compute for a large reduction in activation memory, allowing # batch_size > 1 without OOM. pred_wav = torch.utils.checkpoint.checkpoint( - vocoder, input_mel, use_reentrant=False + vocoder, input_mel.to(dtype), use_reentrant=False ) # [B, 1, T_wav] T = min(pred_wav.shape[-1], target_wav.shape[-1])