feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes

BigVGAN trainer (selva_bigvgan_trainer.py):
- Add snake_alpha_only train mode: tunes only ~27K per-channel α params
  (0.024% of 112M) — physically cannot cause harmonic smearing
- Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights
- Add optional discriminator_path: frozen MPD+MRD feature matching loss
  replaces mel L1 when a BigVGAN discriminator checkpoint is provided
- Inline MPD + MRD discriminator implementations (no extra dependencies)

DITTO optimizer (selva_ditto_optimizer.py):
- New node: inference-time noise optimization (arXiv:2401.12179)
- Optimizes x₀ via mel Gram matrix style loss against BJ reference clips
- All model weights frozen — zero quality degradation risk
- Truncated BPTT through last n_grad_steps of the ODE (configurable)
- Gradient checkpointing on each differentiated step

Docs:
- README: document all 20 nodes (was 3), add workflow diagrams
- STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers,
  why LoRA/TI fail, combined approach, dataset prep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 12:04:05 +02:00
parent f17f6f0863
commit 1e9551152e
5 changed files with 1159 additions and 44 deletions
+312 -36
View File
@@ -1,14 +1,29 @@
"""SelVA BigVGAN Vocoder Fine-tuner.
Fine-tunes only the BigVGAN vocoder (mel → waveform) on BJ audio clips using
spectral reconstruction losses. The DiT and VAE are completely untouched.
Tier-1 approach based on research: snake alpha fine-tuning + L2-SP anchor
regularization + optional frozen discriminator feature matching.
Loss: L1 mel reconstruction + multi-resolution STFT magnitude L1.
No GAN discriminator — this is a proof-of-concept to verify that the vocoder
can absorb BJ timbral characteristics before investing in full adversarial training.
Root cause of harmonic smearing with plain mel/STFT losses:
Spectral L1 minimizes expected reconstruction error — averaging over
high-variance harmonics. This is a loss-function topology problem, not
an LR/step-count problem. The fix is either (a) restrict trainable params
so the model lacks capacity to smear, or (b) use a perceptual loss that
penalizes harmonic averaging.
Save format: {'generator': vocoder.state_dict()} — same as the original BigVGAN
checkpoint so it can be loaded with SelVA BigVGAN Loader.
Tier-1 implementation:
1. snake_alpha_only mode — only tune ~5K per-channel α parameters in
Snake/SnakeBeta activations. These control harmonic periodicity per
channel. With only 5K trainable params, the model physically cannot
reshape the spectrum enough to cause the "green smear".
2. L2-SP anchor loss — penalizes parameter drift from pretrained values
(strictly better than weight decay, which anchors to zero).
3. Frozen discriminator feature matching — if a BigVGAN discriminator
checkpoint is provided, the pretrained MPD+MRD networks are used as
fixed perceptual feature extractors. Feature matching loss penalizes
harmonic smearing directly without any GAN instability.
Save format: {'generator': vocoder.state_dict()} — same as the original
BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
"""
import random
@@ -16,6 +31,7 @@ import threading
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import comfy.utils
@@ -23,12 +39,133 @@ import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
def _save_spectrogram(path, mel_tensor):
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep).
Normalises to [0, 255], flips frequency axis so low freqs are at the bottom,
and saves as a greyscale PNG with a simple viridis-like colour map.
"""
# ---------------------------------------------------------------------------
# Minimal MPD + MRD discriminators matching BigVGAN pretrained checkpoint keys
# ---------------------------------------------------------------------------
def _get_pad(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class _DiscriminatorP(nn.Module):
    """Single-period branch of the multi-period discriminator (HiFi-GAN / BigVGAN style).

    The input waveform is folded into a 2-D grid whose last axis holds
    ``period`` consecutive samples, then passed through a stack of
    weight-normalised ``Conv2d`` layers. Only the intermediate activations
    are returned; no real/fake score is split out separately.

    The sub-module attribute names (``convs``, ``conv_post``) are part of the
    state-dict layout and must not be renamed.
    """

    def __init__(self, period):
        """Build the convolution stack for the given folding ``period``."""
        super().__init__()
        from torch.nn.utils.parametrizations import weight_norm

        self.period = period

        # Every layer in the stack shares the (5, 1) kernel and its padding;
        # only the channel widths and the stride of the last layer differ.
        kernel = (5, 1)
        padding = (_get_pad(5, 1), 0)
        layer_specs = [
            (1, 32, (3, 1)),
            (32, 128, (3, 1)),
            (128, 512, (3, 1)),
            (512, 1024, (3, 1)),
            (1024, 1024, 1),
        ]
        self.convs = nn.ModuleList([
            weight_norm(nn.Conv2d(c_in, c_out, kernel, stride, padding))
            for c_in, c_out, stride in layer_specs
        ])
        self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0)))

    def forward(self, x):
        """Return the list of feature maps for a 3-D input ``[B, C, T]``.

        The list holds one activation per layer in ``convs`` (after
        LeakyReLU 0.1) followed by the raw output of ``conv_post``.
        """
        batch, channels, length = x.shape

        # Right-pad (reflect) so the time axis divides evenly into periods.
        remainder = length % self.period
        if remainder != 0:
            extra = self.period - remainder
            x = F.pad(x, (0, extra), "reflect")
            length = length + extra
        x = x.view(batch, channels, length // self.period, self.period)

        features = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            features.append(x)
        features.append(self.conv_post(x))
        return features
class _MultiPeriodDiscriminator(nn.Module):
    """Bundle of ``_DiscriminatorP`` branches for periods 2, 3, 5, 7 and 11.

    ``forward`` returns one flat list containing every branch's feature
    maps, in branch order.
    """

    def __init__(self):
        super().__init__()
        periods = (2, 3, 5, 7, 11)
        self.discriminators = nn.ModuleList(
            [_DiscriminatorP(period) for period in periods]
        )

    def forward(self, y):
        """Run every period branch on ``y`` and concatenate their feature-map lists."""
        return [fmap for branch in self.discriminators for fmap in branch(y)]
class _DiscriminatorR(nn.Module):
"""Multi-Resolution Discriminator sub-module."""
def __init__(self, fft_size, shift_size, win_length):
super().__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
from torch.nn.utils.parametrizations import weight_norm
norm = weight_norm
self.convs = nn.ModuleList([
norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
])
self.conv_post = norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
def spectrogram(self, x):
"""x: [B, 1, T] → [B, 1, freq, time]"""
n, hop, win = self.fft_size, self.shift_size, self.win_length
window = torch.hann_window(win, device=x.device)
x = x.squeeze(1) # [B, T]
pad = (win - hop) // 2
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
x = torch.stft(x, n, hop, win, window, center=False, return_complex=True)
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
return x
def forward(self, x):
fmap = []
x = self.spectrogram(x)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.1)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
return fmap
class _MultiResolutionDiscriminator(nn.Module):
    """Bundle of ``_DiscriminatorR`` branches at three STFT resolutions.

    Each resolution is ``(fft_size, shift_size, win_length)``. ``forward``
    returns one flat list containing every branch's feature maps, in branch
    order.
    """

    def __init__(self):
        super().__init__()
        stft_settings = (
            (1024, 120, 600),
            (2048, 240, 1200),
            (512, 50, 240),
        )
        self.discriminators = nn.ModuleList(
            [_DiscriminatorR(fft, hop, win) for fft, hop, win in stft_settings]
        )

    def forward(self, y):
        """Run every resolution branch on ``y`` and concatenate their feature-map lists."""
        return [fmap for branch in self.discriminators for fmap in branch(y)]
def _feature_matching_loss(fmaps_real, fmaps_gen):
"""L1 between paired feature map lists (both already detach-safe for real)."""
loss = torch.zeros(1, device=fmaps_gen[0].device)
for fr, fg in zip(fmaps_real, fmaps_gen):
T = min(fr.shape[-1], fg.shape[-1])
loss = loss + F.l1_loss(fg[..., :T], fr[..., :T].detach())
return loss / len(fmaps_real)
# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------
def _save_spectrogram(path, mel_tensor):
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep)."""
try:
from PIL import Image
import numpy as np
@@ -120,6 +257,10 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device):
return loss / len(_STFT_RESOLUTIONS)
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaBigvganTrainer:
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
@@ -128,10 +269,10 @@ class SelvaBigvganTrainer:
RETURN_NAMES = ("checkpoint_path",)
OUTPUT_TOOLTIPS = ("Path to saved vocoder checkpoint — load with SelVA BigVGAN Loader.",)
DESCRIPTION = (
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using "
"spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. "
"Supports both 16k (BigVGAN) and 44k (BigVGANv2) models. "
"Load the result with SelVA BigVGAN Loader."
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips. "
"Default mode (snake_alpha_only) tunes only the ~5K Snake activation α "
"parameters — cannot cause harmonic smearing. Add a discriminator path "
"for perceptual feature matching loss. DiT and VAE stay frozen."
)
@classmethod
@@ -147,26 +288,53 @@ class SelvaBigvganTrainer:
"default": "bigvgan_bj.pt",
"tooltip": "Where to save the fine-tuned vocoder. Relative paths → ComfyUI output dir.",
}),
"train_mode": (["snake_alpha_only", "all_params"], {
"default": "snake_alpha_only",
"tooltip": (
"snake_alpha_only: only tune ~5K per-channel α parameters in Snake/SnakeBeta "
"activations. These control harmonic periodicity. Cannot cause spectral smearing. "
"all_params: tune all vocoder weights — set lambda_l2sp>0 to prevent drift."
),
}),
"steps": ("INT", {
"default": 2000, "min": 100, "max": 50000,
"tooltip": "Training steps. 10002000 is a good first experiment.",
"tooltip": "Training steps. 10002000 is a good first experiment with snake_alpha_only.",
}),
"lr": ("FLOAT", {
"default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-5,
"tooltip": "Learning rate. BigVGAN default is 1e-4.",
"tooltip": "Learning rate. 1e-4 for snake_alpha_only, 1e-5 for all_params.",
}),
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32}),
"segment_seconds": ("FLOAT", {
"default": 1.0, "min": 0.25, "max": 4.0, "step": 0.25,
"tooltip": "Audio segment length per training sample in seconds.",
}),
"lambda_l2sp": ("FLOAT", {
"default": 1e-3, "min": 0.0, "max": 0.1, "step": 1e-4,
"tooltip": (
"L2-SP anchor regularization: penalizes parameter drift from pretrained values. "
"0 = disabled. 1e-3 is good for snake_alpha_only. "
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
),
}),
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
},
"optional": {
"discriminator_path": ("STRING", {
"default": "",
"tooltip": (
"Optional path to BigVGAN discriminator checkpoint "
"(bigvgan_discriminator_optimizer.pt from the BigVGAN pretrained release). "
"When provided, frozen MPD+MRD feature matching replaces mel L1 — "
"the key fix for harmonic smearing. Leave empty to use mel+STFT losses only."
),
}),
},
}
def train(self, model, data_dir, output_path, steps, lr, batch_size,
segment_seconds, save_every, seed):
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""):
import traceback
device = get_device()
@@ -197,6 +365,14 @@ class SelvaBigvganTrainer:
out_path = Path(folder_paths.get_output_directory()) / out_path
out_path.parent.mkdir(parents=True, exist_ok=True)
disc_path = None
if discriminator_path and discriminator_path.strip():
disc_path = Path(discriminator_path.strip())
if not disc_path.is_absolute():
disc_path = Path(folder_paths.get_output_directory()) / disc_path
if not disc_path.exists():
raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}")
# Find and pre-load audio clips
segment_samples = int(segment_seconds * sample_rate)
audio_files = []
@@ -227,8 +403,15 @@ class SelvaBigvganTrainer:
raise RuntimeError(
f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)"
)
print(f"[BigVGAN] {len(clips)} clips ready segment={segment_seconds}s "
f"steps={steps} lr={lr} batch={batch_size}\n", flush=True)
trainable_count = sum(
1 for n, _ in vocoder.named_parameters() if "alpha" in n
) if train_mode == "snake_alpha_only" else sum(
1 for _ in vocoder.parameters()
)
print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} "
f"segment={segment_seconds}s steps={steps} lr={lr} "
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
if strategy == "offload_to_cpu":
feature_utils.to(device)
@@ -259,8 +442,8 @@ class SelvaBigvganTrainer:
vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed,
out_path, pbar,
train_mode, steps, lr, batch_size, lambda_l2sp,
save_every, seed, out_path, disc_path, pbar,
)
except Exception as e:
_exc[0] = e
@@ -275,11 +458,15 @@ class SelvaBigvganTrainer:
return (_result[0],)
# ---------------------------------------------------------------------------
# Training worker
# ---------------------------------------------------------------------------
def _do_train(vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed,
out_path, pbar):
train_mode, steps, lr, batch_size, lambda_l2sp,
save_every, seed, out_path, disc_path, pbar):
"""Execute training. Called in a fresh thread — no inference_mode active.
Even though inference_mode is off here, tensors created in the calling
@@ -372,7 +559,65 @@ def _do_train(vocoder, mel_converter, clips,
if buf is not None:
module._buffers[bname] = buf.clone()
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
# ── Training mode: select which parameters to train ──────────────────────
if train_mode == "snake_alpha_only":
alpha_params = []
for name, param in vocoder.named_parameters():
if "alpha" in name:
param.requires_grad_(True)
alpha_params.append(param)
else:
param.requires_grad_(False)
n_trainable = sum(p.numel() for p in alpha_params)
print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params "
f"({len(alpha_params)} alpha tensors)", flush=True)
trainable_params = alpha_params
else: # all_params
for param in vocoder.parameters():
param.requires_grad_(True)
n_trainable = sum(p.numel() for p in vocoder.parameters())
print(f"[BigVGAN] all_params: {n_trainable} trainable params", flush=True)
trainable_params = list(vocoder.parameters())
# ── L2-SP: cache reference parameter values (before any gradient steps) ──
ref_params = {}
if lambda_l2sp > 0.0:
for name, param in vocoder.named_parameters():
if param.requires_grad:
ref_params[name] = param.data.clone().detach()
print(f"[BigVGAN] L2-SP anchor: {len(ref_params)} params λ={lambda_l2sp}", flush=True)
# ── Optional: load pretrained discriminator for feature matching ──────────
mpd = mrd = None
if disc_path is not None:
try:
ckpt_d = torch.load(str(disc_path), map_location="cpu", weights_only=False)
mpd = _MultiPeriodDiscriminator()
mrd = _MultiResolutionDiscriminator()
# Try common key names used by different BigVGAN releases
for mpd_key in ("mpd", "discriminator_mpd", "MPD"):
if mpd_key in ckpt_d:
mpd.load_state_dict(ckpt_d[mpd_key], strict=False)
print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True)
break
for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"):
if mrd_key in ckpt_d:
mrd.load_state_dict(ckpt_d[mrd_key], strict=False)
print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True)
break
mpd.to(device).eval()
mrd.to(device).eval()
for p in mpd.parameters():
p.requires_grad_(False)
for p in mrd.parameters():
p.requires_grad_(False)
print(f"[BigVGAN] Frozen discriminators ready for feature matching", flush=True)
except Exception as e:
print(f"[BigVGAN] WARNING: Could not load discriminator ({e}), "
f"falling back to mel+STFT losses", flush=True)
mpd = mrd = None
optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99))
vocoder.train()
try:
@@ -396,24 +641,55 @@ def _do_train(vocoder, mel_converter, clips,
pred_t = pred_wav[..., :T]
target_t = target_wav[..., :T]
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel']
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
# ── Compute loss ─────────────────────────────────────────────────
if mpd is not None and mrd is not None:
# Perceptual feature matching via frozen discriminators
with torch.no_grad():
fmaps_real_mpd = mpd(target_t)
fmaps_real_mrd = mrd(target_t)
fmaps_gen_mpd = mpd(pred_t)
fmaps_gen_mrd = mrd(pred_t)
fm_loss = (
_feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) +
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
)
# Keep a small mel loss for stable frequency alignment
pred_mel = mel_converter(pred_t.squeeze(1))
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
primary_loss = 2.0 * fm_loss + 0.1 * mel_loss
loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}"
else:
# Fallback: mel L1 + multi-resolution STFT L1
pred_mel = mel_converter(pred_t.squeeze(1))
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
primary_loss = mel_loss + stft_loss
loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}"
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
# ── L2-SP regularization ─────────────────────────────────────────
l2sp_loss = torch.zeros(1, device=device)
if lambda_l2sp > 0.0 and ref_params:
for name, param in vocoder.named_parameters():
if name in ref_params and param.requires_grad:
l2sp_loss = l2sp_loss + F.mse_loss(
param, ref_params[name], reduction="sum"
)
l2sp_loss = l2sp_loss * lambda_l2sp
loss = mel_loss + stft_loss
loss = primary_loss + l2sp_loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
optimizer.step()
pbar.update(1)
if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1:
print(f"[BigVGAN] {step+1}/{steps} "
f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} "
f"total={loss.item():.4f}", flush=True)
l2sp_str = f" l2sp={l2sp_loss.item():.4e}" if lambda_l2sp > 0 else ""
print(f"[BigVGAN] {step+1}/{steps} {loss_desc}"
f" total={loss.item():.4f}{l2sp_str}", flush=True)
if (step + 1) % save_every == 0 and (step + 1) < steps:
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"