diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 25e096a..4130a17 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -954,8 +954,8 @@ def _do_train(vocoder, mel_converter, clips, f"MPD loaded={mpd_loaded}, MRD loaded={mrd_loaded}. " f"Available keys: {list(ckpt_d.keys())}" ) - mpd.to(device).eval() - mrd.to(device).eval() + mpd.to(device, dtype).eval() + mrd.to(device, dtype).eval() for p in mpd.parameters(): p.requires_grad_(False) for p in mrd.parameters():