feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes
BigVGAN trainer (selva_bigvgan_trainer.py): - Add snake_alpha_only train mode: tunes only ~27K per-channel α params (0.024% of 112M) — physically cannot cause harmonic smearing - Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights - Add optional discriminator_path: frozen MPD+MRD feature matching loss replaces mel L1 when a BigVGAN discriminator checkpoint is provided - Inline MPD + MRD discriminator implementations (no extra dependencies) DITTO optimizer (selva_ditto_optimizer.py): - New node: inference-time noise optimization (arXiv:2401.12179) - Optimizes x₀ via mel Gram matrix style loss against BJ reference clips - All model weights frozen — zero quality degradation risk - Truncated BPTT through last n_grad_steps of the ODE (configurable) - Gradient checkpointing on each differentiated step Docs: - README: document all 20 nodes (was 3), add workflow diagrams - STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers, why LoRA/TI fail, combined approach, dataset prep Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+312
-36
@@ -1,14 +1,29 @@
|
||||
"""SelVA BigVGAN Vocoder Fine-tuner.
|
||||
|
||||
Fine-tunes only the BigVGAN vocoder (mel → waveform) on BJ audio clips using
|
||||
spectral reconstruction losses. The DiT and VAE are completely untouched.
|
||||
Tier-1 approach based on research: snake alpha fine-tuning + L2-SP anchor
|
||||
regularization + optional frozen discriminator feature matching.
|
||||
|
||||
Loss: L1 mel reconstruction + multi-resolution STFT magnitude L1.
|
||||
No GAN discriminator — this is a proof-of-concept to verify that the vocoder
|
||||
can absorb BJ timbral characteristics before investing in full adversarial training.
|
||||
Root cause of harmonic smearing with plain mel/STFT losses:
|
||||
Spectral L1 minimizes expected reconstruction error — averaging over
|
||||
high-variance harmonics. This is a loss-function topology problem, not
|
||||
an LR/step-count problem. The fix is either (a) restrict trainable params
|
||||
so the model lacks capacity to smear, or (b) use a perceptual loss that
|
||||
penalizes harmonic averaging.
|
||||
|
||||
Save format: {'generator': vocoder.state_dict()} — same as the original BigVGAN
|
||||
checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
||||
Tier-1 implementation:
|
||||
1. snake_alpha_only mode — only tune ~5K per-channel α parameters in
|
||||
Snake/SnakeBeta activations. These control harmonic periodicity per
|
||||
channel. With only 5K trainable params, the model physically cannot
|
||||
reshape the spectrum enough to cause the "green smear".
|
||||
2. L2-SP anchor loss — penalizes parameter drift from pretrained values
|
||||
(strictly better than weight decay, which anchors to zero).
|
||||
3. Frozen discriminator feature matching — if a BigVGAN discriminator
|
||||
checkpoint is provided, the pretrained MPD+MRD networks are used as
|
||||
fixed perceptual feature extractors. Feature matching loss penalizes
|
||||
harmonic smearing directly without any GAN instability.
|
||||
|
||||
Save format: {'generator': vocoder.state_dict()} — same as the original
|
||||
BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
|
||||
"""
|
||||
|
||||
import random
|
||||
@@ -16,6 +31,7 @@ import threading
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import comfy.utils
|
||||
@@ -23,12 +39,133 @@ import folder_paths
|
||||
|
||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||
|
||||
def _save_spectrogram(path, mel_tensor):
|
||||
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep).
|
||||
|
||||
Normalises to [0, 255], flips frequency axis so low freqs are at the bottom,
|
||||
and saves as a greyscale PNG with a simple viridis-like colour map.
|
||||
"""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal MPD + MRD discriminators matching BigVGAN pretrained checkpoint keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_pad(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class _DiscriminatorP(nn.Module):
|
||||
"""Multi-Period Discriminator sub-module (HiFi-GAN / BigVGAN style)."""
|
||||
def __init__(self, period):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
norm = weight_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm(nn.Conv2d(1, 32, (5, 1), (3, 1), (_get_pad(5, 1), 0))),
|
||||
norm(nn.Conv2d(32, 128, (5, 1), (3, 1), (_get_pad(5, 1), 0))),
|
||||
norm(nn.Conv2d(128, 512, (5, 1), (3, 1), (_get_pad(5, 1), 0))),
|
||||
norm(nn.Conv2d(512, 1024,(5, 1), (3, 1), (_get_pad(5, 1), 0))),
|
||||
norm(nn.Conv2d(1024,1024,(5, 1), 1, (_get_pad(5, 1), 0))),
|
||||
])
|
||||
self.conv_post = norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0:
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, 0.1)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
return fmap
|
||||
|
||||
|
||||
class _MultiPeriodDiscriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList([
|
||||
_DiscriminatorP(p) for p in [2, 3, 5, 7, 11]
|
||||
])
|
||||
|
||||
def forward(self, y):
|
||||
fmaps = []
|
||||
for d in self.discriminators:
|
||||
fmaps.extend(d(y))
|
||||
return fmaps
|
||||
|
||||
|
||||
class _DiscriminatorR(nn.Module):
|
||||
"""Multi-Resolution Discriminator sub-module."""
|
||||
def __init__(self, fft_size, shift_size, win_length):
|
||||
super().__init__()
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
norm = weight_norm
|
||||
self.convs = nn.ModuleList([
|
||||
norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
|
||||
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
|
||||
])
|
||||
self.conv_post = norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
def spectrogram(self, x):
|
||||
"""x: [B, 1, T] → [B, 1, freq, time]"""
|
||||
n, hop, win = self.fft_size, self.shift_size, self.win_length
|
||||
window = torch.hann_window(win, device=x.device)
|
||||
x = x.squeeze(1) # [B, T]
|
||||
pad = (win - hop) // 2
|
||||
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
|
||||
x = torch.stft(x, n, hop, win, window, center=False, return_complex=True)
|
||||
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
x = self.spectrogram(x)
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, 0.1)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
return fmap
|
||||
|
||||
|
||||
class _MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]
|
||||
self.discriminators = nn.ModuleList([
|
||||
_DiscriminatorR(*r) for r in resolutions
|
||||
])
|
||||
|
||||
def forward(self, y):
|
||||
fmaps = []
|
||||
for d in self.discriminators:
|
||||
fmaps.extend(d(y))
|
||||
return fmaps
|
||||
|
||||
|
||||
def _feature_matching_loss(fmaps_real, fmaps_gen):
|
||||
"""L1 between paired feature map lists (both already detach-safe for real)."""
|
||||
loss = torch.zeros(1, device=fmaps_gen[0].device)
|
||||
for fr, fg in zip(fmaps_real, fmaps_gen):
|
||||
T = min(fr.shape[-1], fg.shape[-1])
|
||||
loss = loss + F.l1_loss(fg[..., :T], fr[..., :T].detach())
|
||||
return loss / len(fmaps_real)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _save_spectrogram(path, mel_tensor):
|
||||
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep)."""
|
||||
try:
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
@@ -120,6 +257,10 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device):
|
||||
return loss / len(_STFT_RESOLUTIONS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SelvaBigvganTrainer:
|
||||
OUTPUT_NODE = True
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
@@ -128,10 +269,10 @@ class SelvaBigvganTrainer:
|
||||
RETURN_NAMES = ("checkpoint_path",)
|
||||
OUTPUT_TOOLTIPS = ("Path to saved vocoder checkpoint — load with SelVA BigVGAN Loader.",)
|
||||
DESCRIPTION = (
|
||||
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using "
|
||||
"spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. "
|
||||
"Supports both 16k (BigVGAN) and 44k (BigVGANv2) models. "
|
||||
"Load the result with SelVA BigVGAN Loader."
|
||||
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips. "
|
||||
"Default mode (snake_alpha_only) tunes only the ~5K Snake activation α "
|
||||
"parameters — cannot cause harmonic smearing. Add a discriminator path "
|
||||
"for perceptual feature matching loss. DiT and VAE stay frozen."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -147,26 +288,53 @@ class SelvaBigvganTrainer:
|
||||
"default": "bigvgan_bj.pt",
|
||||
"tooltip": "Where to save the fine-tuned vocoder. Relative paths → ComfyUI output dir.",
|
||||
}),
|
||||
"train_mode": (["snake_alpha_only", "all_params"], {
|
||||
"default": "snake_alpha_only",
|
||||
"tooltip": (
|
||||
"snake_alpha_only: only tune ~5K per-channel α parameters in Snake/SnakeBeta "
|
||||
"activations. These control harmonic periodicity. Cannot cause spectral smearing. "
|
||||
"all_params: tune all vocoder weights — set lambda_l2sp>0 to prevent drift."
|
||||
),
|
||||
}),
|
||||
"steps": ("INT", {
|
||||
"default": 2000, "min": 100, "max": 50000,
|
||||
"tooltip": "Training steps. 1000–2000 is a good first experiment.",
|
||||
"tooltip": "Training steps. 1000–2000 is a good first experiment with snake_alpha_only.",
|
||||
}),
|
||||
"lr": ("FLOAT", {
|
||||
"default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-5,
|
||||
"tooltip": "Learning rate. BigVGAN default is 1e-4.",
|
||||
"tooltip": "Learning rate. 1e-4 for snake_alpha_only, 1e-5 for all_params.",
|
||||
}),
|
||||
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32}),
|
||||
"segment_seconds": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.25, "max": 4.0, "step": 0.25,
|
||||
"tooltip": "Audio segment length per training sample in seconds.",
|
||||
}),
|
||||
"lambda_l2sp": ("FLOAT", {
|
||||
"default": 1e-3, "min": 0.0, "max": 0.1, "step": 1e-4,
|
||||
"tooltip": (
|
||||
"L2-SP anchor regularization: penalizes parameter drift from pretrained values. "
|
||||
"0 = disabled. 1e-3 is good for snake_alpha_only. "
|
||||
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
|
||||
),
|
||||
}),
|
||||
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
||||
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
||||
},
|
||||
"optional": {
|
||||
"discriminator_path": ("STRING", {
|
||||
"default": "",
|
||||
"tooltip": (
|
||||
"Optional path to BigVGAN discriminator checkpoint "
|
||||
"(bigvgan_discriminator_optimizer.pt from the BigVGAN pretrained release). "
|
||||
"When provided, frozen MPD+MRD feature matching replaces mel L1 — "
|
||||
"the key fix for harmonic smearing. Leave empty to use mel+STFT losses only."
|
||||
),
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
def train(self, model, data_dir, output_path, steps, lr, batch_size,
|
||||
segment_seconds, save_every, seed):
|
||||
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
|
||||
segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""):
|
||||
import traceback
|
||||
|
||||
device = get_device()
|
||||
@@ -197,6 +365,14 @@ class SelvaBigvganTrainer:
|
||||
out_path = Path(folder_paths.get_output_directory()) / out_path
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
disc_path = None
|
||||
if discriminator_path and discriminator_path.strip():
|
||||
disc_path = Path(discriminator_path.strip())
|
||||
if not disc_path.is_absolute():
|
||||
disc_path = Path(folder_paths.get_output_directory()) / disc_path
|
||||
if not disc_path.exists():
|
||||
raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}")
|
||||
|
||||
# Find and pre-load audio clips
|
||||
segment_samples = int(segment_seconds * sample_rate)
|
||||
audio_files = []
|
||||
@@ -227,8 +403,15 @@ class SelvaBigvganTrainer:
|
||||
raise RuntimeError(
|
||||
f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)"
|
||||
)
|
||||
print(f"[BigVGAN] {len(clips)} clips ready segment={segment_seconds}s "
|
||||
f"steps={steps} lr={lr} batch={batch_size}\n", flush=True)
|
||||
|
||||
trainable_count = sum(
|
||||
1 for n, _ in vocoder.named_parameters() if "alpha" in n
|
||||
) if train_mode == "snake_alpha_only" else sum(
|
||||
1 for _ in vocoder.parameters()
|
||||
)
|
||||
print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} "
|
||||
f"segment={segment_seconds}s steps={steps} lr={lr} "
|
||||
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
|
||||
|
||||
if strategy == "offload_to_cpu":
|
||||
feature_utils.to(device)
|
||||
@@ -259,8 +442,8 @@ class SelvaBigvganTrainer:
|
||||
vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
segment_samples, sample_rate,
|
||||
steps, lr, batch_size, save_every, seed,
|
||||
out_path, pbar,
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||
save_every, seed, out_path, disc_path, pbar,
|
||||
)
|
||||
except Exception as e:
|
||||
_exc[0] = e
|
||||
@@ -275,11 +458,15 @@ class SelvaBigvganTrainer:
|
||||
return (_result[0],)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training worker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _do_train(vocoder, mel_converter, clips,
|
||||
device, dtype, strategy, feature_utils,
|
||||
segment_samples, sample_rate,
|
||||
steps, lr, batch_size, save_every, seed,
|
||||
out_path, pbar):
|
||||
train_mode, steps, lr, batch_size, lambda_l2sp,
|
||||
save_every, seed, out_path, disc_path, pbar):
|
||||
"""Execute training. Called in a fresh thread — no inference_mode active.
|
||||
|
||||
Even though inference_mode is off here, tensors created in the calling
|
||||
@@ -372,7 +559,65 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
if buf is not None:
|
||||
module._buffers[bname] = buf.clone()
|
||||
|
||||
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
|
||||
# ── Training mode: select which parameters to train ──────────────────────
|
||||
if train_mode == "snake_alpha_only":
|
||||
alpha_params = []
|
||||
for name, param in vocoder.named_parameters():
|
||||
if "alpha" in name:
|
||||
param.requires_grad_(True)
|
||||
alpha_params.append(param)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
n_trainable = sum(p.numel() for p in alpha_params)
|
||||
print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params "
|
||||
f"({len(alpha_params)} alpha tensors)", flush=True)
|
||||
trainable_params = alpha_params
|
||||
else: # all_params
|
||||
for param in vocoder.parameters():
|
||||
param.requires_grad_(True)
|
||||
n_trainable = sum(p.numel() for p in vocoder.parameters())
|
||||
print(f"[BigVGAN] all_params: {n_trainable} trainable params", flush=True)
|
||||
trainable_params = list(vocoder.parameters())
|
||||
|
||||
# ── L2-SP: cache reference parameter values (before any gradient steps) ──
|
||||
ref_params = {}
|
||||
if lambda_l2sp > 0.0:
|
||||
for name, param in vocoder.named_parameters():
|
||||
if param.requires_grad:
|
||||
ref_params[name] = param.data.clone().detach()
|
||||
print(f"[BigVGAN] L2-SP anchor: {len(ref_params)} params λ={lambda_l2sp}", flush=True)
|
||||
|
||||
# ── Optional: load pretrained discriminator for feature matching ──────────
|
||||
mpd = mrd = None
|
||||
if disc_path is not None:
|
||||
try:
|
||||
ckpt_d = torch.load(str(disc_path), map_location="cpu", weights_only=False)
|
||||
mpd = _MultiPeriodDiscriminator()
|
||||
mrd = _MultiResolutionDiscriminator()
|
||||
# Try common key names used by different BigVGAN releases
|
||||
for mpd_key in ("mpd", "discriminator_mpd", "MPD"):
|
||||
if mpd_key in ckpt_d:
|
||||
mpd.load_state_dict(ckpt_d[mpd_key], strict=False)
|
||||
print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True)
|
||||
break
|
||||
for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"):
|
||||
if mrd_key in ckpt_d:
|
||||
mrd.load_state_dict(ckpt_d[mrd_key], strict=False)
|
||||
print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True)
|
||||
break
|
||||
mpd.to(device).eval()
|
||||
mrd.to(device).eval()
|
||||
for p in mpd.parameters():
|
||||
p.requires_grad_(False)
|
||||
for p in mrd.parameters():
|
||||
p.requires_grad_(False)
|
||||
print(f"[BigVGAN] Frozen discriminators ready for feature matching", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[BigVGAN] WARNING: Could not load discriminator ({e}), "
|
||||
f"falling back to mel+STFT losses", flush=True)
|
||||
mpd = mrd = None
|
||||
|
||||
optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99))
|
||||
vocoder.train()
|
||||
|
||||
try:
|
||||
@@ -396,24 +641,55 @@ def _do_train(vocoder, mel_converter, clips,
|
||||
pred_t = pred_wav[..., :T]
|
||||
target_t = target_wav[..., :T]
|
||||
|
||||
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel']
|
||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||
# ── Compute loss ─────────────────────────────────────────────────
|
||||
if mpd is not None and mrd is not None:
|
||||
# Perceptual feature matching via frozen discriminators
|
||||
with torch.no_grad():
|
||||
fmaps_real_mpd = mpd(target_t)
|
||||
fmaps_real_mrd = mrd(target_t)
|
||||
fmaps_gen_mpd = mpd(pred_t)
|
||||
fmaps_gen_mrd = mrd(pred_t)
|
||||
fm_loss = (
|
||||
_feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) +
|
||||
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
|
||||
)
|
||||
# Keep a small mel loss for stable frequency alignment
|
||||
pred_mel = mel_converter(pred_t.squeeze(1))
|
||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||
primary_loss = 2.0 * fm_loss + 0.1 * mel_loss
|
||||
loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}"
|
||||
else:
|
||||
# Fallback: mel L1 + multi-resolution STFT L1
|
||||
pred_mel = mel_converter(pred_t.squeeze(1))
|
||||
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
|
||||
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
|
||||
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
||||
primary_loss = mel_loss + stft_loss
|
||||
loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}"
|
||||
|
||||
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
|
||||
# ── L2-SP regularization ─────────────────────────────────────────
|
||||
l2sp_loss = torch.zeros(1, device=device)
|
||||
if lambda_l2sp > 0.0 and ref_params:
|
||||
for name, param in vocoder.named_parameters():
|
||||
if name in ref_params and param.requires_grad:
|
||||
l2sp_loss = l2sp_loss + F.mse_loss(
|
||||
param, ref_params[name], reduction="sum"
|
||||
)
|
||||
l2sp_loss = l2sp_loss * lambda_l2sp
|
||||
|
||||
loss = mel_loss + stft_loss
|
||||
loss = primary_loss + l2sp_loss
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0)
|
||||
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
|
||||
optimizer.step()
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1:
|
||||
print(f"[BigVGAN] {step+1}/{steps} "
|
||||
f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} "
|
||||
f"total={loss.item():.4f}", flush=True)
|
||||
l2sp_str = f" l2sp={l2sp_loss.item():.4e}" if lambda_l2sp > 0 else ""
|
||||
print(f"[BigVGAN] {step+1}/{steps} {loss_desc}"
|
||||
f" total={loss.item():.4f}{l2sp_str}", flush=True)
|
||||
|
||||
if (step + 1) % save_every == 0 and (step + 1) < steps:
|
||||
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
|
||||
|
||||
Reference in New Issue
Block a user