feat: add FA-GAN phase-aware STFT loss to BigVGAN trainer

Adds L1 loss on real, imaginary, and magnitude STFT components across
three resolutions (FA-GAN, arXiv:2407.04575). Penalizes phase smearing
directly — magnitude-only losses cannot distinguish a correct spectrum
with wrong phase from a smeared spectrum.

Controlled by lambda_phase (default 1.0, 0 = disabled). Applied on top
of both the discriminator FM path and the fallback mel+STFT path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 16:09:31 +02:00
parent 82e449681c
commit c53ea5517c
+41 -3
View File
@@ -258,6 +258,30 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device):
return loss / len(_STFT_RESOLUTIONS)
def _phase_aware_stft_loss(pred_wav, target_wav, device):
    """Complex-valued multi-resolution STFT loss (FA-GAN, arXiv:2407.04575).

    For every entry of ``_STFT_RESOLUTIONS`` the predicted and target
    waveforms are transformed with a Hann-windowed STFT, and an L1 loss is
    taken separately on the real part, the imaginary part, and the magnitude
    of the two spectrograms. Because the real and imaginary parts carry
    phase, this term penalizes phase errors that a magnitude-only loss
    cannot see.

    Args:
        pred_wav: Predicted waveform, documented by the caller as [B, 1, T].
        target_wav: Reference waveform with the same layout.
        device: Device on which the analysis window and the loss
            accumulator are allocated.

    Returns:
        A tensor of shape ``[1]`` holding the mean of all
        ``3 * len(_STFT_RESOLUTIONS)`` L1 terms.
    """
    # Drop the channel axis and upcast: the FFT is run in float32 regardless
    # of the dtype the model itself uses.
    est = pred_wav.squeeze(1).float()
    ref = target_wav.squeeze(1).float()

    # Applied in this order so the terms are accumulated as
    # real -> imaginary -> magnitude for each resolution.
    components = (torch.real, torch.imag, torch.abs)

    total = torch.zeros(1, device=device)
    for n_fft, hop_length, win_length in _STFT_RESOLUTIONS:
        stft_kwargs = dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=torch.hann_window(win_length, device=device),
            center=True,
            return_complex=True,
        )
        est_spec = torch.stft(est, **stft_kwargs)
        ref_spec = torch.stft(ref, **stft_kwargs)

        # Trim both spectrograms to the shorter frame count so the
        # element-wise losses compare tensors of equal shape.
        n_frames = min(est_spec.shape[-1], ref_spec.shape[-1])
        est_spec = est_spec[..., :n_frames]
        ref_spec = ref_spec[..., :n_frames]

        for component in components:
            total = total + F.l1_loss(component(est_spec), component(ref_spec))

    n_terms = len(_STFT_RESOLUTIONS) * 3
    return total / n_terms
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
@@ -318,6 +342,14 @@ class SelvaBigvganTrainer:
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
),
}),
"lambda_phase": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1,
"tooltip": (
"FA-GAN phase-aware loss weight. Adds L1 loss on real + imaginary + magnitude "
"STFT components, penalizing phase smearing directly. 0 = disabled. "
"1.0 is a good starting point alongside other losses."
),
}),
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
},
@@ -335,7 +367,7 @@ class SelvaBigvganTrainer:
}
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""):
segment_seconds, lambda_l2sp, lambda_phase, save_every, seed, discriminator_path=""):
import traceback
device = get_device()
@@ -443,7 +475,7 @@ class SelvaBigvganTrainer:
vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
train_mode, steps, lr, batch_size, lambda_l2sp,
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
save_every, seed, out_path, disc_path, pbar,
)
except Exception as e:
@@ -466,7 +498,7 @@ class SelvaBigvganTrainer:
def _do_train(vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
train_mode, steps, lr, batch_size, lambda_l2sp,
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
save_every, seed, out_path, disc_path, pbar):
"""Execute training. Called in a fresh thread — no inference_mode active.
@@ -686,6 +718,12 @@ def _do_train(vocoder, mel_converter, clips,
primary_loss = mel_loss + stft_loss
loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}"
# ── FA-GAN phase-aware loss (real + imag + mag STFT) ────────────
if lambda_phase > 0.0:
phase_loss = _phase_aware_stft_loss(pred_t, target_t, device)
primary_loss = primary_loss + lambda_phase * phase_loss
loss_desc += f" phase={phase_loss.item():.4f}"
# ── L2-SP regularization ─────────────────────────────────────────
l2sp_loss = torch.zeros(1, device=device)
if lambda_l2sp > 0.0 and ref_params: