feat: add FA-GAN phase-aware STFT loss to BigVGAN trainer
Adds an L1 loss on the real, imaginary, and magnitude STFT components across three resolutions (FA-GAN, arXiv:2407.04575). This penalizes phase smearing directly — magnitude-only losses cannot distinguish a correct spectrum with wrong phase from a smeared spectrum. Controlled by lambda_phase (default 1.0, 0 = disabled). Applied on top of both the discriminator FM path and the fallback mel+STFT path.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -258,6 +258,30 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device):
|
||||
return loss / len(_STFT_RESOLUTIONS)
|
||||
|
||||
|
||||
def _phase_aware_stft_loss(pred_wav, target_wav, device):
    """Complex-valued multi-resolution STFT loss in the style of FA-GAN.

    For every (n_fft, hop, win) entry in ``_STFT_RESOLUTIONS`` the complex
    STFT of both waveforms is computed and three L1 terms are accumulated:
    real part, imaginary part, and magnitude. Comparing the real and
    imaginary parts makes the loss sensitive to phase, which a
    magnitude-only spectral loss is not (FA-GAN, arXiv:2407.04575).

    Args:
        pred_wav: predicted waveform, [B, 1, T].
        target_wav: reference waveform, [B, 1, T].
        device: device on which the analysis windows and the accumulator
            are allocated.

    Returns:
        Tensor of shape [1]: the mean of the three L1 terms over all
        resolutions.
    """
    # Cast to float32 before the STFT: cuFFT requires float32 regardless
    # of model dtype.
    pred_f32 = pred_wav.squeeze(1).float()      # [B, T]
    target_f32 = target_wav.squeeze(1).float()  # [B, T]

    total = torch.zeros(1, device=device)
    for n_fft, hop, win in _STFT_RESOLUTIONS:
        window = torch.hann_window(win, device=device)
        spec_pred = torch.stft(
            pred_f32, n_fft, hop_length=hop, win_length=win, window=window,
            center=True, return_complex=True,
        )
        spec_target = torch.stft(
            target_f32, n_fft, hop_length=hop, win_length=win, window=window,
            center=True, return_complex=True,
        )

        # Trim both spectrograms to the shorter frame count so the
        # element-wise losses see matching shapes.
        n_frames = min(spec_pred.shape[-1], spec_target.shape[-1])
        spec_pred = spec_pred[..., :n_frames]
        spec_target = spec_target[..., :n_frames]

        # Accumulation order: real, imaginary, magnitude.
        component_pairs = (
            (spec_pred.real, spec_target.real),
            (spec_pred.imag, spec_target.imag),
            (spec_pred.abs(), spec_target.abs()),
        )
        for pred_part, target_part in component_pairs:
            total = total + F.l1_loss(pred_part, target_part)

    # Three L1 terms per resolution.
    return total / (len(_STFT_RESOLUTIONS) * 3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -318,6 +342,14 @@ class SelvaBigvganTrainer:
|
||||
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
|
||||
),
|
||||
}),
|
||||
"lambda_phase": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1,
|
||||
"tooltip": (
|
||||
"FA-GAN phase-aware loss weight. Adds L1 loss on real + imaginary + magnitude "
|
||||
"STFT components, penalizing phase smearing directly. 0 = disabled. "
|
||||
"1.0 is a good starting point alongside other losses."
|
||||
),
|
||||
}),
|
||||
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
||||
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
||||
},
|
||||
@@ -335,7 +367,7 @@ class SelvaBigvganTrainer:
|
||||
}
|
||||
|
||||
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
||||
segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""):
|
||||
segment_seconds, lambda_l2sp, lambda_phase, save_every, seed, discriminator_path=""):
|
||||
import traceback
|
||||
|
||||
device = get_device()
|
||||
@@ -443,7 +475,7 @@ class SelvaBigvganTrainer:
|
||||
vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
segment_samples, sample_rate,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
|
||||
save_every, seed, out_path, disc_path, pbar,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -466,7 +498,7 @@ class SelvaBigvganTrainer:
|
||||
def _do_train(vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
segment_samples, sample_rate,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
|
||||
save_every, seed, out_path, disc_path, pbar):
|
||||
"""Execute training. Called in a fresh thread — no inference_mode active.
|
||||
|
||||
@@ -686,6 +718,12 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
primary_loss = mel_loss + stft_loss
|
||||
loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}"
|
||||
|
||||
# ── FA-GAN phase-aware loss (real + imag + mag STFT) ────────────
|
||||
if lambda_phase > 0.0:
|
||||
phase_loss = _phase_aware_stft_loss(pred_t, target_t, device)
|
||||
primary_loss = primary_loss + lambda_phase * phase_loss
|
||||
loss_desc += f" phase={phase_loss.item():.4f}"
|
||||
|
||||
# ── L2-SP regularization ─────────────────────────────────────────
|
||||
l2sp_loss = torch.zeros(1, device=device)
|
||||
if lambda_l2sp > 0.0 and ref_params:
|
||||
|
||||
Reference in New Issue
Block a user