diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 5ce63cc..e5c238f 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -182,8 +182,11 @@ class SelvaBigvganTrainer: def _save_sample(label): """Vocode the reference mel and save as .wav.""" try: + # Vocoder may have been offloaded to CPU after training — match its device. + voc_device = next(vocoder.parameters()).device + mel = ref_mel.to(voc_device) with torch.no_grad(): - wav = vocoder(ref_mel) # [1, 1, T] or [1, T] + wav = vocoder(mel) # [1, 1, T] or [1, T] if wav.dim() == 2: wav = wav.unsqueeze(1) wav = wav.float().cpu().clamp(-1, 1)