diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 629d0f4..b73c2e9 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -123,7 +123,7 @@ class _DiscriminatorR(nn.Module): x = x.squeeze(1) # [B, T] pad = (win - hop) // 2 x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect") - x = torch.stft(x, n, hop, win, window, center=False, return_complex=True) + x = torch.stft(x.float(), n, hop, win, window, center=False, return_complex=True) x = x.abs().unsqueeze(1) # [B, 1, freq, time] return x @@ -321,8 +321,9 @@ def _stft_mag(wav, n_fft, hop_length, win_length, device): def _multi_resolution_stft_loss(pred_wav, target_wav, device): """Average L1 mag loss across three STFT resolutions. inputs: [B, 1, T]""" - pred = pred_wav.squeeze(1) # [B, T] - target = target_wav.squeeze(1) + # cuFFT requires float32 regardless of model dtype + pred = pred_wav.squeeze(1).float() # [B, T] + target = target_wav.squeeze(1).float() loss = torch.zeros(1, device=device) for n_fft, hop, win in _STFT_RESOLUTIONS: pm = _stft_mag(pred, n_fft, hop, win, device) @@ -818,8 +819,8 @@ def _do_train(vocoder, mel_converter, clips, random.seed(seed) # Reference segment for eval samples — always clip 0, full length - ref_wav = clips[0].to(device, dtype) # full first clip [T] - ref_mel = mel_converter(ref_wav.unsqueeze(0)) # [1, n_mels, T_mel] + ref_wav = clips[0].to(device) # full first clip [T] + ref_mel = mel_converter(ref_wav.float().unsqueeze(0)) # [1, n_mels, T_mel] (cuFFT needs float32) # Ground-truth spectrogram — saved once alongside baseline for comparison gt_spec_path = out_path.parent / f"{out_path.stem}_gt_spec.png" @@ -1017,11 +1018,11 @@ def _do_train(vocoder, mel_converter, clips, target_wav = target_flat.unsqueeze(1) # [B, 1, T] with torch.no_grad(): - input_mel = mel_converter(target_flat) # [B, n_mels, T_mel] + input_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] (cuFFT needs float32) # Clean target mel for mel loss (always from clean audio) with torch.no_grad(): - target_mel = mel_converter(target_flat) # [B, n_mels, T_mel] + target_mel = mel_converter(target_flat.float()) # [B, n_mels, T_mel] # Gradient checkpointing: recompute BigVGAN activations during # backward instead of storing them. The 512x upsampling stack @@ -1049,14 +1050,14 @@ def _do_train(vocoder, mel_converter, clips, _feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd) ) # Keep a small mel loss for stable frequency alignment - pred_mel = mel_converter(pred_t.squeeze(1)) + pred_mel = mel_converter(pred_t.squeeze(1).float()) T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) primary_loss = 2.0 * fm_loss + 0.1 * mel_loss loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}" else: # Fallback: mel L1 + multi-resolution STFT L1 - pred_mel = mel_converter(pred_t.squeeze(1)) + pred_mel = mel_converter(pred_t.squeeze(1).float()) T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)