fix: cast mel_converter buffers to float32 to match STFT input dtype
mel_basis and hann_window buffers inherit bfloat16 from model loading. Since all mel_converter inputs are cast to float32 for cuFFT, the internal buffers must also be float32 to avoid matmul dtype mismatch. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -808,9 +808,11 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
clips = [c.clone() for c in clips]
|
||||
|
||||
# 2. mel_converter buffers (mel_basis, hann_window) — same origin.
|
||||
# Also cast to float32: mel_converter receives float32 audio (cuFFT
|
||||
# requirement) so all internal buffers must match.
|
||||
for name, buf in list(mel_converter._buffers.items()):
|
||||
if buf is not None:
|
||||
mel_converter._buffers[name] = buf.clone()
|
||||
mel_converter._buffers[name] = buf.clone().float()
|
||||
|
||||
# 3. Vocoder parameters are handled below with clone().detach().
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user