fix: cast mel_converter buffers to float32 to match STFT input dtype

The mel_basis and hann_window buffers inherit bfloat16 when the model is
loaded. Because every mel_converter input is cast to float32 for cuFFT,
these internal buffers must also be float32 to avoid a matmul dtype
mismatch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 00:10:52 +02:00
parent bee518a855
commit d06936802b
+3 -1
@@ -808,9 +808,11 @@ def _do_train(vocoder, mel_converter, clips,
     clips = [c.clone() for c in clips]
     # 2. mel_converter buffers (mel_basis, hann_window) — same origin.
+    # Also cast to float32: mel_converter receives float32 audio (cuFFT
+    # requirement) so all internal buffers must match.
     for name, buf in list(mel_converter._buffers.items()):
         if buf is not None:
-            mel_converter._buffers[name] = buf.clone()
+            mel_converter._buffers[name] = buf.clone().float()
     # 3. Vocoder parameters are handled below with clone().detach().
     # ─────────────────────────────────────────────────────────────────────────