feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes

BigVGAN trainer (selva_bigvgan_trainer.py):
- Add snake_alpha_only train mode: tunes only ~27K per-channel α params
  (0.024% of 112M) — physically cannot cause harmonic smearing
- Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights
- Add optional discriminator_path: frozen MPD+MRD feature matching loss
  replaces mel L1 when a BigVGAN discriminator checkpoint is provided
- Inline MPD + MRD discriminator implementations (no extra dependencies)

DITTO optimizer (selva_ditto_optimizer.py):
- New node: inference-time noise optimization (arXiv:2401.12179)
- Optimizes x₀ via mel Gram matrix style loss against BJ reference clips
- All model weights frozen — zero quality degradation risk
- Truncated BPTT through last n_grad_steps of the ODE (configurable)
- Gradient checkpointing on each differentiated step

Docs:
- README: document all 20 nodes (was 3), add workflow diagrams
- STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers,
  why LoRA/TI fail, combined approach, dataset prep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 12:04:05 +02:00
parent f17f6f0863
commit 1e9551152e
5 changed files with 1159 additions and 44 deletions
+1
View File
@@ -21,6 +21,7 @@ _NODES = {
"SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
"SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
"SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
"SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"),
}
for key, (module_path, class_name, display_name) in _NODES.items():
+312 -36
View File
@@ -1,14 +1,29 @@
"""SelVA BigVGAN Vocoder Fine-tuner.
Fine-tunes only the BigVGAN vocoder (mel → waveform) on BJ audio clips using
spectral reconstruction losses. The DiT and VAE are completely untouched.
Tier-1 approach based on research: snake alpha fine-tuning + L2-SP anchor
regularization + optional frozen discriminator feature matching.
Loss: L1 mel reconstruction + multi-resolution STFT magnitude L1.
No GAN discriminator — this is a proof-of-concept to verify that the vocoder
can absorb BJ timbral characteristics before investing in full adversarial training.
Root cause of harmonic smearing with plain mel/STFT losses:
Spectral L1 minimizes expected reconstruction error — averaging over
high-variance harmonics. This is a loss-function topology problem, not
an LR/step-count problem. The fix is either (a) restrict trainable params
so the model lacks capacity to smear, or (b) use a perceptual loss that
penalizes harmonic averaging.
Save format: {'generator': vocoder.state_dict()} — same as the original BigVGAN
checkpoint so it can be loaded with SelVA BigVGAN Loader.
Tier-1 implementation:
1. snake_alpha_only mode — only tune ~5K per-channel α parameters in
Snake/SnakeBeta activations. These control harmonic periodicity per
channel. With only 5K trainable params, the model physically cannot
reshape the spectrum enough to cause the "green smear".
2. L2-SP anchor loss — penalizes parameter drift from pretrained values
(strictly better than weight decay, which anchors to zero).
3. Frozen discriminator feature matching — if a BigVGAN discriminator
checkpoint is provided, the pretrained MPD+MRD networks are used as
fixed perceptual feature extractors. Feature matching loss penalizes
harmonic smearing directly without any GAN instability.
Save format: {'generator': vocoder.state_dict()} — same as the original
BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
"""
import random
@@ -16,6 +31,7 @@ import threading
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import comfy.utils
@@ -23,12 +39,133 @@ import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
def _save_spectrogram(path, mel_tensor):
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep).
Normalises to [0, 255], flips frequency axis so low freqs are at the bottom,
and saves as a greyscale PNG with a simple viridis-like colour map.
"""
# ---------------------------------------------------------------------------
# Minimal MPD + MRD discriminators matching BigVGAN pretrained checkpoint keys
# ---------------------------------------------------------------------------
def _get_pad(kernel_size, dilation=1):
    """Return the symmetric padding for a convolution of the given kernel size.

    The result is ``(kernel_size * dilation - dilation) / 2`` truncated to an
    int, i.e. the padding that keeps the length unchanged for an odd kernel
    at stride 1.
    """
    dilated_span = kernel_size * dilation - dilation
    return int(dilated_span / 2)
class _DiscriminatorP(nn.Module):
    """Multi-Period Discriminator sub-module (HiFi-GAN / BigVGAN style).

    Folds the 1-D waveform into a 2-D grid with ``period`` columns and runs a
    stack of weight-normalised (5, 1) convolutions over it. ``forward``
    returns every intermediate activation plus the final projection so the
    caller can use them as feature maps.
    """

    def __init__(self, period):
        super().__init__()
        self.period = period
        from torch.nn.utils.parametrizations import weight_norm

        time_pad = (_get_pad(5, 1), 0)
        # (in_channels, out_channels, stride) for each layer, in order.
        layer_specs = [
            (1, 32, (3, 1)),
            (32, 128, (3, 1)),
            (128, 512, (3, 1)),
            (512, 1024, (3, 1)),
            (1024, 1024, 1),
        ]
        self.convs = nn.ModuleList([
            weight_norm(nn.Conv2d(c_in, c_out, (5, 1), stride, time_pad))
            for c_in, c_out, stride in layer_specs
        ])
        self.conv_post = weight_norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0)))

    def forward(self, x):
        """Return the list of feature maps for a ``[B, C, T]`` input."""
        batch, channels, length = x.shape
        remainder = length % self.period
        if remainder != 0:
            # Right-pad (reflect) so the length divides evenly into periods.
            extra = self.period - remainder
            x = F.pad(x, (0, extra), "reflect")
            length = length + extra
        x = x.view(batch, channels, length // self.period, self.period)

        features = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            features.append(x)
        x = self.conv_post(x)
        features.append(x)
        return features
class _MultiPeriodDiscriminator(nn.Module):
    """Bank of period discriminators for periods 2, 3, 5, 7 and 11."""

    def __init__(self):
        super().__init__()
        periods = (2, 3, 5, 7, 11)
        self.discriminators = nn.ModuleList(
            [_DiscriminatorP(period) for period in periods]
        )

    def forward(self, y):
        """Return all sub-discriminator feature maps as one flat list, in order."""
        return [fmap for disc in self.discriminators for fmap in disc(y)]
class _DiscriminatorR(nn.Module):
"""Multi-Resolution Discriminator sub-module."""
def __init__(self, fft_size, shift_size, win_length):
super().__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
from torch.nn.utils.parametrizations import weight_norm
norm = weight_norm
self.convs = nn.ModuleList([
norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
])
self.conv_post = norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
def spectrogram(self, x):
"""x: [B, 1, T] → [B, 1, freq, time]"""
n, hop, win = self.fft_size, self.shift_size, self.win_length
window = torch.hann_window(win, device=x.device)
x = x.squeeze(1) # [B, T]
pad = (win - hop) // 2
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
x = torch.stft(x, n, hop, win, window, center=False, return_complex=True)
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
return x
def forward(self, x):
fmap = []
x = self.spectrogram(x)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.1)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
return fmap
class _MultiResolutionDiscriminator(nn.Module):
    """Bank of resolution discriminators over three STFT settings."""

    def __init__(self):
        super().__init__()
        # (fft_size, shift_size, win_length) per sub-discriminator.
        stft_settings = ((1024, 120, 600), (2048, 240, 1200), (512, 50, 240))
        self.discriminators = nn.ModuleList(
            [_DiscriminatorR(fft, hop, win) for fft, hop, win in stft_settings]
        )

    def forward(self, y):
        """Return all sub-discriminator feature maps as one flat list, in order."""
        return [fmap for disc in self.discriminators for fmap in disc(y)]
def _feature_matching_loss(fmaps_real, fmaps_gen):
    """Mean L1 distance between paired discriminator feature maps.

    Args:
        fmaps_real: feature maps computed on the target audio. They are
            detached here, so no gradient reaches them.
        fmaps_gen: feature maps computed on the generated audio, paired
            positionally with ``fmaps_real``.

    Returns:
        A tensor of shape ``[1]`` on the device of the generated maps: the
        per-pair L1 losses averaged over the pairs actually compared.

    Raises:
        ValueError: if there is no pair to compare (either list is empty).
    """
    pairs = list(zip(fmaps_real, fmaps_gen))
    if not pairs:
        raise ValueError(
            "[BigVGAN] feature matching needs at least one pair of feature maps"
        )

    total = torch.zeros(1, device=pairs[0][1].device)
    for real, gen in pairs:
        # The two maps can differ by a frame on the last axis; compare the overlap.
        width = min(real.shape[-1], gen.shape[-1])
        total = total + F.l1_loss(gen[..., :width], real[..., :width].detach())

    # Average over the pairs that were compared. zip() stops at the shorter
    # list, so dividing by len(fmaps_real) would under-weight the loss when
    # the lists differ in length.
    return total / len(pairs)
# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------
def _save_spectrogram(path, mel_tensor):
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep)."""
try:
from PIL import Image
import numpy as np
@@ -120,6 +257,10 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device):
return loss / len(_STFT_RESOLUTIONS)
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaBigvganTrainer:
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
@@ -128,10 +269,10 @@ class SelvaBigvganTrainer:
RETURN_NAMES = ("checkpoint_path",)
OUTPUT_TOOLTIPS = ("Path to saved vocoder checkpoint — load with SelVA BigVGAN Loader.",)
DESCRIPTION = (
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using "
"spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. "
"Supports both 16k (BigVGAN) and 44k (BigVGANv2) models. "
"Load the result with SelVA BigVGAN Loader."
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips. "
"Default mode (snake_alpha_only) tunes only the ~5K Snake activation α "
"parameters — cannot cause harmonic smearing. Add a discriminator path "
"for perceptual feature matching loss. DiT and VAE stay frozen."
)
@classmethod
@@ -147,26 +288,53 @@ class SelvaBigvganTrainer:
"default": "bigvgan_bj.pt",
"tooltip": "Where to save the fine-tuned vocoder. Relative paths → ComfyUI output dir.",
}),
"train_mode": (["snake_alpha_only", "all_params"], {
"default": "snake_alpha_only",
"tooltip": (
"snake_alpha_only: only tune ~5K per-channel α parameters in Snake/SnakeBeta "
"activations. These control harmonic periodicity. Cannot cause spectral smearing. "
"all_params: tune all vocoder weights — set lambda_l2sp>0 to prevent drift."
),
}),
"steps": ("INT", {
"default": 2000, "min": 100, "max": 50000,
"tooltip": "Training steps. 10002000 is a good first experiment.",
"tooltip": "Training steps. 10002000 is a good first experiment with snake_alpha_only.",
}),
"lr": ("FLOAT", {
"default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-5,
"tooltip": "Learning rate. BigVGAN default is 1e-4.",
"tooltip": "Learning rate. 1e-4 for snake_alpha_only, 1e-5 for all_params.",
}),
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32}),
"segment_seconds": ("FLOAT", {
"default": 1.0, "min": 0.25, "max": 4.0, "step": 0.25,
"tooltip": "Audio segment length per training sample in seconds.",
}),
"lambda_l2sp": ("FLOAT", {
"default": 1e-3, "min": 0.0, "max": 0.1, "step": 1e-4,
"tooltip": (
"L2-SP anchor regularization: penalizes parameter drift from pretrained values. "
"0 = disabled. 1e-3 is good for snake_alpha_only. "
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
),
}),
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
},
"optional": {
"discriminator_path": ("STRING", {
"default": "",
"tooltip": (
"Optional path to BigVGAN discriminator checkpoint "
"(bigvgan_discriminator_optimizer.pt from the BigVGAN pretrained release). "
"When provided, frozen MPD+MRD feature matching replaces mel L1 — "
"the key fix for harmonic smearing. Leave empty to use mel+STFT losses only."
),
}),
},
}
def train(self, model, data_dir, output_path, steps, lr, batch_size,
segment_seconds, save_every, seed):
def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""):
import traceback
device = get_device()
@@ -197,6 +365,14 @@ class SelvaBigvganTrainer:
out_path = Path(folder_paths.get_output_directory()) / out_path
out_path.parent.mkdir(parents=True, exist_ok=True)
disc_path = None
if discriminator_path and discriminator_path.strip():
disc_path = Path(discriminator_path.strip())
if not disc_path.is_absolute():
disc_path = Path(folder_paths.get_output_directory()) / disc_path
if not disc_path.exists():
raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}")
# Find and pre-load audio clips
segment_samples = int(segment_seconds * sample_rate)
audio_files = []
@@ -227,8 +403,15 @@ class SelvaBigvganTrainer:
raise RuntimeError(
f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)"
)
print(f"[BigVGAN] {len(clips)} clips ready segment={segment_seconds}s "
f"steps={steps} lr={lr} batch={batch_size}\n", flush=True)
trainable_count = sum(
1 for n, _ in vocoder.named_parameters() if "alpha" in n
) if train_mode == "snake_alpha_only" else sum(
1 for _ in vocoder.parameters()
)
print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} "
f"segment={segment_seconds}s steps={steps} lr={lr} "
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
if strategy == "offload_to_cpu":
feature_utils.to(device)
@@ -259,8 +442,8 @@ class SelvaBigvganTrainer:
vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed,
out_path, pbar,
train_mode, steps, lr, batch_size, lambda_l2sp,
save_every, seed, out_path, disc_path, pbar,
)
except Exception as e:
_exc[0] = e
@@ -275,11 +458,15 @@ class SelvaBigvganTrainer:
return (_result[0],)
# ---------------------------------------------------------------------------
# Training worker
# ---------------------------------------------------------------------------
def _do_train(vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed,
out_path, pbar):
train_mode, steps, lr, batch_size, lambda_l2sp,
save_every, seed, out_path, disc_path, pbar):
"""Execute training. Called in a fresh thread — no inference_mode active.
Even though inference_mode is off here, tensors created in the calling
@@ -372,7 +559,65 @@ def _do_train(vocoder, mel_converter, clips,
if buf is not None:
module._buffers[bname] = buf.clone()
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99))
# ── Training mode: select which parameters to train ──────────────────────
if train_mode == "snake_alpha_only":
alpha_params = []
for name, param in vocoder.named_parameters():
if "alpha" in name:
param.requires_grad_(True)
alpha_params.append(param)
else:
param.requires_grad_(False)
n_trainable = sum(p.numel() for p in alpha_params)
print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params "
f"({len(alpha_params)} alpha tensors)", flush=True)
trainable_params = alpha_params
else: # all_params
for param in vocoder.parameters():
param.requires_grad_(True)
n_trainable = sum(p.numel() for p in vocoder.parameters())
print(f"[BigVGAN] all_params: {n_trainable} trainable params", flush=True)
trainable_params = list(vocoder.parameters())
# ── L2-SP: cache reference parameter values (before any gradient steps) ──
ref_params = {}
if lambda_l2sp > 0.0:
for name, param in vocoder.named_parameters():
if param.requires_grad:
ref_params[name] = param.data.clone().detach()
print(f"[BigVGAN] L2-SP anchor: {len(ref_params)} params λ={lambda_l2sp}", flush=True)
# ── Optional: load pretrained discriminator for feature matching ──────────
mpd = mrd = None
if disc_path is not None:
try:
ckpt_d = torch.load(str(disc_path), map_location="cpu", weights_only=False)
mpd = _MultiPeriodDiscriminator()
mrd = _MultiResolutionDiscriminator()
# Try common key names used by different BigVGAN releases
for mpd_key in ("mpd", "discriminator_mpd", "MPD"):
if mpd_key in ckpt_d:
mpd.load_state_dict(ckpt_d[mpd_key], strict=False)
print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True)
break
for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"):
if mrd_key in ckpt_d:
mrd.load_state_dict(ckpt_d[mrd_key], strict=False)
print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True)
break
mpd.to(device).eval()
mrd.to(device).eval()
for p in mpd.parameters():
p.requires_grad_(False)
for p in mrd.parameters():
p.requires_grad_(False)
print(f"[BigVGAN] Frozen discriminators ready for feature matching", flush=True)
except Exception as e:
print(f"[BigVGAN] WARNING: Could not load discriminator ({e}), "
f"falling back to mel+STFT losses", flush=True)
mpd = mrd = None
optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99))
vocoder.train()
try:
@@ -396,24 +641,55 @@ def _do_train(vocoder, mel_converter, clips,
pred_t = pred_wav[..., :T]
target_t = target_wav[..., :T]
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel']
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
# ── Compute loss ─────────────────────────────────────────────────
if mpd is not None and mrd is not None:
# Perceptual feature matching via frozen discriminators
with torch.no_grad():
fmaps_real_mpd = mpd(target_t)
fmaps_real_mrd = mrd(target_t)
fmaps_gen_mpd = mpd(pred_t)
fmaps_gen_mrd = mrd(pred_t)
fm_loss = (
_feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) +
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
)
# Keep a small mel loss for stable frequency alignment
pred_mel = mel_converter(pred_t.squeeze(1))
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
primary_loss = 2.0 * fm_loss + 0.1 * mel_loss
loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}"
else:
# Fallback: mel L1 + multi-resolution STFT L1
pred_mel = mel_converter(pred_t.squeeze(1))
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
primary_loss = mel_loss + stft_loss
loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}"
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
# ── L2-SP regularization ─────────────────────────────────────────
l2sp_loss = torch.zeros(1, device=device)
if lambda_l2sp > 0.0 and ref_params:
for name, param in vocoder.named_parameters():
if name in ref_params and param.requires_grad:
l2sp_loss = l2sp_loss + F.mse_loss(
param, ref_params[name], reduction="sum"
)
l2sp_loss = l2sp_loss * lambda_l2sp
loss = mel_loss + stft_loss
loss = primary_loss + l2sp_loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
optimizer.step()
pbar.update(1)
if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1:
print(f"[BigVGAN] {step+1}/{steps} "
f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} "
f"total={loss.item():.4f}", flush=True)
l2sp_str = f" l2sp={l2sp_loss.item():.4e}" if lambda_l2sp > 0 else ""
print(f"[BigVGAN] {step+1}/{steps} {loss_desc}"
f" total={loss.item():.4f}{l2sp_str}", flush=True)
if (step + 1) % save_every == 0 and (step + 1) < steps:
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
+434
View File
@@ -0,0 +1,434 @@
"""SelVA DITTO Optimizer.
Inference-time noise optimization: optimizes the initial noise latent x_0
using a style loss against BJ reference clips, backpropagating through the
ODE solver. All model weights remain frozen — only x_0 changes.
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
Style loss: mel-spectrogram statistics matching (mean spectrum + Gram matrix)
against BJ reference clips. Runs entirely before the vocoder — optimization
only requires the DiT + VAE decoder, not BigVGAN.
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
forward pass activations) instead of O(N steps). Backward recomputes each
step's activations on demand.
"""
import dataclasses
import random
import threading
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
from .selva_sampler import SelvaSampler
from .selva_textual_inversion_trainer import _inject_tokens
def _load_wav(path):
    """Read an audio file and return ``(waveform, sample_rate)``.

    torchaudio is tried first. If it raises for any reason, the file is read
    with soundfile instead and transposed so the first axis is channels.
    """
    try:
        return torchaudio.load(str(path))
    except Exception:
        # Deliberate best-effort: fall through to the soundfile reader below.
        pass

    import soundfile as sf

    frames, rate = sf.read(str(path), dtype="float32", always_2d=True)
    # soundfile yields [samples, channels]; transpose to [channels, samples].
    return torch.from_numpy(frames.T), rate
def _mel_style_loss(mel_gen, ref_mean, ref_gram):
    """Style loss between a generated mel and precomputed reference statistics.

    mel_gen:  [1, n_mels, T]    generated mel spectrogram (with grad)
    ref_mean: [n_mels]          mean spectrum of the reference clips
    ref_gram: [n_mels, n_mels]  Gram matrix of the reference clips

    Two terms are combined as ``mean_term + 0.1 * gram_term``:
    an L1 distance between time-averaged spectra (spectral envelope), and an
    MSE between Gram matrices normalised by the frame count (band-to-band
    texture, independent of temporal alignment).
    """
    spec = mel_gen.squeeze(0)  # [n_mels, T]
    n_frames = spec.shape[-1]

    envelope_term = F.l1_loss(spec.mean(dim=-1), ref_mean)
    texture_term = F.mse_loss((spec @ spec.T) / n_frames, ref_gram)
    return envelope_term + 0.1 * texture_term
class SelvaDittoOptimizer:
    """DITTO inference-time noise optimization node.

    Freezes all model weights and optimizes only the initial noise latent x_0
    so the generated audio's mel statistics move toward those of the BJ
    reference clips. Nothing is trained or saved — the optimization is redone
    on every run.
    """

    # Reference clips beyond this count are ignored to keep start-up fast.
    _MAX_REFERENCE_CLIPS = 32
    # Glob patterns searched (recursively) under reference_dir.
    _REFERENCE_PATTERNS = ("*.wav", "*.flac", "*.mp3", "*.ogg")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "features": ("SELVA_FEATURES",),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Sound description. Leave empty to use features prompt.",
                }),
                "negative_prompt": ("STRING", {
                    "default": "", "multiline": False,
                }),
                "reference_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory with BJ reference audio files (.wav/.flac/.mp3). "
                               "Reference mel statistics are precomputed from these once.",
                }),
                "n_opt_steps": ("INT", {
                    "default": 50, "min": 5, "max": 500,
                    "tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
                               "each step requires ~2 DiT forward passes.",
                }),
                "opt_lr": ("FLOAT", {
                    "default": 0.1, "min": 0.001, "max": 2.0, "step": 0.01,
                    "tooltip": "Adam learning rate for x_0 optimization. "
                               "0.1 is the DITTO paper default.",
                }),
                "n_ode_steps": ("INT", {
                    "default": 10, "min": 5, "max": 50,
                    "tooltip": "Euler ODE steps run during each optimization iteration. "
                               "Lower = faster optimization (10–15 is a good trade-off). "
                               "Final generation always uses the steps parameter below.",
                }),
                "n_grad_steps": ("INT", {
                    "default": 5, "min": 1, "max": 50,
                    "tooltip": "ODE steps to differentiate through (truncated BPTT). "
                               "Higher = more accurate gradient, more VRAM. "
                               "Must be ≤ n_ode_steps. 5 is a good default.",
                }),
                "style_weight": ("FLOAT", {
                    "default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
                    "tooltip": "Weight of the BJ style loss. Increase to push harder toward "
                               "BJ style at the cost of coherence with the video.",
                }),
                "steps": ("INT", {
                    "default": 25, "min": 1, "max": 200,
                    "tooltip": "Euler steps for the final generation pass (after optimization).",
                }),
                "cfg_strength": ("FLOAT", {
                    "default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
            "optional": {
                "normalize": ("BOOLEAN", {"default": True}),
                "target_lufs": ("FLOAT", {
                    "default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward BJ style.",)
    FUNCTION = "optimize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "DITTO inference-time noise optimization (arXiv:2401.12179). "
        "Optimizes the initial noise latent x_0 to match BJ reference clips "
        "via mel statistics style loss, backpropagating through the ODE. "
        "All model weights frozen — zero quality degradation risk."
    )

    def optimize(self, model, features, prompt, negative_prompt,
                 reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                 style_weight, steps, cfg_strength, seed,
                 normalize=True, target_lufs=-27.0):
        """Precompute reference mel statistics, run the optimizer, return audio.

        Args:
            model: dict read for the keys "dtype", "strategy", "generator",
                "feature_utils", "variant" and "seq_cfg".
            features: dict read for "variant", "prompt", "duration"; the
                worker additionally reads the conditioning features from it.
            prompt / negative_prompt: text conditioning; an empty prompt
                falls back to ``features["prompt"]``.
            reference_dir: directory of reference audio. Relative paths are
                resolved against ``folder_paths.models_dir``.
            n_opt_steps, opt_lr, n_ode_steps, n_grad_steps, style_weight,
            steps, cfg_strength, seed, normalize, target_lufs: forwarded to
                the optimization worker unchanged.

        Returns:
            A 1-tuple holding the worker's audio result, matching
            ``RETURN_TYPES``.

        Raises:
            ValueError: variant mismatch, missing duration, or empty
                ``reference_dir``.
            FileNotFoundError: ``reference_dir`` missing or without audio.
            RuntimeError: none of the reference clips could be loaded.
        """
        import traceback

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        net_generator = model["generator"]
        feature_utils = model["feature_utils"]
        mel_converter = feature_utils.mel_converter

        # Validate variant match
        feat_variant = features.get("variant")
        if feat_variant is not None and feat_variant != model["variant"]:
            raise ValueError(
                f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
                f"Re-run Feature Extractor."
            )

        if not prompt or not prompt.strip():
            prompt = features.get("prompt", "")

        # Resolve duration and seq_cfg
        duration = features.get("duration", 0)
        if duration <= 0:
            raise ValueError("[DITTO] Features contain no duration field.")
        seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
        sample_rate = seq_cfg.sampling_rate

        # An empty string would resolve to the models directory itself and
        # recursively pick up every audio file below it — reject it up front.
        if not reference_dir or not reference_dir.strip():
            raise ValueError("[DITTO] reference_dir is empty.")

        # Load and precompute reference mel statistics
        ref_dir = Path(reference_dir.strip())
        if not ref_dir.is_absolute():
            ref_dir = Path(folder_paths.models_dir) / ref_dir
        if not ref_dir.exists():
            raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")

        # Sort so the capped subset (and therefore the statistics) does not
        # depend on filesystem enumeration order.
        ref_files = sorted(
            found
            for pattern in self._REFERENCE_PATTERNS
            for found in ref_dir.rglob(pattern)
        )
        if not ref_files:
            raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
        ref_files = ref_files[:self._MAX_REFERENCE_CLIPS]

        print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
        mel_converter.to(device)

        ref_mels = []
        with torch.no_grad():
            for rf in ref_files:
                # Best-effort per clip: an unreadable file is reported and skipped.
                try:
                    wav, sr = _load_wav(rf)
                    if wav.shape[0] > 1:
                        wav = wav.mean(0, keepdim=True)  # downmix to mono
                    if sr != sample_rate:
                        wav = torchaudio.functional.resample(wav, sr, sample_rate)
                    wav = wav.squeeze(0).to(device, dtype)
                    mel = mel_converter(wav.unsqueeze(0))  # [1, n_mels, T]
                    ref_mels.append(mel)
                except Exception as e:
                    print(f"  [DITTO] Skip {rf.name}: {e}", flush=True)

        if not ref_mels:
            raise RuntimeError("[DITTO] No usable reference clips.")

        # Precompute reference statistics (done once — detached, no grad)
        with torch.no_grad():
            per_clip = [m.squeeze(0) for m in ref_mels]            # each [n_mels, T]
            ref_mean = torch.stack(
                [clip.mean(dim=-1) for clip in per_clip]
            ).mean(0)                                              # [n_mels]
            ref_gram = torch.stack(
                [(clip @ clip.T) / clip.shape[-1] for clip in per_clip]
            ).mean(0)                                              # [n_mels, n_mels]

        print(f"[DITTO] Reference stats computed from {len(ref_mels)} clips "
              f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
              f"grad_steps={n_grad_steps}", flush=True)

        if strategy == "offload_to_cpu":
            net_generator.to(device)
            feature_utils.to(device)
            soft_empty_cache()

        pbar = comfy.utils.ProgressBar(n_opt_steps + steps)
        _result = [None]
        _exc = [None]

        def _worker():
            # Exceptions are captured and re-raised on the calling thread below.
            try:
                _result[0] = _do_optimize(
                    net_generator, feature_utils, mel_converter,
                    features, prompt, negative_prompt,
                    ref_mean, ref_gram,
                    seq_cfg, sample_rate, device, dtype,
                    n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                    style_weight, steps, cfg_strength, seed,
                    normalize, target_lufs, pbar,
                )
            except Exception as e:
                _exc[0] = e
                traceback.print_exc()

        t = threading.Thread(target=_worker, daemon=True)
        t.start()
        t.join()

        if strategy == "offload_to_cpu":
            net_generator.to(get_offload_device())
            feature_utils.to(get_offload_device())
            soft_empty_cache()

        if _exc[0] is not None:
            raise _exc[0]

        # The worker hands back the audio already wrapped in a 1-tuple.
        # Unwrap it before building this node's own output tuple; wrapping
        # twice would deliver a tuple, not the audio dict, as the AUDIO output.
        audio = _result[0]
        if isinstance(audio, tuple):
            audio = audio[0]
        return (audio,)
def _do_optimize(net_generator, feature_utils, mel_converter,
                 features, prompt, negative_prompt,
                 ref_mean, ref_gram,
                 seq_cfg, sample_rate, device, dtype,
                 n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                 style_weight, steps, cfg_strength, seed,
                 normalize, target_lufs, pbar):
    """Optimize the initial noise x_0 against the style loss, then generate audio.

    Runs in a fresh thread (no inference_mode active).

    Each optimization iteration integrates the Euler ODE for ``n_ode_steps``:
    the first ``n_ode_steps - n_grad_steps`` steps run without gradient, the
    remaining ``n_grad_steps`` run with gradient checkpointing. The gradient
    of the style loss w.r.t. the state at the start of the differentiated
    suffix is then used directly as the gradient for x_0 (truncated BPTT —
    an approximation, since nothing flows through the no-grad prefix).

    After ``n_opt_steps`` iterations a final ``steps``-step generation is run
    from the optimized x_0, decoded and vocoded.

    Returns:
        A 1-tuple ``({"waveform": <cpu tensor>, "sample_rate": sample_rate},)``.
    """
    # Explicit import so the checkpoint submodule is guaranteed to be loaded.
    from torch.utils.checkpoint import checkpoint

    # Strip inference flags from ref stats (came from main thread)
    ref_mean = ref_mean.clone().detach()
    ref_gram = ref_gram.clone().detach()

    torch.manual_seed(seed)

    clip_f = features["clip_features"].to(device, dtype)
    sync_f = features["sync_features"].to(device, dtype)

    net_generator.update_seq_lengths(
        latent_seq_len=seq_cfg.latent_seq_len,
        clip_seq_len=clip_f.shape[1],
        sync_seq_len=sync_f.shape[1],
    )

    with torch.no_grad():
        text_clip = feature_utils.encode_text_clip([prompt])
        neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
            if negative_prompt.strip() else None
        conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
        empty_conditions = net_generator.get_empty_conditions(
            bs=1, negative_text_features=neg_text_clip
        )

    # Initial noise — x_0 is the parameter we optimize
    x0_init = torch.randn(
        1, seq_cfg.latent_seq_len, net_generator.latent_dim,
        device=device, dtype=dtype,
    )
    x0 = torch.nn.Parameter(x0_init.clone())
    optimizer = torch.optim.Adam([x0], lr=opt_lr)

    # n_grad_steps must not exceed n_ode_steps
    n_grad_steps = min(n_grad_steps, n_ode_steps)
    n_free_steps = n_ode_steps - n_grad_steps  # steps run without gradient

    ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)

    print(f"[DITTO] Optimizing x_0 "
          f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)

    # Freeze all model weights (double-check — should already be frozen at inference)
    net_generator.requires_grad_(False)
    feature_utils.requires_grad_(False)
    mel_converter.requires_grad_(False)

    for opt_step in range(n_opt_steps):
        comfy.model_management.throw_exception_if_processing_interrupted()

        # ── Phase 1: prefix of the ODE without gradient ─────────────────────
        with torch.no_grad():
            x = x0
            for i in range(n_free_steps):
                t = ts[i]
                dt = ts[i + 1] - t
                flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
                x = x + dt * flow

        # Start the differentiated suffix from a fresh leaf. Keep a dedicated
        # handle to it: `x` is rebound to non-leaf tensors in the loop below,
        # and autograd only populates .grad on leaf tensors.
        x_leaf = x.detach().requires_grad_(True)
        x = x_leaf

        # ── Phase 2: last n_grad_steps with gradient + checkpointing ────────
        for i in range(n_free_steps, n_ode_steps):
            t = ts[i]
            dt = ts[i + 1] - t

            # Checkpointing recomputes this step's forward during backward
            # instead of storing its activations. `t=t` binds the current
            # timestep so the recomputation sees the right value.
            def _ode_step(x_in, t=t):
                return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)

            flow = checkpoint(_ode_step, x, use_reentrant=False)
            x = x + dt * flow

        # ── Decode to mel (no vocoder) ──────────────────────────────────────
        x_unnorm = net_generator.unnormalize(x)
        mel_gen = feature_utils.decode(x_unnorm)  # latent → mel [1, n_mels, T]

        # ── Style loss ──────────────────────────────────────────────────────
        loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)

        optimizer.zero_grad()
        loss.backward()

        # x_leaf was detached from x_0, so autograd never writes x0.grad.
        # Use the gradient at the start of the differentiated suffix as the
        # (approximate) gradient for x_0.
        if x_leaf.grad is not None:
            x0.grad = x_leaf.grad.detach().clone()

        torch.nn.utils.clip_grad_norm_([x0], 1.0)
        optimizer.step()
        pbar.update(1)

        if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
            print(f"[DITTO] {opt_step+1}/{n_opt_steps} loss={loss.item():.4f}", flush=True)

    # ── Final generation with optimized x_0 ─────────────────────────────────
    print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)

    with torch.no_grad():
        fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
        x = x0.detach()
        for i in range(steps):
            comfy.model_management.throw_exception_if_processing_interrupted()
            t = fm_ts[i]
            dt = fm_ts[i + 1] - t
            flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
            x = x + dt * flow
            pbar.update(1)

        x1_unnorm = net_generator.unnormalize(x)
        spec = feature_utils.decode(x1_unnorm)
        audio = feature_utils.vocode(spec)

    print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
          flush=True)

    # Bring the waveform to three dimensions with a single channel on axis 1.
    audio = audio.float()
    if audio.dim() == 2:
        audio = audio.unsqueeze(1)
    elif audio.dim() == 3 and audio.shape[1] != 1:
        audio = audio.mean(dim=1, keepdim=True)

    if normalize:
        # target_lufs is applied as a plain RMS level in dB, then the signal
        # is scaled down if that pushes the peak above 1.0.
        target_rms = 10 ** (target_lufs / 20.0)
        rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
        audio = audio * (target_rms / rms)
        peak = audio.abs().max().clamp(min=1e-8)
        if peak > 1.0:
            audio = audio / peak

    print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
    return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)