fix: cast input mel to model dtype before vocoder forward pass

mel_converter outputs float32 (cuFFT requirement) but vocoder weights are
bfloat16 from model loading. Cast input_mel back to model dtype before
feeding the vocoder to avoid conv1d dtype mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-10 00:18:05 +02:00
parent bba5aec7a5
commit 608746ce7b
+1 -1
View File
@@ -1036,7 +1036,7 @@ def _do_train(vocoder, mel_converter, clips,
# ~2x compute for a large reduction in activation memory, allowing # ~2x compute for a large reduction in activation memory, allowing
# batch_size > 1 without OOM. # batch_size > 1 without OOM.
pred_wav = torch.utils.checkpoint.checkpoint( pred_wav = torch.utils.checkpoint.checkpoint(
vocoder, input_mel, use_reentrant=False vocoder, input_mel.to(dtype), use_reentrant=False
) # [B, 1, T_wav] ) # [B, 1, T_wav]
T = min(pred_wav.shape[-1], target_wav.shape[-1]) T = min(pred_wav.shape[-1], target_wav.shape[-1])