fix: cast all STFT inputs to float32 to prevent cuFFT bfloat16 crash

cuFFT does not support bfloat16 tensors. When the model is loaded in
bfloat16, every torch.stft call (mel_converter, discriminator spectrogram,
multi-resolution STFT loss) crashes. Cast the input to float32 with
.float() at every STFT boundary.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 23:53:36 +02:00
parent 48b72c0be0
commit bee518a855
+10 -9
View File
@@ -123,7 +123,7 @@ class _DiscriminatorR(nn.Module):
x = x.squeeze(1) # [B, T]
pad = (win - hop) // 2
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
x = torch.stft(x, n, hop, win, window, center=False, return_complex=True)
x = torch.stft(x.float(), n, hop, win, window, center=False, return_complex=True)
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
return x
@@ -321,8 +321,9 @@ def _stft_mag(wav, n_fft, hop_length, win_length, device):
def _multi_resolution_stft_loss(pred_wav, target_wav, device):
"""Average L1 mag loss across three STFT resolutions. inputs: [B, 1, T]"""
pred = pred_wav.squeeze(1) # [B, T]
target = target_wav.squeeze(1)
# cuFFT requires float32 regardless of model dtype
pred = pred_wav.squeeze(1).float() # [B, T]
target = target_wav.squeeze(1).float()
loss = torch.zeros(1, device=device)
for n_fft, hop, win in _STFT_RESOLUTIONS:
pm = _stft_mag(pred, n_fft, hop, win, device)
@@ -818,8 +819,8 @@ def _do_train(vocoder, mel_converter, clips,
random.seed(seed)
# Reference segment for eval samples — always clip 0, full length
ref_wav = clips[0].to(device, dtype) # full first clip [T]
ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel]
ref_wav = clips[0].to(device) # full first clip [T]
ref_mel = mel_converter(ref_wav.float().unsqueeze(0)) # [1, n_mels, T_mel] (cuFFT needs float32)
# Ground-truth spectrogram — saved once alongside baseline for comparison
gt_spec_path = out_path.parent / f"{out_path.stem}_gt_spec.png"
@@ -1017,11 +1018,11 @@ def _do_train(vocoder, mel_converter, clips,
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
with torch.no_grad():
input_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
input_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] (cuFFT needs float32)
# Clean target mel for mel loss (always from clean audio)
with torch.no_grad():
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel]
# Gradient checkpointing: recompute BigVGAN activations during
# backward instead of storing them. The 512x upsampling stack
@@ -1049,14 +1050,14 @@ def _do_train(vocoder, mel_converter, clips,
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
)
# Keep a small mel loss for stable frequency alignment
pred_mel = mel_converter(pred_t.squeeze(1))
pred_mel = mel_converter(pred_t.squeeze(1).float())
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
primary_loss = 2.0 * fm_loss + 0.1 * mel_loss
loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}"
else:
# Fallback: mel L1 + multi-resolution STFT L1
pred_mel = mel_converter(pred_t.squeeze(1))
pred_mel = mel_converter(pred_t.squeeze(1).float())
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)