fix: cast all STFT inputs to float32 to prevent cuFFT bfloat16 crash
cuFFT does not support bfloat16 tensors. When the model is loaded in bfloat16, all torch.stft calls (mel_converter, discriminator spectrogram, multi-resolution STFT loss) crash. Add .float() at every STFT boundary. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -123,7 +123,7 @@ class _DiscriminatorR(nn.Module):
|
|||||||
x = x.squeeze(1) # [B, T]
|
x = x.squeeze(1) # [B, T]
|
||||||
pad = (win - hop) // 2
|
pad = (win - hop) // 2
|
||||||
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
|
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
|
||||||
x = torch.stft(x, n, hop, win, window, center=False, return_complex=True)
|
x = torch.stft(x.float(), n, hop, win, window, center=False, return_complex=True)
|
||||||
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
|
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -321,8 +321,9 @@ def _stft_mag(wav, n_fft, hop_length, win_length, device):
|
|||||||
|
|
||||||
def _multi_resolution_stft_loss(pred_wav, target_wav, device):
|
def _multi_resolution_stft_loss(pred_wav, target_wav, device):
|
||||||
"""Average L1 mag loss across three STFT resolutions. inputs: [B, 1, T]"""
|
"""Average L1 mag loss across three STFT resolutions. inputs: [B, 1, T]"""
|
||||||
pred = pred_wav.squeeze(1) # [B, T]
|
# cuFFT requires float32 regardless of model dtype
|
||||||
target = target_wav.squeeze(1)
|
pred = pred_wav.squeeze(1).float() # [B, T]
|
||||||
|
target = target_wav.squeeze(1).float()
|
||||||
loss = torch.zeros(1, device=device)
|
loss = torch.zeros(1, device=device)
|
||||||
for n_fft, hop, win in _STFT_RESOLUTIONS:
|
for n_fft, hop, win in _STFT_RESOLUTIONS:
|
||||||
pm = _stft_mag(pred, n_fft, hop, win, device)
|
pm = _stft_mag(pred, n_fft, hop, win, device)
|
||||||
@@ -818,8 +819,8 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
# Reference segment for eval samples — always clip 0, full length
|
# Reference segment for eval samples — always clip 0, full length
|
||||||
ref_wav = clips[0].to(device, dtype) # full first clip [T]
|
ref_wav = clips[0].to(device) # full first clip [T]
|
||||||
ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel]
|
ref_mel = mel_converter(ref_wav.float().unsqueeze(0)) # [1, n_mels, T_mel] (cuFFT needs float32)
|
||||||
|
|
||||||
# Ground-truth spectrogram — saved once alongside baseline for comparison
|
# Ground-truth spectrogram — saved once alongside baseline for comparison
|
||||||
gt_spec_path = out_path.parent / f"{out_path.stem}_gt_spec.png"
|
gt_spec_path = out_path.parent / f"{out_path.stem}_gt_spec.png"
|
||||||
@@ -1017,11 +1018,11 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
target_wav = target_flat.unsqueeze(1) # [B, 1, T]
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
input_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
input_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] (cuFFT needs float32)
|
||||||
|
|
||||||
# Clean target mel for mel loss (always from clean audio)
|
# Clean target mel for mel loss (always from clean audio)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
target_mel = mel_converter(target_flat) # [B, n_mels, T_mel]
|
target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel]
|
||||||
|
|
||||||
# Gradient checkpointing: recompute BigVGAN activations during
|
# Gradient checkpointing: recompute BigVGAN activations during
|
||||||
# backward instead of storing them. The 512x upsampling stack
|
# backward instead of storing them. The 512x upsampling stack
|
||||||
@@ -1049,14 +1050,14 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
|
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
|
||||||
)
|
)
|
||||||
# Keep a small mel loss for stable frequency alignment
|
# Keep a small mel loss for stable frequency alignment
|
||||||
pred_mel = mel_converter(pred_t.squeeze(1))
|
pred_mel = mel_converter(pred_t.squeeze(1).float())
|
||||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||||
primary_loss = 2.0 * fm_loss + 0.1 * mel_loss
|
primary_loss = 2.0 * fm_loss + 0.1 * mel_loss
|
||||||
loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}"
|
loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}"
|
||||||
else:
|
else:
|
||||||
# Fallback: mel L1 + multi-resolution STFT L1
|
# Fallback: mel L1 + multi-resolution STFT L1
|
||||||
pred_mel = mel_converter(pred_t.squeeze(1))
|
pred_mel = mel_converter(pred_t.squeeze(1).float())
|
||||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||||
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
||||||
|
|||||||
Reference in New Issue
Block a user