diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 14756d1..158c514 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -258,6 +258,30 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device): return loss / len(_STFT_RESOLUTIONS) +def _phase_aware_stft_loss(pred_wav, target_wav, device): + """FA-GAN complex STFT loss: L1 on real, imaginary, and magnitude components. + + Penalizes phase smearing directly — plain magnitude loss cannot distinguish + a correct spectrum with wrong phase from a smeared spectrum with random phase. + Based on FA-GAN (arXiv:2407.04575), applied across three STFT resolutions. + inputs: [B, 1, T] + """ + # cuFFT requires float32 regardless of model dtype + pred = pred_wav.squeeze(1).float() # [B, T] + target = target_wav.squeeze(1).float() + loss = torch.zeros(1, device=device) + for n_fft, hop, win in _STFT_RESOLUTIONS: + window = torch.hann_window(win, device=device) + ps = torch.stft(pred, n_fft, hop, win, window, center=True, return_complex=True) + ts = torch.stft(target, n_fft, hop, win, window, center=True, return_complex=True) + T = min(ps.shape[-1], ts.shape[-1]) + ps, ts = ps[..., :T], ts[..., :T] + loss = loss + F.l1_loss(ps.real, ts.real) + loss = loss + F.l1_loss(ps.imag, ts.imag) + loss = loss + F.l1_loss(ps.abs(), ts.abs()) + return loss / (len(_STFT_RESOLUTIONS) * 3) + + # --------------------------------------------------------------------------- # Node # --------------------------------------------------------------------------- @@ -318,6 +342,14 @@ class SelvaBigvganTrainer: "Increase to 1e-2 for all_params to prevent catastrophic forgetting." ), }), + "lambda_phase": ("FLOAT", { + "default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1, + "tooltip": ( + "FA-GAN phase-aware loss weight. Adds L1 loss on real + imaginary + magnitude " + "STFT components, penalizing phase smearing directly. 0 = disabled. " + "1.0 is a good starting point alongside other losses." + ), + }), "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}), "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}), }, @@ -335,7 +367,7 @@ class SelvaBigvganTrainer: } def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size, - segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""): + segment_seconds, lambda_l2sp, lambda_phase, save_every, seed, discriminator_path=""): import traceback device = get_device() @@ -443,7 +475,7 @@ class SelvaBigvganTrainer: vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, segment_samples, sample_rate, - train_mode, steps, lr, batch_size, lambda_l2sp, + train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase, save_every, seed, out_path, disc_path, pbar, ) except Exception as e: @@ -466,7 +498,7 @@ class SelvaBigvganTrainer: def _do_train(vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, segment_samples, sample_rate, - train_mode, steps, lr, batch_size, lambda_l2sp, + train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase, save_every, seed, out_path, disc_path, pbar): """Execute training. Called in a fresh thread — no inference_mode active. @@ -686,6 +718,12 @@ def _do_train(vocoder, mel_converter, clips, primary_loss = mel_loss + stft_loss loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}" + # ── FA-GAN phase-aware loss (real + imag + mag STFT) ──────────── + if lambda_phase > 0.0: + phase_loss = _phase_aware_stft_loss(pred_t, target_t, device) + primary_loss = primary_loss + lambda_phase * phase_loss + loss_desc += f" phase={phase_loss.item():.4f}" + # ── L2-SP regularization ───────────────────────────────────────── l2sp_loss = torch.zeros(1, device=device) if lambda_l2sp > 0.0 and ref_params: