From 1e9551152ebe648fce123edd5d7dfc0ed5f3ddcf Mon Sep 17 00:00:00 2001 From: Ethanfel Date: Thu, 9 Apr 2026 12:04:05 +0200 Subject: [PATCH] feat: add DITTO optimizer, upgrade BigVGAN trainer, document all nodes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BigVGAN trainer (selva_bigvgan_trainer.py): - Add snake_alpha_only train mode: tunes only ~27K per-channel α params (0.024% of 112M) — physically cannot cause harmonic smearing - Add lambda_l2sp: L2-SP anchor regularization toward pretrained weights - Add optional discriminator_path: frozen MPD+MRD feature matching loss replaces mel L1 when a BigVGAN discriminator checkpoint is provided - Inline MPD + MRD discriminator implementations (no extra dependencies) DITTO optimizer (selva_ditto_optimizer.py): - New node: inference-time noise optimization (arXiv:2401.12179) - Optimizes x₀ via mel Gram matrix style loss against BJ reference clips - All model weights frozen — zero quality degradation risk - Truncated BPTT through last n_grad_steps of the ODE (configurable) - Gradient checkpointing on each differentiated step Docs: - README: document all 20 nodes (was 3), add workflow diagrams - STYLE_TRANSFER.md: new guide — DITTO, vocoder fine-tuning tiers, why LoRA/TI fail, combined approach, dataset prep Co-Authored-By: Claude Sonnet 4.6 --- README.md | 262 +++++++++++++++++++- STYLE_TRANSFER.md | 158 ++++++++++++ nodes/__init__.py | 1 + nodes/selva_bigvgan_trainer.py | 348 +++++++++++++++++++++++--- nodes/selva_ditto_optimizer.py | 434 +++++++++++++++++++++++++++++++++ 5 files changed, 1159 insertions(+), 44 deletions(-) create mode 100644 STYLE_TRANSFER.md create mode 100644 nodes/selva_ditto_optimizer.py diff --git a/README.md b/README.md index f4c0966..87550d1 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Generates audio from video features. 
Runs the rectified flow ODE with classifier | Input | Description | |-------|-------------| -| `model` | From SelVA Model Loader | +| `model` | From SelVA Model Loader (or any loader/loader chain) | | `features` | From SelVA Feature Extractor | | `prompt` | Text description — leave empty to use the prompt stored in features | | `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) | @@ -66,22 +66,261 @@ Generates audio from video features. Runs the rectified flow ODE with classifier | `steps` | Sampling steps (default: 25) | | `cfg_strength` | Classifier-free guidance scale (default: 4.5) | | `seed` | RNG seed | -| `normalize` | Peak-normalize output to [-1, 1] (default: true) | +| `normalize` | RMS-normalize output to `target_lufs` (default: true) | +| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) | +| `steering_vectors` | *(optional)* From SelVA Activation Steering Loader | +| `steering_strength` | *(optional)* Scale for steering vectors (default: 0.1) | +| `textual_inversion` | *(optional)* From SelVA Textual Inversion Loader | +| `ti_strength` | *(optional)* Blend strength for TI tokens (default: 1.0) | **Output:** `AUDIO` --- -## Workflow +### SelVA LoRA Loader + +Injects a trained LoRA adapter into the generator. Connect between Model Loader and Sampler. + +| Input | Description | +|-------|-------------| +| `model` | SELVA_MODEL from Model Loader | +| `adapter_path` | Path to `adapter_final.pt` or any step checkpoint | +| `strength` | 0.0 = disabled, 1.0 = full, >1.0 = exaggerated | + +**Output:** `model` (SELVA_MODEL with adapter injected) + +--- + +### SelVA LoRA Trainer + +Fine-tunes LoRA adapters on a `.npz` feature dataset. See [LORA_TRAINING.md](LORA_TRAINING.md) for the full guide. + +**Output:** `adapter` (SELVA_LORA) and `summary_path` (STRING) + +--- + +### SelVA LoRA Scheduler + +Runs a series of LoRA experiments from a JSON sweep file. The dataset is encoded once and reused across all runs. 
Results are collected in `experiment_summary.json` with overlaid loss curves. + +| Input | Description | +|-------|-------------| +| `model` | SELVA_MODEL | +| `experiments_file` | Path to JSON sweep config | + +**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE) + +--- + +### SelVA Skip Experiment + +Signals a running SelVA LoRA Scheduler to skip the current experiment and move to the next. Queue this node while the scheduler is running. + +**Output:** `flag_path` (STRING) + +--- + +### SelVA LoRA Evaluator + +Evaluates multiple LoRA adapters by generating audio from a fixed reference clip, then reports spectral metrics per adapter for comparison. Input is a JSON file listing adapter paths; an empty path means baseline (no LoRA). + +**Outputs:** `summary_path` (STRING), `comparison_image` (IMAGE) + +--- + +### SelVA Dataset Browser + +Reads a `dataset.json` produced by the SelVA dataset preparation pipeline and exposes one entry at a time via an index. Useful for previewing and iterating through a prepared dataset. + +**Outputs:** video path, audio path, frames directory, label, total count + +--- + +### SelVA VAE Roundtrip + +Encodes audio through the SelVA VAE then decodes it back. Use this to measure codec reconstruction quality in isolation — if the output sounds degraded relative to the input, the codec ceiling will limit any downstream fine-tuning approach. + +| Input | Description | +|-------|-------------| +| `model` | SELVA_MODEL | +| `audio` | AUDIO to test | + +**Output:** `audio_reconstructed` (AUDIO) + +--- + +### SelVA HF Smoother + +Attenuates high-frequency content that the SelVA codec handles poorly, by blending a low-pass filtered version of the audio with the original. Use before feature extraction to improve LoRA training targets. 
+ +**Output:** `audio` (AUDIO) + +--- + +### SelVA Spectral Matcher + +Applies a per-band gain correction to bring audio's spectral profile in line with the MMAudio VAE's expected distribution, derived from the normalization statistics baked into the VAE weights. Use on training audio to reduce codec mismatch. + +**Output:** `audio` (AUDIO) + +--- + +### SelVA Textual Inversion Trainer + +Trains K learnable CLIP token embeddings against an audio dataset with all model weights frozen. The tokens are injected into the Sampler to guide generation toward a target style. + +> **Note:** Textual inversion via the text conditioning path has limited effectiveness for fine-grained timbral style transfer in SelVA due to mean-pooling in the text conditioning path. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the current recommended approach. + +**Outputs:** `embeddings_path` (STRING), `loss_curve` (IMAGE) + +--- + +### SelVA Textual Inversion Loader + +Loads CLIP token embeddings from a `.pt` file produced by the Textual Inversion Trainer. Connect to the Sampler's `textual_inversion` input. + +**Output:** `textual_inversion` (TEXTUAL_INVERSION) + +--- + +### SelVA TI Scheduler + +Runs a series of Textual Inversion experiments from a JSON sweep file, reusing the encoded dataset across runs. + +**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE) + +--- + +### SelVA Activation Steering Extractor + +Computes per-block activation steering vectors from a training dataset by comparing DiT hidden states under BJ conditioning vs. empty conditioning. The resulting vectors can nudge the denoising trajectory toward the target style at inference. 
+ +| Input | Description | +|-------|-------------| +| `model` | SELVA_MODEL | +| `data_dir` | Directory with `.npz` feature files | +| `output_path` | Where to save `steering_vectors.pt` | +| `n_samples` | Clips to average over (default: 16) | +| `seed` | RNG seed | + +**Output:** `steering_path` (STRING) + +--- + +### SelVA Activation Steering Loader + +Loads steering vectors from a `.pt` file produced by the Extractor. Connect to the Sampler's `steering_vectors` input. + +**Output:** `steering_vectors` (STEERING_VECTORS) + +--- + +### SelVA BigVGAN Trainer + +Fine-tunes the BigVGAN vocoder (mel → waveform) on a set of target-style audio clips. Only the vocoder is modified — the DiT generator and VAE are completely untouched. + +Default mode (`snake_alpha_only`) tunes only the ~27K per-channel α parameters in Snake/SnakeBeta activations, which directly control harmonic periodicity. With 0.024% of parameters trainable the model cannot produce spectral averaging artifacts regardless of loss function. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the full rationale. 
+ +| Input | Description | +|-------|-------------| +| `model` | SELVA_MODEL | +| `data_dir` | Directory with target-style audio files (searched recursively) | +| `output_path` | Where to save the fine-tuned vocoder `.pt` | +| `train_mode` | `snake_alpha_only` (default) or `all_params` | +| `steps` | Training steps (default: 2000) | +| `lr` | Learning rate (default: 1e-4 for snake_alpha_only) | +| `batch_size` | Clips per step (default: 4) | +| `segment_seconds` | Audio segment length per training sample (default: 1.0 s) | +| `lambda_l2sp` | L2-SP anchor regularization strength — penalizes drift from pretrained weights (default: 1e-3) | +| `save_every` | Checkpoint interval in steps (default: 500) | +| `seed` | RNG seed | +| `discriminator_path` | *(optional)* Path to `bigvgan_discriminator_optimizer.pt` — when provided, frozen MPD+MRD feature matching replaces mel L1, directly penalizing harmonic smearing | + +**Output:** `checkpoint_path` (STRING) — load with SelVA BigVGAN Loader + +Saves eval samples and mel spectrogram PNGs at baseline, each checkpoint, and final. + +--- + +### SelVA BigVGAN Loader + +Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer and replaces the vocoder weights in a SELVA_MODEL in-place. Connect the output to SelVA Sampler instead of the base Model Loader. + +| Input | Description | +|-------|-------------| +| `model` | SELVA_MODEL from Model Loader | +| `path` | Path to fine-tuned vocoder `.pt` (relative = ComfyUI output directory) | + +**Output:** `model` (SELVA_MODEL with fine-tuned vocoder) + +--- + +### SelVA DITTO Optimizer + +Inference-time noise optimization ([arXiv:2401.12179](https://arxiv.org/abs/2401.12179), ICML 2024 Oral). Optimizes the initial noise latent x₀ to make the generated audio match a set of BJ reference clips, by backpropagating a mel style loss through the ODE solver. All model weights remain frozen — zero quality degradation risk. 
+ +Style loss: mean spectrum + Gram matrix computed against reference mels. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference clips. Optimization runs only through the DiT + VAE decoder; the vocoder is only invoked for the final output pass. + +| Input | Description | +|-------|-------------| +| `model` | SELVA_MODEL | +| `features` | From SelVA Feature Extractor | +| `prompt` | Sound description (leave empty to use features prompt) | +| `negative_prompt` | Sounds to suppress | +| `reference_dir` | Directory with BJ reference audio clips (.wav/.flac/.mp3) | +| `n_opt_steps` | Gradient optimization steps on x₀ (default: 50) | +| `opt_lr` | Adam LR for x₀ optimization (default: 0.1) | +| `n_ode_steps` | ODE steps per optimization iteration (default: 10; lower = faster) | +| `n_grad_steps` | ODE steps to differentiate through — truncated BPTT (default: 5) | +| `style_weight` | Style loss weight (default: 1.0; increase for stronger BJ shift) | +| `steps` | Euler steps for the final generation pass (default: 25) | +| `cfg_strength` | CFG scale (default: 4.5) | +| `seed` | RNG seed | +| `normalize` | *(optional)* RMS normalize output (default: true) | +| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) | + +**Output:** `AUDIO` + +--- + +## Workflows + +### Basic generation ``` -VHS LoadVideo ──► SelVA Feature Extractor ──────────────────────► SelVA Sampler ──► Save Audio - │ (video_info) ─► (fps auto) ▲ - │ (features) ────────────────────────────────────►│ - │ (prompt) ──────────────────────────────────────►│ +VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio + │ (video_info) ▲ + │ (features) ──────────────────────────────────►│ + │ (prompt) ────────────────────────────────────►│ ``` -Connect the `prompt` output of Feature Extractor directly to Sampler's `prompt` to keep them in sync. 
Leave Sampler's `prompt` empty and it will use whatever was stored during extraction. +### DITTO style transfer (recommended first approach) + +``` +SelVA Model Loader ─────────────────────────────────────────────► SelVA DITTO Optimizer ──► Save Audio + ▲ +SelVA Feature Extractor ──(features)────────────────────────────────────►│ + (prompt) ──────────────────────────────────────►│ +BJ reference_dir ───────────────────────────────────────────────────────►│ +``` + +No training required. Each run optimizes x₀ independently for the current video and reference set. + +### Vocoder fine-tuning + +``` +SelVA Model Loader ──► SelVA BigVGAN Trainer ──► (checkpoint .pt) + ▲ +BJ audio clips ──(data_dir)──►│ + +SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler ──► Save Audio + ▲ ▲ + checkpoint .pt SelVA Feature Extractor +``` + +### LoRA training + +See [LORA_TRAINING.md](LORA_TRAINING.md). --- @@ -127,8 +366,15 @@ The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available, --- +## Style Transfer + +For adapting SelVA to a specific audio style (e.g. BJ / Bladee / Jersey Club), see [STYLE_TRANSFER.md](STYLE_TRANSFER.md). + +--- + ## Credits - [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training - [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework - [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis +- [DITTO](https://arxiv.org/abs/2401.12179) by Novack et al. — inference-time diffusion optimization diff --git a/STYLE_TRANSFER.md b/STYLE_TRANSFER.md new file mode 100644 index 0000000..03e1e10 --- /dev/null +++ b/STYLE_TRANSFER.md @@ -0,0 +1,158 @@ +# Style Transfer for SelVA + +This document covers approaches for adapting SelVA's audio output to a specific timbral style using a small reference dataset (~50 clips). 
The context here is BJ / Bladee / Jersey Club style — sharp metallic transients, saturated harmonics, 808 sub bass, glassy high-frequency content — but the methods apply to any style target. + +--- + +## Why standard fine-tuning is hard + +SelVA's generation quality depends on the DiT (generator) outputting latents that fall in the high-density region of the VAE decoder's training distribution. BJ's audio maps to a sparse, tail region of that space — the VAE roundtrip already shows ~10–15 dB elevated HF noise floor on BJ material. Any training that pushes the generator toward exact BJ encoder outputs is training toward an already-degraded target. + +**LoRA** makes this worse: it introduces "intruder dimensions" — new high-ranking singular vectors absent from the pretrained weight spectrum — that push DiT outputs further off-manifold. This mechanism is LR- and scale-independent. Reducing LoRA scale shrinks the magnitude of this shift but does not change its direction. Empirically: spectral flatness degrades to ~0.21–0.26 (vs. baseline 0.013) at every scale from 0.0625 to 1.0. + +**Textual inversion** via the text conditioning path suffers from mean-pooling: SelVA's text features are pooled into a single global vector before injection into the DiT. The optimizer finds a spectral bias (noise/buzz) as the cheapest way to reduce reconstruction loss — not a semantic style shift. + +The approaches below are ordered by expected quality and ease of use. + +--- + +## Tier 1 — DITTO (recommended first try) + +**Node: SelVA DITTO Optimizer** + +Inference-time noise optimization. Keeps all model weights frozen and only optimizes the initial noise latent x₀ using a style loss computed against the reference clips. Since the weights never change, there is zero risk of quality degradation — the model still generates from its original manifold, just from a better starting point. + +**Style loss:** mean spectrum + Gram matrix of mel spectrograms. 
The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference. Optimization runs entirely before the vocoder — BigVGAN is only called for the final output pass. + +**How it works:** + +For each video clip you want to process: +1. Run SelVA Feature Extractor as usual. +2. Instead of SelVA Sampler, connect to **SelVA DITTO Optimizer** with your BJ `reference_dir`. +3. The node runs N optimization steps, each backpropagating through the last few ODE Euler steps to compute `∂loss/∂x₀`. +4. After optimization, one final full-ODE pass generates the output audio from the refined x₀. + +``` +SelVA Model Loader ────────────────────────────────► SelVA DITTO Optimizer ──► audio + ▲ +SelVA Feature Extractor ──(features)────────────────────────►│ + (prompt) ──────────────────────────►│ +BJ clips ───────────────────────────(reference_dir) ─────────►│ +``` + +**Tuning guide:** + +| Parameter | Starting value | When to adjust | +|---|---|---| +| `n_opt_steps` | 50 | Increase to 100–200 if style shift is too subtle | +| `opt_lr` | 0.1 | Lower to 0.05 if coherence breaks; raise to 0.3 for stronger shift | +| `n_ode_steps` | 10 | Lower = faster optimization, less accurate gradient | +| `n_grad_steps` | 5 | Number of ODE steps to differentiate through — must be ≤ n_ode_steps | +| `style_weight` | 1.0 | Increase to 2–5 for stronger BJ character; watch for incoherence | + +**Memory:** Each opt step stores activations for `n_grad_steps` DiT forward passes with gradient checkpointing. At n_grad_steps=5, expect ~4–6 GB additional VRAM over baseline inference. + +**Time per video clip:** ~50 opt steps × (10 ODE steps × 2 passes for checkpointing) + 25 final steps ≈ 5–15 minutes depending on GPU. + +**Limitations:** DITTO with mel Gram matrix loss shifts timbral statistics but cannot precisely match the BJ transient sharpness — the Gram matrix is a texture descriptor, not a transient detector. 
See Tier 2 (vocoder fine-tuning) for that. + +--- + +## Tier 2 — Vocoder Fine-tuning + +**Nodes: SelVA BigVGAN Trainer → SelVA BigVGAN Loader** + +The BigVGAN vocoder (mel → waveform) is the component most responsible for the final timbral character of the output. Fine-tuning only the vocoder keeps the DiT completely untouched — latents stay on-manifold, only the waveform rendering changes. + +### Why plain mel L1 loss fails + +BigVGAN was trained with `L_G = Σ[L_adv + 2·L_fm] + 45·L_mel`. The adversarial and feature-matching terms do the perceptual heavy lifting — they prevent the generator from averaging over high-variance harmonic content. Dropping them for a plain mel L1 loss is a loss-function topology problem: the model minimizes expected reconstruction error by averaging over harmonic uncertainty, eroding the saturated 3–8 kHz harmonics visible as "green smear" in spectrograms. This happens regardless of LR or step count. + +### `snake_alpha_only` mode (default, recommended) + +BigVGAN's AMP blocks use Snake/SnakeBeta activations: `y = x + (1/α)·sin²(α·x)` where α is a per-channel learnable scalar. Alpha parameters directly control the harmonic periodicity of each layer's output — they are the "harmonic tuning knobs" of the vocoder. + +With `train_mode=snake_alpha_only`, only the ~27K alpha parameters (0.024% of the 112M parameter model) are trained. The conv weights encoding waveform structure remain frozen. With this few trainable parameters the model physically cannot reshape the spectrum significantly regardless of loss function — no capacity for the green smear. + +**Loss in snake_alpha_only mode:** mel L1 + multi-resolution STFT L1 are still used but can only shift harmonic emphasis, not spectral shape. 
+ +### `all_params` mode with discriminator + +For a stronger shift — or to use proper perceptual losses — run with `train_mode=all_params` and provide a `discriminator_path` (the `bigvgan_discriminator_optimizer.pt` from the BigVGAN pretrained release): + +1. The frozen pretrained MPD and MRD discriminators are loaded and used as fixed perceptual feature extractors. +2. Loss becomes `2·L_fm(frozen_D) + 0.1·L_mel` — feature matching directly penalizes harmonic smearing through the discriminator's learned perceptual space. +3. `lambda_l2sp` (default 1e-3) anchors all parameters to their pretrained values — prevents catastrophic drift on 50 clips. + +This is the highest-quality vocoder fine-tuning path but requires the discriminator checkpoint. + +### Workflow + +``` +SelVA Model Loader ──► SelVA BigVGAN Trainer ──► bigvgan_bj.pt + ▲ +BJ audio clips ──(data_dir)──►│ + +SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler + ▲ ▲ + bigvgan_bj.pt SelVA Feature Extractor +``` + +### Tuning guide + +| Parameter | Default | Notes | +|---|---|---| +| `train_mode` | snake_alpha_only | Safe default; use all_params only with discriminator_path | +| `steps` | 2000 | 1000–2000 for snake_alpha_only; 3000–5000 for all_params | +| `lr` | 1e-4 | For snake_alpha_only; lower to 1e-5 for all_params | +| `lambda_l2sp` | 1e-3 | Increase to 1e-2 for all_params to limit drift | +| `batch_size` | 4 | 4–8 for stable gradients | +| `segment_seconds` | 1.0 | 1–2 s segments recommended | + +**Eval samples:** The trainer saves `.wav` and mel spectrogram `.png` files at baseline, each checkpoint, and final. Compare the spectrograms — saturation (red values in high-frequency bands) should increase relative to baseline. 
+ +--- + +## Tier 3 — DITTO + Vocoder (combined) + +Stack both: + +``` +SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA DITTO Optimizer ──► audio + ▲ ▲ + bigvgan_bj.pt SelVA Feature Extractor + reference_dir +``` + +The fine-tuned vocoder handles waveform rendering; DITTO shifts the latent trajectory. Each addresses a different aspect of style transfer. + +--- + +## What doesn't work (and why) + +### Standard LoRA + +LoRA introduces "intruder dimensions" — high-ranking singular vectors absent from the pretrained weight spectrum — at initialization. These push DiT outputs into decoder-hostile latent regions regardless of scale or LR. The failure is direction-based, not magnitude-based, so reducing LoRA scale does not fix it. + +PiSSA initialization (`init_lora_weights="pissa"`) and rsLoRA scaling (`use_rslora=True`) reduce intruder dimension formation by starting in the pretrained weight subspace. These are planned as future improvements. + +### Textual inversion + +SelVA mean-pools all 77 CLIP tokens into a single AdaLN bias vector. Every token contributes equally to that single pooled vector; the optimizer finds spectral buzz as the minimum-cost way to reduce flow-matching reconstruction loss. More tokens make it worse. + +### Activation steering (global mean difference) + +The raw mean difference between BJ and empty conditions is not a clean style basis — it carries noise from the diversity of the training clips and the many attention blocks that have nothing to do with timbral character. Global injection (all blocks at any strength) kills the sound. Targeted layer injection (only the 3–6 blocks most predictive of BJ style) is theoretically sound but requires per-layer delta magnitude ranking to identify the right layers first. + +--- + +## Reference dataset preparation + +Use the same audio clips for both DITTO and vocoder fine-tuning: + +- **Minimum:** 20–30 clips. DITTO works from 5+; vocoder benefits from 40+. +- **Format:** `.wav` or `.flac` at native sample rate. 
The trainer resamples automatically. +- **Length:** Any length ≥ 1 s. Longer is fine — the trainer segments internally. +- **Quality:** Clean, full-mix BJ clips. Avoid heavily compressed or streaming-ripped files. Use HF Smoother if HF content sounds brittle after VAE roundtrip. +- **Diversity:** Vary tempo, key, vocal density. 20 diverse clips > 50 copies of the same 8-bar loop. + +Normalize all clips to consistent loudness (e.g. -14 LUFS) before training. Inconsistent levels increase loss variance and slow convergence. diff --git a/nodes/__init__.py b/nodes/__init__.py index f75434a..9f447ec 100644 --- a/nodes/__init__.py +++ b/nodes/__init__.py @@ -21,6 +21,7 @@ _NODES = { "SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"), "SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"), "SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"), + "SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"), } for key, (module_path, class_name, display_name) in _NODES.items(): diff --git a/nodes/selva_bigvgan_trainer.py b/nodes/selva_bigvgan_trainer.py index 564576c..0bc2a11 100644 --- a/nodes/selva_bigvgan_trainer.py +++ b/nodes/selva_bigvgan_trainer.py @@ -1,14 +1,29 @@ """SelVA BigVGAN Vocoder Fine-tuner. -Fine-tunes only the BigVGAN vocoder (mel → waveform) on BJ audio clips using -spectral reconstruction losses. The DiT and VAE are completely untouched. +Tier-1 approach based on research: snake alpha fine-tuning + L2-SP anchor +regularization + optional frozen discriminator feature matching. -Loss: L1 mel reconstruction + multi-resolution STFT magnitude L1. -No GAN discriminator — this is a proof-of-concept to verify that the vocoder -can absorb BJ timbral characteristics before investing in full adversarial training. 
+Root cause of harmonic smearing with plain mel/STFT losses: + Spectral L1 minimizes expected reconstruction error — averaging over + high-variance harmonics. This is a loss-function topology problem, not + an LR/step-count problem. The fix is either (a) restrict trainable params + so the model lacks capacity to smear, or (b) use a perceptual loss that + penalizes harmonic averaging. -Save format: {'generator': vocoder.state_dict()} — same as the original BigVGAN -checkpoint so it can be loaded with SelVA BigVGAN Loader. +Tier-1 implementation: + 1. snake_alpha_only mode — only tune ~27K per-channel α parameters in + Snake/SnakeBeta activations. These control harmonic periodicity per + channel. With only ~27K trainable params, the model physically cannot + reshape the spectrum enough to cause the "green smear". + 2. L2-SP anchor loss — penalizes parameter drift from pretrained values + (strictly better than weight decay, which anchors to zero). + 3. Frozen discriminator feature matching — if a BigVGAN discriminator + checkpoint is provided, the pretrained MPD+MRD networks are used as + fixed perceptual feature extractors. Feature matching loss penalizes + harmonic smearing directly without any GAN instability. + +Save format: {'generator': vocoder.state_dict()} — same as the original +BigVGAN checkpoint so it can be loaded with SelVA BigVGAN Loader. """ import random @@ -16,6 +31,7 @@ import threading from pathlib import Path import torch +import torch.nn as nn import torch.nn.functional as F import torchaudio import comfy.utils @@ -23,12 +39,133 @@ import folder_paths from .utils import SELVA_CATEGORY, get_device, soft_empty_cache -def _save_spectrogram(path, mel_tensor): - """Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep). - Normalises to [0, 255], flips frequency axis so low freqs are at the bottom, - and saves as a greyscale PNG with a simple viridis-like colour map. 
- """ +# --------------------------------------------------------------------------- +# Minimal MPD + MRD discriminators matching BigVGAN pretrained checkpoint keys +# --------------------------------------------------------------------------- + +def _get_pad(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class _DiscriminatorP(nn.Module): + """Multi-Period Discriminator sub-module (HiFi-GAN / BigVGAN style).""" + def __init__(self, period): + super().__init__() + self.period = period + from torch.nn.utils.parametrizations import weight_norm + norm = weight_norm + self.convs = nn.ModuleList([ + norm(nn.Conv2d(1, 32, (5, 1), (3, 1), (_get_pad(5, 1), 0))), + norm(nn.Conv2d(32, 128, (5, 1), (3, 1), (_get_pad(5, 1), 0))), + norm(nn.Conv2d(128, 512, (5, 1), (3, 1), (_get_pad(5, 1), 0))), + norm(nn.Conv2d(512, 1024,(5, 1), (3, 1), (_get_pad(5, 1), 0))), + norm(nn.Conv2d(1024,1024,(5, 1), 1, (_get_pad(5, 1), 0))), + ]) + self.conv_post = norm(nn.Conv2d(1024, 1, (3, 1), 1, (1, 0))) + + def forward(self, x): + fmap = [] + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + return fmap + + +class _MultiPeriodDiscriminator(nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList([ + _DiscriminatorP(p) for p in [2, 3, 5, 7, 11] + ]) + + def forward(self, y): + fmaps = [] + for d in self.discriminators: + fmaps.extend(d(y)) + return fmaps + + +class _DiscriminatorR(nn.Module): + """Multi-Resolution Discriminator sub-module.""" + def __init__(self, fft_size, shift_size, win_length): + super().__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + from torch.nn.utils.parametrizations import weight_norm + 
norm = weight_norm + self.convs = nn.ModuleList([ + norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))), + norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))), + norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))), + ]) + self.conv_post = norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1))) + + def spectrogram(self, x): + """x: [B, 1, T] → [B, 1, freq, time]""" + n, hop, win = self.fft_size, self.shift_size, self.win_length + window = torch.hann_window(win, device=x.device) + x = x.squeeze(1) # [B, T] + pad = (win - hop) // 2 + x = F.pad(x, (pad, pad + (win - hop) % 2), mode="reflect") + x = torch.stft(x, n, hop, win, window, center=False, return_complex=True) + x = x.abs().unsqueeze(1) # [B, 1, freq, time] + return x + + def forward(self, x): + fmap = [] + x = self.spectrogram(x) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, 0.1) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + return fmap + + +class _MultiResolutionDiscriminator(nn.Module): + def __init__(self): + super().__init__() + resolutions = [(1024, 120, 600), (2048, 240, 1200), (512, 50, 240)] + self.discriminators = nn.ModuleList([ + _DiscriminatorR(*r) for r in resolutions + ]) + + def forward(self, y): + fmaps = [] + for d in self.discriminators: + fmaps.extend(d(y)) + return fmaps + + +def _feature_matching_loss(fmaps_real, fmaps_gen): + """L1 between paired feature map lists (both already detach-safe for real).""" + loss = torch.zeros(1, device=fmaps_gen[0].device) + for fr, fg in zip(fmaps_real, fmaps_gen): + T = min(fr.shape[-1], fg.shape[-1]) + loss = loss + F.l1_loss(fg[..., :T], fr[..., :T].detach()) + return loss / len(fmaps_real) + + +# --------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + +def _save_spectrogram(path, 
mel_tensor): + """Save mel spectrogram [1, n_mels, T] as a PNG using PIL (no matplotlib dep).""" try: from PIL import Image import numpy as np @@ -120,6 +257,10 @@ def _multi_resolution_stft_loss(pred_wav, target_wav, device): return loss / len(_STFT_RESOLUTIONS) +# --------------------------------------------------------------------------- +# Node +# --------------------------------------------------------------------------- + class SelvaBigvganTrainer: OUTPUT_NODE = True CATEGORY = SELVA_CATEGORY @@ -128,10 +269,10 @@ class SelvaBigvganTrainer: RETURN_NAMES = ("checkpoint_path",) OUTPUT_TOOLTIPS = ("Path to saved vocoder checkpoint — load with SelVA BigVGAN Loader.",) DESCRIPTION = ( - "Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips using " - "spectral losses (mel L1 + multi-resolution STFT L1). DiT and VAE stay frozen. " - "Supports both 16k (BigVGAN) and 44k (BigVGANv2) models. " - "Load the result with SelVA BigVGAN Loader." + "Fine-tunes the BigVGAN vocoder (mel→waveform) on BJ audio clips. " + "Default mode (snake_alpha_only) tunes only the ~5K Snake activation α " + "parameters — cannot cause harmonic smearing. Add a discriminator path " + "for perceptual feature matching loss. DiT and VAE stay frozen." ) @classmethod @@ -147,26 +288,53 @@ class SelvaBigvganTrainer: "default": "bigvgan_bj.pt", "tooltip": "Where to save the fine-tuned vocoder. Relative paths → ComfyUI output dir.", }), + "train_mode": (["snake_alpha_only", "all_params"], { + "default": "snake_alpha_only", + "tooltip": ( + "snake_alpha_only: only tune ~5K per-channel α parameters in Snake/SnakeBeta " + "activations. These control harmonic periodicity. Cannot cause spectral smearing. " + "all_params: tune all vocoder weights — set lambda_l2sp>0 to prevent drift." + ), + }), "steps": ("INT", { "default": 2000, "min": 100, "max": 50000, - "tooltip": "Training steps. 1000–2000 is a good first experiment.", + "tooltip": "Training steps. 
1000–2000 is a good first experiment with snake_alpha_only.", }), "lr": ("FLOAT", { "default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-5, - "tooltip": "Learning rate. BigVGAN default is 1e-4.", + "tooltip": "Learning rate. 1e-4 for snake_alpha_only, 1e-5 for all_params.", }), "batch_size": ("INT", {"default": 4, "min": 1, "max": 32}), "segment_seconds": ("FLOAT", { "default": 1.0, "min": 0.25, "max": 4.0, "step": 0.25, "tooltip": "Audio segment length per training sample in seconds.", }), + "lambda_l2sp": ("FLOAT", { + "default": 1e-3, "min": 0.0, "max": 0.1, "step": 1e-4, + "tooltip": ( + "L2-SP anchor regularization: penalizes parameter drift from pretrained values. " + "0 = disabled. 1e-3 is good for snake_alpha_only. " + "Increase to 1e-2 for all_params to prevent catastrophic forgetting." + ), + }), "save_every": ("INT", {"default": 500, "min": 50, "max": 10000}), "seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}), }, + "optional": { + "discriminator_path": ("STRING", { + "default": "", + "tooltip": ( + "Optional path to BigVGAN discriminator checkpoint " + "(bigvgan_discriminator_optimizer.pt from the BigVGAN pretrained release). " + "When provided, frozen MPD+MRD feature matching replaces mel L1 — " + "the key fix for harmonic smearing. Leave empty to use mel+STFT losses only." 
+ ), + }), + }, } - def train(self, model, data_dir, output_path, steps, lr, batch_size, - segment_seconds, save_every, seed): + def train(self, model, data_dir, output_path, train_mode, steps, lr, batch_size, + segment_seconds, lambda_l2sp, save_every, seed, discriminator_path=""): import traceback device = get_device() @@ -197,6 +365,14 @@ class SelvaBigvganTrainer: out_path = Path(folder_paths.get_output_directory()) / out_path out_path.parent.mkdir(parents=True, exist_ok=True) + disc_path = None + if discriminator_path and discriminator_path.strip(): + disc_path = Path(discriminator_path.strip()) + if not disc_path.is_absolute(): + disc_path = Path(folder_paths.get_output_directory()) / disc_path + if not disc_path.exists(): + raise FileNotFoundError(f"[BigVGAN] Discriminator checkpoint not found: {disc_path}") + # Find and pre-load audio clips segment_samples = int(segment_seconds * sample_rate) audio_files = [] @@ -227,8 +403,15 @@ class SelvaBigvganTrainer: raise RuntimeError( f"[BigVGAN] No usable clips found (need audio >= {segment_seconds}s)" ) - print(f"[BigVGAN] {len(clips)} clips ready segment={segment_seconds}s " - f"steps={steps} lr={lr} batch={batch_size}\n", flush=True) + + trainable_count = sum( + 1 for n, _ in vocoder.named_parameters() if "alpha" in n + ) if train_mode == "snake_alpha_only" else sum( + 1 for _ in vocoder.parameters() + ) + print(f"[BigVGAN] {len(clips)} clips ready mode={train_mode} " + f"segment={segment_seconds}s steps={steps} lr={lr} " + f"batch={batch_size} lambda_l2sp={lambda_l2sp}\n", flush=True) if strategy == "offload_to_cpu": feature_utils.to(device) @@ -259,8 +442,8 @@ class SelvaBigvganTrainer: vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, segment_samples, sample_rate, - steps, lr, batch_size, save_every, seed, - out_path, pbar, + train_mode, steps, lr, batch_size, lambda_l2sp, + save_every, seed, out_path, disc_path, pbar, ) except Exception as e: _exc[0] = e @@ -275,11 +458,15 @@ class 
SelvaBigvganTrainer: return (_result[0],) +# --------------------------------------------------------------------------- +# Training worker +# --------------------------------------------------------------------------- + def _do_train(vocoder, mel_converter, clips, device, dtype, strategy, feature_utils, segment_samples, sample_rate, - steps, lr, batch_size, save_every, seed, - out_path, pbar): + train_mode, steps, lr, batch_size, lambda_l2sp, + save_every, seed, out_path, disc_path, pbar): """Execute training. Called in a fresh thread — no inference_mode active. Even though inference_mode is off here, tensors created in the calling @@ -372,7 +559,65 @@ def _do_train(vocoder, mel_converter, clips, if buf is not None: module._buffers[bname] = buf.clone() - optimizer = torch.optim.AdamW(vocoder.parameters(), lr=lr, betas=(0.8, 0.99)) + # ── Training mode: select which parameters to train ────────────────────── + if train_mode == "snake_alpha_only": + alpha_params = [] + for name, param in vocoder.named_parameters(): + if "alpha" in name: + param.requires_grad_(True) + alpha_params.append(param) + else: + param.requires_grad_(False) + n_trainable = sum(p.numel() for p in alpha_params) + print(f"[BigVGAN] snake_alpha_only: {n_trainable} trainable params " + f"({len(alpha_params)} alpha tensors)", flush=True) + trainable_params = alpha_params + else: # all_params + for param in vocoder.parameters(): + param.requires_grad_(True) + n_trainable = sum(p.numel() for p in vocoder.parameters()) + print(f"[BigVGAN] all_params: {n_trainable} trainable params", flush=True) + trainable_params = list(vocoder.parameters()) + + # ── L2-SP: cache reference parameter values (before any gradient steps) ── + ref_params = {} + if lambda_l2sp > 0.0: + for name, param in vocoder.named_parameters(): + if param.requires_grad: + ref_params[name] = param.data.clone().detach() + print(f"[BigVGAN] L2-SP anchor: {len(ref_params)} params λ={lambda_l2sp}", flush=True) + + # ── Optional: load 
pretrained discriminator for feature matching ────────── + mpd = mrd = None + if disc_path is not None: + try: + ckpt_d = torch.load(str(disc_path), map_location="cpu", weights_only=False) + mpd = _MultiPeriodDiscriminator() + mrd = _MultiResolutionDiscriminator() + # Try common key names used by different BigVGAN releases + for mpd_key in ("mpd", "discriminator_mpd", "MPD"): + if mpd_key in ckpt_d: + mpd.load_state_dict(ckpt_d[mpd_key], strict=False) + print(f"[BigVGAN] Loaded MPD from key '{mpd_key}'", flush=True) + break + for mrd_key in ("mrd", "discriminator_mrd", "MRD", "msd", "discriminator_msd"): + if mrd_key in ckpt_d: + mrd.load_state_dict(ckpt_d[mrd_key], strict=False) + print(f"[BigVGAN] Loaded MRD from key '{mrd_key}'", flush=True) + break + mpd.to(device).eval() + mrd.to(device).eval() + for p in mpd.parameters(): + p.requires_grad_(False) + for p in mrd.parameters(): + p.requires_grad_(False) + print(f"[BigVGAN] Frozen discriminators ready for feature matching", flush=True) + except Exception as e: + print(f"[BigVGAN] WARNING: Could not load discriminator ({e}), " + f"falling back to mel+STFT losses", flush=True) + mpd = mrd = None + + optimizer = torch.optim.AdamW(trainable_params, lr=lr, betas=(0.8, 0.99)) vocoder.train() try: @@ -396,24 +641,55 @@ def _do_train(vocoder, mel_converter, clips, pred_t = pred_wav[..., :T] target_t = target_wav[..., :T] - pred_mel = mel_converter(pred_t.squeeze(1)) # [B, n_mels, T_mel'] - T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) - mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) + # ── Compute loss ───────────────────────────────────────────────── + if mpd is not None and mrd is not None: + # Perceptual feature matching via frozen discriminators + with torch.no_grad(): + fmaps_real_mpd = mpd(target_t) + fmaps_real_mrd = mrd(target_t) + fmaps_gen_mpd = mpd(pred_t) + fmaps_gen_mrd = mrd(pred_t) + fm_loss = ( + _feature_matching_loss(fmaps_real_mpd, fmaps_gen_mpd) + + 
_feature_matching_loss(fmaps_real_mrd, fmaps_gen_mrd) + ) + # Keep a small mel loss for stable frequency alignment + pred_mel = mel_converter(pred_t.squeeze(1)) + T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) + mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) + primary_loss = 2.0 * fm_loss + 0.1 * mel_loss + loss_desc = f"fm={fm_loss.item():.4f} mel={mel_loss.item():.4f}" + else: + # Fallback: mel L1 + multi-resolution STFT L1 + pred_mel = mel_converter(pred_t.squeeze(1)) + T_mel = min(pred_mel.shape[-1], target_mel.shape[-1]) + mel_loss = F.l1_loss(pred_mel[..., :T_mel], target_mel[..., :T_mel]) + stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device) + primary_loss = mel_loss + stft_loss + loss_desc = f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f}" - stft_loss = _multi_resolution_stft_loss(pred_t, target_t, device) + # ── L2-SP regularization ───────────────────────────────────────── + l2sp_loss = torch.zeros(1, device=device) + if lambda_l2sp > 0.0 and ref_params: + for name, param in vocoder.named_parameters(): + if name in ref_params and param.requires_grad: + l2sp_loss = l2sp_loss + F.mse_loss( + param, ref_params[name], reduction="sum" + ) + l2sp_loss = l2sp_loss * lambda_l2sp - loss = mel_loss + stft_loss + loss = primary_loss + l2sp_loss optimizer.zero_grad() loss.backward() - torch.nn.utils.clip_grad_norm_(vocoder.parameters(), 1.0) + torch.nn.utils.clip_grad_norm_(trainable_params, 1.0) optimizer.step() pbar.update(1) if (step + 1) % max(1, steps // 20) == 0 or step == steps - 1: - print(f"[BigVGAN] {step+1}/{steps} " - f"mel={mel_loss.item():.4f} stft={stft_loss.item():.4f} " - f"total={loss.item():.4f}", flush=True) + l2sp_str = f" l2sp={l2sp_loss.item():.4e}" if lambda_l2sp > 0 else "" + print(f"[BigVGAN] {step+1}/{steps} {loss_desc}" + f" total={loss.item():.4f}{l2sp_str}", flush=True) if (step + 1) % save_every == 0 and (step + 1) < steps: step_path = out_path.parent / 
f"{out_path.stem}_step{step+1}{out_path.suffix}" diff --git a/nodes/selva_ditto_optimizer.py b/nodes/selva_ditto_optimizer.py new file mode 100644 index 0000000..cb0fdfd --- /dev/null +++ b/nodes/selva_ditto_optimizer.py @@ -0,0 +1,434 @@ +"""SelVA DITTO Optimizer. + +Inference-time noise optimization: optimizes the initial noise latent x_0 +using a style loss against BJ reference clips, backpropagating through the +ODE solver. All model weights remain frozen — only x_0 changes. + +Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179, +ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE. + +Style loss: mel-spectrogram statistics matching (mean spectrum + Gram matrix) +against BJ reference clips. Runs entirely before the vocoder — optimization +only requires the DiT + VAE decoder, not BigVGAN. + +Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT +forward pass activations) instead of O(N steps). Backward recomputes each +step's activations on demand. +""" + +import dataclasses +import random +import threading +from pathlib import Path + +import torch +import torch.nn.functional as F +import torchaudio +import comfy.utils +import comfy.model_management +import folder_paths + +from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache +from .selva_sampler import SelvaSampler +from .selva_textual_inversion_trainer import _inject_tokens + + +def _load_wav(path): + """Load audio file to [channels, samples] float32 tensor.""" + try: + return torchaudio.load(str(path)) + except Exception: + pass + import soundfile as sf + data, sr = sf.read(str(path), dtype="float32", always_2d=True) + wav = torch.from_numpy(data.T) + return wav, sr + + +def _mel_style_loss(mel_gen, ref_mean, ref_gram): + """Style loss between generated mel and precomputed reference statistics. 
class SelvaDittoOptimizer:
    """DITTO inference-time noise optimization.

    Freezes all model weights and optimizes only the initial noise latent x_0
    to make the generated audio sound like the BJ reference clips.
    No training data or gradient updates to the model — per-video per-run.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "features": ("SELVA_FEATURES",),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Sound description. Leave empty to use features prompt.",
                }),
                "negative_prompt": ("STRING", {
                    "default": "", "multiline": False,
                }),
                "reference_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Directory with BJ reference audio files (.wav/.flac/.mp3). "
                               "Reference mel statistics are precomputed from these once.",
                }),
                "n_opt_steps": ("INT", {
                    "default": 50, "min": 5, "max": 500,
                    "tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
                               "each step requires ~2 DiT forward passes.",
                }),
                "opt_lr": ("FLOAT", {
                    "default": 0.1, "min": 0.001, "max": 2.0, "step": 0.01,
                    "tooltip": "Adam learning rate for x_0 optimization. "
                               "0.1 is the DITTO paper default.",
                }),
                "n_ode_steps": ("INT", {
                    "default": 10, "min": 5, "max": 50,
                    "tooltip": "Euler ODE steps run during each optimization iteration. "
                               "Lower = faster optimization (10–15 is a good trade-off). "
                               "Final generation always uses the steps parameter below.",
                }),
                "n_grad_steps": ("INT", {
                    "default": 5, "min": 1, "max": 50,
                    "tooltip": "ODE steps to differentiate through (truncated BPTT). "
                               "Higher = more accurate gradient, more VRAM. "
                               "Must be ≤ n_ode_steps. 5 is a good default.",
                }),
                "style_weight": ("FLOAT", {
                    "default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
                    "tooltip": "Weight of the BJ style loss. Increase to push harder toward "
                               "BJ style at the cost of coherence with the video.",
                }),
                "steps": ("INT", {
                    "default": 25, "min": 1, "max": 200,
                    "tooltip": "Euler steps for the final generation pass (after optimization).",
                }),
                "cfg_strength": ("FLOAT", {
                    "default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
            "optional": {
                "normalize": ("BOOLEAN", {"default": True}),
                "target_lufs": ("FLOAT", {
                    "default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward BJ style.",)
    FUNCTION = "optimize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "DITTO inference-time noise optimization (arXiv:2401.12179). "
        "Optimizes the initial noise latent x_0 to match BJ reference clips "
        "via mel statistics style loss, backpropagating through the ODE. "
        "All model weights frozen — zero quality degradation risk."
    )

    def optimize(self, model, features, prompt, negative_prompt,
                 reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                 style_weight, steps, cfg_strength, seed,
                 normalize=True, target_lufs=-27.0):
        """Node entry point: precompute reference statistics, then optimize x_0.

        Reference mel statistics (mean spectrum and Gram matrix) are computed
        from up to 32 clips found under ``reference_dir``. The optimization
        itself runs in a worker thread via ``_do_optimize``.

        Returns:
            One-element tuple holding the AUDIO dict produced by
            ``_do_optimize``.

        Raises:
            ValueError: variant mismatch between features and model, missing
                duration in features, or empty ``reference_dir``.
            FileNotFoundError: ``reference_dir`` does not exist or contains no
                audio files.
            RuntimeError: no reference clip could be loaded.
            Exception: anything raised inside the worker is re-raised here.
        """
        import traceback

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        net_generator = model["generator"]
        feature_utils = model["feature_utils"]
        mel_converter = feature_utils.mel_converter

        # Validate variant match
        feat_variant = features.get("variant")
        if feat_variant is not None and feat_variant != model["variant"]:
            raise ValueError(
                f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
                f"Re-run Feature Extractor."
            )

        if not prompt or not prompt.strip():
            prompt = features.get("prompt", "")

        # Resolve duration and seq_cfg
        duration = features.get("duration", 0)
        if duration <= 0:
            raise ValueError("[DITTO] Features contain no duration field.")
        seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
        sample_rate = seq_cfg.sampling_rate

        # An empty string would resolve to the models root below and trigger a
        # recursive scan of every model folder; reject it explicitly.
        if not reference_dir or not reference_dir.strip():
            raise ValueError("[DITTO] reference_dir is empty — set it to a folder of reference clips.")

        # Load and precompute reference mel statistics
        ref_dir = Path(reference_dir.strip())
        if not ref_dir.is_absolute():
            ref_dir = Path(folder_paths.models_dir) / ref_dir
        if not ref_dir.exists():
            raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")

        ref_files = []
        for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"):
            ref_files.extend(ref_dir.rglob(ext))
        if not ref_files:
            raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")

        # Sort before capping so the same 32 clips are used on every run;
        # rglob order depends on the filesystem.
        ref_files = sorted(ref_files)[:32]  # cap at 32 for speed

        print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
        mel_converter.to(device)

        ref_mels = []
        with torch.no_grad():
            for rf in ref_files:
                try:
                    wav, sr = _load_wav(rf)
                    if wav.shape[0] > 1:
                        wav = wav.mean(0, keepdim=True)  # downmix to mono
                    if sr != sample_rate:
                        wav = torchaudio.functional.resample(wav, sr, sample_rate)
                    wav = wav.squeeze(0).to(device, dtype)
                    mel = mel_converter(wav.unsqueeze(0))  # [1, n_mels, T]
                    ref_mels.append(mel)
                except Exception as e:
                    # Best-effort: one unreadable clip must not abort the run.
                    print(f"  [DITTO] Skip {rf.name}: {e}", flush=True)

        if not ref_mels:
            raise RuntimeError("[DITTO] No usable reference clips.")

        # Precompute reference statistics (done once — detached, no grad)
        with torch.no_grad():
            all_means = torch.stack([m.squeeze(0).mean(dim=-1) for m in ref_mels])
            ref_mean = all_means.mean(0)  # [n_mels]

            all_grams = []
            for m in ref_mels:
                M = m.squeeze(0)  # [n_mels, T]
                all_grams.append((M @ M.T) / M.shape[-1])
            ref_gram = torch.stack(all_grams).mean(0)  # [n_mels, n_mels]

        print(f"[DITTO] Reference stats computed from {len(ref_mels)} clips "
              f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
              f"grad_steps={n_grad_steps}", flush=True)

        if strategy == "offload_to_cpu":
            net_generator.to(device)
            feature_utils.to(device)
            soft_empty_cache()

        pbar = comfy.utils.ProgressBar(n_opt_steps + steps)

        # One-slot lists carry the worker's result/exception back to this thread.
        _result = [None]
        _exc = [None]

        def _worker():
            try:
                _result[0] = _do_optimize(
                    net_generator, feature_utils, mel_converter,
                    features, prompt, negative_prompt,
                    ref_mean, ref_gram,
                    seq_cfg, sample_rate, device, dtype,
                    n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                    style_weight, steps, cfg_strength, seed,
                    normalize, target_lufs, pbar,
                )
            except Exception as e:
                _exc[0] = e
                traceback.print_exc()

        # The optimization needs autograd, so it runs in a fresh thread where
        # no inference_mode is active (see the _do_optimize docstring).
        t = threading.Thread(target=_worker, daemon=True)
        t.start()
        t.join()

        if strategy == "offload_to_cpu":
            net_generator.to(get_offload_device())
            feature_utils.to(get_offload_device())
            soft_empty_cache()

        if _exc[0] is not None:
            raise _exc[0]
        # _do_optimize returns the bare AUDIO dict; wrap it exactly once to
        # match RETURN_TYPES.
        return (_result[0],)


def _do_optimize(net_generator, feature_utils, mel_converter,
                 features, prompt, negative_prompt,
                 ref_mean, ref_gram,
                 seq_cfg, sample_rate, device, dtype,
                 n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
                 style_weight, steps, cfg_strength, seed,
                 normalize, target_lufs, pbar):
    """Optimization loop — runs in a fresh thread (no inference_mode active).

    Each iteration integrates the Euler ODE from x_0 for ``n_ode_steps``
    steps. The first ``n_ode_steps - n_grad_steps`` steps run without
    gradient; the remaining steps are differentiated with per-step gradient
    checkpointing. The gradient at the start of the differentiated segment is
    copied onto x_0 unchanged (truncated BPTT that treats the no-grad prefix
    as the identity). After ``n_opt_steps`` Adam updates a final
    ``steps``-step generation is decoded and vocoded.

    Returns:
        AUDIO dict ``{"waveform": <float CPU tensor>, "sample_rate": int}``.
        The caller wraps it in the node's output tuple.
    """
    # Explicit import: do not rely on ``import torch`` exposing the submodule.
    from torch.utils.checkpoint import checkpoint

    # Strip inference flags from ref stats (came from main thread)
    ref_mean = ref_mean.clone().detach()
    ref_gram = ref_gram.clone().detach()

    torch.manual_seed(seed)

    clip_f = features["clip_features"].to(device, dtype)
    sync_f = features["sync_features"].to(device, dtype)

    net_generator.update_seq_lengths(
        latent_seq_len=seq_cfg.latent_seq_len,
        clip_seq_len=clip_f.shape[1],
        sync_seq_len=sync_f.shape[1],
    )

    # Conditioning does not depend on x_0, so build all of it without a graph;
    # a graph kept here would be reused (and already freed) on the second
    # backward pass.
    with torch.no_grad():
        text_clip = feature_utils.encode_text_clip([prompt])

        neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
            if negative_prompt.strip() else None

        conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
        empty_conditions = net_generator.get_empty_conditions(
            bs=1, negative_text_features=neg_text_clip
        )

    # Initial noise — x_0 is the parameter we optimize
    x0_init = torch.randn(
        1, seq_cfg.latent_seq_len, net_generator.latent_dim,
        device=device, dtype=dtype,
    )
    x0 = torch.nn.Parameter(x0_init.clone())
    optimizer = torch.optim.Adam([x0], lr=opt_lr)

    # n_grad_steps must not exceed n_ode_steps
    n_grad_steps = min(n_grad_steps, n_ode_steps)
    n_free_steps = n_ode_steps - n_grad_steps  # steps run without gradient

    ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)

    print(f"[DITTO] Optimizing x_0 "
          f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)

    # Freeze all model weights (double-check — should already be frozen at inference)
    # NOTE(review): if the weights were created under inference_mode in the
    # calling thread, autograd may refuse to save them for backward — confirm
    # against the model loader.
    net_generator.requires_grad_(False)
    feature_utils.requires_grad_(False)
    mel_converter.requires_grad_(False)

    for opt_step in range(n_opt_steps):
        comfy.model_management.throw_exception_if_processing_interrupted()

        # ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
        # Cheaper than checkpointing every step, at the cost of an
        # approximate (truncated) gradient.
        with torch.no_grad():
            x = x0
            for i in range(n_free_steps):
                t = ts[i]
                dt = ts[i + 1] - t
                flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
                x = x + dt * flow

        # Start the differentiated segment from a fresh leaf. The leaf keeps
        # its own name: ``x`` is rebound to non-leaf tensors in the loop
        # below, and only the leaf has ``.grad`` populated by backward().
        x_leaf = x.detach().requires_grad_(True)
        x = x_leaf

        # ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
        for i in range(n_free_steps, n_ode_steps):
            t = ts[i]
            dt = ts[i + 1] - t

            # Gradient checkpointing: recompute forward during backward,
            # avoiding storage of DiT activations for each step.
            # ``t=t`` binds the current timestep (avoids late-binding closure).
            def _ode_step(x_in, t=t):
                return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)

            flow = checkpoint(_ode_step, x, use_reentrant=False)
            x = x + dt * flow

        # ── Decode to mel (no vocoder — cheap) ──────────────────────────────
        x_unnorm = net_generator.unnormalize(x)
        mel_gen = feature_utils.decode(x_unnorm)  # latent → mel [1, n_mels, T]

        # ── Style loss ───────────────────────────────────────────────────────
        loss = style_weight * _mel_style_loss(mel_gen, ref_mean, ref_gram)

        optimizer.zero_grad()
        loss.backward()

        # x_leaf was detached from x_0, so backward() never reaches x_0.
        # Transfer the gradient manually: d(loss)/d(x_leaf) is used as
        # d(loss)/d(x_0), i.e. the no-grad prefix is treated as the identity.
        if x_leaf.grad is not None:
            x0.grad = x_leaf.grad.detach().clone()

        torch.nn.utils.clip_grad_norm_([x0], 1.0)
        optimizer.step()

        pbar.update(1)

        if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
            print(f"[DITTO] {opt_step+1}/{n_opt_steps} loss={loss.item():.4f}", flush=True)

    # ── Final generation with optimized x_0 ─────────────────────────────────
    print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)

    with torch.no_grad():
        fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
        x = x0.detach()
        for i in range(steps):
            comfy.model_management.throw_exception_if_processing_interrupted()
            t = fm_ts[i]
            dt = fm_ts[i + 1] - t
            flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
            x = x + dt * flow
            pbar.update(1)

        x1_unnorm = net_generator.unnormalize(x)
        spec = feature_utils.decode(x1_unnorm)
        audio = feature_utils.vocode(spec)

    print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
          flush=True)

    # Bring the waveform to [B, 1, T]: add a channel axis to 2-D output,
    # downmix multi-channel 3-D output to mono.
    audio = audio.float()
    if audio.dim() == 2:
        audio = audio.unsqueeze(1)
    elif audio.dim() == 3 and audio.shape[1] != 1:
        audio = audio.mean(dim=1, keepdim=True)

    if normalize:
        # target_lufs is applied as a plain RMS level in dB (no loudness weighting).
        target_rms = 10 ** (target_lufs / 20.0)
        rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
        audio = audio * (target_rms / rms)
        # Scale down only if the RMS gain pushed peaks past full scale.
        peak = audio.abs().max().clamp(min=1e-8)
        if peak > 1.0:
            audio = audio / peak

    print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
    return {"waveform": audio.cpu(), "sample_rate": sample_rate}