fix: cast all STFT inputs to float32 to prevent cuFFT bfloat16 crash
cuFFT does not support bfloat16 tensors. When the model is loaded in bfloat16, all torch.stft calls (mel_converter, discriminator spectrogram, multi-resolution STFT loss) crash. Add .float() at every STFT boundary.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -123,7 +123,7 @@ class _DiscriminatorR(nn.Module):
|
||||
x = x.squeeze(1) # [B, T]
|
||||
pad = (win - hop) // 2
|
||||
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
|
||||
x = torch.stft(x, n, hop, win, window, center=False, return_complex=True)
|
||||
x = torch.stft(x.float(), n, hop, win, window, center=False, return_complex=True)
|
||||
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
|
||||
return x
|
||||
|
||||
@@ -321,8 +321,9 @@ def _stft_mag(wav, n_fft, hop_length, win_length, device):
|
||||
|
||||
def _multi_resolution_stft_loss(pred_wav, target_wav, device):
|
||||
"""Average L1 mag loss across three STFT resolutions. inputs: [B, 1, T]"""
|
||||
pred = pred_wav.squeeze(1) # [B, T]
|
||||
target = target_wav.squeeze(1)
|
||||
# cuFFT requires float32 regardless of model dtype
|
||||
pred = pred_wav.squeeze(1).float() # [B, T]
|
||||
target = target_wav.squeeze(1).float()
|
||||
loss = torch.zeros(1, device=device)
|
||||
for n_fft, hop, win in _STFT_RESOLUTIONS:
|
||||
pm = _stft_mag(pred, n_fft, hop, win, device)
|
||||
@@ -818,8 +819,8 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
random.seed(seed)
|
||||
|
||||
# Reference segment for eval samples — always clip 0, full length
|
||||
ref_wav = clips[0].to(device, dtype) # full first clip [T]
|
||||
ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel]
|
||||
ref_wav = clips[0].to(device) # full first clip [T]
|
||||
ref_mel = mel_converter(ref_wav.float().unsqueeze(0)) # [1, n_mels, T_mel] (cuFFT needs float32)
|
||||
|
||||
# Ground-truth spectrogram — saved once alongside baseline for comparison
|
||||
gt_spec_path = out_path.parent / f"{out_path.stem}_gt_spec.png"
|
||||
@@ -1017,11 +1018,11 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||
|
||||
with torch.no_grad():
|
||||
input_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||
input_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] (cuFFT needs float32)
|
||||
|
||||
# Clean target mel for mel loss (always from clean audio)
|
||||
with torch.no_grad():
|
||||
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
||||
target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel]
|
||||
|
||||
# Gradient checkpointing: recompute BigVGAN activations during
|
||||
# backward instead of storing them. The 512x upsampling stack
|
||||
@@ -1049,14 +1050,14 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
|
||||
)
|
||||
# Keep a small mel loss for stable frequency alignment
|
||||
pred_mel = mel_converter(pred_t.squeeze(1))
|
||||
pred_mel = mel_converter(pred_t.squeeze(1).float())
|
||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||
primary_loss = 2.0 * fm_loss + 0.1 * mel_loss
|
||||
loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}"
|
||||
else:
|
||||
# Fallback: mel L1 + multi-resolution STFT L1
|
||||
pred_mel = mel_converter(pred_t.squeeze(1))
|
||||
pred_mel = mel_converter(pred_t.squeeze(1).float())
|
||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
||||
|
||||
Reference in New Issue
Block a user