fix: match mel dtype to vocoder in baseline sample generation

ref_mel is float32 (from mel_converter) but vocoder weights are bfloat16
before inference flag stripping. Cast mel to vocoder's dtype to prevent
input/bias type mismatch during baseline sample save.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 00:45:31 +02:00
parent cb9a1eef01
commit 37a27160aa
+2 -2
View File
@@ -915,8 +915,8 @@ def _do_train(vocoder, mel_converter, clips,
def _save_sample(label):
try:
voc_device = next(vocoder.parameters()).device
mel = ref_mel.to(voc_device)
voc_p = next(vocoder.parameters())
mel = ref_mel.to(voc_p.device, voc_p.dtype)
with torch.no_grad():
wav = vocoder(mel)
if wav.dim() == 2: