diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index bc6074a..f66d2cf 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -915,8 +915,8 @@ def _do_train(vocoder, mel_converter, clips, def _save_sample(label): try: - voc_device = next(vocoder.parameters()).device - mel = ref_mel.to(voc_device) + voc_p = next(vocoder.parameters()) + mel = ref_mel.to(voc_p.device, voc_p.dtype) with torch.no_grad(): wav = vocoder(mel) if wav.dim() == 2: