feat: add FA-GAN phase-aware STFT loss to BigVGAN trainer
Adds L1 loss on real, imaginary, and magnitude STFT components across three resolutions (FA-GAN, arXiv:2407.04575). Penalizes phase smearing directly — magnitude-only losses cannot distinguish correct spectrum with wrong phase from a smeared spectrum. Controlled by lambda_phase (default 1.0, 0 = disabled). Applied on top of both the discriminator FM path and the fallback mel+STFT path. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -258,6 +258,30 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device):
|
|||||||
return loss / len(_STFT_RESOLUTIONS)
|
return loss / len(_STFT_RESOLUTIONS)
|
||||||
|
|
||||||
|
|
||||||
|
def _phase_aware_stft_loss(pred_wav, target_wav, device):
|
||||||
|
"""FA-GAN complex STFT loss: L1 on real, imaginary, and magnitude components.
|
||||||
|
|
||||||
|
Penalizes phase smearing directly — plain magnitude loss cannot distinguish
|
||||||
|
a correct spectrum with wrong phase from a smeared spectrum with random phase.
|
||||||
|
Based on FA-GAN (arXiv:2407.04575), applied across three STFT resolutions.
|
||||||
|
inputs: [B, 1, T]
|
||||||
|
"""
|
||||||
|
# cuFFT requires float32 regardless of model dtype
|
||||||
|
pred = pred_wav.squeeze(1).float() # [B, T]
|
||||||
|
target = target_wav.squeeze(1).float()
|
||||||
|
loss = torch.zeros(1, device=device)
|
||||||
|
for n_fft, hop, win in _STFT_RESOLUTIONS:
|
||||||
|
window = torch.hann_window(win, device=device)
|
||||||
|
ps = torch.stft(pred, n_fft, hop, win, window, center=True, return_complex=True)
|
||||||
|
ts = torch.stft(target, n_fft, hop, win, window, center=True, return_complex=True)
|
||||||
|
T = min(ps.shape[-1], ts.shape[-1])
|
||||||
|
ps, ts = ps[..., :T], ts[..., :T]
|
||||||
|
loss = loss + F.l1_loss(ps.real, ts.real)
|
||||||
|
loss = loss + F.l1_loss(ps.imag, ts.imag)
|
||||||
|
loss = loss + F.l1_loss(ps.abs(), ts.abs())
|
||||||
|
return loss / (len(_STFT_RESOLUTIONS) * 3)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Node
|
# Node
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -318,6 +342,14 @@ class SelvaBigvganTrainer:
|
|||||||
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
|
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
|
||||||
),
|
),
|
||||||
}),
|
}),
|
||||||
|
"lambda_phase": ("FLOAT", {
|
||||||
|
"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1,
|
||||||
|
"tooltip": (
|
||||||
|
"FA-GAN phase-aware loss weight. Adds L1 loss on real + imaginary + magnitude "
|
||||||
|
"STFT components, penalizing phase smearing directly. 0 = disabled. "
|
||||||
|
"1.0 is a good starting point alongside other losses."
|
||||||
|
),
|
||||||
|
}),
|
||||||
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
||||||
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
},
|
},
|
||||||
@@ -335,7 +367,7 @@ class SelvaBigvganTrainer:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
||||||
segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""):
|
segment_seconds, lambda_l2sp, lambda_phase, save_every, seed, discriminator_path=""):
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
@@ -443,7 +475,7 @@ class SelvaBigvganTrainer:
|
|||||||
vocoder, mel_converter, clips,
|
vocoder, mel_converter, clips,
|
||||||
device, dtype, strategy, feature_utils,
|
device, dtype, strategy, feature_utils,
|
||||||
segment_samples, sample_rate,
|
segment_samples, sample_rate,
|
||||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
|
||||||
save_every, seed, out_path, disc_path, pbar,
|
save_every, seed, out_path, disc_path, pbar,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -466,7 +498,7 @@ class SelvaBigvganTrainer:
|
|||||||
def _do_train(vocoder, mel_converter, clips,
|
def _do_train(vocoder, mel_converter, clips,
|
||||||
device, dtype, strategy, feature_utils,
|
device, dtype, strategy, feature_utils,
|
||||||
segment_samples, sample_rate,
|
segment_samples, sample_rate,
|
||||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
train_mode, steps, lr, batch_size, lambda_l2sp, lambda_phase,
|
||||||
save_every, seed, out_path, disc_path, pbar):
|
save_every, seed, out_path, disc_path, pbar):
|
||||||
"""Execute training. Called in a fresh thread — no inference_mode active.
|
"""Execute training. Called in a fresh thread — no inference_mode active.
|
||||||
|
|
||||||
@@ -686,6 +718,12 @@ def _do_train(vocoder, mel_converter, clips,
|
|||||||
primary_loss = mel_loss + stft_loss
|
primary_loss = mel_loss + stft_loss
|
||||||
loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}"
|
loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}"
|
||||||
|
|
||||||
|
# ── FA-GAN phase-aware loss (real + imag + mag STFT) ────────────
|
||||||
|
if lambda_phase > 0.0:
|
||||||
|
phase_loss = _phase_aware_stft_loss(pred_t, target_t, device)
|
||||||
|
primary_loss = primary_loss + lambda_phase * phase_loss
|
||||||
|
loss_desc += f" phase={phase_loss.item():.4f}"
|
||||||
|
|
||||||
# ── L2-SP regularization ─────────────────────────────────────────
|
# ── L2-SP regularization ─────────────────────────────────────────
|
||||||
l2sp_loss = torch.zeros(1, device=device)
|
l2sp_loss = torch.zeros(1, device=device)
|
||||||
if lambda_l2sp > 0.0 and ref_params:
|
if lambda_l2sp > 0.0 and ref_params:
|
||||||
|
|||||||
Reference in New Issue
Block a user