fix: cast input mel to model dtype before vocoder forward pass

mel_converter outputs float32 (cuFFT requirement) but vocoder weights are
bfloat16 from model loading. Cast input_mel back to model dtype before
feeding the vocoder to avoid conv1d dtype mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@@ -1036,7 +1036,7 @@ def _do_train(vocoder, mel_converter, clips,
 # ~2x compute for a large reduction in activation memory, allowing
 # batch_size > 1 without OOM.
 pred_wav = torch.utils.checkpoint.checkpoint(
-    vocoder, input_mel, use_reentrant=False
+    vocoder, input_mel.to(dtype), use_reentrant=False
 )  # [B, 1, T_wav]
 
 T = min(pred_wav.shape[-1], target_wav.shape[-1])
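
The mismatch described in the commit message can be reproduced in isolation. The sketch below is not the project's code: `toy_vocoder` is a hypothetical single-layer stand-in for the real vocoder, the tensor shapes are arbitrary, and `dtype` is assumed to be the dtype of the model weights (bfloat16 here). It shows that a float32 mel fed to bfloat16 conv1d weights raises a RuntimeError, and that casting the input first lets the checkpointed forward pass run.

```python
import torch
import torch.utils.checkpoint
from torch import nn

# Hypothetical stand-in for the vocoder: one conv1d layer with bfloat16 weights.
toy_vocoder = nn.Conv1d(in_channels=80, out_channels=1, kernel_size=3, padding=1)
toy_vocoder = toy_vocoder.to(torch.bfloat16)

# Stand-in for the mel_converter output, which the commit says is float32.
input_mel = torch.randn(2, 80, 16, dtype=torch.float32)  # [B, n_mels, T]

# Assumption: `dtype` is the dtype the model weights were loaded in.
dtype = next(toy_vocoder.parameters()).dtype


def run_vocoder(mel):
    # Same call shape as in the diff: non-reentrant activation checkpointing.
    return torch.utils.checkpoint.checkpoint(toy_vocoder, mel, use_reentrant=False)


# Without the cast, conv1d sees a float32 input and bfloat16 weights.
try:
    run_vocoder(input_mel)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# With the cast, input and weights share a dtype and the forward pass runs.
pred_wav = run_vocoder(input_mel.to(dtype))  # [B, 1, T]
```

Casting at the call site leaves the mel computation itself in float32, so only the vocoder input is reduced to the lower-precision dtype.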