feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes

BigVGAN trainer (selva_bigvgan_trainer.py):
- Add snake_alpha_only train mode: tunes only ~27K per-channel α params
  (0.024% of 112M) — physically cannot cause harmonic smearing
- Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights
- Add optional discriminator_path: frozen MPD+MRD feature matching loss
  replaces mel L1 when a BigVGAN discriminator checkpoint is provided
- Inline MPD + MRD discriminator implementations (no extra dependencies)

DITTO optimizer (selva_ditto_optimizer.py):
- New node: inference-time noise optimization (arXiv:2401.12179)
- Optimizes x₀ via mel Gram matrix style loss against BJ reference clips
- All model weights frozen — zero quality degradation risk
- Truncated BPTT through last n_grad_steps of the ODE (configurable)
- Gradient checkpointing on each differentiated step

Docs:
- README: document all 20 nodes (was 3), add workflow diagrams
- STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers,
  why LoRA/TI fail, combined approach, dataset prep

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-09 12:04:05 +02:00
parent f17f6f0863
commit 1e9551152e
5 changed files with 1159 additions and 44 deletions
+254 -8
View File
@@ -58,7 +58,7 @@ Generates audio from video features. Runs the rectified flow ODE with classifier
| Input | Description | | Input | Description |
|-------|-------------| |-------|-------------|
| `model` | From SelVA Model Loader | | `model` | From SelVA Model Loader (or any loader/loader chain) |
| `features` | From SelVA Feature Extractor | | `features` | From SelVA Feature Extractor |
| `prompt` | Text description — leave empty to use the prompt stored in features | | `prompt` | Text description — leave empty to use the prompt stored in features |
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) | | `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
@@ -66,22 +66,261 @@ Generates audio from video features. Runs the rectified flow ODE with classifier
| `steps` | Sampling steps (default: 25) | | `steps` | Sampling steps (default: 25) |
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) | | `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
| `seed` | RNG seed | | `seed` | RNG seed |
| `normalize` | Peak-normalize output to [-1, 1] (default: true) | | `normalize` | RMS-normalize output to `target_lufs` (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
| `steering_vectors` | *(optional)* From SelVA Activation Steering Loader |
| `steering_strength` | *(optional)* Scale for steering vectors (default: 0.1) |
| `textual_inversion` | *(optional)* From SelVA Textual Inversion Loader |
| `ti_strength` | *(optional)* Blend strength for TI tokens (default: 1.0) |
**Output:** `AUDIO` **Output:** `AUDIO`
--- ---
## Workflow ### SelVA LoRA Loader
Injects a trained LoRA adapter into the generator. Connect between Model Loader and Sampler.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `adapter_path` | Path to `adapter_final.pt` or any step checkpoint |
| `strength` | 0.0 = disabled, 1.0 = full, >1.0 = exaggerated |
**Output:** `model` (SELVA_MODEL with adapter injected)
---
### SelVA LoRA Trainer
Fine-tunes LoRA adapters on a `.npz` feature dataset. See [LORA_TRAINING.md](LORA_TRAINING.md) for the full guide.
**Output:** `adapter` (SELVA_LORA) and `summary_path` (STRING)
---
### SelVA LoRA Scheduler
Runs a series of LoRA experiments from a JSON sweep file. The dataset is encoded once and reused across all runs. Results are collected in `experiment_summary.json` with overlaid loss curves.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `experiments_file` | Path to JSON sweep config |
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
---
### SelVA Skip Experiment
Signals a running SelVA LoRA Scheduler to skip the current experiment and move to the next. Queue this node while the scheduler is running.
**Output:** `flag_path` (STRING)
---
### SelVA LoRA Evaluator
Evaluates multiple LoRA adapters by generating audio from a fixed reference clip, then reports spectral metrics per adapter for comparison. Input is a JSON file listing adapter paths; an empty path means baseline (no LoRA).
**Outputs:** `summary_path` (STRING), `comparison_image` (IMAGE)
---
### SelVA Dataset Browser
Reads a `dataset.json` produced by the SelVA dataset preparation pipeline and exposes one entry at a time via an index. Useful for previewing and iterating through a prepared dataset.
**Outputs:** video path, audio path, frames directory, label, total count
---
### SelVA VAE Roundtrip
Encodes audio through the SelVA VAE then decodes it back. Use this to measure codec reconstruction quality in isolation — if the output sounds degraded relative to the input, the codec ceiling will limit any downstream fine-tuning approach.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `audio` | AUDIO to test |
**Output:** `audio_reconstructed` (AUDIO)
---
### SelVA HF Smoother
Attenuates high-frequency content that the SelVA codec handles poorly, by blending a low-pass filtered version of the audio with the original. Use before feature extraction to improve LoRA training targets.
**Output:** `audio` (AUDIO)
---
### SelVA Spectral Matcher
Applies a per-band gain correction to bring audio's spectral profile in line with the MMAudio VAE's expected distribution, derived from the normalization statistics baked into the VAE weights. Use on training audio to reduce codec mismatch.
**Output:** `audio` (AUDIO)
---
### SelVA Textual Inversion Trainer
Trains K learnable CLIP token embeddings against an audio dataset with all model weights frozen. The tokens are injected into the Sampler to guide generation toward a target style.
> **Note:** Textual inversion via the text conditioning path has limited effectiveness for fine-grained timbral style transfer in SelVA due to mean-pooling in the text conditioning path. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the current recommended approach.
**Outputs:** `embeddings_path` (STRING), `loss_curve` (IMAGE)
---
### SelVA Textual Inversion Loader
Loads CLIP token embeddings from a `.pt` file produced by the Textual Inversion Trainer. Connect to the Sampler's `textual_inversion` input.
**Output:** `textual_inversion` (TEXTUAL_INVERSION)
---
### SelVA TI Scheduler
Runs a series of Textual Inversion experiments from a JSON sweep file, reusing the encoded dataset across runs.
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
---
### SelVA Activation Steering Extractor
Computes per-block activation steering vectors from a training dataset by comparing DiT hidden states under BJ conditioning vs. empty conditioning. The resulting vectors can nudge the denoising trajectory toward the target style at inference.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with `.npz` feature files |
| `output_path` | Where to save `steering_vectors.pt` |
| `n_samples` | Clips to average over (default: 16) |
| `seed` | RNG seed |
**Output:** `steering_path` (STRING)
---
### SelVA Activation Steering Loader
Loads steering vectors from a `.pt` file produced by the Extractor. Connect to the Sampler's `steering_vectors` input.
**Output:** `steering_vectors` (STEERING_VECTORS)
---
### SelVA BigVGAN Trainer
Fine-tunes the BigVGAN vocoder (mel → waveform) on a set of target-style audio clips. Only the vocoder is modified — the DiT generator and VAE are completely untouched.
Default mode (`snake_alpha_only`) tunes only the ~27K per-channel α parameters in Snake/SnakeBeta activations, which directly control harmonic periodicity. With 0.024% of parameters trainable the model cannot produce spectral averaging artifacts regardless of loss function. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the full rationale.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with target-style audio files (searched recursively) |
| `output_path` | Where to save the fine-tuned vocoder `.pt` |
| `train_mode` | `snake_alpha_only` (default) or `all_params` |
| `steps` | Training steps (default: 2000) |
| `lr` | Learning rate (default: 1e-4 for snake_alpha_only) |
| `batch_size` | Clips per step (default: 4) |
| `segment_seconds` | Audio segment length per training sample (default: 1.0 s) |
| `lambda_l2sp` | L2-SP anchor regularization strength — penalizes drift from pretrained weights (default: 1e-3) |
| `save_every` | Checkpoint interval in steps (default: 500) |
| `seed` | RNG seed |
| `discriminator_path` | *(optional)* Path to `bigvgan_discriminator_optimizer.pt` — when provided, frozen MPD+MRD feature matching replaces mel L1, directly penalizing harmonic smearing |
**Output:** `checkpoint_path` (STRING) — load with SelVA BigVGAN Loader
Saves eval samples and mel spectrogram PNGs at baseline, each checkpoint, and final.
---
### SelVA BigVGAN Loader
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer and replaces the vocoder weights in a SELVA_MODEL in-place. Connect the output to SelVA Sampler instead of the base Model Loader.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `path` | Path to fine-tuned vocoder `.pt` (relative = ComfyUI output directory) |
**Output:** `model` (SELVA_MODEL with fine-tuned vocoder)
---
### SelVA DITTO Optimizer
Inference-time noise optimization ([arXiv:2401.12179](https://arxiv.org/abs/2401.12179), ICML 2024 Oral). Optimizes the initial noise latent x₀ to make the generated audio match a set of BJ reference clips, by backpropagating a mel style loss through the ODE solver. All model weights remain frozen — zero quality degradation risk.
Style loss: mean spectrum + Gram matrix computed against reference mels. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference clips. Optimization runs only through the DiT + VAE decoder; the vocoder is only invoked for the final output pass.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `features` | From SelVA Feature Extractor |
| `prompt` | Sound description (leave empty to use features prompt) |
| `negative_prompt` | Sounds to suppress |
| `reference_dir` | Directory with BJ reference audio clips (.wav/.flac/.mp3) |
| `n_opt_steps` | Gradient optimization steps on x₀ (default: 50) |
| `opt_lr` | Adam LR for x₀ optimization (default: 0.1) |
| `n_ode_steps` | ODE steps per optimization iteration (default: 10; lower = faster) |
| `n_grad_steps` | ODE steps to differentiate through — truncated BPTT (default: 5) |
| `style_weight` | Style loss weight (default: 1.0; increase for stronger BJ shift) |
| `steps` | Euler steps for the final generation pass (default: 25) |
| `cfg_strength` | CFG scale (default: 4.5) |
| `seed` | RNG seed |
| `normalize` | *(optional)* RMS normalize output (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
**Output:** `AUDIO`
---
## Workflows
### Basic generation
``` ```
VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio
│ (video_info) ─► (fps auto) │ (video_info)
│ (features) ────────────────────────────────────►│ │ (features) ──────────────────────────────────►│
│ (prompt) ──────────────────────────────────────►│ │ (prompt) ────────────────────────────────────►│
``` ```
Connect the `prompt` output of Feature Extractor directly to Sampler's `prompt` to keep them in sync. Leave Sampler's `prompt` empty and it will use whatever was stored during extraction. ### DITTO style transfer (recommended first approach)
```
SelVA Model Loader ─────────────────────────────────────────────► SelVA DITTO Optimizer ──► Save Audio
SelVA Feature Extractor ──(features)────────────────────────────────────►│
(prompt) ──────────────────────────────────────►│
BJ reference_dir ───────────────────────────────────────────────────────►│
```
No training required. Each run optimizes x₀ independently for the current video and reference set.
### Vocoder fine-tuning
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► (checkpoint .pt)
BJ audio clips ──(data_dir)──►│
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler ──► Save Audio
▲ ▲
checkpoint .pt SelVA Feature Extractor
```
### LoRA training
See [LORA_TRAINING.md](LORA_TRAINING.md).
--- ---
@@ -127,8 +366,15 @@ The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available,
--- ---
## Style Transfer
For adapting SelVA to a specific audio style (e.g. BJ / Bladee / Jersey Club), see [STYLE_TRANSFER.md](STYLE_TRANSFER.md).
---
## Credits ## Credits
- [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training - [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training
- [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework - [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework
- [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis - [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis
- [DITTO](https://arxiv.org/abs/2401.12179) by Novack et al. — inference-time diffusion optimization
+158
View File
@@ -0,0 +1,158 @@
# Style Transfer for SelVA
This document covers approaches for adapting SelVA's audio output to a specific timbral style using a small reference dataset (~50 clips). The context here is BJ / Bladee / Jersey Club style — sharp metallic transients, saturated harmonics, 808 sub bass, glassy high-frequency content — but the methods apply to any style target.
---
## Why standard fine-tuning is hard
SelVA's generation quality depends on the DiT (generator) outputting latents that fall in the high-density region of the VAE decoder's training distribution. BJ's audio maps to a sparse, tail region of that space — the VAE roundtrip already shows ~1015 dB elevated HF noise floor on BJ material. Any training that pushes the generator toward exact BJ encoder outputs is training toward an already-degraded target.
**LoRA** makes this worse: it introduces "intruder dimensions" — new high-rank singular vectors absent from the pretrained weight spectrum — that push DiT outputs further off-manifold. This mechanism is LR- and scale-independent. Reducing LoRA scale does not fix the direction, only the magnitude. Empirically: spectral flatness degrades to ~0.210.26 (vs. baseline 0.013) at every scale from 0.0625 to 1.0.
**Textual inversion** via the text conditioning path suffers from mean-pooling: SelVA's text features are pooled into a single global vector before injection into the DiT. The optimizer finds a spectral bias (noise/buzz) as the cheapest way to reduce reconstruction loss — not a semantic style shift.
The approaches below are ordered by expected quality and ease of use.
---
## Tier 1 — DITTO (recommended first try)
**Node: SelVA DITTO Optimizer**
Inference-time noise optimization. Keeps all model weights frozen and only optimizes the initial noise latent x₀ using a style loss computed against the reference clips. Since the weights never change, there is zero risk of quality degradation — the model still generates from its original manifold, just from a better starting point.
**Style loss:** mean spectrum + Gram matrix of mel spectrograms. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference. Optimization runs entirely before the vocoder — BigVGAN is only called for the final output pass.
**How it works:**
For each video clip you want to process:
1. Run SelVA Feature Extractor as usual.
2. Instead of SelVA Sampler, connect to **SelVA DITTO Optimizer** with your BJ `reference_dir`.
3. The node runs N optimization steps, each backpropagating through the last few ODE Euler steps to compute `∂loss/∂x₀`.
4. After optimization, one final full-ODE pass generates the output audio from the refined x₀.
```
SelVA Model Loader ────────────────────────────────► SelVA DITTO Optimizer ──► audio
SelVA Feature Extractor ──(features)────────────────────────►│
(prompt) ──────────────────────────►│
BJ clips ───────────────────────────(reference_dir) ─────────►│
```
**Tuning guide:**
| Parameter | Starting value | When to adjust |
|---|---|---|
| `n_opt_steps` | 50 | Increase to 100200 if style shift is too subtle |
| `opt_lr` | 0.1 | Lower to 0.05 if coherence breaks; raise to 0.3 for stronger shift |
| `n_ode_steps` | 10 | Lower = faster optimization, less accurate gradient |
| `n_grad_steps` | 5 | Number of ODE steps to differentiate through — must be ≤ n_ode_steps |
| `style_weight` | 1.0 | Increase to 25 for stronger BJ character; watch for incoherence |
**Memory:** Each opt step stores activations for `n_grad_steps` DiT forward passes with gradient checkpointing. At n_grad_steps=5, expect ~46 GB additional VRAM over baseline inference.
**Time per video clip:** ~50 opt steps × (10 ODE steps × 2 passes for checkpointing) + 25 final steps ≈ 515 minutes depending on GPU.
**Limitations:** DITTO with mel Gram matrix loss shifts timbral statistics but cannot precisely match the BJ transient sharpness — the Gram matrix is a texture descriptor, not a transient detector. See Tier 2 (vocoder fine-tuning) for that.
---
## Tier 2 — Vocoder Fine-tuning
**Nodes: SelVA BigVGAN Trainer → SelVA BigVGAN Loader**
The BigVGAN vocoder (mel → waveform) is the component most responsible for the final timbral character of the output. Fine-tuning only the vocoder keeps the DiT completely untouched — latents stay on-manifold, only the waveform rendering changes.
### Why plain mel L1 loss fails
BigVGAN was trained with `L_G = Σ[L_adv + 2·L_fm] + 45·L_mel`. The adversarial and feature-matching terms do the perceptual heavy lifting — they prevent the generator from averaging over high-variance harmonic content. Dropping them for a plain mel L1 loss is a loss-function topology problem: the model minimizes expected reconstruction error by averaging over harmonic uncertainty, eroding the saturated 38 kHz harmonics visible as "green smear" in spectrograms. This happens regardless of LR or step count.
### `snake_alpha_only` mode (default, recommended)
BigVGAN's AMP blocks use Snake/SnakeBeta activations: `y = x + (1/α)·sin²(α·x)` where α is a per-channel learnable scalar. Alpha parameters directly control the harmonic periodicity of each layer's output — they are the "harmonic tuning knobs" of the vocoder.
With `train_mode=snake_alpha_only`, only the ~27K alpha parameters (0.024% of the 112M parameter model) are trained. The conv weights encoding waveform structure remain frozen. With this few trainable parameters the model physically cannot reshape the spectrum significantly regardless of loss function — no capacity for the green smear.
**Loss in snake_alpha_only mode:** mel L1 + multi-resolution STFT L1 are still used but can only shift harmonic emphasis, not spectral shape.
### `all_params` mode with discriminator
For a stronger shift — or to use proper perceptual losses — run with `train_mode=all_params` and provide a `discriminator_path` (the `bigvgan_discriminator_optimizer.pt` from the BigVGAN pretrained release):
1. The frozen pretrained MPD and MRD discriminators are loaded and used as fixed perceptual feature extractors.
2. Loss becomes `2·L_fm(frozen_D) + 0.1·L_mel` — feature matching directly penalizes harmonic smearing through the discriminator's learned perceptual space.
3. `lambda_l2sp` (default 1e-3) anchors all parameters to their pretrained values — prevents catastrophic drift on 50 clips.
This is the highest-quality vocoder fine-tuning path but requires the discriminator checkpoint.
### Workflow
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► bigvgan_bj.pt
BJ audio clips ──(data_dir)──►│
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor
```
### Tuning guide
| Parameter | Default | Notes |
|---|---|---|
| `train_mode` | snake_alpha_only | Safe default; use all_params only with discriminator_path |
| `steps` | 2000 | 10002000 for snake_alpha_only; 30005000 for all_params |
| `lr` | 1e-4 | For snake_alpha_only; lower to 1e-5 for all_params |
| `lambda_l2sp` | 1e-3 | Increase to 1e-2 for all_params to limit drift |
| `batch_size` | 4 | 48 for stable gradients |
| `segment_seconds` | 1.0 | 12 s segments recommended |
**Eval samples:** The trainer saves `.wav` and mel spectrogram `.png` files at baseline, each checkpoint, and final. Compare the spectrograms — saturation (red values in high-frequency bands) should increase relative to baseline.
---
## Tier 3 — DITTO + Vocoder (combined)
Stack both:
```
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA DITTO Optimizer ──► audio
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor + reference_dir
```
The fine-tuned vocoder handles waveform rendering; DITTO shifts the latent trajectory. Each addresses a different aspect of style transfer.
---
## What doesn't work (and why)
### Standard LoRA
LoRA introduces "intruder dimensions" — high-rank singular vectors absent from the pretrained weight spectrum — at initialization. These push DiT outputs into decoder-hostile latent regions regardless of scale or LR. The failure is direction-based, not magnitude-based, so reducing LoRA scale does not fix it.
PiSSA initialization (`init_lora_weights="pissa"`) and rsLoRA scaling (`use_rslora=True`) reduce intruder dimension formation by starting in the pretrained weight subspace. These are planned as future improvements.
### Textual inversion
SelVA mean-pools all 77 CLIP tokens into a single AdaLN bias vector. Every token contributes equally to a scalar offset; the optimizer finds spectral buzz as the minimum-cost way to reduce flow-matching reconstruction loss. More tokens make it worse.
### Activation steering (global mean difference)
The raw mean difference between BJ and empty conditions is not a clean style basis — it carries noise from the diversity of the training clips and the many attention blocks that have nothing to do with timbral character. Global injection (all blocks at any strength) kills the sound. Targeted layer injection (only the 36 blocks most predictive of BJ style) is theoretically sound but requires per-layer delta magnitude ranking to identify the right layers first.
---
## Reference dataset preparation
Use the same audio clips for both DITTO and vocoder fine-tuning:
- **Minimum:** 2030 clips. DITTO works from 5+; vocoder benefits from 40+.
- **Format:** `.wav` or `.flac` at native sample rate. The trainer resamples automatically.
- **Length:** Any length ≥ 1 s. Longer is fine — the trainer segments internally.
- **Quality:** Clean, full-mix BJ clips. Avoid heavily compressed or streaming-ripped files. Use HF Smoother if HF content sounds brittle after VAE roundtrip.
- **Diversity:** Vary tempo, key, vocal density. 20 diverse clips > 50 copies of the same 8-bar loop.
Normalize all clips to consistent loudness (e.g. -14 LUFS) before training. Inconsistent levels increase loss variance and slow convergence.
+1
View File
@@ -21,6 +21,7 @@ _NODES = {
"SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"), "SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
"SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"), "SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
"SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"), "SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
"SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"),
} }
for key, (module_path, class_name, display_name) in _NODES.items(): for key, (module_path, class_name, display_name) in _NODES.items():
+310 -34
View File
@@ -1,14 +1,29 @@
"""SelVA BigVGAN Vocoder Fine-tuner. """SelVA BigVGAN Vocoder Fine-tuner.
Fine-tunes only the BigVGAN vocoder (mel → waveform) on BJ audio clips using Tier-1 approach based on research: snake alpha fine-tuning + L2-SP anchor
spectral reconstruction losses. The DiT and VAE are completely untouched. regularization + optional frozen discriminator feature matching.
Loss: L1 mel reconstruction + multi-resolution STFT magnitude L1. Root cause of harmonic smearing with plain mel/STFT losses:
No GAN discriminator — this is a proof-of-concept to verify that the vocoder Spectral L1 minimizes expected reconstruction error — averaging over
can absorb BJ timbral characteristics before investing in full adversarial training. high-variance harmonics. This is a loss-function topology problem, not
an LR/step-count problem. The fix is either (a) restrict trainable params
so the model lacks capacity to smear, or (b) use a perceptual loss that
penalizes harmonic averaging.
Save format: {'generator': vocoder.state_dict()} — same as the original BigVGAN Tier-1 implementation:
checkpoint so it can be loaded with SelVA BigVGAN Loader. 1. snake_alpha_only mode — only tune ~5K per-channel α parameters in
Snake/SnakeBeta activations. These control harmonic periodicity per
channel. With only 5K trainable params, the model physically cannot
reshape the spectrum enough to cause the "green smear".
2. L2-SP anchor loss — penalizes parameter drift from pretrained values
(strictly better than weight decay, which anchors to zero).
3. Frozen discriminator feature matching — if a BigVGAN discriminator
checkpoint is provided, the pretrained MPD+MRD networks are used as
fixed perceptual feature extractors. Feature matching loss penalizes
harmonic smearing directly without any GAN instability.
Save format: {'generator': vocoder.state_dict()} — same as the original
BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader.
""" """
import random import random
@@ -16,6 +31,7 @@ import threading
from pathlib import Path from pathlib import Path
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
import comfy.utils import comfy.utils
@@ -23,12 +39,133 @@ import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
def _save_spectrogram(path, mel_tensor):
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep).
Normalises to [0, 255], flips frequency axis so low freqs are at the bottom, # ---------------------------------------------------------------------------
and saves as a greyscale PNG with a simple viridis-like colour map. # Minimal MPD + MRD discriminators matching BigVGAN pretrained checkpoint keys
""" # ---------------------------------------------------------------------------
def _get_pad(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class _DiscriminatorP(nn.Module):
"""Multi-Period Discriminator sub-module (HiFi-GAN / BigVGAN style)."""
def __init__(self, period):
super().__init__()
self.period = period
from torch.nn.utils.parametrizations import weight_norm
norm = weight_norm
self.convs = nn.ModuleList([
norm(nn.Conv2d(1, 32, (5, 1), (3, 1), (_get_pad(5, 1), 0))),
norm(nn.Conv2d(32, 128, (5, 1), (3, 1), (_get_pad(5, 1), 0))),
norm(nn.Conv2d(128, 512, (5, 1), (3, 1), (_get_pad(5, 1), 0))),
norm(nn.Conv2d(512, 1024,(5, 1), (3, 1), (_get_pad(5, 1), 0))),
norm(nn.Conv2d(1024,1024,(5, 1), 1, (_get_pad(5, 1), 0))),
])
self.conv_post = norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0)))
def forward(self, x):
fmap = []
b, c, t = x.shape
if t % self.period != 0:
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.1)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
return fmap
class _MultiPeriodDiscriminator(nn.Module):
def __init__(self):
super().__init__()
self.discriminators = nn.ModuleList([
_DiscriminatorP(p) for p in [2, 3, 5, 7, 11]
])
def forward(self, y):
fmaps = []
for d in self.discriminators:
fmaps.extend(d(y))
return fmaps
class _DiscriminatorR(nn.Module):
"""Multi-Resolution Discriminator sub-module."""
def __init__(self, fft_size, shift_size, win_length):
super().__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
from torch.nn.utils.parametrizations import weight_norm
norm = weight_norm
self.convs = nn.ModuleList([
norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
])
self.conv_post = norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
def spectrogram(self, x):
"""x: [B, 1, T] → [B, 1, freq, time]"""
n, hop, win = self.fft_size, self.shift_size, self.win_length
window = torch.hann_window(win, device=x.device)
x = x.squeeze(1) # [B, T]
pad = (win - hop) // 2
x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect")
x = torch.stft(x, n, hop, win, window, center=False, return_complex=True)
x = x.abs().unsqueeze(1) # [B, 1, freq, time]
return x
def forward(self, x):
fmap = []
x = self.spectrogram(x)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.1)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
return fmap
class _MultiResolutionDiscriminator(nn.Module):
def __init__(self):
super().__init__()
resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)]
self.discriminators = nn.ModuleList([
_DiscriminatorR(*r) for r in resolutions
])
def forward(self, y):
fmaps = []
for d in self.discriminators:
fmaps.extend(d(y))
return fmaps
def _feature_matching_loss(fmaps_real, fmaps_gen):
"""L1 between paired feature map lists (both already detach-safe for real)."""
loss = torch.zeros(1, device=fmaps_gen[0].device)
for fr, fg in zip(fmaps_real, fmaps_gen):
T = min(fr.shape[-1], fg.shape[-1])
loss = loss + F.l1_loss(fg[..., :T], fr[..., :T].detach())
return loss / len(fmaps_real)
# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------
def _save_spectrogram(path, mel_tensor):
"""Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep)."""
try: try:
from PIL import Image from PIL import Image
import numpy as np import numpy as np
@@ -120,6 +257,10 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device):
return loss / len(_STFT_RESOLUTIONS) return loss / len(_STFT_RESOLUTIONS)
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaBigvganTrainer: class SelvaBigvganTrainer:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY CATEGORY = SELVA_CATEGORY
@@ -128,10 +269,10 @@ class SelvaBigvganTrainer:
RETURN_NAMES = ("checkpoint_path",) RETURN_NAMES = ("checkpoint_path",)
OUTPUT_TOOLTIPS = ("Path to saved vocoder checkpoint — load with SelVA BigVGAN Loader.",) OUTPUT_TOOLTIPS = ("Path to saved vocoder checkpoint — load with SelVA BigVGAN Loader.",)
DESCRIPTION = ( DESCRIPTION = (
"Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using " "Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips. "
"spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. " "Default mode (snake_alpha_only) tunes only the ~5K Snake activation α "
"Supports both 16k (BigVGAN) and 44k (BigVGANv2) models. " "parameters — cannot cause harmonic smearing. Add a discriminator path "
"Load the result with SelVA BigVGAN Loader." "for perceptual feature matching loss. DiT and VAE stay frozen."
) )
@classmethod @classmethod
@@ -147,26 +288,53 @@ class SelvaBigvganTrainer:
"default": "bigvgan_bj.pt", "default": "bigvgan_bj.pt",
"tooltip": "Where to save the fine-tuned vocoder. Relative paths → ComfyUI output dir.", "tooltip": "Where to save the fine-tuned vocoder. Relative paths → ComfyUI output dir.",
}), }),
"train_mode": (["snake_alpha_only", "all_params"], {
"default": "snake_alpha_only",
"tooltip": (
"snake_alpha_only: only tune ~5K per-channel α parameters in Snake/SnakeBeta "
"activations. These control harmonic periodicity. Cannot cause spectral smearing. "
"all_params: tune all vocoder weights — set lambda_l2sp>0 to prevent drift."
),
}),
"steps": ("INT", { "steps": ("INT", {
"default": 2000, "min": 100, "max": 50000, "default": 2000, "min": 100, "max": 50000,
"tooltip": "Training steps. 10002000 is a good first experiment.", "tooltip": "Training steps. 10002000 is a good first experiment with snake_alpha_only.",
}), }),
"lr": ("FLOAT", { "lr": ("FLOAT", {
"default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-5, "default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-5,
"tooltip": "Learning rate. BigVGAN default is 1e-4.", "tooltip": "Learning rate. 1e-4 for snake_alpha_only, 1e-5 for all_params.",
}), }),
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32}), "batch_size": ("INT", {"default": 4, "min": 1, "max": 32}),
"segment_seconds": ("FLOAT", { "segment_seconds": ("FLOAT", {
"default": 1.0, "min": 0.25, "max": 4.0, "step": 0.25, "default": 1.0, "min": 0.25, "max": 4.0, "step": 0.25,
"tooltip": "Audio segment length per training sample in seconds.", "tooltip": "Audio segment length per training sample in seconds.",
}), }),
"lambda_l2sp": ("FLOAT", {
"default": 1e-3, "min": 0.0, "max": 0.1, "step": 1e-4,
"tooltip": (
"L2-SP anchor regularization: penalizes parameter drift from pretrained values. "
"0 = disabled. 1e-3 is good for snake_alpha_only. "
"Increase to 1e-2 for all_params to prevent catastrophic forgetting."
),
}),
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}), "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}), "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
}, },
"optional": {
"discriminator_path": ("STRING", {
"default": "",
"tooltip": (
"Optional path to BigVGAN discriminator checkpoint "
"(bigvgan_discriminator_optimizer.pt from the BigVGAN pretrained release). "
"When provided, frozen MPD+MRD feature matching replaces mel L1 — "
"the key fix for harmonic smearing. Leave empty to use mel+STFT losses only."
),
}),
},
} }
def train(self, model, data_dir, output_path, steps, lr, batch_size, def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size,
segment_seconds, save_every, seed): segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""):
import traceback import traceback
device = get_device() device = get_device()
@@ -197,6 +365,14 @@ class SelvaBigvganTrainer:
out_path = Path(folder_paths.get_output_directory()) / out_path out_path = Path(folder_paths.get_output_directory()) / out_path
out_path.parent.mkdir(parents=True, exist_ok=True) out_path.parent.mkdir(parents=True, exist_ok=True)
disc_path = None
if discriminator_path and discriminator_path.strip():
disc_path = Path(discriminator_path.strip())
if not disc_path.is_absolute():
disc_path = Path(folder_paths.get_output_directory()) / disc_path
if not disc_path.exists():
raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}")
# Find and pre-load audio clips # Find and pre-load audio clips
segment_samples = int(segment_seconds * sample_rate) segment_samples = int(segment_seconds * sample_rate)
audio_files = [] audio_files = []
@@ -227,8 +403,15 @@ class SelvaBigvganTrainer:
raise RuntimeError( raise RuntimeError(
f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)" f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)"
) )
print(f"[BigVGAN] {len(clips)} clips ready segment={segment_seconds}s "
f"steps={steps} lr={lr} batch={batch_size}\n", flush=True) trainable_count = sum(
1 for n, _ in vocoder.named_parameters() if "alpha" in n
) if train_mode == "snake_alpha_only" else sum(
1 for _ in vocoder.parameters()
)
print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} "
f"segment={segment_seconds}s steps={steps} lr={lr} "
f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True)
if strategy == "offload_to_cpu": if strategy == "offload_to_cpu":
feature_utils.to(device) feature_utils.to(device)
@@ -259,8 +442,8 @@ class SelvaBigvganTrainer:
vocoder, mel_converter, clips, vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils, device, dtype, strategy, feature_utils,
segment_samples, sample_rate, segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed, train_mode, steps, lr, batch_size, lambda_l2sp,
out_path, pbar, save_every, seed, out_path, disc_path, pbar,
) )
except Exception as e: except Exception as e:
_exc[0] = e _exc[0] = e
@@ -275,11 +458,15 @@ class SelvaBigvganTrainer:
return (_result[0],) return (_result[0],)
# ---------------------------------------------------------------------------
# Training worker
# ---------------------------------------------------------------------------
def _do_train(vocoder, mel_converter, clips, def _do_train(vocoder, mel_converter, clips,
device, dtype, strategy, feature_utils, device, dtype, strategy, feature_utils,
segment_samples, sample_rate, segment_samples, sample_rate,
steps, lr, batch_size, save_every, seed, train_mode, steps, lr, batch_size, lambda_l2sp,
out_path, pbar): save_every, seed, out_path, disc_path, pbar):
"""Execute training. Called in a fresh thread — no inference_mode active. """Execute training. Called in a fresh thread — no inference_mode active.
Even though inference_mode is off here, tensors created in the calling Even though inference_mode is off here, tensors created in the calling
@@ -372,7 +559,65 @@ def _do_train(vocoder, mel_converter, clips,
if buf is not None: if buf is not None:
module._buffers[bname] = buf.clone() module._buffers[bname] = buf.clone()
optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99)) # ── Training mode: select which parameters to train ──────────────────────
if train_mode == "snake_alpha_only":
alpha_params = []
for name, param in vocoder.named_parameters():
if "alpha" in name:
param.requires_grad_(True)
alpha_params.append(param)
else:
param.requires_grad_(False)
n_trainable = sum(p.numel() for p in alpha_params)
print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params "
f"({len(alpha_params)} alpha tensors)", flush=True)
trainable_params = alpha_params
else: # all_params
for param in vocoder.parameters():
param.requires_grad_(True)
n_trainable = sum(p.numel() for p in vocoder.parameters())
print(f"[BigVGAN] all_params: {n_trainable} trainable params", flush=True)
trainable_params = list(vocoder.parameters())
# ── L2-SP: cache reference parameter values (before any gradient steps) ──
ref_params = {}
if lambda_l2sp > 0.0:
for name, param in vocoder.named_parameters():
if param.requires_grad:
ref_params[name] = param.data.clone().detach()
print(f"[BigVGAN] L2-SP anchor: {len(ref_params)} params λ={lambda_l2sp}", flush=True)
# ── Optional: load pretrained discriminator for feature matching ──────────
mpd = mrd = None
if disc_path is not None:
try:
ckpt_d = torch.load(str(disc_path), map_location="cpu", weights_only=False)
mpd = _MultiPeriodDiscriminator()
mrd = _MultiResolutionDiscriminator()
# Try common key names used by different BigVGAN releases
for mpd_key in ("mpd", "discriminator_mpd", "MPD"):
if mpd_key in ckpt_d:
mpd.load_state_dict(ckpt_d[mpd_key], strict=False)
print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True)
break
for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"):
if mrd_key in ckpt_d:
mrd.load_state_dict(ckpt_d[mrd_key], strict=False)
print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True)
break
mpd.to(device).eval()
mrd.to(device).eval()
for p in mpd.parameters():
p.requires_grad_(False)
for p in mrd.parameters():
p.requires_grad_(False)
print(f"[BigVGAN] Frozen discriminators ready for feature matching", flush=True)
except Exception as e:
print(f"[BigVGAN] WARNING: Could not load discriminator ({e}), "
f"falling back to mel+STFT losses", flush=True)
mpd = mrd = None
optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99))
vocoder.train() vocoder.train()
try: try:
@@ -396,24 +641,55 @@ def _do_train(vocoder, mel_converter, clips,
pred_t = pred_wav[..., :T] pred_t = pred_wav[..., :T]
target_t = target_wav[..., :T] target_t = target_wav[..., :T]
pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel'] # ── Compute loss ─────────────────────────────────────────────────
if mpd is not None and mrd is not None:
# Perceptual feature matching via frozen discriminators
with torch.no_grad():
fmaps_real_mpd = mpd(target_t)
fmaps_real_mrd = mrd(target_t)
fmaps_gen_mpd = mpd(pred_t)
fmaps_gen_mrd = mrd(pred_t)
fm_loss = (
_feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) +
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd)
)
# Keep a small mel loss for stable frequency alignment
pred_mel = mel_converter(pred_t.squeeze(1))
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
primary_loss = 2.0 * fm_loss + 0.1 * mel_loss
loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}"
else:
# Fallback: mel L1 + multi-resolution STFT L1
pred_mel = mel_converter(pred_t.squeeze(1))
T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) T_mel = min(pred_mel.shape[-1], target_mel.shape[-1])
mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel])
stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device) stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device)
primary_loss = mel_loss + stft_loss
loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}"
loss = mel_loss + stft_loss # ── L2-SP regularization ─────────────────────────────────────────
l2sp_loss = torch.zeros(1, device=device)
if lambda_l2sp > 0.0 and ref_params:
for name, param in vocoder.named_parameters():
if name in ref_params and param.requires_grad:
l2sp_loss = l2sp_loss + F.mse_loss(
param, ref_params[name], reduction="sum"
)
l2sp_loss = l2sp_loss * lambda_l2sp
loss = primary_loss + l2sp_loss
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
optimizer.step() optimizer.step()
pbar.update(1) pbar.update(1)
if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1: if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1:
print(f"[BigVGAN] {step+1}/{steps} " l2sp_str = f" l2sp={l2sp_loss.item():.4e}" if lambda_l2sp > 0 else ""
f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} " print(f"[BigVGAN] {step+1}/{steps} {loss_desc}"
f"total={loss.item():.4f}", flush=True) f" total={loss.item():.4f}{l2sp_str}", flush=True)
if (step + 1) % save_every == 0 and (step + 1) < steps: if (step + 1) % save_every == 0 and (step + 1) < steps:
step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}" step_path = out_path.parent / f"{out_path.stem}_step{step+1}{out_path.suffix}"
+434
View File
@@ -0,0 +1,434 @@
"""SelVA DITTO Optimizer.
Inference-time noise optimization: optimizes the initial noise latent x_0
using a style loss against BJ reference clips, backpropagating through the
ODE solver. All model weights remain frozen — only x_0 changes.
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
Style loss: mel-spectrogram statistics matching (mean spectrum + Gram matrix)
against BJ reference clips. Runs entirely before the vocoder — optimization
only requires the DiT + VAE decoder, not BigVGAN.
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
forward pass activations) instead of O(N steps). Backward recomputes each
step's activations on demand.
"""
import dataclasses
import random
import threading
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
from .selva_sampler import SelvaSampler
from .selva_textual_inversion_trainer import _inject_tokens
def _load_wav(path):
"""Load audio file to [channels, samples] float32 tensor."""
try:
return torchaudio.load(str(path))
except Exception:
pass
import soundfile as sf
data, sr = sf.read(str(path), dtype="float32", always_2d=True)
wav = torch.from_numpy(data.T)
return wav, sr
def _mel_style_loss(mel_gen, ref_mean, ref_gram):
"""Style loss between generated mel and precomputed reference statistics.
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
ref_mean: [n_mels] mean spectrum of BJ reference clips (detached)
ref_gram: [n_mels, n_mels] Gram matrix of BJ reference clips (detached)
Mean spectrum loss captures the spectral envelope (which harmonics are
boosted). Gram matrix loss captures timbral texture — covariance between
frequency bands — without requiring temporal alignment.
"""
m = mel_gen.squeeze(0) # [n_mels, T]
# Mean spectrum loss
gen_mean = m.mean(dim=-1) # [n_mels]
loss_mean = F.l1_loss(gen_mean, ref_mean)
# Gram matrix loss (texture, position-invariant)
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + 0.1 * loss_gram
class SelvaDittoOptimizer:
"""DITTO inference-time noise optimization.
Freezes all model weights and optimizes only the initial noise latent x_0
to make the generated audio sound like the BJ reference clips.
No training data or gradient updates to the model — per-video per-run.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"features": ("SELVA_FEATURES",),
"prompt": ("STRING", {
"default": "", "multiline": True,
"tooltip": "Sound description. Leave empty to use features prompt.",
}),
"negative_prompt": ("STRING", {
"default": "", "multiline": False,
}),
"reference_dir": ("STRING", {
"default": "",
"tooltip": "Directory with BJ reference audio files (.wav/.flac/.mp3). "
"Reference mel statistics are precomputed from these once.",
}),
"n_opt_steps": ("INT", {
"default": 50, "min": 5, "max": 500,
"tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
"each step requires ~2 DiT forward passes.",
}),
"opt_lr": ("FLOAT", {
"default": 0.1, "min": 0.001, "max": 2.0, "step": 0.01,
"tooltip": "Adam learning rate for x_0 optimization. "
"0.1 is the DITTO paper default.",
}),
"n_ode_steps": ("INT", {
"default": 10, "min": 5, "max": 50,
"tooltip": "Euler ODE steps run during each optimization iteration. "
"Lower = faster optimization (1015 is a good trade-off). "
"Final generation always uses the steps parameter below.",
}),
"n_grad_steps": ("INT", {
"default": 5, "min": 1, "max": 50,
"tooltip": "ODE steps to differentiate through (truncated BPTT). "
"Higher = more accurate gradient, more VRAM. "
"Must be ≤ n_ode_steps. 5 is a good default.",
}),
"style_weight": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
"tooltip": "Weight of the BJ style loss. Increase to push harder toward "
"BJ style at the cost of coherence with the video.",
}),
"steps": ("INT", {
"default": 25, "min": 1, "max": 200,
"tooltip": "Euler steps for the final generation pass (after optimization).",
}),
"cfg_strength": ("FLOAT", {
"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
},
"optional": {
"normalize": ("BOOLEAN", {"default": True}),
"target_lufs": ("FLOAT", {
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward BJ style.",)
FUNCTION = "optimize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"DITTO inference-time noise optimization (arXiv:2401.12179). "
"Optimizes the initial noise latent x_0 to match BJ reference clips "
"via mel statistics style loss, backpropagating through the ODE. "
"All model weights frozen — zero quality degradation risk."
)
def optimize(self, model, features, prompt, negative_prompt,
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, steps, cfg_strength, seed,
normalize=True, target_lufs=-27.0):
import traceback
device = get_device()
dtype = model["dtype"]
strategy = model["strategy"]
net_generator = model["generator"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
# Validate variant match
feat_variant = features.get("variant")
if feat_variant is not None and feat_variant != model["variant"]:
raise ValueError(
f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
f"Re-run Feature Extractor."
)
if not prompt or not prompt.strip():
prompt = features.get("prompt", "")
# Resolve duration and seq_cfg
duration = features.get("duration", 0)
if duration <= 0:
raise ValueError("[DITTO] Features contain no duration field.")
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
sample_rate = seq_cfg.sampling_rate
# Load and precompute reference mel statistics
ref_dir = Path(reference_dir.strip())
if not ref_dir.is_absolute():
ref_dir = Path(folder_paths.models_dir) / ref_dir
if not ref_dir.exists():
raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")
ref_files = []
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"):
ref_files.extend(ref_dir.rglob(ext))
if not ref_files:
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
mel_converter.to(device)
ref_mels = []
with torch.no_grad():
for rf in ref_files[:32]: # cap at 32 for speed
try:
wav, sr = _load_wav(rf)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0).to(device, dtype)
mel = mel_converter(wav.unsqueeze(0)) # [1, n_mels, T]
ref_mels.append(mel)
except Exception as e:
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
if not ref_mels:
raise RuntimeError("[DITTO] No usable reference clips.")
# Precompute reference statistics (done once — detached, no grad)
with torch.no_grad():
all_means = torch.stack([m.squeeze(0).mean(dim=-1) for m in ref_mels])
ref_mean = all_means.mean(0) # [n_mels]
all_grams = []
for m in ref_mels:
M = m.squeeze(0) # [n_mels, T]
all_grams.append((M @ M.T) / M.shape[-1])
ref_gram = torch.stack(all_grams).mean(0) # [n_mels, n_mels]
print(f"[DITTO] Reference stats computed from {len(ref_mels)} clips "
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
f"grad_steps={n_grad_steps}", flush=True)
if strategy == "offload_to_cpu":
net_generator.to(device)
feature_utils.to(device)
soft_empty_cache()
pbar = comfy.utils.ProgressBar(n_opt_steps + steps)
_result = [None]
_exc = [None]
def _worker():
try:
_result[0] = _do_optimize(
net_generator, feature_utils, mel_converter,
features, prompt, negative_prompt,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar,
)
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if strategy == "offload_to_cpu":
net_generator.to(get_offload_device())
feature_utils.to(get_offload_device())
soft_empty_cache()
if _exc[0] is not None:
raise _exc[0]
return (_result[0],)
def _do_optimize(net_generator, feature_utils, mel_converter,
features, prompt, negative_prompt,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar):
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
# Strip inference flags from ref stats (came from main thread)
ref_mean = ref_mean.clone().detach()
ref_gram = ref_gram.clone().detach()
torch.manual_seed(seed)
clip_f = features["clip_features"].to(device, dtype)
sync_f = features["sync_features"].to(device, dtype)
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
with torch.no_grad():
text_clip = feature_utils.encode_text_clip([prompt])
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip
)
# Initial noise — x_0 is the parameter we optimize
x0_init = torch.randn(
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
device=device, dtype=dtype,
)
x0 = torch.nn.Parameter(x0_init.clone())
optimizer = torch.optim.Adam([x0], lr=opt_lr)
# n_grad_steps must not exceed n_ode_steps
n_grad_steps = min(n_grad_steps, n_ode_steps)
n_free_steps = n_ode_steps - n_grad_steps # steps run without gradient
ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)
print(f"[DITTO] Optimizing x_0 "
f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)
# Freeze all model weights (double-check — should already be frozen at inference)
net_generator.requires_grad_(False)
feature_utils.requires_grad_(False)
mel_converter.requires_grad_(False)
for opt_step in range(n_opt_steps):
comfy.model_management.throw_exception_if_processing_interrupted()
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
# This is cheaper than checkpointing all steps, at the cost of an
# approximate (truncated) gradient. The gradient still flows through
# n_grad_steps steps, which is sufficient for meaningful x_0 updates.
with torch.no_grad():
x = x0
for i in range(n_free_steps):
t = ts[i]
dt = ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
# Detach and re-leaf so backward only goes n_grad_steps deep.
# We treat x_k as a new leaf but seed it from x_0's value — so at
# opt step 0 the gradient is a true n_grad_steps truncated BPTT,
# and x_0 gets updated via x_k's dependence on x_0 through the
# no-grad prefix (approximation: gradient doesn't flow through prefix).
#
# Richer alternative: full checkpointing through all steps (uncomment
# the checkpoint block below and remove the no-grad prefix).
x = x.detach().requires_grad_(True)
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
for i in range(n_free_steps, n_ode_steps):
t = ts[i]
dt = ts[i + 1] - t
# Gradient checkpointing: recompute forward during backward,
# avoiding storage of DiT activations for each step.
def _ode_step(x_in, t=t):
return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)
flow = torch.utils.checkpoint.checkpoint(
_ode_step, x, use_reentrant=False
)
x = x + dt * flow
# ── Decode to mel (no vocoder — cheap) ──────────────────────────────
x_unnorm = net_generator.unnormalize(x)
mel_gen = feature_utils.decode(x_unnorm) # latent → mel [1, n_mels, T]
# ── Style loss ───────────────────────────────────────────────────────
loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)
optimizer.zero_grad()
loss.backward()
# Propagate gradient from x (grad_fn leaf) back to x_0.
# x was detached from x_0, so we manually transfer the gradient:
# the no-grad prefix is an approximation — skip this if doing full
# checkpointing (x would have grad_fn pointing back to x_0).
# Here x.grad is the gradient w.r.t. x at step n_free_steps;
# we directly add it to x_0.grad as an approximation.
if x.grad is not None:
if x0.grad is None:
x0.grad = x.grad.clone()
else:
x0.grad.add_(x.grad)
torch.nn.utils.clip_grad_norm_([x0], 1.0)
optimizer.step()
pbar.update(1)
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
print(f"[DITTO] {opt_step+1}/{n_opt_steps} loss={loss.item():.4f}", flush=True)
# ── Final generation with optimized x_0 ─────────────────────────────────
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
with torch.no_grad():
fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
x = x0.detach()
for i in range(steps):
comfy.model_management.throw_exception_if_processing_interrupted()
t = fm_ts[i]
dt = fm_ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
pbar.update(1)
x1_unnorm = net_generator.unnormalize(x)
spec = feature_utils.decode(x1_unnorm)
audio = feature_utils.vocode(spec)
print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
flush=True)
audio = audio.float()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
if normalize:
target_rms = 10 ** (target_lufs / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = audio * (target_rms / rms)
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)