2 Commits

Author SHA1 Message Date
Ethanfel c9550ce693 experiment: add crop_rect option — rect bbox crop without squarification
crop_rect crops frames to a rectangle around the mask bounding box
(with margin applied independently on each axis) before resizing.
The model still stretches the result to 384×384 / 224×224, but only
sees the region around the target element — simpler than crop_to_mask.

- Refactored _compute_mask_bbox with square= param (True = existing
  square logic, False = rect with per-axis margin)
- Empty mask fallback for rect mode returns the full frame (no-op)
- crop_to_mask takes priority over crop_rect when both are enabled
- crop_margin is shared between both crop modes
- Hash includes crop_rect; crop_margin hashed when either crop is active
- Log line reports mode (square/rect) alongside bbox dimensions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 13:04:46 +02:00
Ethanfel f3cabcad90 experiment: crop-to-mask mode on feature extractor
Instead of squishing the full frame to a square, optionally crops a square
region around the mask bounding box (union across all frames) before resizing.
Preserves aspect ratio of the subject and gives the model a focused,
undistorted view.

New optional inputs on SelVA Feature Extractor:
- crop_to_mask (BOOLEAN, default false) — enable the crop path
- crop_margin (FLOAT 0–1, default 0.1) — padding around the bbox as a
  fraction of the bounding box side

_compute_mask_bbox: resizes mask to frame resolution, takes union over
all mask frames, expands to square + margin, shifts into frame bounds to
preserve square shape before clamping. Falls back to center square crop
if mask is empty.

Bbox is computed once from the original-resolution mask and reused for
both CLIP (384px) and sync (224px) frame sequences. Combine with
mask_clip/mask_sync for full background suppression on top of the crop.
Cache hash includes crop_to_mask and crop_margin when mask is connected.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-05 12:52:03 +02:00
45 changed files with 118 additions and 11164 deletions
-459
View File
@@ -1,459 +0,0 @@
# LoRA Training for SelVA
LoRA lets you teach the model new or partially-known sound classes using a small set of video+audio pairs. Only ~10 MB of adapter weights are trained instead of the full 4.4 GB model.
---
## Overview
Training is split into two steps:
1. **Dataset preparation** (in ComfyUI) — extract visual features from your video clips using the `SelVA Feature Extractor` node, and collect clean matching audio files.
2. **Training** (in ComfyUI or command line) — run the `SelVA LoRA Trainer` node or `train_lora.py`.
The training script only loads the generator and the VAE encoder. CLIP visual features and sync features come pre-computed from the `.npz` files, so Synchformer and T5 are not loaded during training, saving 34 GB of VRAM.
---
## Requirements
Same environment as SelVA inference. Additional Python packages:
```
torchaudio
soundfile
```
---
## Step 1 — Prepare the dataset
### 1.1 Video format
The feature extractor accepts any input but internally resamples frames to fixed square resolutions (384×384 for CLIP, 224×224 for Synchformer). Both encoders were trained on standard video datasets — predominantly landscape footage. This has two practical implications:
**Aspect ratio** — use **16:9 landscape** whenever possible. Portrait clips (9:16) are mechanically supported but the bicubic stretch into square distorts the image relative to the encoders' training distribution, which can degrade sync feature quality. If your source is portrait, center-crop to square before extraction. Square (1:1) is also fine.
**Resolution** — anything ≥ 480p is sufficient. The extractor downscales to 384px and 224px regardless of source resolution; higher resolution adds no benefit.
**Frame rate** — any. Connect `VHS_VIDEOINFO` from VHS LoadVideo to the feature extractor so fps is read automatically from the file instead of being entered manually.
| Format | Recommendation |
|---|---|
| Aspect ratio | 16:9 landscape (preferred) or 1:1 square |
| Resolution | ≥ 480p (720p+ is fine, no upper limit that matters) |
| Frame rate | Any — set via VHS_VIDEOINFO |
| Portrait (9:16) | Center-crop to square before extraction |
### 1.2 Extract visual features in ComfyUI
For each video clip you want to train on:
1. Load the video with a VHS LoadVideo node.
2. Connect it to **SelVA Feature Extractor**.
3. Set **`cache_dir`** to a dedicated dataset folder, e.g. `dataset/my_sound`.
4. Set **`name`** to a short descriptive label, e.g. `dog_bark`. The node will save `dog_bark_001.npz`, then `dog_bark_002.npz`, etc. automatically as you process more clips.
5. Set the **`prompt`** to describe the sound (e.g. `a dog barking`). This prompt conditions the Synchformer sync features — be as specific as possible (see prompt guide below).
6. Optionally connect a **mask** to isolate the sound source in frame (strongly recommended when multiple objects are visible — see masking note below).
> **Tip:** The prompt used for feature extraction conditions the *visual sync features*. You can use a different, more precise prompt at training time — see Step 2.
### Prompt guide
The prompt is not just a label — it directly shapes what the Synchformer pays attention to in the video. Imprecise prompts produce unfocused sync features, which the LoRA then has to compensate for, often introducing noise.
**Good prompts are specific about:**
- The sound source (what object is making the sound)
- The acoustic character (loud/quiet, sharp/soft, wet/dry)
- The action producing the sound (if applicable)
| Sound | Weak prompt | Strong prompt |
|---|---|---|
| Dog bark | `dog` | `a large dog barking loudly` |
| Footsteps | `walking` | `heavy boots on a wooden floor` |
| Water | `water` | `water dripping into a metal bucket` |
| Explosion | `explosion` | `a large explosion with deep bass rumble` |
| Door | `door` | `a heavy wooden door slamming shut` |
**Rules of thumb:**
- Describe the *sound*, not the visual scene. `a person hitting a drum` is better than `a drummer on stage`.
- Keep prompts consistent across all clips for the same sound class. Mixing `a dog barking` and `loud barking dog` in the same dataset creates conflicting sync features.
- Avoid negations (`no background noise`) — the model does not understand negations in sync feature conditioning.
### Masking note
If the video frame contains multiple moving objects, CLIP and sync features will be diluted by irrelevant motion. Use a segmentation mask (SAM2 or Grounding DINO+SAM) to isolate the sound source:
- Connect the mask to the **`mask`** input on SelVA Feature Extractor.
- Leave `mask_strength` at `1.0` for clean isolation; lower it only if the masked region is very small and the model loses context.
- Re-extract features with a mask even if you already have `.npz` files — better features directly reduce training noise.
### 1.3 Collect clean audio
For each `.npz` file, place a matching audio file with the **same filename stem** in the same directory:
```
dataset/my_sound/
dog_bark_001.npz ← from SelVA Feature Extractor
dog_bark_001.wav ← clean isolated audio recording
dog_bark_002.npz
dog_bark_002.wav
dog_bark_003.npz
dog_bark_003.wav
```
Supported audio formats: `.wav`, `.flac`, `.ogg`, `.aiff`, `.aif`
> `.mp3` is not recommended — lossy compression degrades training quality. Use `.flac` or `.wav`.
The audio will be automatically resampled and trimmed/padded to match the model's expected duration. Use clean, isolated recordings — no background noise.
### 1.4 Optional: prompts.txt
If you want a different prompt at training time than the one embedded in the `.npz`, create a `prompts.txt` file in the dataset directory:
```
# One line per file: filename: prompt text
dog_bark.npz: a large dog barking aggressively
dog_bark_001.npz: a dog barking in the distance
```
Priority: `prompts.txt` > prompt embedded in `.npz` > directory name as fallback.
---
## Step 2 — Train
### Option A — SelVA LoRA Trainer node (ComfyUI)
Connect the node and set parameters directly in the UI. The node outputs the trained model ready to wire into the Sampler, and saves loss curve images to the output directory.
```
SelVA Model Loader → SelVA LoRA Trainer → SelVA Sampler
```
### Option B — Command line
```bash
python train_lora.py \
--data_dir dataset/my_sound \
--output_dir lora_output/my_sound \
--variant large_44k \
--selva_dir /path/to/ComfyUI/models/selva \
--rank 16 \
--steps 4000 \
--batch_size 4 \
--lr 1e-4
```
The script will:
1. Load the VAE, CLIP text encoder, and generator.
2. Pre-load all clips (audio encoded to latents, features loaded from `.npz`).
3. Train LoRA adapters for the specified number of steps.
4. Save a checkpoint every `--save_every` steps, a final `adapter_final.pt`, and loss curve images.
---
## CLI Reference
| Argument | Default | Description |
|---|---|---|
| `--data_dir` | required | Directory containing `.npz` + audio pairs |
| `--output_dir` | `lora_output` | Where to save adapter checkpoints |
| `--variant` | `large_44k` | Model variant: `small_16k`, `small_44k`, `medium_44k`, `large_44k` |
| `--selva_dir` | required | Path to SelVA model weights directory |
| `--rank` | `16` | LoRA rank — higher = more capacity, more VRAM |
| `--alpha` | `rank` | LoRA alpha scaling. Default (= rank) means scale = 1.0 |
| `--target` | `attn.qkv` | Which layers to adapt. Add `linear1` for post-attention projections |
| `--lr` | `1e-4` | Learning rate |
| `--steps` | `2000` | Total training steps |
| `--warmup_steps` | `100` | Linear LR warmup steps |
| `--batch_size` | `4` | Clips per training step — higher is more stable, uses more VRAM |
| `--grad_accum` | `1` | Gradient accumulation steps (use when batch_size is already > 1) |
| `--save_every` | `500` | Save a checkpoint every N steps |
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
| `--seed` | `42` | Random seed |
| `--timestep_mode` | `uniform` | Timestep sampling: `uniform`, `logit_normal`, or `curriculum` |
| `--logit_normal_sigma` | `1.0` | Spread of the logit-normal distribution. Only used with `logit_normal` / `curriculum` |
| `--curriculum_switch` | `0.6` | Fraction of steps to use logit_normal before switching to uniform. Only with `curriculum` |
| `--lora_dropout` | `0.0` | Dropout on the LoRA path only. `0.05``0.1` helps regularize on small datasets |
| `--lora_plus_ratio` | `1.0` | LoRA+ LR ratio: `lr_B = lr × ratio`. `1.0` = standard LoRA, `16.0` = LoRA+ |
---
## Step 3 — Load the adapter in ComfyUI
Connect **SelVA LoRA Loader** between the model loader and the sampler:
```
SelVA Model Loader → SelVA LoRA Loader → SelVA Sampler
```
> **Important:** Wire the LoRA Loader output to the **Sampler**, not the Feature Extractor. The LoRA adapts the generator which only runs in the Sampler.
| Input | Description |
|---|---|
| `model` | SELVA_MODEL from the model loader |
| `adapter_path` | Path to `adapter_final.pt` or any `adapter_stepXXXXX.pt` |
| `strength` | 0.0 = adapter disabled, 1.0 = full strength, >1.0 = exaggerated |
The loader reads rank, alpha, and target layers from the metadata embedded in the `.pt` file — no need to set them manually.
> The base model is not modified. The loader returns a shallow copy with a deep-copied generator so the original stays intact.
---
## Tuning Guide
### Clip length
The model has a **fixed input duration of 8 seconds** for all variants (both 16k and 44k). This is not a parameter you can change.
- Audio shorter than 8 s is **zero-padded** (silence appended). The model will learn the sound but may also learn silence as part of the pattern — keep in mind for very short sounds.
- Audio longer than 8 s is **trimmed** at 8 s. Content beyond that is lost.
- Video shorter than 8 s has its **last frame repeated** to fill the clip.
**Practical recommendations:**
| Sound type | Clip strategy |
|---|---|
| Continuous sound (rain, engine, wind) | 8 s recordings, as many positions in the audio as possible |
| Single event < 2 s (click, bark, knock) | Center the event — pad deliberately with silence before/after, or loop the event 23 times per clip |
| Repeating event (footsteps, dripping) | Record full 8 s with natural repetition at the intended cadence |
| Sound with a clear onset (explosion, splash) | Put the onset at ~12 s from the start, not at 0 s — gives the model context |
> **Tip:** When extracting features in ComfyUI, set `duration` to 0 to use the full video length up to 8 s. Clips longer than 8 s are automatically clamped.
### How many clips do I need?
The table below gives a rough scaling guide. Quality and diversity of recordings matter more than raw count.
| Dataset size | Scenario | Expected result |
|---|---|---|
| **510 clips** | Quick test / proof of concept | May work if the model already partially knows the sound; often underfits |
| **1530 clips** | Fine-tuning a sound the model knows but gets wrong | Good starting point — covers the main variations |
| **3060 clips** | Teaching a new but acoustically simple sound class | Reliable convergence with default hyperparameters |
| **60150 clips** | Unusual or complex sounds, strong style shift | Needed for stable generalization across video contexts |
| **150300 clips** | Sounds the model has never encountered | Required to avoid overfitting; increase rank to 32 |
| **300+** | Large-scale domain shift | Consider also targeting `linear1` in addition to `attn.qkv` |
**Diversity beats quantity.** Ten clips of a dog barking in different environments (indoors, outdoors, distant, close) train better than fifty clips of the same recording. Vary: distance, room acoustics, intensity, speed.
### Batch size
| Batch size | VRAM (large_44k) | Use case |
|---|---|---|
| `1` | ~9 GB | Minimal VRAM, noisy gradients |
| `4` | ~12 GB | Good default — stable gradients, reasonable speed |
| `8` | ~15 GB | Better convergence on larger datasets |
| `16` | ~20 GB | Best gradient quality when VRAM allows |
Higher batch size gives smoother loss curves and faster convergence. If you have headroom, prefer larger batches over more steps.
**Observed results:** batch 16 reaches the same loss in ~2600 steps that batch 1 needed 8000+ steps to reach, with a near-perfectly smooth curve. On a 24 GB GPU, batch 16 is the recommended default for `large_44k`.
### Rank
| Rank | Use case |
|---|---|
| `8` | Fine details on a sound the model already knows well |
| `16` | Default — good balance of capacity and VRAM |
| `32` | Harder sounds or larger style shifts (30+ clips recommended) |
Higher rank increases VRAM usage and overfitting risk on small datasets.
### Steps
With `batch_size=4` as the default, these are rough guidelines:
| Dataset size | Recommended steps |
|---|---|
| 1020 clips | 20004000 |
| 2050 clips | 40008000 |
| 50+ clips | 600015000 |
Watch the loss curve — if the smoothed line has been flat for 2000+ steps, training has converged for your dataset size. Adding more clips will let it go lower.
### Learning rate
`1e-4` is the recommended default for any batch size. If training is unstable (loss spikes in the first 200 steps), try `5e-5`. If convergence is very slow, try `2e-4`.
Warmup (default 100 steps) ramps the LR from 0 to avoid instability at the start.
### Target layers
`attn.qkv` (default) adapts only the self-attention QKV projections. This is the recommended starting point for all dataset sizes.
Add `linear1` to also adapt post-attention projections for large-scale domain shifts or when `attn.qkv` alone plateaus too early:
```bash
--target attn.qkv linear1
```
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
### Timestep sampling mode
Controls how training timesteps are sampled at each step.
`uniform` (default) samples all timesteps equally — equivalent to original MMAudio training.
`logit_normal` concentrates more steps near t=0.5 via `sigmoid(N(0, σ))`. This is the semantically rich mid-noise region. Consistently reaches a lower loss floor but the perceptual improvement on small datasets is marginal.
`curriculum` uses logit_normal for the first `curriculum_switch` fraction of steps (default 60%), then switches to uniform for the remainder. The motivation: logit_normal accelerates early structure learning but undertrains the high-t boundary region; uniform then fills in the fine detail. A switch message is logged when the transition happens.
| Mode | When to use |
|---|---|
| `uniform` (default) | Baseline — safe, equivalent to original training |
| `logit_normal` | When you want a lower loss floor; marginal on small datasets |
| `curriculum` | Experimental — may improve convergence quality on small datasets |
The `logit_normal_sigma` parameter controls the width of the logit-normal distribution (used by both `logit_normal` and the first phase of `curriculum`):
- σ=1.0: moderate peak at t=0.5, balanced coverage (default)
- σ=0.5: sharper peak, less coverage of extremes
- σ=2.0: broader, approaches uniform
### LoRA dropout
`lora_dropout` applies dropout to the input of the LoRA path (not the frozen base linear). It regularizes the low-rank update without disturbing pretrained weights — helpful on small datasets where the LoRA would otherwise overfit to the training clips.
| Value | Use case |
|---|---|
| `0.0` (default) | No regularization — fine for 30+ clips |
| `0.05` | Light regularization — recommended starting point on 1020 clips |
| `0.1` | Stronger regularization — use if loss plateaus but audio is still noisy |
Dropout is not saved in the adapter file — it only affects training. Loading the adapter at inference does not require setting dropout.
### LoRA+ (asymmetric learning rate)
`lora_plus_ratio` splits the learning rate between LoRA A and B matrices: `lr_B = lr × ratio`. The B matrix is the output-side projection and benefits from a higher LR. Setting ratio to 16 enables the LoRA+ scheme from arXiv:2402.12354.
| Ratio | Effect |
|---|---|
| `1.0` (default) | Standard LoRA — identical A and B learning rates |
| `4.0` | Mild asymmetry |
| `16.0` | LoRA+ — faster convergence, especially on early steps |
LoRA+ is orthogonal to dropout and curriculum sampling — all three can be combined.
### Adapter strength at inference
| Strength | Effect |
|---|---|
| `0.50.7` | Conservative — blends adapter with base model, less noise |
| `1.0` | Full adapter strength (default) |
| `>1.0` | Exaggerated effect, may introduce artifacts |
If the generated audio has noticeable white noise or artifacts, lower the strength to `0.60.7` before adjusting anything else. Also try lowering CFG scale in the Sampler.
### Loss interpretation
A typical loss curve:
- Starts around `0.81.0`
- Should reach `0.550.65` after convergence on a clean sound class with 1030 clips
- Below `0.4` indicates strong learning — usually requires 50+ diverse clips
- Below `0.1` on a small dataset means overfitting
The smoothed curve flattening for 2000+ steps is the clearest sign to stop or add more data.
### Precision
Use `bf16` on Ampere+ GPUs (RTX 3xxx/4xxx, A100). Fall back to `fp16` on older GPUs. `fp32` is only needed for debugging — 2× more VRAM.
---
## Output files
```
lora_output/my_sound/
adapter_step00500.pt ← step checkpoint (includes optimizer state for resume)
adapter_step01000.pt
...
adapter_final.pt ← final adapter with embedded metadata (inference only)
meta.json ← human-readable metadata
sample_step00500.wav ← quick eval sample at each checkpoint
loss_raw.png ← raw loss curve
loss_smoothed.png ← EMA-smoothed loss curve
```
`adapter_final.pt` format:
```python
{
"state_dict": { "blocks.0.attn.qkv.lora_A": ..., ... },
"meta": {
"variant": "large_44k",
"rank": 16,
"alpha": 16.0,
"target": ["attn.qkv"],
"steps": 2000
}
}
```
Step checkpoints (e.g. `adapter_step01000.pt`) additionally contain `optimizer` and `scheduler` state for resuming.
---
## Troubleshooting
**`No layers matched target=...`**
The `--target` suffixes do not match any layer names. The default `attn.qkv` targets `SelfAttention.qkv` in all transformer blocks. If you changed `--target`, verify the layer names with `model.named_modules()`.
**`No .npz files found in ...`**
The `--data_dir` path is wrong or no `.npz` files were extracted there yet. Run SelVA Feature Extractor in ComfyUI first with the matching `cache_dir`.
**`No audio file found for clip.npz`**
Place an audio file with the exact same stem next to the `.npz`: `clip.wav`, `clip.flac`, etc.
**The sound is audible but there is white noise on top**
Lower the adapter strength to `0.60.7` in SelVA LoRA Loader. Also try lowering CFG scale in the Sampler. This is normal when the model hasn't fully converged — more clips and more steps will reduce it.
**LoRA appears to have no effect**
Make sure the SelVA LoRA Loader output is wired to the **Sampler** input, not the Feature Extractor. The Feature Extractor does not use the generator.
**Loss does not decrease**
- Increase `batch_size` for more stable gradients.
- Try a higher learning rate (`2e-4`) or check that warmup isn't too long.
- Check that the audio files are clean and actually contain the target sound.
- Check that the `.npz` features were extracted with a relevant prompt.
**Loss explodes or NaN**
- Lower the learning rate (`5e-5`).
- Make sure audio is normalized to `[-1, 1]`. PCM files with 16-bit integer encoding may need to be converted: `ffmpeg -i input.wav -ar 44100 -sample_fmt s16 output.wav`
**Loss plateaus early (above 0.7)**
Dataset is the bottleneck. Add more clips — diversity matters more than quantity.
---
## Observations (work in progress)
These are empirical findings from ongoing experiments. They will be promoted to the main guide once more validated.
### Precision and batch size
| Config | Smoothed loss at step 2000 | Notes |
|---|---|---|
| bf16 batch 1 | ~0.73 | Noisy gradients, slow |
| bf16 batch 16 | ~0.65 | Stable, plateaued around step 60008000 at ~0.59 |
| bf16 batch 16 logit_normal | ~0.47 | Lower loss floor, similar or marginally better audio |
| fp32 batch 32 | ~0.58 | Matches bf16 batch 16 at step 6000 already at step 2000 |
**Key finding:** fp32 batch 32 converges to the same perceptual quality point in ~2000 steps that bf16 batch 16 needs 6000+ steps to reach. However, fp32 batch 32 continues descending well past that point on small datasets (10 clips), eventually overfitting. **Stop fp32 batch 32 around step 2000 on a 10-clip dataset** — later checkpoints sound worse despite lower loss.
**Lower loss ≠ better audio.** Once overfitting begins the model memorizes training clips rather than generalizing to new video inputs. Test intermediate checkpoints (e.g. step 500, 1000, 2000) to find the perceptual sweet spot.
### logit_normal vs uniform
logit_normal consistently reaches a lower loss floor than uniform. However perceptual improvement is dataset-dependent — on 10 clips the difference is marginal. May be more impactful with larger datasets. No conclusion yet.
### White noise
Residual white noise on generated audio is primarily a **dataset** problem, not a training one. Observed with all configs on 10 clips. Likely causes:
- Too few clips for the model to confidently predict the target sound
- Imprecise extraction prompts producing unfocused sync features
- Missing mask when multiple objects are in frame
CFG scale amplifies any adapter noise bias. Reducing CFG to 3.03.5 or adapter strength to 0.60.7 helps at inference.
+8 -254
View File
@@ -58,7 +58,7 @@ Generates audio from video features. Runs the rectified flow ODE with classifier
| Input | Description | | Input | Description |
|-------|-------------| |-------|-------------|
| `model` | From SelVA Model Loader (or any loader/loader chain) | | `model` | From SelVA Model Loader |
| `features` | From SelVA Feature Extractor | | `features` | From SelVA Feature Extractor |
| `prompt` | Text description — leave empty to use the prompt stored in features | | `prompt` | Text description — leave empty to use the prompt stored in features |
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) | | `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
@@ -66,261 +66,22 @@ Generates audio from video features. Runs the rectified flow ODE with classifier
| `steps` | Sampling steps (default: 25) | | `steps` | Sampling steps (default: 25) |
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) | | `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
| `seed` | RNG seed | | `seed` | RNG seed |
| `normalize` | RMS-normalize output to `target_lufs` (default: true) | | `normalize` | Peak-normalize output to [-1, 1] (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
| `steering_vectors` | *(optional)* From SelVA Activation Steering Loader |
| `steering_strength` | *(optional)* Scale for steering vectors (default: 0.1) |
| `textual_inversion` | *(optional)* From SelVA Textual Inversion Loader |
| `ti_strength` | *(optional)* Blend strength for TI tokens (default: 1.0) |
**Output:** `AUDIO` **Output:** `AUDIO`
--- ---
### SelVA LoRA Loader ## Workflow
Injects a trained LoRA adapter into the generator. Connect between Model Loader and Sampler.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `adapter_path` | Path to `adapter_final.pt` or any step checkpoint |
| `strength` | 0.0 = disabled, 1.0 = full, >1.0 = exaggerated |
**Output:** `model` (SELVA_MODEL with adapter injected)
---
### SelVA LoRA Trainer
Fine-tunes LoRA adapters on a `.npz` feature dataset. See [LORA_TRAINING.md](LORA_TRAINING.md) for the full guide.
**Output:** `adapter` (SELVA_LORA) and `summary_path` (STRING)
---
### SelVA LoRA Scheduler
Runs a series of LoRA experiments from a JSON sweep file. The dataset is encoded once and reused across all runs. Results are collected in `experiment_summary.json` with overlaid loss curves.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `experiments_file` | Path to JSON sweep config |
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
---
### SelVA Skip Experiment
Signals a running SelVA LoRA Scheduler to skip the current experiment and move to the next. Queue this node while the scheduler is running.
**Output:** `flag_path` (STRING)
---
### SelVA LoRA Evaluator
Evaluates multiple LoRA adapters by generating audio from a fixed reference clip, then reports spectral metrics per adapter for comparison. Input is a JSON file listing adapter paths; an empty path means baseline (no LoRA).
**Outputs:** `summary_path` (STRING), `comparison_image` (IMAGE)
---
### SelVA Dataset Browser
Reads a `dataset.json` produced by the SelVA dataset preparation pipeline and exposes one entry at a time via an index. Useful for previewing and iterating through a prepared dataset.
**Outputs:** video path, audio path, frames directory, label, total count
---
### SelVA VAE Roundtrip
Encodes audio through the SelVA VAE then decodes it back. Use this to measure codec reconstruction quality in isolation — if the output sounds degraded relative to the input, the codec ceiling will limit any downstream fine-tuning approach.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `audio` | AUDIO to test |
**Output:** `audio_reconstructed` (AUDIO)
---
### SelVA HF Smoother
Attenuates high-frequency content that the SelVA codec handles poorly, by blending a low-pass filtered version of the audio with the original. Use before feature extraction to improve LoRA training targets.
**Output:** `audio` (AUDIO)
---
### SelVA Spectral Matcher
Applies a per-band gain correction to bring audio's spectral profile in line with the MMAudio VAE's expected distribution, derived from the normalization statistics baked into the VAE weights. Use on training audio to reduce codec mismatch.
**Output:** `audio` (AUDIO)
---
### SelVA Textual Inversion Trainer
Trains K learnable CLIP token embeddings against an audio dataset with all model weights frozen. The tokens are injected into the Sampler to guide generation toward a target style.
> **Note:** Textual inversion via the text conditioning path has limited effectiveness for fine-grained timbral style transfer in SelVA due to mean-pooling in the text conditioning path. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the current recommended approach.
**Outputs:** `embeddings_path` (STRING), `loss_curve` (IMAGE)
---
### SelVA Textual Inversion Loader
Loads CLIP token embeddings from a `.pt` file produced by the Textual Inversion Trainer. Connect to the Sampler's `textual_inversion` input.
**Output:** `textual_inversion` (TEXTUAL_INVERSION)
---
### SelVA TI Scheduler
Runs a series of Textual Inversion experiments from a JSON sweep file, reusing the encoded dataset across runs.
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
---
### SelVA Activation Steering Extractor
Computes per-block activation steering vectors from a training dataset by comparing DiT hidden states under BJ conditioning vs. empty conditioning. The resulting vectors can nudge the denoising trajectory toward the target style at inference.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with `.npz` feature files |
| `output_path` | Where to save `steering_vectors.pt` |
| `n_samples` | Clips to average over (default: 16) |
| `seed` | RNG seed |
**Output:** `steering_path` (STRING)
---
### SelVA Activation Steering Loader
Loads steering vectors from a `.pt` file produced by the Extractor. Connect to the Sampler's `steering_vectors` input.
**Output:** `steering_vectors` (STEERING_VECTORS)
---
### SelVA BigVGAN Trainer
Fine-tunes the BigVGAN vocoder (mel → waveform) on a set of target-style audio clips. Only the vocoder is modified — the DiT generator and VAE are completely untouched.
Default mode (`snake_alpha_only`) tunes only the ~27K per-channel α parameters in Snake/SnakeBeta activations, which directly control harmonic periodicity. With 0.024% of parameters trainable the model cannot produce spectral averaging artifacts regardless of loss function. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the full rationale.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with target-style audio files (searched recursively) |
| `output_path` | Where to save the fine-tuned vocoder `.pt` |
| `train_mode` | `snake_alpha_only` (default) or `all_params` |
| `steps` | Training steps (default: 2000) |
| `lr` | Learning rate (default: 1e-4 for snake_alpha_only) |
| `batch_size` | Clips per step (default: 4) |
| `segment_seconds` | Audio segment length per training sample (default: 1.0 s) |
| `lambda_l2sp` | L2-SP anchor regularization strength — penalizes drift from pretrained weights (default: 1e-3) |
| `save_every` | Checkpoint interval in steps (default: 500) |
| `seed` | RNG seed |
| `discriminator_path` | *(optional)* Path to `bigvgan_discriminator_optimizer.pt` — when provided, frozen MPD+MRD feature matching replaces mel L1, directly penalizing harmonic smearing |
**Output:** `checkpoint_path` (STRING) — load with SelVA BigVGAN Loader
Saves eval samples and mel spectrogram PNGs at baseline, each checkpoint, and final.
---
### SelVA BigVGAN Loader
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer and replaces the vocoder weights in a SELVA_MODEL in-place. Connect the output to SelVA Sampler instead of the base Model Loader.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `path` | Path to fine-tuned vocoder `.pt` (relative = ComfyUI output directory) |
**Output:** `model` (SELVA_MODEL with fine-tuned vocoder)
---
### SelVA DITTO Optimizer
Inference-time noise optimization ([arXiv:2401.12179](https://arxiv.org/abs/2401.12179), ICML 2024 Oral). Optimizes the initial noise latent x₀ to make the generated audio match a set of BJ reference clips, by backpropagating a mel style loss through the ODE solver. All model weights remain frozen — zero quality degradation risk.
Style loss: mean spectrum + Gram matrix computed against reference mels. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference clips. Optimization runs only through the DiT + VAE decoder; the vocoder is only invoked for the final output pass.
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `features` | From SelVA Feature Extractor |
| `prompt` | Sound description (leave empty to use features prompt) |
| `negative_prompt` | Sounds to suppress |
| `reference_dir` | Directory with BJ reference audio clips (.wav/.flac/.mp3) |
| `n_opt_steps` | Gradient optimization steps on x₀ (default: 50) |
| `opt_lr` | Adam LR for x₀ optimization (default: 0.1) |
| `n_ode_steps` | ODE steps per optimization iteration (default: 10; lower = faster) |
| `n_grad_steps` | ODE steps to differentiate through — truncated BPTT (default: 5) |
| `style_weight` | Style loss weight (default: 1.0; increase for stronger BJ shift) |
| `steps` | Euler steps for the final generation pass (default: 25) |
| `cfg_strength` | CFG scale (default: 4.5) |
| `seed` | RNG seed |
| `normalize` | *(optional)* RMS normalize output (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
**Output:** `AUDIO`
---
## Workflows
### Basic generation
``` ```
VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio
│ (video_info) │ (video_info) ─► (fps auto)
│ (features) ──────────────────────────────────►│ │ (features) ────────────────────────────────────►│
│ (prompt) ────────────────────────────────────►│ │ (prompt) ──────────────────────────────────────►│
``` ```
### DITTO style transfer (recommended first approach) Connect the `prompt` output of Feature Extractor directly to Sampler's `prompt` to keep them in sync. Leave Sampler's `prompt` empty and it will use whatever was stored during extraction.
```
SelVA Model Loader ─────────────────────────────────────────────► SelVA DITTO Optimizer ──► Save Audio
SelVA Feature Extractor ──(features)────────────────────────────────────►│
(prompt) ──────────────────────────────────────►│
BJ reference_dir ───────────────────────────────────────────────────────►│
```
No training required. Each run optimizes x₀ independently for the current video and reference set.
### Vocoder fine-tuning
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► (checkpoint .pt)
BJ audio clips ──(data_dir)──►│
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler ──► Save Audio
▲ ▲
checkpoint .pt SelVA Feature Extractor
```
### LoRA training
See [LORA_TRAINING.md](LORA_TRAINING.md).
--- ---
@@ -366,15 +127,8 @@ The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available,
--- ---
## Style Transfer
For adapting SelVA to a specific audio style (e.g. BJ / Bladee / Jersey Club), see [STYLE_TRANSFER.md](STYLE_TRANSFER.md).
---
## Credits ## Credits
- [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training - [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training
- [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework - [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework
- [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis - [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis
- [DITTO](https://arxiv.org/abs/2401.12179) by Novack et al. — inference-time diffusion optimization
-158
View File
@@ -1,158 +0,0 @@
# Style Transfer for SelVA
This document covers approaches for adapting SelVA's audio output to a specific timbral style using a small reference dataset (~50 clips). The context here is BJ / Bladee / Jersey Club style — sharp metallic transients, saturated harmonics, 808 sub bass, glassy high-frequency content — but the methods apply to any style target.
---
## Why standard fine-tuning is hard
SelVA's generation quality depends on the DiT (generator) outputting latents that fall in the high-density region of the VAE decoder's training distribution. BJ's audio maps to a sparse, tail region of that space — the VAE roundtrip already shows ~1015 dB elevated HF noise floor on BJ material. Any training that pushes the generator toward exact BJ encoder outputs is training toward an already-degraded target.
**LoRA** makes this worse: it introduces "intruder dimensions" — new high-rank singular vectors absent from the pretrained weight spectrum — that push DiT outputs further off-manifold. This mechanism is LR- and scale-independent. Reducing LoRA scale does not fix the direction, only the magnitude. Empirically: spectral flatness degrades to ~0.210.26 (vs. baseline 0.013) at every scale from 0.0625 to 1.0.
**Textual inversion** via the text conditioning path suffers from mean-pooling: SelVA's text features are pooled into a single global vector before injection into the DiT. The optimizer finds a spectral bias (noise/buzz) as the cheapest way to reduce reconstruction loss — not a semantic style shift.
The approaches below are ordered by expected quality and ease of use.
---
## Tier 1 — DITTO (recommended first try)
**Node: SelVA DITTO Optimizer**
Inference-time noise optimization. Keeps all model weights frozen and only optimizes the initial noise latent x₀ using a style loss computed against the reference clips. Since the weights never change, there is zero risk of quality degradation — the model still generates from its original manifold, just from a better starting point.
**Style loss:** mean spectrum + Gram matrix of mel spectrograms. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference. Optimization runs entirely before the vocoder — BigVGAN is only called for the final output pass.
**How it works:**
For each video clip you want to process:
1. Run SelVA Feature Extractor as usual.
2. Instead of SelVA Sampler, connect to **SelVA DITTO Optimizer** with your BJ `reference_dir`.
3. The node runs N optimization steps, each backpropagating through the last few ODE Euler steps to compute `∂loss/∂x₀`.
4. After optimization, one final full-ODE pass generates the output audio from the refined x₀.
```
SelVA Model Loader ────────────────────────────────► SelVA DITTO Optimizer ──► audio
SelVA Feature Extractor ──(features)────────────────────────►│
(prompt) ──────────────────────────►│
BJ clips ───────────────────────────(reference_dir) ─────────►│
```
**Tuning guide:**
| Parameter | Starting value | When to adjust |
|---|---|---|
| `n_opt_steps` | 50 | Increase to 100200 if style shift is too subtle |
| `opt_lr` | 0.1 | Lower to 0.05 if coherence breaks; raise to 0.3 for stronger shift |
| `n_ode_steps` | 10 | Lower = faster optimization, less accurate gradient |
| `n_grad_steps` | 5 | Number of ODE steps to differentiate through — must be ≤ n_ode_steps |
| `style_weight` | 1.0 | Increase to 25 for stronger BJ character; watch for incoherence |
**Memory:** Each opt step stores activations for `n_grad_steps` DiT forward passes with gradient checkpointing. At n_grad_steps=5, expect ~46 GB additional VRAM over baseline inference.
**Time per video clip:** ~50 opt steps × (10 ODE steps × 2 passes for checkpointing) + 25 final steps ≈ 515 minutes depending on GPU.
**Limitations:** DITTO with mel Gram matrix loss shifts timbral statistics but cannot precisely match the BJ transient sharpness — the Gram matrix is a texture descriptor, not a transient detector. See Tier 2 (vocoder fine-tuning) for that.
---
## Tier 2 — Vocoder Fine-tuning
**Nodes: SelVA BigVGAN Trainer → SelVA BigVGAN Loader**
The BigVGAN vocoder (mel → waveform) is the component most responsible for the final timbral character of the output. Fine-tuning only the vocoder keeps the DiT completely untouched — latents stay on-manifold, only the waveform rendering changes.
### Why plain mel L1 loss fails
BigVGAN was trained with `L_G = Σ[L_adv + 2·L_fm] + 45·L_mel`. The adversarial and feature-matching terms do the perceptual heavy lifting — they prevent the generator from averaging over high-variance harmonic content. Dropping them for a plain mel L1 loss is a loss-function topology problem: the model minimizes expected reconstruction error by averaging over harmonic uncertainty, eroding the saturated 38 kHz harmonics visible as "green smear" in spectrograms. This happens regardless of LR or step count.
### `snake_alpha_only` mode (default, recommended)
BigVGAN's AMP blocks use Snake/SnakeBeta activations: `y = x + (1/α)·sin²(α·x)` where α is a per-channel learnable scalar. Alpha parameters directly control the harmonic periodicity of each layer's output — they are the "harmonic tuning knobs" of the vocoder.
With `train_mode=snake_alpha_only`, only the ~27K alpha parameters (0.024% of the 112M parameter model) are trained. The conv weights encoding waveform structure remain frozen. With this few trainable parameters the model physically cannot reshape the spectrum significantly regardless of loss function — no capacity for the green smear.
**Loss in snake_alpha_only mode:** mel L1 + multi-resolution STFT L1 are still used but can only shift harmonic emphasis, not spectral shape.
### `all_params` mode with discriminator
For a stronger shift — or to use proper perceptual losses — run with `train_mode=all_params` and provide a `discriminator_path` (the `bigvgan_discriminator_optimizer.pt` from the BigVGAN pretrained release):
1. The frozen pretrained MPD and MRD discriminators are loaded and used as fixed perceptual feature extractors.
2. Loss becomes `2·L_fm(frozen_D) + 0.1·L_mel` — feature matching directly penalizes harmonic smearing through the discriminator's learned perceptual space.
3. `lambda_l2sp` (default 1e-3) anchors all parameters to their pretrained values — prevents catastrophic drift on 50 clips.
This is the highest-quality vocoder fine-tuning path but requires the discriminator checkpoint.
### Workflow
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► bigvgan_bj.pt
BJ audio clips ──(data_dir)──►│
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor
```
### Tuning guide
| Parameter | Default | Notes |
|---|---|---|
| `train_mode` | snake_alpha_only | Safe default; use all_params only with discriminator_path |
| `steps` | 2000 | 10002000 for snake_alpha_only; 30005000 for all_params |
| `lr` | 1e-4 | For snake_alpha_only; lower to 1e-5 for all_params |
| `lambda_l2sp` | 1e-3 | Increase to 1e-2 for all_params to limit drift |
| `batch_size` | 4 | 48 for stable gradients |
| `segment_seconds` | 1.0 | 12 s segments recommended |
**Eval samples:** The trainer saves `.wav` and mel spectrogram `.png` files at baseline, each checkpoint, and final. Compare the spectrograms — saturation (red values in high-frequency bands) should increase relative to baseline.
---
## Tier 3 — DITTO + Vocoder (combined)
Stack both:
```
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA DITTO Optimizer ──► audio
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor + reference_dir
```
The fine-tuned vocoder handles waveform rendering; DITTO shifts the latent trajectory. Each addresses a different aspect of style transfer.
---
## What doesn't work (and why)
### Standard LoRA
LoRA introduces "intruder dimensions" — high-rank singular vectors absent from the pretrained weight spectrum — at initialization. These push DiT outputs into decoder-hostile latent regions regardless of scale or LR. The failure is direction-based, not magnitude-based, so reducing LoRA scale does not fix it.
PiSSA initialization (`init_lora_weights="pissa"`) and rsLoRA scaling (`use_rslora=True`) reduce intruder dimension formation by starting in the pretrained weight subspace. These are planned as future improvements.
### Textual inversion
SelVA mean-pools all 77 CLIP tokens into a single AdaLN bias vector. Every token contributes equally to a scalar offset; the optimizer finds spectral buzz as the minimum-cost way to reduce flow-matching reconstruction loss. More tokens make it worse.
### Activation steering (global mean difference)
The raw mean difference between BJ and empty conditions is not a clean style basis — it carries noise from the diversity of the training clips and the many attention blocks that have nothing to do with timbral character. Global injection (all blocks at any strength) kills the sound. Targeted layer injection (only the 36 blocks most predictive of BJ style) is theoretically sound but requires per-layer delta magnitude ranking to identify the right layers first.
---
## Reference dataset preparation
Use the same audio clips for both DITTO and vocoder fine-tuning:
- **Minimum:** 2030 clips. DITTO works from 5+; vocoder benefits from 40+.
- **Format:** `.wav` or `.flac` at native sample rate. The trainer resamples automatically.
- **Length:** Any length ≥ 1 s. Longer is fine — the trainer segments internally.
- **Quality:** Clean, full-mix BJ clips. Avoid heavily compressed or streaming-ripped files. Use HF Smoother if HF content sounds brittle after VAE roundtrip.
- **Diversity:** Vary tempo, key, vocal density. 20 diverse clips > 50 copies of the same 8-bar loop.
Normalize all clips to consistent loudness (e.g. -14 LUFS) before training. Inconsistent levels increase loss variance and slow convergence.
-170
View File
@@ -1,170 +0,0 @@
# Audio Dataset Pipeline for Generative Model Training
Research notes on audio cleaning, augmentation, and quality metrics for LoRA fine-tuning of MMAudio/SelVA. Based on papers and tooling survey (April 2026).
---
## Core Principle
Augmentation for generative models ≠ augmentation for classifiers.
The goal is **not invariance** — it is expanding the training manifold so the model learns the distribution of a sound rather than memorizing a fixed set of waveforms.
With 10 clips, velocity field collapse (arXiv:2410.23594) is mathematically expected: the flow-matching model memorizes the training trajectories instead of generalizing. More diverse data is the only real fix.
---
## Recommended Pipeline
### Step 1 — Quality Screening
```python
# Clipping check
clip_ratio = np.sum(np.abs(audio) >= 0.99) / len(audio) # flag if > 0.1%
# DC offset check + removal
dc = np.mean(audio)
audio -= dc
# LUFS normalization to -14 LUFS (essential for training consistency)
# pip install pyloudnorm
import pyloudnorm as pyln
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(audio)
audio = pyln.normalize.loudness(audio, loudness, -14.0)
# Or via ffmpeg: ffmpeg -af loudnorm=I=-14:LRA=7:TP=-1
# DNSMOS quality gate (discard if OVRL < 3.5 for training; < 2.5 is unusable)
# from Microsoft DNS-Challenge repo
```
### Step 2 — Cleaning
| Tool | Install | Use |
|---|---|---|
| **AudioSep** | `pip install audiosep` | Isolate target sound from background — most impactful tool |
| **noisereduce** | `pip install noisereduce` | Light stationary/non-stationary denoising, preserves character |
| **librosa** | `pip install librosa` | Silence trimming: `librosa.effects.trim(audio, top_db=30)` |
| **torchaudio.transforms.Fade** | (torchaudio) | Prevent click artifacts at clip edges |
| **DeepFilterNet** | `pip install deepfilternet` | Heavy denoising — good for speech, may alter tonal sounds |
**AudioSep usage:**
```python
from audiosep import AudioSep
model = AudioSep.from_pretrained("audio-agi/audiosep")
# ~1.5 GB checkpoint, ~4 GB VRAM
model.inference(audio_path, "a dog barking loudly", output_path)
```
### Step 3 — Waveform Augmentation (10 clips → 50100)
Apply stochastically per clip:
| Transform | Params | Notes |
|---|---|---|
| **PitchShift** | ±13 semitones | 3 variants per clip. Limit to ±1 st for tonal/pitched sounds |
| **ApplyImpulseResponse** | 5 different RIRs | 5 variants per clip — EchoThief (~150 free IRs) or pyroomacoustics |
| **LoudnessNormalization** | ±2 dB random | Subtle level variation |
| **SevenBandParametricEQ** | ±3 dB | Gentle spectral variation |
| **TimeStretch** | 0.91.1× only | Do NOT use 2× to pad short clips — breaks video sync |
```python
# pip install audiomentations pedalboard pyroomacoustics
import audiomentations as A
augment = A.Compose([
A.PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
A.ApplyImpulseResponse(ir_paths="path/to/irs/", p=0.5),
A.SevenBandParametricEQ(min_gain_db=-3, max_gain_db=3, p=0.3),
A.LoudnessNormalization(min_lufs=-16, max_lufs=-12, p=0.5),
A.TimeStretch(min_rate=0.9, max_rate=1.1, p=0.3),
])
audio_aug = augment(samples=audio, sample_rate=sr)
```
**RIR sources:**
- EchoThief: ~150 free real-world IRs (churches, caves, parking garages)
- pyroomacoustics: synthetic room simulation, fully controllable
### Step 4 — Latent Augmentation (at training time)
After VAE encoding:
**Latent mixup** between same-category pairs:
```python
# Mix latents BEFORE flow-matching noise is added
# Only mix clips from the same sound category — cross-category mixing produces garbage
lam = torch.distributions.Beta(0.4, 0.4).sample()
z_mix = lam * z1 + (1 - lam) * z2
```
With 10 clips: C(10,2) = 45 possible pairs → significant expansion without new recordings.
**Small Gaussian noise:**
```python
z_noised = z + torch.randn_like(z) * 0.02 * z.std()
```
Prevents trivial memorization of exact latent coordinates.
MusicLDM (arXiv:2308.01546) shows latent mixup > waveform mixup for generative quality.
---
## Transforms to AVOID for Generative Training
| Transform | Why |
|---|---|
| ClippingDistortion, BitCrush, TanhDistortion, Mp3Compression | Model learns the artifact |
| Reverse | Breaks temporal structure for video-to-audio task |
| TimeMask (creating silence gaps) | Unnatural — model learns to produce silence |
| TimeStretch > 1.3× | Phase vocoder artifacts become part of the target distribution |
| Heavy background noise (< 15 dB SNR) | Model learns to reproduce the noise |
---
## Quality Metrics
| Metric | Tool | Threshold |
|---|---|---|
| DNSMOS P.835 (SIG/BAK/OVRL) | Microsoft DNS-Challenge | OVRL > 3.5 for training |
| LUFS | pyloudnorm | Normalize all clips to -14 LUFS |
| WADA-SNR | (standalone) | No-reference SNR estimate |
| Clipping ratio | NumPy | Flag if > 0.1% of samples at ±0.99 |
---
## Tool Reference
| Tool | Install | Purpose |
|---|---|---|
| audiomentations | `pip install audiomentations` | Primary augmentation library |
| pedalboard | `pip install pedalboard` | Higher quality pitch shift, IR convolution |
| AudioSep | `pip install audiosep` | Source separation / isolation |
| noisereduce | `pip install noisereduce` | Non-stationary denoising |
| DeepFilterNet | `pip install deepfilternet` | Heavy denoising (speech-optimized) |
| pyloudnorm | `pip install pyloudnorm` | LUFS normalization |
| Silero VAD | `pip install silero-vad` | Voice/silence detection |
| pyroomacoustics | `pip install pyroomacoustics` | Synthetic RIR generation |
---
## Integration with PrismAudio / SelVA
No established ComfyUI audio preprocessing ecosystem as of early 2026. Build thin wrapper nodes around the tools above. PrismAudio already has all required patterns (subprocess isolation, AUDIO type transport).
**Target node set:**
- `SelVA Dataset Cleaner` — wraps noisereduce + LUFS normalization + trim + DNSMOS gate
- `SelVA Dataset Augmenter` — wraps audiomentations Compose pipeline
Steps 13 are preprocessing (run once before feature extraction).
Step 4 (latent mixup) is a training loop modification — integrate into `selva_lora_trainer.py`.
---
## Key Papers
| Paper | ArXiv | Finding |
|---|---|---|
| MusicLDM | 2308.01546 | Latent mixup > waveform mixup for generative quality |
| EDMSound | 2311.08667 | Memorization documented — same failure mode as 10-clip training |
| Synthio | 2410.02056 | Synthetic audio as augmentation data (ICLR 2025) |
| HunyuanVideo-Foley | 2508.16930 | V2A data pipeline at scale (100K hrs) |
| FM memorization | 2410.23594 | Velocity field collapse theory — proves early overfitting on small datasets |
-184
View File
@@ -1,184 +0,0 @@
# AudioX vs SelVA — Evaluation
AudioX (arXiv:2503.10522, ICLR 2026) is a unified multimodal audio generation model from HKUST.
This document compares it against SelVA/MMAudio and assesses the cost of adding it to PrismAudio.
---
## Quick Decision Guide
| Situation | Use |
|---|---|
| Video → realistic sound effects | **SelVA** — faster, purpose-built, MIT license |
| Music generation from video or text | **AudioX** — SelVA cannot do this |
| Audio inpainting / music continuation | **AudioX** — SelVA cannot do this |
| LoRA fine-tuning on a custom sound | **SelVA** — full training infrastructure already exists |
| Variable output duration | **AudioX** — SelVA is fixed at 8 s |
| Inference speed matters | **SelVA** — 25 steps vs 250 (10× faster) |
| Non-commercial research | Either |
| Any commercial use | **SelVA only** — AudioX is CC-BY-NC-4.0 |
---
## Architecture
| Dimension | SelVA (MMAudio) | AudioX-MAF |
|---|---|---|
| Core paradigm | Flow matching | Diffusion (k-diffusion / DPM++) |
| Inference steps | 25 ODE steps (Euler) | 250 diffusion steps (DPM++ 3M SDE) |
| Sample rate | 44.1 kHz (large) / 16 kHz (small) | 48 kHz (fixed) |
| Generator | MM-DiT, velocity prediction | ContinuousMMDiTTransformer |
| Video encoder | Synchformer | Synchformer (AudioX custom re-impl, same concept) |
| VAE / codec | DAC (descript-audio-codec) | DAC + AudioCraft options |
| Text encoder | T5-large | T5 (configurable small → XXL) |
| Video-audio fusion | Cross-attention in MM-DiT | MAF: dual-projection (dim alignment + seq length alignment) |
| Output duration | Fixed 8 s | Configurable via `sample_size` (default ~44 s at 48kHz) |
| Training data | ~2 M samples (MMAudio paper) | 7 M samples (IF-caps dataset, curated) |
| License | MIT | CC-BY-NC-4.0 |
**MAF (Multimodal Adaptive Fusion):** AudioX's key architectural contribution. Instead of directly
concatenating multimodal tokens into the DiT's cross-attention, MAF projects each modality to
match the latent's sequence length via a dedicated linear + transposed-conv stack, then applies
`MMDitSingleBlock` layers for cross-modal fusion. The paper reports this improves cross-modal
alignment particularly for video-to-audio tasks.
**Flow matching vs diffusion:** Flow matching (SelVA) trains a single velocity field to move
directly from noise to data along a straight trajectory — this is why 25 steps suffice. Standard
diffusion (AudioX) approximates a longer stochastic path, requiring 250 steps for quality output.
This is not a quality difference per se; flow matching is simply more efficient.
---
## Capabilities
| Task | SelVA | AudioX |
|---|---|---|
| Video → sound effects | ✓ (primary use case) | ✓ |
| Text → sound effects | Partial (T5 conditions quality but not primary) | ✓ (strong benchmark scores) |
| Video → music | ✗ | ✓ |
| Text → music | ✗ | ✓ |
| Audio inpainting | ✗ | ✓ (mask_args parameter) |
| Music continuation | ✗ | ✓ (init_audio parameter) |
| Variable output duration | ✗ (fixed 8 s) | ✓ |
| Multiple input modalities simultaneously | Partial | ✓ (text + video + audio at once) |
AudioX benchmarks claim superior results on text-to-audio (AudioCaps) and text-to-music
(MusicCaps) vs prior models. Video-to-audio comparison against MMAudio specifically is not
prominently featured in the paper. Perceptual evaluation confirms this: AudioX does not sound
noticeably better than SelVA on video-to-audio tasks. AudioX's advantage is **breadth**
(music, inpainting, variable duration), not raw video-to-audio quality.
---
## Integration Cost
Adding AudioX inference-only nodes to PrismAudio would require:
### New nodes (3 files)
```
nodes/
audiox_model_loader.py AUDIOX_MODEL loader — get_pretrained_model("HKUSTAudio/AudioX-MAF")
audiox_sampler.py wraps generate_diffusion_cond(), inputs: model + text + video + audio
audiox_feature_extractor.py optional — pre-extract Synchformer sync features (caching)
```
### Installation
```bash
pip install git+https://github.com/ZeyueT/AudioX.git
```
New dependencies not currently in PrismAudio:
- `pytorch-lightning==2.4.0`
- `k-diffusion==0.1.1`
- `v-diffusion-pytorch==0.0.2`
- `descript-audio-codec==1.0.0` (already used by SelVA — no conflict, same package)
- `gradio==4.44.1` (optional — only for the upstream Gradio UI)
Model weights: `HKUSTAudio/AudioX-MAF` on HuggingFace (~several GB).
### Inference API surface
```python
from audiox import get_pretrained_model
from audiox.inference.generation import generate_diffusion_cond
model, config = get_pretrained_model("HKUSTAudio/AudioX-MAF")
output = generate_diffusion_cond(
model,
steps=250,
cfg_scale=6.0,
conditioning={
"text_prompt": "a dog barking",
"video_prompt": {"video": frames_tensor, "sync_features": sync_feat},
"seconds_total": 8.0,
},
sample_size=384000, # 8 s at 48kHz
sample_rate=48000,
device="cuda",
)
# output: torch.Tensor (batch, channels, num_samples) float32 [-1, 1]
```
---
## LoRA Training
Adding AudioX LoRA training to PrismAudio is **significantly harder** than the SelVA trainer:
| Aspect | SelVA LoRA | AudioX LoRA |
|---|---|---|
| Loss function | Single MSE velocity loss | Diffusion loss over 250-step schedule |
| Training steps needed | ~2000 steps practical | Unknown — likely much more |
| Step cost | Fast (1 velocity prediction) | Slow (full diffusion forward pass per step) |
| Existing infrastructure | Full trainer + scheduler + experiments | Nothing — would need to build from scratch |
| Noise schedule | Trivial (linear interpolation) | Cosine alpha-sigma schedule |
| Prior art for LoRA | LoRA on flow matching well-studied | Less explored; closer to Stable Diffusion LoRA |
**Conclusion:** AudioX LoRA training is feasible (it would follow SD-style LoRA with the DPM++
noise schedule) but would be a substantial new project. Not worth building until inference nodes
are stable and there is a clear use case that SelVA cannot serve.
---
## License
AudioX weights are released under **CC-BY-NC-4.0** (Creative Commons Non-Commercial).
- Free for personal use, research, and non-commercial projects
- **Cannot be used in commercial products or services** without a separate agreement
- Attribution required
- SelVA/MMAudio: MIT (no restrictions)
If PrismAudio is ever distributed as part of a commercial tool, AudioX nodes must be clearly
opt-in with a license warning, or excluded entirely.
---
## Recommendation
**Short term:** AudioX is not a replacement for SelVA for the current use case (video → custom
sound effects with LoRA fine-tuning). SelVA is faster, has full training infrastructure, and
is MIT licensed.
**When AudioX becomes worth integrating:**
- If you need to generate background music synchronized to video
- If you need audio inpainting (fill a gap in an existing audio track)
- If you need text-to-audio generation without a video input
- After verifying the CC-BY-NC-4.0 license is acceptable for your use
**Estimated integration effort for inference nodes only:** 23 days of work (3 new node files,
dependency management, testing). No changes to existing SelVA nodes required — they would
coexist in the same package.
---
## References
- Paper: arXiv:2503.10522 — *AudioX: Diffusion Transformer for Anything-to-Audio Generation*
- GitHub: https://github.com/ZeyueT/AudioX
- Model weights: https://huggingface.co/HKUSTAudio/AudioX-MAF
- Demo: https://huggingface.co/spaces/Zeyue7/AudioX
- Project page: https://zeyuet.github.io/AudioX/
@@ -1,606 +0,0 @@
# Audio Dataset Pipeline Implementation Plan
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** Add 5 chainable ComfyUI nodes for in-memory audio dataset preprocessing: load → resample → LUFS normalize → inspect/filter → extract single item.
**Architecture:** Single new file `nodes/selva_dataset_pipeline.py` defines a custom `AUDIO_DATASET` type (list of dicts) and all 5 node classes. Nodes are stateless transforms — each takes `AUDIO_DATASET` and returns `AUDIO_DATASET`. No disk I/O except in the Loader. Register all nodes in `nodes/__init__.py`.
**Tech Stack:** `pyloudnorm` (BS.1770-4 LUFS), `soxr` (VHQ resampling), `torchaudio`, `torch`. Both confirmed present in the ComfyUI environment at `/media/p5/miniforge3/envs/latestcomfyui`.
---
## The `AUDIO_DATASET` type
Used as the ComfyUI type string `"AUDIO_DATASET"`. At runtime it is a Python list of dicts:
```python
[
{
"waveform": torch.Tensor, # shape [1, C, L], float32, range [-1, 1]
"sample_rate": int,
"name": str, # original filename stem, for reporting
},
...
]
```
---
### Task 1: Create the file skeleton and AUDIO_DATASET constant
**Files:**
- Create: `nodes/selva_dataset_pipeline.py`
**Step 1: Write the file with imports and type constant only**
```python
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.
Typical chain:
SelvaDatasetLoader
↓ AUDIO_DATASET
SelvaDatasetResampler (optional)
↓ AUDIO_DATASET
SelvaDatasetLUFSNormalizer (optional)
↓ AUDIO_DATASET
SelvaDatasetInspector (optional)
↓ AUDIO_DATASET + STRING report
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
"""
from pathlib import Path
import numpy as np
import torch
import torchaudio
from .utils import SELVA_CATEGORY
# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
```
**Step 2: Verify import works (no test framework needed — just a quick smoke check)**
```bash
cd /media/p5/Comfyui-Prismaudio
python3 -c "from nodes.selva_dataset_pipeline import AUDIO_DATASET; print(AUDIO_DATASET)"
```
Expected output: `AUDIO_DATASET`
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add audio dataset pipeline skeleton"
```
---
### Task 2: SelvaDatasetLoader
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the Loader class**
```python
class SelvaDatasetLoader:
"""Load all audio files in a folder into an in-memory AUDIO_DATASET."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("STRING", {
"default": "",
"tooltip": "Absolute path to folder containing audio files. Searched recursively.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."
def load(self, folder: str):
folder = Path(folder.strip())
if not folder.exists():
raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
if not files:
raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")
dataset = []
for f in sorted(files):
try:
wav, sr = torchaudio.load(str(f)) # [C, L]
wav = wav.unsqueeze(0).float() # [1, C, L]
dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
except Exception as e:
print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
return (dataset,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader
node = SelvaDatasetLoader()
ds, = node.load('/media/unraid/davinci/Selva/BJ')
print(len(ds), 'clips', ds[0]['name'], ds[0]['waveform'].shape, ds[0]['sample_rate'])
"
```
Expected: prints clip count, first clip name, shape like `torch.Size([1, 2, 352800])`, sample rate.
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetLoader node"
```
---
### Task 3: SelvaDatasetResampler
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the Resampler class**
Uses `soxr` directly for VHQ quality. `soxr.resample` operates on numpy arrays, shape `[L, C]` (time-first).
```python
class SelvaDatasetResampler:
"""Resample all clips in a dataset to a target sample rate using soxr VHQ."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_sr": ("INT", {
"default": 44100, "min": 8000, "max": 192000,
"tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "resample"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."
def resample(self, dataset, target_sr: int):
import soxr
out = []
changed = 0
for item in dataset:
sr = item["sample_rate"]
if sr == target_sr:
out.append(item)
continue
wav = item["waveform"][0] # [C, L]
# soxr expects [L, C] (time-first), float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L]
out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
changed += 1
print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
return (out,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetResampler
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
ds2, = SelvaDatasetResampler().resample(ds, 44100)
print('ok', ds2[0]['sample_rate'], ds2[0]['waveform'].shape)
"
```
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetResampler node (soxr VHQ)"
```
---
### Task 4: SelvaDatasetLUFSNormalizer
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the LUFS normalizer class**
`pyloudnorm.Meter` requires numpy float64 array shape `[L]` (mono) or `[L, C]` (multichannel, channels last). True peak limit applied after gain.
```python
class SelvaDatasetLUFSNormalizer:
"""Normalize each clip to a target integrated LUFS level + true peak limit."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_lufs": ("FLOAT", {
"default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
"tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
}),
"true_peak_dbtp": ("FLOAT", {
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
"tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "normalize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
"Skips clips that are too short for LUFS measurement (< 0.4 s)."
)
def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
import pyloudnorm as pyln
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
out = []
skipped = 0
for item in dataset:
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# pyloudnorm wants [L] mono or [L, C] multichannel, float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [L] mono
meter = pyln.Meter(sr)
try:
loudness = meter.integrated_loudness(wav_np)
except Exception:
skipped += 1
out.append(item)
continue
if not np.isfinite(loudness):
skipped += 1
out.append(item)
continue
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
wav_norm = wav * gain_linear
# True peak limit
peak = wav_norm.abs().max().item()
if peak > tp_linear:
wav_norm = wav_norm * (tp_linear / peak)
out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})
print(
f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
flush=True,
)
return (out,)
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetLUFSNormalizer
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
ds2, = SelvaDatasetLUFSNormalizer().normalize(ds, -23.0, -1.0)
print('ok', ds2[0]['name'], ds2[0]['waveform'].abs().max().item())
"
```
Expected: peak ≤ ~0.89 (≈ -1 dBTP).
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetLUFSNormalizer node (pyloudnorm BS.1770-4)"
```
---
### Task 5: SelvaDatasetInspector
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add helper functions for artifact detection**
```python
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
"""Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).
Method: compare mean energy in 15 kHz band vs 1520 kHz band via STFT.
A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
"""
if sr < 32000:
return False # can't assess HF at low sample rates
n_fft = 2048
hop = 512
window = torch.hann_window(n_fft)
mono = wav[0].mean(0) # [L]
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1)
band_lo = (freqs >= 1000) & (freqs < 5000)
band_hi = (freqs >= 15000) & (freqs < 20000)
if band_hi.sum() == 0:
return False
energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
return ratio_db > 40.0
def _estimate_snr(wav: torch.Tensor) -> float:
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
mono = wav[0].mean(0) # [L]
frames = mono.unfold(0, 2048, 512) # [N, 2048]
rms = frames.pow(2).mean(-1).sqrt() # [N]
p95 = torch.quantile(rms, 0.95).item()
p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
return 20.0 * np.log10(p95 / p05 + 1e-8)
```
**Step 2: Add the Inspector class**
```python
class SelvaDatasetInspector:
"""Analyze each clip for quality issues and optionally filter out flagged clips."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"skip_rejected": ("BOOLEAN", {
"default": True,
"tooltip": "If True, flagged clips are removed from the output dataset. "
"If False, all clips pass through but the report still lists issues.",
}),
"min_snr_db": ("FLOAT", {
"default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
"tooltip": "Clips with estimated SNR below this value are flagged.",
}),
"check_codec_artifacts": ("BOOLEAN", {
"default": True,
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET, "STRING")
RETURN_NAMES = ("dataset", "report")
FUNCTION = "inspect"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Analyze each clip for clipping, low SNR, and codec artifacts. "
"Outputs a filtered AUDIO_DATASET and a text report. "
"Connect report to a ShowText node to preview in the UI."
)
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
clean = []
flagged = []
lines = ["SelVA Dataset Inspector Report", "=" * 40]
for item in dataset:
wav = item["waveform"]
sr = item["sample_rate"]
name = item["name"]
issues = []
# Clipping
peak = wav.abs().max().item()
if peak > 0.99:
issues.append(f"clipping (peak={peak:.3f})")
# Low SNR
snr = _estimate_snr(wav)
if snr < min_snr_db:
issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")
# Codec artifacts
if check_codec_artifacts and _check_hf_shelf(wav, sr):
issues.append("codec artifact (HF shelf > 15 kHz)")
if issues:
flagged.append(name)
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
if not skip_rejected:
clean.append(item)
else:
clean.append(item)
lines.append(f" OK {name}")
lines.append("=" * 40)
lines.append(
f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
+ (" (removed)" if skip_rejected else " (kept)")
)
report = "\n".join(lines)
print(f"[DatasetInspector]\n{report}", flush=True)
return (clean, report)
```
**Step 3: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetInspector
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
clean, report = SelvaDatasetInspector().inspect(ds, skip_rejected=False, min_snr_db=15.0, check_codec_artifacts=True)
print(report)
"
```
Expected: report with per-clip OK/FLAGGED lines and summary counts.
**Step 4: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetInspector node (codec artifacts, SNR, clipping)"
```
---
### Task 6: SelvaDatasetItemExtractor
**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`
**Step 1: Add the extractor class**
```python
class SelvaDatasetItemExtractor:
"""Extract a single AUDIO item from an AUDIO_DATASET by index.
Bridges the dataset pipeline to any node that accepts a standard AUDIO
input — save audio, HF Smoother, Spectral Matcher, etc.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"index": ("INT", {
"default": 0, "min": 0, "max": 9999,
"tooltip": "0-based index. Wraps around if index >= dataset length.",
}),
}
}
RETURN_TYPES = ("AUDIO", "STRING", "INT")
RETURN_NAMES = ("audio", "name", "total")
FUNCTION = "extract"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Extract one clip from an AUDIO_DATASET by index. "
"Returns standard AUDIO (compatible with all audio nodes), "
"the clip name, and the total dataset length."
)
def extract(self, dataset, index: int):
if not dataset:
raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
idx = index % len(dataset)
item = dataset[idx]
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
print(
f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
flush=True,
)
return (audio, item["name"], len(dataset))
```
**Step 2: Smoke test**
```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetItemExtractor
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
audio, name, total = SelvaDatasetItemExtractor().extract(ds, 0)
print(name, total, audio['waveform'].shape, audio['sample_rate'])
"
```
**Step 3: Commit**
```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetItemExtractor node"
```
---
### Task 7: Register all nodes in __init__.py
**Files:**
- Modify: `nodes/__init__.py:4-25`
**Step 1: Add the 5 new entries to `_NODES`**
Add inside the `_NODES` dict, after `"SelvaDittoOptimizer"`:
```python
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
"SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
"SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
"SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
```
**Step 2: Verify registration**
```bash
python3 -c "
import sys; sys.path.insert(0, '/media/p5/Comfyui-Prismaudio')
from nodes import NODE_CLASS_MAPPINGS
keys = [k for k in NODE_CLASS_MAPPINGS if 'Dataset' in k]
print(keys)
"
```
Expected: list of 5 dataset node keys.
**Step 3: Final commit**
```bash
git add nodes/__init__.py
git commit -m "feat: register audio dataset pipeline nodes in __init__.py"
```
---
## Summary
5 nodes in `nodes/selva_dataset_pipeline.py`, all registered in `__init__.py`:
| Node | In | Out |
|------|----|-----|
| SelvaDatasetLoader | folder path | AUDIO_DATASET |
| SelvaDatasetResampler | AUDIO_DATASET | AUDIO_DATASET |
| SelvaDatasetLUFSNormalizer | AUDIO_DATASET | AUDIO_DATASET |
| SelvaDatasetInspector | AUDIO_DATASET | AUDIO_DATASET + STRING |
| SelvaDatasetItemExtractor | AUDIO_DATASET + index | AUDIO + name + total |
Dependencies: `pyloudnorm`, `soxr` — both confirmed present in the ComfyUI env.
-77
View File
@@ -1,77 +0,0 @@
{
"name": "alpha_scale_sweep",
"description": "Fix LoRA noise contamination (flatness 0.013→0.094 at alpha=rank). Root cause: alpha=rank (scale=1.0) at high rank drowns base model priors. Testing dramatically lower alpha to nudge rather than overwrite. All runs at lr=3e-4 (best stable LR from r128_sweet_spot).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/alpha_scale_sweep",
"base": {
"steps": 6000,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant"
},
"experiments": [
{
"id": "g1_r16_alpha4",
"group": "conservative",
"description": "Back to basics: rank=16 alpha=4 (scale=0.25). Small adapter, gentle scale — cleanest possible LoRA signal.",
"rank": 16,
"alpha": 4.0
},
{
"id": "g1_r16_alpha16",
"group": "conservative",
"description": "rank=16 alpha=16 (scale=1.0) — the original default. Reference point: is the noise issue rank-specific or universal?",
"rank": 16,
"alpha": 16.0
},
{
"id": "g2_r32_alpha8",
"group": "mid",
"description": "rank=32 alpha=8 (scale=0.25). More capacity than r16 but still gentle scale.",
"rank": 32,
"alpha": 8.0
},
{
"id": "g2_r32_alpha32",
"group": "mid",
"description": "rank=32 alpha=32 (scale=1.0). Same rank, full scale — isolates whether scale or rank is causing noise.",
"rank": 32,
"alpha": 32.0
},
{
"id": "g3_r128_alpha8",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=8 (scale=0.0625). High capacity, very gentle contribution — can r128 stay clean at low alpha?",
"rank": 128,
"alpha": 8.0
},
{
"id": "g3_r128_alpha16",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=16 (scale=0.125). Slightly more signal than alpha=8.",
"rank": 128,
"alpha": 16.0
},
{
"id": "g3_r128_alpha32",
"group": "high_rank_low_alpha",
"description": "rank=128 alpha=32 (scale=0.25). Same scale as r16_alpha4 and r32_alpha8 — comparable contribution across ranks.",
"rank": 128,
"alpha": 32.0
}
]
}
-31
View File
@@ -1,31 +0,0 @@
{
"name": "bigvgan_disc_fm_retest",
"description": "Retest discriminator feature matching after bfloat16 dtype fix. Uses optimal config from overnight sweep (snake_alpha, GAFilter, lr=1e-4, phase=1.0, L2-SP=1e-3, 5000 steps).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_disc_fm_retest",
"base": {
"train_mode": "snake_alpha_only",
"steps": 5000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42,
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
},
"experiments": [
{
"id": "snake_5k_control",
"description": "Control: best config from overnight sweep without discriminator. Baseline for A/B comparison."
},
{
"id": "disc_fm_5k",
"description": "Discriminator feature matching at 5k steps. Tests if perceptual FM loss improves over mel+phase alone.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
}
]
}
@@ -1,35 +0,0 @@
{
"name": "bigvgan_optimized_dataset",
"description": "BigVGAN fine-tuning on optimized dataset (134 clips, 44.1kHz, LUFS-normalized). Standard mode (no LoRA) — trains decoder to faithfully reconstruct target domain audio from mel spectrograms. Uses optimal config from prior sweeps.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_optimized_dataset",
"base": {
"train_mode": "snake_alpha_only",
"steps": 5000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42
},
"experiments": [
{
"id": "standard_5k",
"description": "Standard mode: mel from clean FLAC → BigVGAN → reconstruct FLAC. No LoRA. Directly improves VAE roundtrip quality."
},
{
"id": "disc_fm_5k",
"description": "Standard mode + discriminator feature matching. Tests if perceptual loss helps on clean audio reconstruction.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
},
{
"id": "standard_10k",
"description": "Extended 10k steps. More data passes on 134 clips may extract more from the optimized dataset.",
"steps": 10000
}
]
}
-65
View File
@@ -1,65 +0,0 @@
{
"name": "bigvgan_overnight",
"description": "BigVGAN vocoder quality sweep. Axes: snake_alpha steps, all_params short run, GAFilter on/off, discriminator FM, phase loss weight. All use LoRA-distorted mels as input so vocoder learns to fix LoRA artifacts.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_overnight",
"base": {
"train_mode": "snake_alpha_only",
"steps": 3000,
"lr": 1e-4,
"batch_size": 8,
"segment_seconds": 0.5,
"lambda_l2sp": 1e-3,
"use_gafilter": true,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 1000,
"seed": 42,
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
},
"experiments": [
{
"id": "snake_3k_baseline",
"description": "Snake alpha + GAFilter baseline. 3000 steps, same as first successful run but longer."
},
{
"id": "snake_5k",
"description": "Snake alpha + GAFilter, 5000 steps. Test if longer training improves further.",
"steps": 5000
},
{
"id": "snake_no_gafilter",
"description": "Snake alpha only, no GAFilter. Isolate GAFilter contribution.",
"use_gafilter": false
},
{
"id": "snake_no_phase",
"description": "Snake alpha + GAFilter, no phase loss. Isolate phase loss contribution.",
"lambda_phase": 0.0
},
{
"id": "snake_phase_2",
"description": "Snake alpha + GAFilter, phase weight 2.0. Stronger phase penalty.",
"lambda_phase": 2.0
},
{
"id": "snake_lr5e-5",
"description": "Snake alpha + GAFilter, lower LR 5e-5. Test if slower converges better.",
"lr": 5e-5,
"steps": 5000
},
{
"id": "snake_disc_fm",
"description": "Snake alpha + GAFilter + discriminator feature matching. Perceptual loss should directly penalize harmonic smearing.",
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
},
{
"id": "all_2k_l2sp1e-2",
"description": "All params, 2000 steps, strong L2-SP (1e-2). Test if full param tuning with heavy anchor beats snake-only.",
"train_mode": "all_params",
"steps": 2000,
"lr": 1e-5,
"lambda_l2sp": 1e-2
}
]
}
-39
View File
@@ -1,39 +0,0 @@
{
"name": "eval_r128_candidates",
"description": "Top candidates from r128_sweet_spot. Comparing the two lowest-loss runs, the stable lr=3e-4, and the curriculum run that hit 0.161 before regressing. Baseline included as perceptual reference.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_dir": "/media/unraid/davinci/Selva/BJ/evals/r128_candidates",
"steps": 25,
"seed": 42,
"adapters": [
{
"id": "baseline",
"description": "No LoRA — base model output for perceptual reference"
},
{
"id": "lr_5e4_r128",
"description": "Best loss overall (0.137), still descending at step 10k",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_5e4/adapter_final.pt"
},
{
"id": "lr_3e4_r256",
"description": "Tied with lr_5e4 at 0.139, higher rank — does extra capacity help perceptually?",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g4_r256_lr_3e4/adapter_final.pt"
},
{
"id": "lr_3e4_r128",
"description": "Stable plateau from step 4k to 10k (0.221) — visually confirmed clean spectrograms",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_3e4/adapter_final.pt"
},
{
"id": "curriculum_lr_3e4",
"description": "Best min loss of all (0.161 at step 6k), regressed to 0.193 after curriculum switch — curious if the early checkpoint sounds better",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_final.pt"
},
{
"id": "curriculum_lr_3e4_step6000",
"description": "Same run at its actual best step (before regression) — compare against adapter_final to hear the regression",
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_step06000.pt"
}
]
}
-33
View File
@@ -1,33 +0,0 @@
{
"name": "lora_logit_cosine_combo",
"description": "Combine the two best findings from optimized dataset sweep: logit-normal timestep sampling + cosine LR schedule. Both individually outperformed baseline by large margins (56% and 68% lower loss). Tests if gains stack.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_logit_cosine_combo",
"base": {
"rank": 128,
"lr": 3e-4,
"steps": 5000,
"batch_size": 16,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_mode": "pissa",
"use_rslora": true,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lr_schedule": "constant"
},
"experiments": [
{
"id": "logit_normal_cosine",
"description": "Logit-normal timesteps + cosine LR decay. Combining the two best individual improvements.",
"timestep_mode": "logit_normal",
"lr_schedule": "cosine"
},
{
"id": "logit_normal_control",
"description": "Control: logit-normal only (constant LR). Reproduces previous winner for direct comparison.",
"timestep_mode": "logit_normal"
}
]
}
-64
View File
@@ -1,64 +0,0 @@
{
"name": "lora_optimized_dataset",
"description": "LoRA training on optimized dataset (134 clips: resampled 44.1kHz, LUFS-normalized, spectral matched, HF smoothed, gain-augmented). Tests latent augmentation and schedule variants on top of known-best config (PiSSA, rank=128, lr=3e-4).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_optimized_dataset",
"base": {
"rank": 128,
"lr": 3e-4,
"steps": 5000,
"batch_size": 16,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_mode": "pissa",
"use_rslora": true,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lr_schedule": "constant"
},
"experiments": [
{
"id": "baseline",
"description": "Control: known-best config (PiSSA r128 lr=3e-4) on the optimized dataset. No latent augmentation."
},
{
"id": "latent_mixup",
"description": "Latent mixup alpha=0.4 (MusicLDM). Tests if mixing training latents reduces memorization on 134 clips.",
"latent_mixup_alpha": 0.4
},
{
"id": "latent_noise",
"description": "Latent noise sigma=0.02. Mild Gaussian noise on training latents for regularization.",
"latent_noise_sigma": 0.02
},
{
"id": "mixup_and_noise",
"description": "Both latent mixup (0.4) and noise (0.02). Combined regularization.",
"latent_mixup_alpha": 0.4,
"latent_noise_sigma": 0.02
},
{
"id": "cosine_schedule",
"description": "Cosine LR decay. lr=3e-4 was stable with constant, but cosine may extract more from 5k steps.",
"lr_schedule": "cosine"
},
{
"id": "cosine_mixup",
"description": "Cosine LR + latent mixup. Best regularization combo candidate.",
"lr_schedule": "cosine",
"latent_mixup_alpha": 0.4
},
{
"id": "logit_normal",
"description": "Logit-normal timestep sampling (sigma=1.0). Concentrates training near t=0.5 where flow matching is hardest.",
"timestep_mode": "logit_normal"
},
{
"id": "curriculum_mixup",
"description": "Curriculum timesteps (logit_normal first 60%, then uniform) + latent mixup. Full regularization stack.",
"timestep_mode": "curriculum",
"latent_mixup_alpha": 0.4
}
]
}
-62
View File
@@ -1,62 +0,0 @@
{
"name": "pissa_sweep",
"description": "PiSSA vs standard init ablation at rank 128. Best prior config (lr=3e-4, bs=16, 10k steps) as baseline. PiSSA starts on-manifold via SVD init — should eliminate intruder dimensions. rsLoRA stabilises scaling at high rank.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant",
"init_mode": "pissa",
"use_rslora": true
},
"experiments": [
{
"id": "standard_baseline",
"description": "Standard Kaiming init + classic alpha/rank scaling. Replicates best prior config for A/B comparison.",
"init_mode": "standard",
"use_rslora": false
},
{
"id": "pissa_rslora",
"description": "PiSSA init + rsLoRA scaling. Full Tier-S config. Should start on-manifold and avoid intruder dimensions."
},
{
"id": "pissa_classic_scale",
"description": "PiSSA init + classic alpha/rank scaling. Isolates PiSSA contribution from rsLoRA.",
"use_rslora": false
},
{
"id": "standard_rslora",
"description": "Standard init + rsLoRA only. Isolates rsLoRA contribution from PiSSA.",
"init_mode": "standard"
},
{
"id": "pissa_rslora_lr1e-4",
"description": "PiSSA+rsLoRA at lower lr=1e-4. PiSSA starts closer to optimum — may need less aggressive lr.",
"lr": 1e-4
},
{
"id": "pissa_rslora_lr5e-4",
"description": "PiSSA+rsLoRA at higher lr=5e-4. Test if on-manifold start tolerates faster learning.",
"lr": 5e-4
},
{
"id": "pissa_rslora_dropout",
"description": "PiSSA+rsLoRA with dropout 0.05. Note: PiSSA forces dropout=0 (principal components should not be dropped) — this tests standard init with rsLoRA + dropout.",
"init_mode": "standard",
"lora_dropout": 0.05
}
]
}
-103
View File
@@ -1,103 +0,0 @@
{
"name": "r128_sweet_spot",
"description": "Find the noise-free sweet spot on rank 128. LoRA+ ratio=16 caused noise — testing higher base LR without LoRA+ as a cleaner alternative. Target loss range 0.250.35. Also probing rank 256 since 102GB VRAM allows it.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_r128_lr_2e4",
"group": "lr",
"description": "LR=2e-4. Conservative 2× step up from baseline — noise-free descent toward sweet spot.",
"lr": 2e-4
},
{
"id": "g1_r128_lr_3e4",
"group": "lr",
"description": "LR=3e-4. 3× baseline — landed at 0.41 on r64, should reach 0.250.35 on r128.",
"lr": 3e-4
},
{
"id": "g1_r128_lr_5e4",
"group": "lr",
"description": "LR=5e-4. Aggressive but no LoRA+ B-matrix asymmetry — cleaner noise profile.",
"lr": 5e-4
},
{
"id": "g2_r128_curriculum",
"group": "curriculum",
"description": "Curriculum only at baseline LR. Clean slow descent — reference for what curriculum contributes alone.",
"timestep_mode": "curriculum"
},
{
"id": "g2_r128_lr_3e4_curriculum",
"group": "curriculum",
"description": "LR=3e-4 + curriculum. Speed of higher LR with coverage of curriculum — no LoRA+.",
"lr": 3e-4,
"timestep_mode": "curriculum"
},
{
"id": "g2_r128_lr_3e4_curriculum_dropout",
"group": "curriculum",
"description": "LR=3e-4 + curriculum + dropout=0.05. Full controlled stack without LoRA+.",
"lr": 3e-4,
"timestep_mode": "curriculum",
"lora_dropout": 0.05
},
{
"id": "g3_r128_lora_plus_4",
"group": "lora_plus",
"description": "LoRA+ ratio=4 (lr_B=4e-4). Much more conservative than ratio=16 — tests if noise came from ratio not the technique.",
"lora_plus_ratio": 4.0
},
{
"id": "g4_r256_baseline",
"group": "rank256",
"description": "Rank 256 at baseline LR. 102GB VRAM makes this viable — does more capacity keep helping?",
"rank": 256
},
{
"id": "g4_r256_lr_3e4",
"group": "rank256",
"description": "Rank 256 + LR=3e-4. Best rank + best LR candidate combined.",
"rank": 256,
"lr": 3e-4
},
{
"id": "g5_r128_lr_2e4_cosine",
"group": "cosine",
"description": "LR=2e-4 + cosine decay. Fixes the oscillation observed at step 60008000 by decaying LR to ~0 instead of staying flat.",
"lr": 2e-4,
"lr_schedule": "cosine"
},
{
"id": "g5_r128_lr_3e4_cosine",
"group": "cosine",
"description": "LR=3e-4 + cosine decay. Higher LR with decay — should reach lower loss faster then lock in.",
"lr": 3e-4,
"lr_schedule": "cosine"
}
]
}
-130
View File
@@ -1,130 +0,0 @@
{
"name": "r64_overnight",
"description": "Focused rank-64 overnight sweep. All experiments use rank 64 as base — confirmed best from tier1_thorough early results. 8000 steps to reach convergence (none converged at 4000).",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/r64_overnight",
"base": {
"steps": 8000,
"rank": 64,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_r64_baseline",
"group": "rank",
"description": "Rank 64 baseline — clean reference at 8000 steps."
},
{
"id": "g1_r128_baseline",
"group": "rank",
"description": "Rank 128 — 102GB VRAM makes this free. Does doubling rank from 64 help further?",
"rank": 128
},
{
"id": "g2_r64_alpha_32",
"group": "alpha",
"description": "Rank 64 alpha=32 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
"alpha": 32.0
},
{
"id": "g2_r64_alpha_16",
"group": "alpha",
"description": "Rank 64 alpha=16 (scale=0.25). More aggressive scale reduction — may over-constrain.",
"alpha": 16.0
},
{
"id": "g3_r64_lora_plus",
"group": "regularisation",
"description": "LoRA+ ratio=16. lr_B = 16 × lr_A. Faster convergence at constant step budget.",
"lora_plus_ratio": 16.0
},
{
"id": "g3_r64_dropout_0.05",
"group": "regularisation",
"description": "Dropout=0.05. Light sparsity regularisation on LoRA path.",
"lora_dropout": 0.05
},
{
"id": "g3_r64_dropout_0.1",
"group": "regularisation",
"description": "Dropout=0.1. Stronger regularisation — tests if 49 clips needs heavier constraint.",
"lora_dropout": 0.1
},
{
"id": "g3_r64_curriculum",
"group": "regularisation",
"description": "Curriculum sampling: logit_normal for steps 1-4800, then uniform (arXiv:2603.12517).",
"timestep_mode": "curriculum"
},
{
"id": "g4_r64_lr_low",
"group": "lr",
"description": "LR=3e-5. 3× lower — checks if 1e-4 is overshooting at rank 64.",
"lr": 3e-5
},
{
"id": "g4_r64_lr_high",
"group": "lr",
"description": "LR=3e-4. 3× higher — may converge faster but risk instability.",
"lr": 3e-4
},
{
"id": "g5_r64_target_full",
"group": "target",
"description": "Rank 64 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
"target": "attn.qkv linear1"
},
{
"id": "g5_r128_target_full",
"group": "target",
"description": "Rank 128 + full target. Maximum possible coverage with available VRAM.",
"rank": 128,
"target": "attn.qkv linear1"
},
{
"id": "g6_r64_full_tier1",
"group": "combined",
"description": "All Tier 1 at rank 64: LoRA+ 16 + dropout 0.05 + curriculum. Full stack at 8000 steps.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g6_r64_alpha32_full",
"group": "combined",
"description": "Rank 64 alpha=32 + all Tier 1. Best alpha scaling + best regularisation stack.",
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g6_r128_full_tier1",
"group": "combined",
"description": "Rank 128 + all Tier 1. Tests if more capacity + regularisation beats rank 64 full.",
"rank": 128,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
}
]
}
-52
View File
@@ -1,52 +0,0 @@
{
"name": "ti_sweep_1",
"description": "First TI sweep. n4_baseline (suffix, batch=16, lr=1e-3) completed — buzz artifact diagnosed as token norm drifting to 3.2x outside CLIP manifold. All new experiments use norm clamping (auto from dataset) + corrected lr/batch.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
"base": {
"steps": 3000,
"batch_size": 4,
"warmup_steps": 100,
"save_every": 1000,
"seed": 42,
"init_text": "",
"lr": 2e-4,
"n_tokens": 4,
"inject_mode": "suffix"
},
"experiments": [
{
"id": "n4_baseline",
"group": "reference",
"description": "COMPLETED (old code, no norm clamp). batch=16, lr=1e-3. Token norm drifted to 3.2 → buzz artifact. Kept for loss curve comparison only."
},
{
"id": "n4_clamped",
"group": "norm_clamp",
"description": "Same as baseline but with norm clamping enabled. Primary diagnostic: does clamping alone fix the buzz? lr=2e-4, batch=4, suffix."
},
{
"id": "n4_prefix_clamped",
"group": "norm_clamp",
"description": "Prefix injection + norm clamping. Best of both: high-attention positions, tokens stay on CLIP manifold.",
"inject_mode": "prefix"
},
{
"id": "n8_prefix_clamped",
"group": "norm_clamp",
"description": "8 tokens, prefix, clamped. More capacity without the artifact.",
"n_tokens": 8,
"inject_mode": "prefix"
},
{
"id": "n4_prefix_warm_clamped",
"group": "norm_clamp",
"description": "4 tokens, prefix, warm init from 'mechanical impact sound design', clamped. Should converge fastest — starts in-manifold, stays in-manifold.",
"inject_mode": "prefix",
"init_text": "mechanical impact sound design"
}
]
}
-61
View File
@@ -1,61 +0,0 @@
{
"name": "tier1_sweep",
"description": "Ablation of Tier 1 improvements: LoRA+, dropout, curriculum sampling. Baseline = uniform, no regularisation.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "lora_sweeps/tier1_sweep",
"base": {
"steps": 4000,
"rank": 16,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 500,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "baseline",
"description": "Standard LoRA — no Tier 1 changes. Reference point."
},
{
"id": "lora_plus_16",
"description": "LoRA+ only: lr_B = 16 * lr_A. Should converge faster in early steps.",
"lora_plus_ratio": 16.0
},
{
"id": "dropout_0.05",
"description": "LoRA dropout 0.05 only. Light regularisation for 49-clip dataset.",
"lora_dropout": 0.05
},
{
"id": "dropout_0.1",
"description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
"lora_dropout": 0.1
},
{
"id": "curriculum",
"description": "Curriculum sampling only: logit_normal for steps 1-2400, then uniform. Should improve convergence vs pure uniform.",
"timestep_mode": "curriculum"
},
{
"id": "full_tier1",
"description": "All Tier 1 combined: LoRA+ + dropout 0.05 + curriculum.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "rank_64",
"description": "Rank 64 baseline — MMAudio LoRA guide default. More expressive adapter for 49-clip dataset.",
"rank": 64
}
]
}
-144
View File
@@ -1,144 +0,0 @@
{
"name": "tier1_thorough",
"description": "Full overnight Tier 1 ablation on 49-clip BJ dataset. 4 groups: rank, alpha, regularisation, and best combinations. ~10-12h depending on GPU.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/tier1_thorough",
"base": {
"steps": 4000,
"rank": 16,
"alpha": 0.0,
"lr": 1e-4,
"batch_size": 16,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 1000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0
},
"experiments": [
{
"id": "g1_rank_16",
"group": "rank",
"description": "Rank 16 baseline — reference point for all groups."
},
{
"id": "g1_rank_32",
"group": "rank",
"description": "Rank 32 — midpoint. Does doubling rank improve quality without overfitting?",
"rank": 32
},
{
"id": "g1_rank_64",
"group": "rank",
"description": "Rank 64 — MMAudio LoRA guide default. Maximum expressiveness at 49 clips.",
"rank": 64
},
{
"id": "g2_alpha_half_r16",
"group": "alpha",
"description": "Alpha=8 with rank 16 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
"alpha": 8.0
},
{
"id": "g2_alpha_half_r64",
"group": "alpha",
"description": "Alpha=32 with rank 64 (scale=0.5). Best-practice scaling for high-rank adapters.",
"rank": 64,
"alpha": 32.0
},
{
"id": "g3_lora_plus_4",
"group": "regularisation",
"description": "LoRA+ ratio=4 — conservative asymmetric LR. Lower bound for the technique.",
"lora_plus_ratio": 4.0
},
{
"id": "g3_lora_plus_16",
"group": "regularisation",
"description": "LoRA+ ratio=16 — standard from FLUX LoRA literature. Faster early convergence.",
"lora_plus_ratio": 16.0
},
{
"id": "g3_dropout_0.05",
"group": "regularisation",
"description": "LoRA dropout 0.05 only. Light sparsity regularisation (arXiv:2404.09610).",
"lora_dropout": 0.05
},
{
"id": "g3_dropout_0.1",
"group": "regularisation",
"description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
"lora_dropout": 0.1
},
{
"id": "g3_curriculum",
"group": "regularisation",
"description": "Curriculum sampling only: logit_normal steps 1-2400, then uniform (arXiv:2603.12517).",
"timestep_mode": "curriculum"
},
{
"id": "g4_full_r16",
"group": "combined",
"description": "All Tier 1 at rank 16: LoRA+ 16 + dropout 0.05 + curriculum.",
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g4_full_r64",
"group": "combined",
"description": "All Tier 1 at rank 64 + alpha=32. Best expressiveness + best regularisation.",
"rank": 64,
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum"
},
{
"id": "g5_lr_low",
"group": "lr",
"description": "LR=3e-5 — 3× lower than baseline. Tests if 1e-4 is overshooting.",
"lr": 3e-5
},
{
"id": "g5_lr_high",
"group": "lr",
"description": "LR=3e-4 — 3× higher than baseline. Tests if 1e-4 is too conservative.",
"lr": 3e-4
},
{
"id": "g6_target_full_r16",
"group": "target",
"description": "Rank 16 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
"target": "attn.qkv linear1"
},
{
"id": "g6_target_full_r64",
"group": "target",
"description": "Rank 64 + alpha=32 targeting attn.qkv + linear1. Maximum coverage + expressiveness.",
"rank": 64,
"alpha": 32.0,
"target": "attn.qkv linear1"
},
{
"id": "g4_full_r64_6k",
"group": "combined",
"description": "All Tier 1 at rank 64 + alpha=32, extended to 6000 steps. Checks if convergence is done at 4000.",
"rank": 64,
"alpha": 32.0,
"lora_plus_ratio": 16.0,
"lora_dropout": 0.05,
"timestep_mode": "curriculum",
"steps": 6000,
"save_every": 1000
}
]
}
-30
View File
@@ -1,30 +0,0 @@
{
"name": "vocoder_finetune",
"description": "Single run with fine-tuned BJ BigVGAN vocoder injected. Validates vocoder integration with LoRA training. Best known config: lr=3e-4, rank=128.",
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/vocoder_finetune",
"base": {
"steps": 10000,
"rank": 128,
"alpha": 0.0,
"lr": 3e-4,
"batch_size": 16,
"warmup_steps": 200,
"grad_accum": 1,
"save_every": 2000,
"seed": 42,
"target": "attn.qkv",
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant"
},
"experiments": [
{
"id": "r128_lr_3e4_bj_vocoder",
"description": "lr=3e-4 rank=128 with fine-tuned BJ BigVGAN vocoder. Direct comparison baseline against previous best g1_r128_lr_3e4."
}
]
}
-30
View File
@@ -5,36 +5,6 @@ _NODES = {
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"), "SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"), "SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"), "SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
"SelvaLoraEvaluator": (".selva_lora_evaluator", "SelvaLoraEvaluator", "SelVA LoRA Evaluator"),
"SelvaVaeRoundtrip": (".selva_vae_roundtrip", "SelvaVaeRoundtrip", "SelVA VAE Roundtrip"),
"SelvaHfSmoother": (".selva_audio_preprocessors", "SelvaHfSmoother", "SelVA HF Smoother"),
"SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"),
"SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"),
"SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"),
"SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"),
"SelvaActivationSteeringExtractor": (".selva_activation_steering_extractor", "SelvaActivationSteeringExtractor", "SelVA Activation Steering Extractor"),
"SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
"SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
"SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
"SelvaBigvganScheduler": (".selva_bigvgan_scheduler", "SelvaBigvganScheduler", "SelVA BigVGAN Scheduler"),
"SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"),
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
"SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
"SelvaDatasetCompressor": (".selva_dataset_pipeline", "SelvaDatasetCompressor", "SelVA Dataset Compressor"),
"SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
"SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
"SelvaDatasetSaver": (".selva_dataset_pipeline", "SelvaDatasetSaver", "SelVA Dataset Saver"),
"SelvaHarmonicExciter": (".selva_audio_postprocess", "SelvaHarmonicExciter", "SelVA Harmonic Exciter"),
"SelvaOutputNormalizer": (".selva_audio_postprocess", "SelvaOutputNormalizer", "SelVA Output Normalizer"),
"SelvaDatasetSpectralMatcher": (".selva_dataset_pipeline", "SelvaDatasetSpectralMatcher", "SelVA Dataset Spectral Matcher"),
"SelvaDatasetHfSmoother": (".selva_dataset_pipeline", "SelvaDatasetHfSmoother", "SelVA Dataset HF Smoother"),
"SelvaDatasetAugmenter": (".selva_dataset_pipeline", "SelvaDatasetAugmenter", "SelVA Dataset Augmenter"),
} }
for key, (module_path, class_name, display_name) in _NODES.items(): for key, (module_path, class_name, display_name) in _NODES.items():
@@ -1,201 +0,0 @@
"""SelVA Activation Steering Extractor.
Computes per-block steering vectors by running the frozen generator on the
training dataset and recording how target style's conditioning shifts the DiT hidden
states vs. empty/unconditional conditioning.
For each block i:
steering[i] = mean(latent_hidden | target style conditions)
- mean(latent_hidden | empty conditions)
The resulting vectors are injected at inference time (via SelVA Sampler's
steering_strength input) to nudge the denoising trajectory toward target style's
activation patterns without modifying any model weights.
"""
import random
from pathlib import Path
import torch
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_lora_trainer import _prepare_dataset
def _collect_activations(generator, conditions, latent, t_tensor):
"""Run one predict_flow call, collecting latent hidden states per block.
Returns a list of [seq, hidden_dim] float32 CPU tensors,
one per block (joint_blocks first, then fused_blocks).
"""
activations = []
def make_hook(is_joint):
def hook(module, input, output):
h = output[0] if is_joint else output
activations.append(h.detach().float().mean(0).cpu()) # [seq, hidden]
return hook
handles = []
for block in generator.joint_blocks:
handles.append(block.register_forward_hook(make_hook(is_joint=True)))
for block in generator.fused_blocks:
handles.append(block.register_forward_hook(make_hook(is_joint=False)))
try:
with torch.no_grad():
generator.predict_flow(latent, t_tensor, conditions)
finally:
for h in handles:
h.remove()
return activations # list of n_blocks tensors [seq, hidden]
class SelvaActivationSteeringExtractor:
"""Computes activation steering vectors from a training dataset.
Runs the frozen generator on N clips at random timesteps with both
target style-conditioned and empty-conditioned inputs, then saves the mean
difference per DiT block to a .pt file.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "extract"
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("steering_path",)
OUTPUT_TOOLTIPS = ("Path to saved steering_vectors.pt — load with SelVA Activation Steering Loader.",)
DESCRIPTION = (
"Computes per-block activation steering vectors: mean(target style activations) "
"mean(empty activations) at each DiT block. Load the result with "
"SelVA Activation Steering Loader and connect to the Sampler."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"data_dir": ("STRING", {
"default": "",
"tooltip": "Directory containing .npz feature files (same as LoRA/TI trainer).",
}),
"output_path": ("STRING", {
"default": "steering_vectors.pt",
"tooltip": "Where to save the steering vectors. Relative paths resolve to ComfyUI output directory.",
}),
"n_samples": ("INT", {
"default": 16, "min": 1, "max": 256,
"tooltip": "Number of clips to average over. More = more stable vectors, slower extraction.",
}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
},
}
def extract(self, model, data_dir, output_path, n_samples, seed):
device = get_device()
dtype = model["dtype"]
seq_cfg = model["seq_cfg"]
data_dir = Path(data_dir.strip())
if not data_dir.is_absolute():
data_dir = Path(folder_paths.models_dir) / data_dir
if not data_dir.exists():
raise FileNotFoundError(f"[Steering] data_dir not found: {data_dir}")
out_path = Path(output_path.strip())
if not out_path.is_absolute():
out_path = Path(folder_paths.get_output_directory()) / out_path
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[Steering] Extracting steering vectors n_samples={n_samples}", flush=True)
print(f"[Steering] data_dir = {data_dir}", flush=True)
print(f"[Steering] output = {out_path}\n", flush=True)
dataset = _prepare_dataset(model, data_dir, device)
generator = model["generator"]
generator.eval()
torch.manual_seed(seed)
random.seed(seed)
indices = random.choices(range(len(dataset)), k=n_samples)
n_blocks = len(generator.joint_blocks) + len(generator.fused_blocks)
style_sums = [None] * n_blocks
empty_sums = [None] * n_blocks
counts = [0] * n_blocks
pbar = comfy.utils.ProgressBar(n_samples)
for sample_i, clip_idx in enumerate(indices):
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
clip_f = clip_f_cpu.to(device, dtype) # [1, T_clip, 1024]
sync_f = sync_f_cpu.to(device, dtype) # [1, T_sync, 768]
text_clip = text_clip_cpu.to(device, dtype) # [1, 77, 1024]
# x1 shape is [1, latent_seq_len, latent_dim] — dim 1 is the sequence length.
clip_latent_seq_len = x1_cpu.shape[1]
generator.update_seq_lengths(
latent_seq_len=clip_latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = generator.get_empty_conditions(bs=1)
# Random timestep and noise latent for this clip
t_val = torch.rand(1).item()
t_tensor = torch.tensor([t_val], device=device, dtype=dtype)
latent = torch.randn(
1, clip_latent_seq_len, generator.latent_dim,
device=device, dtype=dtype,
)
style_acts = _collect_activations(generator, conditions, latent, t_tensor)
empty_acts = _collect_activations(generator, empty_conditions, latent, t_tensor)
for i, (st, em) in enumerate(zip(style_acts, empty_acts)):
if style_sums[i] is None:
style_sums[i] = st.clone()
empty_sums[i] = em.clone()
else:
style_sums[i] += st
empty_sums[i] += em
counts[i] += 1
pbar.update(1)
if (sample_i + 1) % 4 == 0 or sample_i == n_samples - 1:
print(f"[Steering] Processed {sample_i + 1}/{n_samples} clips", flush=True)
# Steering vector per block: mean(target style) - mean(empty)
steering_vectors = []
for i in range(n_blocks):
vec = (style_sums[i] - empty_sums[i]) / counts[i] # [hidden]
steering_vectors.append(vec)
norm = vec.norm().item()
print(f"[Steering] Block {i:2d} steering_norm={norm:.4f}", flush=True)
n_joint = len(generator.joint_blocks)
payload = {
"steering_vectors": steering_vectors, # list of [seq, hidden] tensors
"n_blocks": n_blocks,
"n_joint": n_joint,
"n_fused": len(generator.fused_blocks),
"latent_seq_len": seq_cfg.latent_seq_len,
"n_samples": n_samples,
"seed": seed,
"mode": model["mode"],
"variant": model["variant"],
}
torch.save(payload, str(out_path))
print(f"\n[Steering] Saved: {out_path}", flush=True)
soft_empty_cache()
return (str(out_path),)
-62
View File
@@ -1,62 +0,0 @@
"""SelVA Activation Steering Loader.
Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor
and returns a STEERING_VECTORS dict for use by SelVA Sampler.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaActivationSteeringLoader:
CATEGORY = SELVA_CATEGORY
FUNCTION = "load"
RETURN_TYPES = ("STEERING_VECTORS",)
RETURN_NAMES = ("steering_vectors",)
OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",)
DESCRIPTION = (
"Loads activation steering vectors from a .pt file produced by "
"SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge "
"denoising toward the target activation patterns."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"path": ("STRING", {
"default": "steering_vectors.pt",
"tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.",
}),
},
}
def load(self, path):
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[Steering] File not found: {p}")
payload = torch.load(str(p), map_location="cpu", weights_only=False)
n_blocks = payload["n_blocks"]
n_joint = payload["n_joint"]
n_fused = payload["n_fused"]
n_vecs = len(payload["steering_vectors"])
print(f"[Steering] Loaded: {p}", flush=True)
print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) "
f"latent_seq_len={payload['latent_seq_len']} "
f"n_samples={payload['n_samples']}", flush=True)
print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True)
norms = [payload["steering_vectors"][i].norm().item() for i in range(n_vecs)]
mean_norm = sum(norms) / len(norms)
print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True)
return (payload,)
-153
View File
@@ -1,153 +0,0 @@
"""SelVA Audio Post-Processing nodes.
Post-generation enhancement applied to standard AUDIO outputs:
SelvaHarmonicExciter — multi-band harmonic exciter (HPF → tanh → mix)
SelvaOutputNormalizer — LUFS normalization + true peak limiting
"""
import numpy as np
import torch
from .utils import SELVA_CATEGORY
class SelvaHarmonicExciter:
"""Multi-band harmonic exciter for post-generation enhancement.
Isolates high-frequency content above a cutoff, applies tanh saturation
to generate 2nd/3rd harmonics, then mixes back with the dry signal.
Restores harmonic richness lost during BigVGAN vocoder reconstruction.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"cutoff_hz": ("FLOAT", {
"default": 3000.0, "min": 500.0, "max": 16000.0, "step": 100.0,
"tooltip": "Highpass cutoff frequency in Hz. Only content above this is excited. "
"3000 Hz targets the upper harmonics BigVGAN tends to smear.",
}),
"drive": ("FLOAT", {
"default": 2.0, "min": 1.0, "max": 10.0, "step": 0.5,
"tooltip": "Saturation drive. Higher = more harmonics generated. "
"2-3 is subtle, 5+ is aggressive.",
}),
"mix": ("FLOAT", {
"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Wet/dry blend. 0.1-0.2 is subtle enhancement, "
"0.5+ is aggressive harmonic addition.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "excite"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Multi-band harmonic exciter. Applies tanh saturation to the high-frequency band "
"to restore harmonics lost during BigVGAN vocoder reconstruction. "
"Uses pedalboard.HighpassFilter for band isolation."
)
def excite(self, audio, cutoff_hz: float, drive: float, mix: float):
from pedalboard import Pedalboard, HighpassFilter
wav = audio["waveform"][0] # [C, T]
sr = audio["sample_rate"]
wav_np = wav.float().numpy() # [C, T]
# Isolate HF band
board = Pedalboard([HighpassFilter(cutoff_frequency_hz=cutoff_hz)])
hf = board(wav_np, sr) # [C, T]
# Tanh saturation — normalize by drive so output stays in [-1, 1]
excited = np.tanh(hf * drive) / max(drive, 1.0)
# Mix back with dry
mixed = wav_np + mix * excited
# Soft clip to prevent going over
mixed = np.tanh(mixed)
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, T]
print(
f"[HarmonicExciter] cutoff={cutoff_hz}Hz drive={drive} mix={mix:.0%}",
flush=True,
)
return ({"waveform": wav_out, "sample_rate": sr},)
class SelvaOutputNormalizer:
"""Normalize generated audio to a target LUFS level with true peak limiting.
Apply as the final node before saving — brings generated audio to a
consistent loudness target regardless of input video loudness variation.
Uses pyloudnorm (BS.1770-4).
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"target_lufs": ("FLOAT", {
"default": -14.0, "min": -40.0, "max": -6.0, "step": 0.5,
"tooltip": "Target integrated loudness in LUFS. "
"-14 LUFS for streaming (Spotify/YouTube), "
"-9 to -7 for production masters.",
}),
"true_peak_dbtp": ("FLOAT", {
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
"tooltip": "True peak ceiling in dBTP applied after LUFS gain.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "normalize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Normalize output audio to a target LUFS level (BS.1770-4) with true peak limiting. "
"Apply as the last node before saving. Uses pyloudnorm."
)
def normalize(self, audio, target_lufs: float, true_peak_dbtp: float):
import pyloudnorm as pyln
wav = audio["waveform"][0] # [C, T]
sr = audio["sample_rate"]
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
wav_np = wav.permute(1, 0).double().numpy() # [T, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [T] mono
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(wav_np)
if not np.isfinite(loudness):
print("[OutputNormalizer] Could not measure loudness — clip too short or silent. Passing through.", flush=True)
return (audio,)
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
wav_out = wav * gain_linear
peak = wav_out.abs().max().item()
if peak > tp_linear:
wav_out = wav_out * (tp_linear / peak)
print(
f"[OutputNormalizer] {loudness:.1f} LUFS → {target_lufs} LUFS "
f"gain={gain_db:+.1f}dB TP={true_peak_dbtp}dBTP",
flush=True,
)
return ({"waveform": wav_out.unsqueeze(0), "sample_rate": sr},)
-293
View File
@@ -1,293 +0,0 @@
"""SelVA Audio Preprocessors — condition training clips for codec compatibility.
Two nodes that reduce the domain mismatch between custom training audio and the
MMAudio VAE's expected spectral distribution, improving LoRA training quality:
SelvaHfSmoother — soft low-pass blend to attenuate extreme HF content
SelvaSpectralMatcher — adaptive per-band EQ toward the codec's training distribution
Root cause they address: MMAudio was trained on natural sounds (speech, foley, env)
with limited engineered HF content. The BigVGANv2 vocoder (frozen, pre-trained) handles
the codec's HF reconstruction poorly for sound design / music training clips, because
those clips land in a latent-space region the vocoder never saw during training.
Recommended order: SpectralMatcher → HfSmoother → feature extraction → LoRA training.
"""
import numpy as np
import torch
import torchaudio.functional as AF
from .utils import SELVA_CATEGORY
# ── Mel filterbank (same algorithm as selva_core/ext/mel_converter.py) ────────
def _mel_filterbank(sr: int, n_fft: int, n_mels: int,
fmin: float, fmax: float) -> torch.Tensor:
"""Returns mel filterbank matrix [n_mels, n_fft//2+1]."""
def hz_to_mel(f):
return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)
def mel_to_hz(m):
return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)
n_freqs = n_fft // 2 + 1
fft_freqs = np.linspace(0.0, sr / 2.0, n_freqs)
mel_pts = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
hz_pts = mel_to_hz(mel_pts)
fb = np.zeros((n_mels, n_freqs), dtype=np.float32)
for m in range(1, n_mels + 1):
lo, mid, hi = hz_pts[m - 1], hz_pts[m], hz_pts[m + 1]
up = (fft_freqs - lo) / (mid - lo + 1e-12)
down = (hi - fft_freqs) / (hi - mid + 1e-12)
fb[m - 1] = np.maximum(0.0, np.minimum(up, down))
return torch.from_numpy(fb)
# ── VAE target log-mel means (source: selva_core/ext/autoencoder/vae.py) ──────
# These are the per-band expected log-mel energy means from MMAudio's training data.
# Used as the spectral matching target: clips are EQ'd to match this profile.
_TARGET_MEAN_80D = [
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439,
-1.2922, -1.2927, -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912,
-1.4313, -1.4152, -1.4527, -1.4728, -1.4568, -1.5101, -1.5051, -1.5172,
-1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131, -1.6081, -1.6331,
-1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377,
-1.8417, -1.8643, -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673,
-1.9824, -2.0042, -2.0215, -2.0436, -2.0766, -2.1064, -2.1418, -2.1855,
-2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282, -2.4659, -2.5072,
-2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673,
]
_TARGET_MEAN_128D = [
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006,
-2.2357, -2.4597, -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047,
-2.7483, -2.5926, -2.7462, -2.7033, -2.7386, -2.8112, -2.7502, -2.9594,
-2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157, -3.1191, -2.9893,
-3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509,
-3.5089, -3.4647, -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747,
-3.7072, -3.7279, -3.7283, -3.7795, -3.8259, -3.8447, -3.8663, -3.9182,
-3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121, -4.1488, -4.1874,
-4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053,
-5.4927, -5.5712, -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103,
-6.0955, -6.1673, -6.2362, -6.3120, -6.3926, -6.4797, -6.5565, -6.6511,
-6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663, -7.6136, -7.7469,
-7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861,
]
_MEL_CONFIGS = {
"16k": dict(sr=16_000, n_fft=1024, n_mels=80, hop=256, fmin=0, fmax=8_000,
target=_TARGET_MEAN_80D, log10=True),
"44k": dict(sr=44_100, n_fft=2048, n_mels=128, hop=512, fmin=0, fmax=22_050,
target=_TARGET_MEAN_128D, log10=False),
}
# ── Node 1: HF Smoother ───────────────────────────────────────────────────────
class SelvaHfSmoother:
"""Soft high-frequency attenuation for LoRA training clip preprocessing.
Blends a low-pass filtered copy of the audio with the original. Attenuates
the extreme HF content common in engineered sound design that the BigVGANv2
vocoder handles poorly, bringing the clip closer to the spectral region the
MMAudio codec was trained on (natural sounds with limited HF energy).
A blend of 0.7 at 12 kHz is a transparent starting point — audible only on
close comparison. Increase blend or lower cutoff if roundtrip quality is still
poor after spectral matching.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"cutoff_hz": ("FLOAT", {
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
}),
"blend": ("FLOAT", {
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = original, 1 = fully filtered. 0.7 is a transparent starting point.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Blends a low-pass filtered version of the audio with the original to gently attenuate "
"high-frequency content that the SelVA codec handles poorly. "
"Use before feature extraction to improve LoRA training targets. "
"Run after SelVA Spectral Matcher for best results."
)
def process(self, audio, cutoff_hz: float, blend: float):
waveform = audio["waveform"].float() # [1, C, L]
sr = audio["sample_rate"]
filtered = AF.lowpass_biquad(waveform, sr, cutoff_hz)
out = blend * filtered + (1.0 - blend) * waveform
# Preserve RMS level — LPF removes energy, keep the clip at its original loudness
rms_in = waveform.pow(2).mean().sqrt().clamp(min=1e-8)
rms_out = out.pow(2).mean().sqrt().clamp(min=1e-8)
out = out * (rms_in / rms_out)
peak = out.abs().max()
if peak > 1.0:
out = out / peak
print(f"[HF Smoother] cutoff={cutoff_hz:.0f} Hz blend={blend:.2f} "
f"rms={rms_in:.4f}{out.pow(2).mean().sqrt():.4f} "
f"peak={out.abs().max():.4f}", flush=True)
return ({"waveform": out, "sample_rate": sr},)
# ── Node 2: Spectral Matcher ──────────────────────────────────────────────────
class SelvaSpectralMatcher:
"""Adaptive per-band EQ toward the SelVA VAE's expected spectral distribution.
Computes the log-mel energy profile of the clip and compares it to the per-band
means stored in the VAE's normalization buffers (the statistics MMAudio was trained
on). Applies a smooth frequency-domain gain correction so the clip's spectral shape
matches what the codec expects, improving encode→decode roundtrip quality and
therefore LoRA training target quality.
The correction is additive in log space (multiplicative in linear), so it only
changes spectral balance — not the waveform's timing or phase structure.
max_gain_db clamps the correction to prevent extreme boosts on very quiet bands.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"mode": (["44k", "16k"], {
"tooltip": "Must match the SelVA model you are training. "
"44k = large model, 16k = small model.",
}),
"strength": ("FLOAT", {
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = no correction, 1 = full match to VAE distribution. "
"0.8 is a good starting point.",
}),
"max_gain_db": ("FLOAT", {
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
"tooltip": "Clamps per-band gain to ±dB. Prevents extreme boosts on "
"very quiet frequency bands. 12 dB is conservative.",
}),
}
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Applies a smooth per-band gain correction to bring the audio's spectral profile "
"in line with the MMAudio VAE's expected distribution, derived from the per-band "
"normalization statistics baked into the VAE weights. "
"Use before feature extraction to improve LoRA training target quality. "
"Run before SelVA HF Smoother."
)
def process(self, audio, mode: str, strength: float, max_gain_db: float):
cfg = _MEL_CONFIGS[mode]
waveform = audio["waveform"].float() # [1, C, L]
sr_in = audio["sample_rate"]
sr_tgt = cfg["sr"]
n_fft = cfg["n_fft"]
hop = cfg["hop"]
# ── flatten to mono and resample if needed ────────────────────────────
wav = waveform[0].mean(0) # [L]
if sr_in != sr_tgt:
wav = AF.resample(wav.unsqueeze(0), sr_in, sr_tgt).squeeze(0)
device = wav.device
window = torch.hann_window(n_fft, device=device)
# ── STFT ──────────────────────────────────────────────────────────────
stft = torch.stft(wav, n_fft, hop_length=hop, win_length=n_fft,
window=window, center=True, return_complex=True) # [n_freqs, T]
mag = stft.abs() # [n_freqs, T]
# ── current log-mel mean per band ─────────────────────────────────────
fb = _mel_filterbank(sr_tgt, n_fft, cfg["n_mels"],
cfg["fmin"], cfg["fmax"]).to(device) # [n_mels, n_freqs]
mel_mag = torch.matmul(fb, mag).clamp(min=1e-5) # [n_mels, T]
if cfg["log10"]:
mel_log = torch.log10(mel_mag)
else:
mel_log = torch.log(mel_mag)
current_mean = mel_log.mean(dim=-1) # [n_mels]
target_mean = torch.tensor(cfg["target"], device=device) # [n_mels]
# ── per-mel-band gain (log space) ─────────────────────────────────────
mel_gain = (target_mean - current_mean) * strength # [n_mels]
# Clamp to ±max_gain_db
if cfg["log10"]:
max_log = max_gain_db / 20.0 # log10: 20 log10 = dB
else:
max_log = max_gain_db / 8.6859 # ln: 20 * log10(e) ≈ 8.686
mel_gain = mel_gain.clamp(-max_log, max_log)
# ── map mel gains → STFT frequency bins (weighted average) ────────────
fb_sum = fb.sum(0).clamp(min=1e-8) # [n_freqs]
freq_gain = (mel_gain @ fb) / fb_sum # [n_freqs]
if cfg["log10"]:
linear_gain = 10.0 ** freq_gain # [n_freqs]
else:
linear_gain = torch.exp(freq_gain) # [n_freqs]
# ── apply gain in frequency domain and reconstruct ───────────────────
stft_out = stft * linear_gain.unsqueeze(-1) # [n_freqs, T]
wav_out = torch.istft(stft_out, n_fft, hop_length=hop, win_length=n_fft,
window=window, center=True,
length=wav.shape[0]) # [L]
# ── resample back to original sr ──────────────────────────────────────
if sr_in != sr_tgt:
wav_out = AF.resample(wav_out.unsqueeze(0), sr_tgt, sr_in).squeeze(0)
# ── preserve original RMS level ───────────────────────────────────────
rms_in = wav.pow(2).mean().sqrt().clamp(min=1e-8)
rms_out = wav_out.pow(2).mean().sqrt().clamp(min=1e-8)
wav_out = wav_out * (rms_in / rms_out)
peak = wav_out.abs().max()
if peak > 1.0:
wav_out = wav_out / peak
# ── reshape to match input layout [1, C, L] ───────────────────────────
out = wav_out.unsqueeze(0).unsqueeze(0)
if waveform.shape[1] > 1:
out = out.expand(-1, waveform.shape[1], -1).clone()
gain_db_range = (
20.0 * torch.log10(linear_gain.clamp(min=1e-8))
)
print(f"[Spectral Matcher] mode={mode} strength={strength:.2f} "
f"gain [{gain_db_range.min():.1f}, {gain_db_range.max():.1f}] dB "
f"rms={rms_in:.4f}{out.pow(2).mean().sqrt():.4f}", flush=True)
return ({"waveform": out, "sample_rate": sr_in},)
-77
View File
@@ -1,77 +0,0 @@
"""SelVA BigVGAN Loader.
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer
and replaces the vocoder weights in the loaded SELVA_MODEL in-place.
The model is modified in-place so ComfyUI's model cache is updated — no need to
reload the full SelVA model. Subsequent Sampler runs will use the fine-tuned vocoder.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
from .selva_bigvgan_trainer import inject_gafilters
class SelvaBigvganLoader:
CATEGORY = SELVA_CATEGORY
FUNCTION = "load"
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
DESCRIPTION = (
"Loads a fine-tuned BigVGAN/BigVGANv2 vocoder checkpoint from SelVA BigVGAN Trainer "
"and replaces the vocoder weights in the SELVA_MODEL in-place. "
"Supports both 16k and 44k models. "
"Connect the output to SelVA Sampler instead of the base model loader."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"path": ("STRING", {
"default": "bigvgan_bj.pt",
"tooltip": "Path to fine-tuned vocoder checkpoint (.pt). "
"Relative paths resolve to ComfyUI output directory.",
}),
},
}
def load(self, model, path):
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {p}")
ckpt = torch.load(str(p), map_location="cpu", weights_only=False)
if "generator" not in ckpt:
raise ValueError(f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {list(ckpt.keys())}")
mode = model["mode"]
if mode == "16k":
vocoder = model["feature_utils"].tod.vocoder.vocoder # BigVGANVocoder
elif mode == "44k":
vocoder = model["feature_utils"].tod.vocoder # BigVGANv2 directly
else:
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
# Remember device before injecting new modules (which default to CPU)
target_device = next(vocoder.parameters()).device
if ckpt.get("has_gafilter", False):
kernel_size = ckpt.get("gafilter_kernel_size", 9)
n_gaf = inject_gafilters(vocoder, kernel_size)
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True)
vocoder.load_state_dict(ckpt["generator"])
vocoder.to(target_device)
vocoder.eval()
print(f"[BigVGAN] Loaded fine-tuned vocoder from: {p}", flush=True)
return (model,)
-625
View File
@@ -1,625 +0,0 @@
"""SelVA BigVGAN Vocoder Scheduler — runs a sweep of vocoder fine-tuning experiments.
Each experiment inherits from a shared `base` config and overrides specific keys.
Audio clips are loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image.
JSON format:
{
"name": "bigvgan_sweep",
"description": "optional note",
"data_dir": "/path/to/audio/clips",
"output_root": "/path/to/output",
"base": { "train_mode": "snake_alpha_only", "steps": 2000, "lr": 1e-4, ... },
"experiments": [
{"id": "baseline", "description": "..."},
{"id": "all_5k", "train_mode": "all_params", "steps": 5000, "lr": 1e-5},
...
]
}
"""
import copy
import csv
import json
import threading
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_bigvgan_trainer import (
_do_train,
_pregenerate_lora_mels,
_load_wav,
)
from .selva_lora_trainer import _smooth_losses, _pil_to_tensor
from .selva_lora_scheduler import (
_get_system_info,
_resolve_path,
_draw_comparison_curves,
)
# Defaults mirror SelvaBigvganTrainer INPUT_TYPES defaults
_PARAM_DEFAULTS = {
"train_mode": "snake_alpha_only",
"steps": 2000,
"lr": 1e-4,
"batch_size": 4,
"segment_seconds": 2.0,
"lambda_l2sp": 1e-3,
"use_gafilter": True,
"gafilter_kernel_size": 9,
"lambda_phase": 1.0,
"save_every": 500,
"seed": 42,
"discriminator_path": "",
"lora_adapter": "",
}
def _merge_config(base: dict, experiment: dict) -> dict:
"""Merge param defaults + file base + experiment overrides."""
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _parse_training_log(log_path: Path) -> list:
"""Parse BigVGAN training CSV → list of total_loss values."""
losses = []
if not log_path.exists():
return losses
try:
with open(log_path) as f:
reader = csv.DictReader(f)
for row in reader:
losses.append(float(row["total_loss"]))
except Exception:
pass
return losses
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
total_steps: int) -> dict:
"""Build {step: loss} at each save_every boundary.
Uses round-to-nearest to handle log_interval that doesn't divide
save_every evenly (e.g. steps=3000 → log_interval=150, save_every=1000).
"""
result = {}
for target in range(save_every, total_steps + 1, save_every):
# loss_history[i] = loss at step (i+1)*log_interval
idx = round(target / log_interval) - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
class SelvaBigvganScheduler:
"""Runs a sweep of BigVGAN vocoder fine-tuning experiments from a JSON file.
Audio clips are loaded once and reused across all experiments. Each experiment
deep-copies the vocoder and trains independently. Results are written to
`experiment_summary.json` after every completed run so partial results are
preserved if the sweep is interrupted.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — share this file to compare runs.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of BigVGAN vocoder fine-tuning experiments defined in a JSON sweep file. "
"Audio clips are loaded once and reused across all experiments. "
"Results (loss, config, checkpoint paths) are collected in experiment_summary.json."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "bigvgan_experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"models directory; absolute paths are used as-is."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate the JSON file
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[BigVGAN Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[BigVGAN Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[BigVGAN Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[BigVGAN Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"bigvgan_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
mode = model["mode"]
dtype = model["dtype"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
strategy = model["strategy"]
if mode == "16k":
original_vocoder = feature_utils.tod.vocoder.vocoder
sample_rate = 16_000
elif mode == "44k":
original_vocoder = feature_utils.tod.vocoder
sample_rate = 44_100
else:
raise ValueError(f"[BigVGAN Scheduler] Unknown mode: {mode}")
print(f"\n[BigVGAN Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[BigVGAN Scheduler] {description}", flush=True)
print(f"[BigVGAN Scheduler] data_dir = {data_dir}", flush=True)
print(f"[BigVGAN Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load audio clips once
# ------------------------------------------------------------------
# Find minimum segment length across all experiments so we load enough
min_segment_seconds = float("inf")
for exp in spec["experiments"]:
cfg = _merge_config(base_cfg, exp)
min_segment_seconds = min(min_segment_seconds, float(cfg.get("segment_seconds", 2.0)))
min_segment_samples = int(min_segment_seconds * sample_rate)
audio_files = []
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg", "*.aac"):
audio_files.extend(data_dir.rglob(ext))
if not audio_files:
raise FileNotFoundError(f"[BigVGAN Scheduler] No audio files in {data_dir}")
print(f"[BigVGAN Scheduler] Loading {len(audio_files)} audio files...", flush=True)
clips = []
for af in audio_files:
try:
wav, sr = _load_wav(af)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0) # [L]
if wav.shape[0] >= min_segment_samples:
clips.append(wav.cpu())
else:
print(f" [BigVGAN Scheduler] Skip {af.name}: "
f"shorter than {min_segment_seconds}s", flush=True)
except Exception as e:
print(f" [BigVGAN Scheduler] Failed {af.name}: {e}", flush=True)
if not clips:
raise RuntimeError(
f"[BigVGAN Scheduler] No usable clips (need audio >= {min_segment_seconds}s)"
)
print(f"[BigVGAN Scheduler] {len(clips)} clips ready\n", flush=True)
# ------------------------------------------------------------------
# 4. Offload unused components to free VRAM
# ------------------------------------------------------------------
comfy.model_management.unload_all_models()
feature_utils.to("cpu")
if "generator" in model:
model["generator"].to("cpu")
if "video_enc" in model:
model["video_enc"].to("cpu")
soft_empty_cache()
# ------------------------------------------------------------------
# 5. Pre-compute text CLIP embeddings if any experiment uses LoRA
# ------------------------------------------------------------------
text_clip_cache = {}
any_lora = any(
_merge_config(base_cfg, exp).get("lora_adapter", "")
for exp in spec["experiments"]
)
if any_lora:
npz_files = sorted(data_dir.glob("*.npz"))
if npz_files:
prompt_map = {}
prompts_file = data_dir / "prompts.txt"
if prompts_file.exists():
for line in prompts_file.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "|" in line:
fname, prompt = line.split("|", 1)
prompt_map[fname.strip()] = prompt.strip()
default_prompt = data_dir.name
clip_model = feature_utils.clip_model
if clip_model is not None:
clip_model.to(device)
try:
for npz_path in npz_files:
data = dict(np.load(str(npz_path), allow_pickle=False))
prompt = prompt_map.get(
npz_path.name, data.get("prompt", default_prompt)
)
if isinstance(prompt, np.ndarray):
prompt = str(prompt)
tc = feature_utils.encode_text_clip([prompt])
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
finally:
if clip_model is not None:
clip_model.to("cpu")
soft_empty_cache()
if device.type == "cuda":
torch.cuda.empty_cache()
print(f"[BigVGAN Scheduler] Pre-encoded {len(text_clip_cache)} "
f"CLIP embeddings", flush=True)
# ------------------------------------------------------------------
# 6. Build or restore the summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = []
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
lh = rec["results"].get("loss_history", [])
all_curve_data.append({
"id": rec["id"],
"loss_history": lh,
"log_interval": rec["results"].get("log_interval", 100),
"start_step": 0,
})
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[BigVGAN Scheduler] Resuming — skipping "
f"{len(completed_ids)} completed: "
f"{sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[BigVGAN Scheduler] Could not read existing summary "
f"({e}) — starting fresh", flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": len(clips),
"experiments": [],
}
def _write_summary():
summary_path.write_text(
json.dumps(summary, indent=2), encoding="utf-8"
)
_write_summary()
# ------------------------------------------------------------------
# 7. Compute total steps for progress bar
# ------------------------------------------------------------------
total_steps = 0
for exp in spec["experiments"]:
if exp["id"] not in completed_ids:
cfg = _merge_config(base_cfg, exp)
total_steps += int(cfg.get("steps", 2000))
pbar = comfy.utils.ProgressBar(max(total_steps, 1))
# ------------------------------------------------------------------
# 8. Run experiments in a worker thread
# ------------------------------------------------------------------
# BigVGAN training requires a fresh thread because ComfyUI runs nodes
# inside torch.inference_mode(). inference_mode is thread-local — a
# new thread starts with it OFF, so all tensor operations produce
# normal autograd-compatible tensors.
_exc = [None]
def _worker():
try:
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[BigVGAN Scheduler] Skipping '{exp_id}' "
f"(already completed)", flush=True)
continue
cfg = _merge_config(base_cfg, exp)
# ── Extract experiment parameters ────────────────────
train_mode = str(cfg.get("train_mode", "snake_alpha_only"))
exp_steps = int(cfg.get("steps", 2000))
exp_lr = float(cfg.get("lr", 1e-4))
exp_bs = int(cfg.get("batch_size", 4))
exp_seg_s = float(cfg.get("segment_seconds", 2.0))
exp_l2sp = float(cfg.get("lambda_l2sp", 1e-3))
exp_gafilter = bool(cfg.get("use_gafilter", True))
exp_gaf_ks = int(cfg.get("gafilter_kernel_size", 9))
exp_phase = float(cfg.get("lambda_phase", 1.0))
exp_save = int(cfg.get("save_every", 500))
exp_seed = int(cfg.get("seed", 42))
exp_disc = str(cfg.get("discriminator_path", ""))
exp_lora = str(cfg.get("lora_adapter", ""))
segment_samples = int(exp_seg_s * sample_rate)
# Filter clips long enough for this experiment
exp_clips = [c for c in clips if c.shape[0] >= segment_samples]
if not exp_clips:
print(f"[BigVGAN Scheduler] '{exp_id}' skipped: "
f"no clips >= {exp_seg_s}s", flush=True)
summary["experiments"].append({
"id": exp_id, "description": exp_desc,
"config": dict(cfg),
"results": {
"status": "failed",
"error": f"No clips >= {exp_seg_s}s",
"duration_seconds": 0,
},
"checkpoint_path": None,
"output_dir": str(output_root / exp_id),
})
_write_summary()
continue
# ── Resolve discriminator path ───────────────────────
disc_path = None
if exp_disc:
disc_path = Path(exp_disc.strip())
if not disc_path.is_absolute():
disc_path = (
Path(folder_paths.get_output_directory()) / disc_path
)
if not disc_path.exists():
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"discriminator not found: {disc_path}",
flush=True)
disc_path = None
# ── Pre-generate LoRA mels (disk-cached) ─────────────
lora_mel_pairs = None
if exp_lora:
lora_path = Path(exp_lora.strip())
if not lora_path.is_absolute():
lora_path = Path(folder_paths.base_path) / lora_path
if lora_path.exists():
seq_cfg = model["seq_cfg"]
lora_mel_pairs = _pregenerate_lora_mels(
model, data_dir, str(lora_path),
device, dtype, sample_rate,
seq_cfg.duration, seed=exp_seed,
cache_dir=str(output_root),
text_clip_cache=text_clip_cache,
)
if not lora_mel_pairs:
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"no LoRA mel pairs generated",
flush=True)
lora_mel_pairs = None
if device.type == "cuda":
torch.cuda.empty_cache()
else:
print(f"[BigVGAN Scheduler] '{exp_id}': "
f"LoRA adapter not found: {lora_path}",
flush=True)
# ── Output dir ───────────────────────────────────────
exp_dir = output_root / exp_id
exp_dir.mkdir(parents=True, exist_ok=True)
out_path = exp_dir / f"bigvgan_{exp_id}.pt"
print(f"\n[BigVGAN Scheduler] ── Experiment '{exp_id}' ──",
flush=True)
if exp_desc:
print(f"[BigVGAN Scheduler] {exp_desc}", flush=True)
print(f"[BigVGAN Scheduler] mode={train_mode} "
f"steps={exp_steps} lr={exp_lr} bs={exp_bs} "
f"seg={exp_seg_s}s gafilter={exp_gafilter} "
f"phase={exp_phase} l2sp={exp_l2sp}", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"train_mode": train_mode, "steps": exp_steps,
"lr": exp_lr, "batch_size": exp_bs,
"segment_seconds": exp_seg_s,
"lambda_l2sp": exp_l2sp,
"use_gafilter": exp_gafilter,
"gafilter_kernel_size": exp_gaf_ks,
"lambda_phase": exp_phase,
"save_every": exp_save, "seed": exp_seed,
"discriminator_path": exp_disc,
"lora_adapter": exp_lora,
},
"results": {"status": "running"},
"checkpoint_path": None,
"output_dir": str(exp_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
# Ensure mel_converter is on device for this experiment
mel_converter.to(device)
# Fresh vocoder copy — _do_train modifies it in-place
vocoder_copy = copy.deepcopy(original_vocoder)
checkpoint_path = _do_train(
vocoder_copy, mel_converter, exp_clips,
device, dtype, strategy, feature_utils,
segment_samples, sample_rate,
train_mode, exp_steps, exp_lr, exp_bs,
exp_l2sp, exp_gafilter, exp_gaf_ks,
exp_phase, exp_save, exp_seed,
out_path, disc_path, pbar,
lora_mel_pairs,
)
duration = time.monotonic() - t_start
# Parse training CSV for loss history
log_path = exp_dir / f"bigvgan_{exp_id}_training_log.csv"
loss_history = _parse_training_log(log_path)
log_interval = max(1, exp_steps // 20)
smoothed = (
_smooth_losses(loss_history)
if loss_history else []
)
final_loss = (
round(smoothed[-1], 6) if smoothed else None
)
min_loss = (
round(min(smoothed), 6) if smoothed else None
)
min_idx = (
smoothed.index(min(smoothed))
if smoothed else None
)
min_loss_step = (
(min_idx + 1) * log_interval
if min_idx is not None else None
)
if loss_history:
quarter = max(1, len(loss_history) // 4)
loss_std = round(
float(np.std(loss_history[-quarter:])), 6
)
else:
loss_std = None
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval,
exp_save, exp_steps,
),
"loss_history": [
round(v, 6) for v in loss_history
],
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["checkpoint_path"] = checkpoint_path
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
"log_interval": log_interval,
"start_step": 0,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[BigVGAN Scheduler] Experiment '{exp_id}' "
f"failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
finally:
# Clean up vocoder copy to free VRAM
soft_empty_cache()
_write_summary()
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if _exc[0] is not None:
raise _exc[0]
# ------------------------------------------------------------------
# 9. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[BigVGAN Scheduler] Sweep complete. "
f"Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 10. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)
File diff suppressed because it is too large Load Diff
-106
View File
@@ -1,106 +0,0 @@
import json
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaDatasetBrowser:
"""Browse a dataset.json file entry by entry using an integer index.
Each entry in the JSON is expected to have:
- "path" : base path (no extension) — directory that holds frame images
- "label" : text description of the clip
Derived outputs:
- video_path : path + ".mp4"
- audio_path : path + ".wav"
- frames_dir : path (the directory itself, for image-sequence loaders)
- label : entry["label"]
- count : total number of entries in the file
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset_json": ("STRING", {
"default": "",
"tooltip": "Absolute or ComfyUI-relative path to a dataset.json file.",
}),
"index": ("INT", {
"default": 0,
"min": 0,
"max": 9999,
"step": 1,
"tooltip": "Zero-based index of the entry to inspect.",
}),
},
}
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "INT")
RETURN_NAMES = ("video_path", "audio_wav", "audio_flac", "features_path", "frames_dir", "mask_dir", "label", "max_index")
OUTPUT_TOOLTIPS = (
"path + '.mp4'",
"features/ + name + '.wav'",
"features/ + name + '.flac'",
"features/ + name + '.npz' (pre-extracted SelVA features)",
"path (image-sequence directory)",
"path + '_mask' (mask image-sequence directory)",
"Text label for this clip",
"count - 1 — wire to a primitive INT's max to constrain the index widget",
)
FUNCTION = "browse"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Reads a dataset.json produced by the SelVA dataset preparation pipeline "
"and exposes one entry at a time via an integer index. "
"Outputs the video path, audio path, frames directory, label, and total entry count."
)
# Re-read the file every call so edits are picked up without restarting ComfyUI.
IS_CHANGED = classmethod(lambda cls, **_: float("nan"))
def browse(self, dataset_json: str, index: int):
p = Path(dataset_json.strip())
if not p.is_absolute():
p = Path(folder_paths.base_path) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA Dataset Browser] File not found: {p}")
with p.open("r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, list) or len(data) == 0:
raise ValueError(f"[SelVA Dataset Browser] Expected a non-empty JSON array in {p}")
count = len(data)
if index >= count:
raise IndexError(
f"[SelVA Dataset Browser] index {index} is out of range "
f"(dataset has {count} entries, last index is {count - 1})"
)
entry = data[index]
base = entry["path"]
label = entry.get("label", "")
p_base = Path(base)
feat_base = str(p_base.parent / "features" / p_base.name)
print(
f"[SelVA Dataset Browser] {index + 1}/{count} label='{label}' base={base}",
flush=True,
)
return (
base + ".mp4",
feat_base + ".wav",
feat_base + ".flac",
feat_base + ".npz",
base,
base + "_mask",
label,
count - 1,
)
-788
View File
@@ -1,788 +0,0 @@
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.
Typical chain:
SelvaDatasetLoader
↓ AUDIO_DATASET
SelvaDatasetResampler (optional)
↓ AUDIO_DATASET
SelvaDatasetLUFSNormalizer (optional)
↓ AUDIO_DATASET
SelvaDatasetCompressor (optional)
↓ AUDIO_DATASET
SelvaDatasetSpectralMatcher (optional — batch spectral EQ)
↓ AUDIO_DATASET
SelvaDatasetHfSmoother (optional — batch HF attenuation)
↓ AUDIO_DATASET
SelvaDatasetAugmenter (optional — gain/pitch/stretch variants)
↓ AUDIO_DATASET
SelvaDatasetInspector (optional)
↓ AUDIO_DATASET + STRING report
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
"""
from pathlib import Path
import numpy as np
import torch
import torchaudio
from .utils import SELVA_CATEGORY
# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
_SOUNDFILE_EXTS = {".wav", ".flac", ".ogg"} # handled natively without FFmpeg
def _load_audio(path: Path):
"""Load audio file. Uses soundfile for WAV/FLAC/OGG to avoid torchcodec/FFmpeg issues."""
if path.suffix.lower() in _SOUNDFILE_EXTS:
import soundfile as sf
wav_np, sr = sf.read(str(path), dtype="float32", always_2d=True) # [L, C]
wav = torch.from_numpy(wav_np).T.unsqueeze(0) # [1, C, L]
else:
wav, sr = torchaudio.load(str(path)) # [C, L]
wav = wav.unsqueeze(0).float() # [1, C, L]
return wav, sr
class SelvaDatasetLoader:
"""Load all audio files in a folder into an in-memory AUDIO_DATASET."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"folder": ("STRING", {
"default": "",
"tooltip": "Absolute path to folder containing audio files. Searched recursively.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."
def load(self, folder: str):
folder = Path(folder.strip())
if not folder.exists():
raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
if not files:
raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")
dataset = []
for f in sorted(files):
try:
wav, sr = _load_audio(f)
dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
except Exception as e:
print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
return (dataset,)
class SelvaDatasetResampler:
"""Resample all clips in a dataset to a target sample rate using soxr VHQ."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_sr": ("INT", {
"default": 44100, "min": 8000, "max": 192000,
"tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "resample"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."
def resample(self, dataset, target_sr: int):
import soxr
out = []
changed = 0
for item in dataset:
sr = item["sample_rate"]
if sr == target_sr:
out.append(item)
continue
wav = item["waveform"][0] # [C, L]
# soxr expects [L, C] (time-first), float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L]
out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
changed += 1
print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
return (out,)
class SelvaDatasetLUFSNormalizer:
"""Normalize each clip to a target integrated LUFS level + true peak limit."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"target_lufs": ("FLOAT", {
"default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
"tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
}),
"true_peak_dbtp": ("FLOAT", {
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
"tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "normalize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
"Skips clips that are too short for LUFS measurement (< 0.4 s)."
)
def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
import pyloudnorm as pyln
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
out = []
skipped = 0
for item in dataset:
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# pyloudnorm wants [L] mono or [L, C] multichannel, float64
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [L] mono
meter = pyln.Meter(sr)
try:
loudness = meter.integrated_loudness(wav_np)
except Exception:
skipped += 1
out.append(item)
continue
if not np.isfinite(loudness):
skipped += 1
out.append(item)
continue
gain_db = target_lufs - loudness
gain_linear = 10.0 ** (gain_db / 20.0)
wav_norm = wav * gain_linear
# True peak limit
peak = wav_norm.abs().max().item()
if peak > tp_linear:
wav_norm = wav_norm * (tp_linear / peak)
out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})
print(
f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
flush=True,
)
return (out,)
class SelvaDatasetCompressor:
"""Apply mild parallel compression to reduce within-clip loudness variance.
Uses pedalboard.Compressor (2:13:1 ratio). Parallel (New York) style:
blends compressed signal with dry so transients are preserved while
the dynamic range is gently tightened. Apply after LUFS normalization.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"threshold_db": ("FLOAT", {
"default": -18.0, "min": -40.0, "max": -6.0, "step": 1.0,
"tooltip": "Compression kicks in above this level. -18 dB is a safe starting point after LUFS normalization.",
}),
"ratio": ("FLOAT", {
"default": 2.5, "min": 1.5, "max": 4.0, "step": 0.5,
"tooltip": "Compression ratio. 2:13:1 is mild; stay below 4:1 to avoid pumping.",
}),
"attack_ms": ("FLOAT", {
"default": 10.0, "min": 1.0, "max": 100.0, "step": 1.0,
"tooltip": "Attack time in ms. Slower attack preserves transients.",
}),
"release_ms": ("FLOAT", {
"default": 100.0, "min": 20.0, "max": 500.0, "step": 10.0,
"tooltip": "Release time in ms.",
}),
"mix": ("FLOAT", {
"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Parallel blend: 0.0 = dry only, 1.0 = fully compressed. 0.30.5 is typical.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "compress"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Mild parallel compression to reduce within-clip dynamic range. "
"Blends compressed signal with dry at the given mix ratio. "
"Apply after LUFS normalization."
)
def compress(self, dataset, threshold_db: float, ratio: float,
attack_ms: float, release_ms: float, mix: float):
from pedalboard import Compressor, Pedalboard
board = Pedalboard([Compressor(
threshold_db=threshold_db,
ratio=ratio,
attack_ms=attack_ms,
release_ms=release_ms,
)])
out = []
for item in dataset:
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# pedalboard expects [C, L] float32 numpy
wav_np = wav.float().numpy() # [C, L]
compressed = board(wav_np, sr) # [C, L]
mixed = (1.0 - mix) * wav_np + mix * compressed
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, L]
out.append({"waveform": wav_out, "sample_rate": sr, "name": item["name"]})
print(
f"[DatasetCompressor] {len(out)} clips compressed "
f"thr={threshold_db}dB ratio={ratio}:1 mix={mix:.0%}",
flush=True,
)
return (out,)
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
"""Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).
Method: compare mean energy in 15 kHz band vs 1520 kHz band via STFT.
A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
"""
if sr < 32000:
return False # can't assess HF at low sample rates
n_fft = 2048
hop = 512
mono = wav[0].mean(0) # [L]
window = torch.hann_window(n_fft, device=mono.device)
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device)
band_lo = (freqs >= 1000) & (freqs < 5000)
band_hi = (freqs >= 15000) & (freqs < 20000)
if band_hi.sum() == 0:
return False
energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
return ratio_db > 40.0
def _estimate_snr(wav: torch.Tensor) -> float:
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
mono = wav[0].mean(0) # [L]
if mono.shape[0] < 2048:
return 60.0 # clip too short to frame — assume clean
frames = mono.unfold(0, 2048, 512) # [N, 2048]
rms = frames.pow(2).mean(-1).sqrt() # [N]
p95 = torch.quantile(rms, 0.95).item()
p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
return 20.0 * np.log10(p95 / p05 + 1e-8)
class SelvaDatasetInspector:
"""Analyze each clip for quality issues and optionally filter out flagged clips."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"skip_rejected": ("BOOLEAN", {
"default": True,
"tooltip": "If True, flagged clips are removed from the output dataset. "
"If False, all clips pass through but the report still lists issues.",
}),
"min_snr_db": ("FLOAT", {
"default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
"tooltip": "Clips with estimated SNR below this value are flagged.",
}),
"check_codec_artifacts": ("BOOLEAN", {
"default": True,
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
}),
"max_silence_fraction": ("FLOAT", {
"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Flag clips where more than this fraction of frames are near-silent "
"(< -60 dBFS). Set to 0 to disable silence detection.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET, "STRING")
RETURN_NAMES = ("dataset", "report")
FUNCTION = "inspect"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Analyze each clip for clipping, low SNR, and codec artifacts. "
"Outputs a filtered AUDIO_DATASET and a text report. "
"Connect report to a ShowText node to preview in the UI."
)
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float,
check_codec_artifacts: bool, max_silence_fraction: float = 0.5):
clean = []
flagged = []
lines = ["SelVA Dataset Inspector Report", "=" * 40]
for item in dataset:
wav = item["waveform"]
sr = item["sample_rate"]
name = item["name"]
issues = []
# Clipping
peak = wav.abs().max().item()
if peak > 0.99:
issues.append(f"clipping (peak={peak:.3f})")
# Low SNR
snr = _estimate_snr(wav)
if snr < min_snr_db:
issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")
# Codec artifacts
if check_codec_artifacts and _check_hf_shelf(wav, sr):
issues.append("codec artifact (HF shelf > 15 kHz)")
# Silence detection
if max_silence_fraction > 0:
mono = wav[0].mean(0)
if mono.shape[0] >= 2048:
frames = mono.unfold(0, 2048, 512)
rms = frames.pow(2).mean(-1).sqrt()
silent_frac = (rms < 1e-3).float().mean().item()
if silent_frac > max_silence_fraction:
issues.append(f"mostly silent ({silent_frac:.0%} < -60 dBFS)")
if issues:
flagged.append(name)
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
if not skip_rejected:
clean.append(item)
else:
clean.append(item)
lines.append(f" OK {name}")
lines.append("=" * 40)
lines.append(
f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
+ (" (removed)" if skip_rejected else " (kept)")
)
report = "\n".join(lines)
print(f"[DatasetInspector]\n{report}", flush=True)
return (clean, report)
class SelvaDatasetItemExtractor:
"""Extract a single AUDIO item from an AUDIO_DATASET by index.
Bridges the dataset pipeline to any node that accepts a standard AUDIO
input — save audio, HF Smoother, Spectral Matcher, etc.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"index": ("INT", {
"default": 0, "min": 0, "max": 9999,
"tooltip": "0-based index. Wraps around if index >= dataset length.",
}),
}
}
RETURN_TYPES = ("AUDIO", "STRING", "INT")
RETURN_NAMES = ("audio", "name", "total")
FUNCTION = "extract"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Extract one clip from an AUDIO_DATASET by index. "
"Returns standard AUDIO (compatible with all audio nodes), "
"the clip name, and the total dataset length."
)
def extract(self, dataset, index: int):
if not dataset:
raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
idx = index % len(dataset)
item = dataset[idx]
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
print(
f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
flush=True,
)
return (audio, item["name"], len(dataset))
class SelvaDatasetSaver:
"""Save all clips in an AUDIO_DATASET to disk as FLAC files.
Optionally copies matching .npz feature files from a source directory,
keeping FLAC/NPZ pairs in sync after the inspector has filtered clips.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"output_dir": ("STRING", {
"default": "",
"tooltip": "Absolute path to output folder. Created if it does not exist.",
}),
},
"optional": {
"npz_source_dir": ("STRING", {
"default": "",
"tooltip": "If set, copies {name}.npz from this folder alongside each saved FLAC. "
"Missing NPZs are warned but do not abort the save.",
}),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("report",)
OUTPUT_NODE = True
FUNCTION = "save"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Save every clip in an AUDIO_DATASET to output_dir as FLAC. "
"If npz_source_dir is provided, copies the matching .npz file for each clip — "
"so rejected clips never get their NPZ copied."
)
def save(self, dataset, output_dir: str, npz_source_dir: str = ""):
import shutil
import soundfile as sf
out = Path(output_dir.strip())
out.mkdir(parents=True, exist_ok=True)
npz_src = Path(npz_source_dir.strip()) if npz_source_dir.strip() else None
saved = 0
npz_copied = 0
npz_missing = []
for item in dataset:
name = item["name"]
wav = item["waveform"][0] # [C, L]
sr = item["sample_rate"]
# soundfile wants [L] mono or [L, C] multichannel, float32
wav_np = wav.permute(1, 0).float().numpy() # [L, C]
if wav_np.shape[1] == 1:
wav_np = wav_np[:, 0] # [L] mono
flac_path = out / f"{name}.flac"
sf.write(str(flac_path), wav_np, sr, subtype="PCM_24")
saved += 1
if npz_src is not None:
# Augmented clips store their origin name — use it to find the .npz
lookup = item.get("origin_name", name)
npz_path = npz_src / f"{lookup}.npz"
if npz_path.exists():
shutil.copy2(str(npz_path), str(out / f"{name}.npz"))
npz_copied += 1
else:
npz_missing.append(name)
lines = [
f"[DatasetSaver] Saved {saved} clips → {out}",
]
if npz_src is not None:
lines.append(f" NPZ copied: {npz_copied} missing: {len(npz_missing)}")
for n in npz_missing:
lines.append(f" MISSING NPZ: {n}")
report = "\n".join(lines)
print(report, flush=True)
return (report,)
# ── Batch wrappers for audio preprocessors ───────────────────────────────────
class SelvaDatasetSpectralMatcher:
"""Apply SelVA Spectral Matcher to every clip in an AUDIO_DATASET.
Wraps SelvaSpectralMatcher so it works on batch datasets instead of
individual AUDIO items. Same parameters — see that node for details.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"mode": (["44k", "16k"], {
"tooltip": "Must match the SelVA model you are training. "
"44k = large model, 16k = small model.",
}),
"strength": ("FLOAT", {
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = no correction, 1 = full match to VAE distribution.",
}),
"max_gain_db": ("FLOAT", {
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
"tooltip": "Clamps per-band gain to ±dB.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Apply adaptive spectral matching to every clip in a dataset. "
"Batch version of SelVA Spectral Matcher — same per-band EQ toward the "
"VAE's expected distribution."
)
def process(self, dataset, mode: str, strength: float, max_gain_db: float):
from .selva_audio_preprocessors import SelvaSpectralMatcher
matcher = SelvaSpectralMatcher()
out = []
for item in dataset:
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
(result,) = matcher.process(audio, mode, strength, max_gain_db)
new_item = dict(item) # preserve origin_name and any extra keys
new_item["waveform"] = result["waveform"]
new_item["sample_rate"] = result["sample_rate"]
out.append(new_item)
print(f"[DatasetSpectralMatcher] {len(out)} clips processed "
f"mode={mode} strength={strength}", flush=True)
return (out,)
class SelvaDatasetHfSmoother:
"""Apply SelVA HF Smoother to every clip in an AUDIO_DATASET.
Wraps SelvaHfSmoother so it works on batch datasets instead of
individual AUDIO items. Same parameters — see that node for details.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"cutoff_hz": ("FLOAT", {
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
}),
"blend": ("FLOAT", {
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "0 = original, 1 = fully filtered.",
}),
}
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "process"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Apply soft HF attenuation to every clip in a dataset. "
"Batch version of SelVA HF Smoother — blends a low-pass filtered copy "
"with the original to tame extreme HF content."
)
def process(self, dataset, cutoff_hz: float, blend: float):
from .selva_audio_preprocessors import SelvaHfSmoother
smoother = SelvaHfSmoother()
out = []
for item in dataset:
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
(result,) = smoother.process(audio, cutoff_hz, blend)
new_item = dict(item) # preserve origin_name and any extra keys
new_item["waveform"] = result["waveform"]
new_item["sample_rate"] = result["sample_rate"]
out.append(new_item)
print(f"[DatasetHfSmoother] {len(out)} clips processed "
f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True)
return (out,)
# ── Dataset augmenter ────────────────────────────────────────────────────────
class SelvaDatasetAugmenter:
"""Create augmented variants of each clip to expand a small dataset.
Supports gain variation (always available) and optionally pitch shift
and time stretch via audiomentations. Install audiomentations for the
full feature set: ``pip install audiomentations``
Each original clip produces ``variants_per_clip`` augmented copies.
Originals are kept by default (toggle ``keep_originals``).
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"dataset": (AUDIO_DATASET,),
"variants_per_clip": ("INT", {
"default": 2, "min": 1, "max": 20,
"tooltip": "Number of augmented copies per original clip.",
}),
"gain_range_db": ("FLOAT", {
"default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5,
"tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.",
}),
"seed": ("INT", {"default": 42}),
},
"optional": {
"pitch_range_semitones": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25,
"tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.",
}),
"time_stretch_range": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05,
"tooltip": "Random time stretch ±fraction (0.1 = 90%110% speed). "
"Requires audiomentations. 0 = disabled.",
}),
"keep_originals": ("BOOLEAN", {
"default": True,
"tooltip": "Include the original unaugmented clips in the output.",
}),
},
}
RETURN_TYPES = (AUDIO_DATASET,)
RETURN_NAMES = ("dataset",)
FUNCTION = "augment"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Create augmented variants of each clip (gain, pitch, time stretch) "
"to expand small training datasets. Gain is always available; pitch and "
"time stretch require audiomentations (pip install audiomentations)."
)
def augment(self, dataset, variants_per_clip: int, gain_range_db: float,
seed: int, pitch_range_semitones: float = 0.0,
time_stretch_range: float = 0.0, keep_originals: bool = True):
rng = np.random.RandomState(seed)
# Try audiomentations for pitch/stretch
use_am = False
am_compose = None
needs_am = pitch_range_semitones > 0 or time_stretch_range > 0
if needs_am:
try:
import audiomentations as am
transforms = []
if pitch_range_semitones > 0:
transforms.append(am.PitchShift(
min_semitones=-pitch_range_semitones,
max_semitones=pitch_range_semitones,
p=0.5,
))
if time_stretch_range > 0:
transforms.append(am.TimeStretch(
min_rate=1.0 - time_stretch_range,
max_rate=1.0 + time_stretch_range,
leave_length_unchanged=True,
p=0.5,
))
am_compose = am.Compose(transforms)
use_am = True
except ImportError:
print("[DatasetAugmenter] audiomentations not installed — "
"pitch_shift and time_stretch disabled. "
"Install: pip install audiomentations", flush=True)
out = []
if keep_originals:
out.extend(dataset)
for item in dataset:
wav = item["waveform"] # [1, C, L]
sr = item["sample_rate"]
name = item["name"]
for v in range(variants_per_clip):
# Gain variation (always applied)
gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0
gain_lin = 10.0 ** (gain_db / 20.0)
wav_aug = wav * gain_lin
# Pitch/stretch via audiomentations
if use_am and am_compose is not None:
wav_np = wav_aug[0].numpy() # [C, L] float32
if wav_np.shape[0] == 1:
wav_np = wav_np[0] # [L] mono for audiomentations
wav_np = am_compose(samples=wav_np, sample_rate=sr)
if wav_np.ndim == 1:
wav_np = wav_np[np.newaxis, :] # back to [1, L]
wav_aug = torch.from_numpy(wav_np).unsqueeze(0) # [1, C, L]
# Prevent clipping
peak = wav_aug.abs().max()
if peak > 1.0:
wav_aug = wav_aug / peak
out.append({
"waveform": wav_aug,
"sample_rate": sr,
"name": f"{name}_aug{v:02d}",
"origin_name": name,
})
print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips "
f"gain=±{gain_range_db:.1f}dB"
+ (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "")
+ (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""),
flush=True)
return (out,)
-515
View File
@@ -1,515 +0,0 @@
"""SelVA DITTO Optimizer.
Inference-time noise optimization: optimizes the initial noise latent x_0
using a style loss against target style reference clips, backpropagating through the
ODE solver. All model weights remain frozen — only x_0 changes.
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
Style loss: mel-spectrogram statistics matching (mean spectrum + Gram matrix)
against target style reference clips. Runs entirely before the vocoder — optimization
only requires the DiT + VAE decoder, not BigVGAN.
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
forward pass activations) instead of O(N steps). Backward recomputes each
step's activations on demand.
"""
import dataclasses
import threading
from pathlib import Path
import torch
import torch.nn.functional as F
import torchaudio
import comfy.utils
import comfy.model_management
import folder_paths
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
def _load_wav(path):
"""Load audio file to [channels, samples] float32 tensor."""
try:
return torchaudio.load(str(path))
except Exception:
pass
import soundfile as sf
data, sr = sf.read(str(path), dtype="float32", always_2d=True)
wav = torch.from_numpy(data.T)
return wav, sr
def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
"""Style loss between generated mel and precomputed reference statistics.
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
ref_mean: [n_mels] mean spectrum of reference clips (detached)
ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
Start at 0; enable only if mean-only optimization converges
without noise, then increase slowly (0.010.1).
"""
m = mel_gen.squeeze(0) # [n_mels, T]
# Mean spectrum loss — captures spectral envelope
gen_mean = m.mean(dim=-1) # [n_mels]
loss_mean = F.l1_loss(gen_mean, ref_mean)
if gram_weight <= 0.0:
return loss_mean
# Gram matrix loss — captures timbral texture (can add noise if too high)
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + gram_weight * loss_gram
def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
"""Style loss computed directly in VAE latent space.
z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad)
ref_mean: [C_lat] mean latent vector of reference clips
ref_gram: [C_lat, C_lat] Gram matrix of reference latents
gram_weight: weight for Gram component — 0 = mean only (recommended start)
Operating in latent space avoids backprop through the VAE decoder, which
is @torch.inference_mode() and produces noisy, unstable gradients.
"""
# Mean latent loss — matches average activation per channel
gen_mean = z.mean(dim=0) # [C_lat]
loss_mean = F.l1_loss(gen_mean, ref_mean)
if gram_weight <= 0.0:
return loss_mean
# Gram matrix — inter-channel covariance, position-invariant
gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat]
loss_gram = F.mse_loss(gram_gen, ref_gram)
return loss_mean + gram_weight * loss_gram
class SelvaDittoOptimizer:
"""DITTO inference-time noise optimization.
Freezes all model weights and optimizes only the initial noise latent x_0
to make the generated audio sound like the target style reference clips.
No training data or gradient updates to the model — per-video per-run.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"features": ("SELVA_FEATURES",),
"prompt": ("STRING", {
"default": "", "multiline": True,
"tooltip": "Sound description. Leave empty to use features prompt.",
}),
"negative_prompt": ("STRING", {
"default": "", "multiline": False,
}),
"reference_dir": ("STRING", {
"default": "",
"tooltip": "Directory with target style reference audio files (.wav/.flac/.mp3). "
"Reference mel statistics are precomputed from these once.",
}),
"n_opt_steps": ("INT", {
"default": 50, "min": 5, "max": 500,
"tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
"each step requires ~2 DiT forward passes.",
}),
"opt_lr": ("FLOAT", {
"default": 0.02, "min": 0.001, "max": 2.0, "step": 0.001,
"tooltip": "Adam learning rate for x_0 optimization. "
"0.020.05 is recommended; 0.1 (paper default) causes oscillation.",
}),
"n_ode_steps": ("INT", {
"default": 10, "min": 5, "max": 50,
"tooltip": "Euler ODE steps run during each optimization iteration. "
"Lower = faster optimization (1015 is a good trade-off). "
"Final generation always uses the steps parameter below.",
}),
"n_grad_steps": ("INT", {
"default": 5, "min": 1, "max": 50,
"tooltip": "ODE steps to differentiate through (truncated BPTT). "
"Higher = more accurate gradient, more VRAM. "
"Must be ≤ n_ode_steps. 5 is a good default.",
}),
"style_weight": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
"tooltip": "Weight of the target style style loss. High values push harder toward "
"target style style but add noise. Start at 0.1 and increase slowly.",
}),
"gram_weight": ("FLOAT", {
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
"tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
"0.1 adds texture matching but can introduce white noise.",
}),
"anchor_weight": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
"tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
"Prevents optimization from pushing x0 out of the flow's "
"expected distribution (which causes white noise). "
"Higher = cleaner audio, weaker style. 1.0 is a safe default.",
}),
"steps": ("INT", {
"default": 25, "min": 1, "max": 200,
"tooltip": "Euler steps for the final generation pass (after optimization).",
}),
"cfg_strength": ("FLOAT", {
"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
},
"optional": {
"normalize": ("BOOLEAN", {"default": True}),
"target_lufs": ("FLOAT", {
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward target style style.",)
FUNCTION = "optimize"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"DITTO inference-time noise optimization (arXiv:2401.12179). "
"Optimizes the initial noise latent x_0 to match target style reference clips "
"via mel statistics style loss, backpropagating through the ODE. "
"All model weights frozen — zero quality degradation risk."
)
def optimize(self, model, features, prompt, negative_prompt,
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize=True, target_lufs=-27.0):
import traceback
device = get_device()
dtype = model["dtype"]
strategy = model["strategy"]
net_generator = model["generator"]
feature_utils = model["feature_utils"]
mel_converter = feature_utils.mel_converter
# Validate variant match
feat_variant = features.get("variant")
if feat_variant is not None and feat_variant != model["variant"]:
raise ValueError(
f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
f"Re-run Feature Extractor."
)
if not prompt or not prompt.strip():
prompt = features.get("prompt", "")
# Resolve duration and seq_cfg
duration = features.get("duration", 0)
if duration <= 0:
raise ValueError("[DITTO] Features contain no duration field.")
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
sample_rate = seq_cfg.sampling_rate
# Load reference clips and encode to latent space.
# Style loss is computed in latent space (after net_generator.unnormalize)
# rather than mel space — this avoids backpropagating through the VAE
# decoder (which is @torch.inference_mode() and produces noisy gradients).
ref_dir = Path(reference_dir.strip())
if not ref_dir.is_absolute():
ref_dir = Path(folder_paths.models_dir) / ref_dir
if not ref_dir.exists():
raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")
ref_files = []
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"):
ref_files.extend(ref_dir.rglob(ext))
if not ref_files:
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
if not hasattr(feature_utils.tod.vae, "encoder"):
raise RuntimeError(
"[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. "
"Reload the model with the encoder enabled."
)
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
mel_converter.to(device, torch.float32) # cuFFT requires float32
ref_latents = []
with torch.no_grad():
for rf in ref_files:
try:
wav, sr = _load_wav(rf)
if wav.shape[0] > 1:
wav = wav.mean(0, keepdim=True)
if sr != sample_rate:
wav = torchaudio.functional.resample(wav, sr, sample_rate)
wav = wav.squeeze(0).to(device, torch.float32)
mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel]
# encode → sample → VAE latent space (matches unnormalize(x) in loss)
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
ref_latents.append(z_sample.to(dtype).squeeze(0).clone()) # [T_lat, C_lat]
except Exception as e:
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
if not ref_latents:
raise RuntimeError("[DITTO] No usable reference clips.")
# Precompute reference latent statistics (done once — detached, no grad)
with torch.no_grad():
all_means = torch.stack([z.mean(dim=0) for z in ref_latents])
ref_mean = all_means.mean(0) # [C_lat]
all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents]
ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat]
print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips "
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
f"grad_steps={n_grad_steps}", flush=True)
if strategy == "offload_to_cpu":
net_generator.to(device)
feature_utils.to(device)
soft_empty_cache()
pbar = comfy.utils.ProgressBar(n_opt_steps + steps)
_result = [None]
_exc = [None]
def _worker():
try:
_result[0] = _do_optimize(
net_generator, feature_utils, mel_converter,
features, prompt, negative_prompt,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar,
)
except Exception as e:
_exc[0] = e
traceback.print_exc()
t = threading.Thread(target=_worker, daemon=True)
t.start()
t.join()
if strategy == "offload_to_cpu":
net_generator.to(get_offload_device())
feature_utils.to(get_offload_device())
soft_empty_cache()
if _exc[0] is not None:
raise _exc[0]
return (_result[0],)
def _do_optimize(net_generator, feature_utils, mel_converter,
features, prompt, negative_prompt,
ref_mean, ref_gram,
seq_cfg, sample_rate, device, dtype,
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
normalize, target_lufs, pbar):
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
# Strip inference flags from ref stats (came from main thread) and cast to
# model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path);
# mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient →
# "Found dtype Float but expected BFloat16" in backward through bfloat16 ops.
ref_mean = ref_mean.clone().detach().to(dtype)
ref_gram = ref_gram.clone().detach().to(dtype)
torch.manual_seed(seed)
clip_f = features["clip_features"].to(device, dtype).clone()
sync_f = features["sync_features"].to(device, dtype).clone()
# Strip inference-mode flags from all model weights and buffers BEFORE any
# forward pass. Parameters were loaded in ComfyUI's inference_mode context;
# operations on inference tensors produce inference tensors, so conditions
# computed from tainted weights would also be tainted. clone() outside
# inference_mode produces a normal tensor regardless of the source flag.
def _strip_inference(module):
for mod in module.modules():
for name, buf in list(mod._buffers.items()):
if buf is not None:
mod._buffers[name] = buf.clone()
for name, param in list(mod._parameters.items()):
if param is not None:
mod._parameters[name] = torch.nn.Parameter(
param.data.clone(), requires_grad=False
)
_strip_inference(net_generator)
_strip_inference(feature_utils)
_strip_inference(mel_converter)
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=clip_f.shape[1],
sync_seq_len=sync_f.shape[1],
)
with torch.no_grad():
text_clip = feature_utils.encode_text_clip([prompt])
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip
)
# Clone all tensors inside conditions/empty_conditions to ensure no inference
# flags survived from intermediate computations inside preprocess_conditions.
def _clone_nested(obj):
if isinstance(obj, torch.Tensor):
return obj.clone()
elif isinstance(obj, dict):
return {k: _clone_nested(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return type(obj)(_clone_nested(v) for v in obj)
return obj
conditions = _clone_nested(conditions)
empty_conditions = _clone_nested(empty_conditions)
# Initial noise — x_0 is the parameter we optimize
x0_init = torch.randn(
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
device=device, dtype=dtype,
)
x0 = torch.nn.Parameter(x0_init.clone())
x0_init = x0_init.detach() # anchor — kept fixed, no grad
optimizer = torch.optim.Adam([x0], lr=opt_lr)
# n_grad_steps must not exceed n_ode_steps
n_grad_steps = min(n_grad_steps, n_ode_steps)
n_free_steps = n_ode_steps - n_grad_steps # steps run without gradient
ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)
print(f"[DITTO] Optimizing x_0 "
f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)
# Freeze all model weights (double-check — should already be frozen at inference)
net_generator.requires_grad_(False)
feature_utils.requires_grad_(False)
mel_converter.requires_grad_(False)
for opt_step in range(n_opt_steps):
comfy.model_management.throw_exception_if_processing_interrupted()
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
# Detach from x0 so Phase 1 does not build a computation graph.
with torch.no_grad():
x = x0.detach()
for i in range(n_free_steps):
t = ts[i]
dt = ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
# Straight-through estimator: reconnect x to x0's gradient path by
# adding the zero tensor (x0 - x0.detach()). This adds zero value but
# creates a grad_fn pointing back to x0, so loss.backward() will
# propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
# The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
# treated as identity for gradient purposes (truncated BPTT).
#
# x may carry an inference tensor flag from Phase 1 (derived from
# conditions which were built outside inference_mode but may have
# propagated the flag). .clone() strips it so the STE addition does
# not try to save an inference tensor for backward.
x = x.clone()
x = x + (x0 - x0.detach())
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
for i in range(n_free_steps, n_ode_steps):
t = ts[i]
dt = ts[i + 1] - t
# Gradient checkpointing: recompute forward during backward,
# avoiding storage of DiT activations for each step.
def _ode_step(x_in, t=t):
return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)
flow = torch.utils.checkpoint.checkpoint(
_ode_step, x, use_reentrant=False
)
x = x + dt * flow
# ── Style loss in latent space ───────────────────────────────────────
# Unnormalize x back to VAE latent space — fully differentiable, no
# decode needed. ref_mean/ref_gram are computed from encoded reference
# clips in the same space. Avoids backprop through VAE decoder which
# is @torch.inference_mode() and produces noisy gradients.
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
# Anchor regularization — penalize x0 drifting from its initial N(0,1)
# value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
# the ODE into an out-of-distribution region that decodes as white noise.
anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
loss = style_loss + anchor_loss
optimizer.zero_grad()
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
torch.nn.utils.clip_grad_norm_([x0], 1.0)
optimizer.step()
pbar.update(1)
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
f"x0_std={x0.data.std().item():.3f}", flush=True)
# ── Final generation with optimized x_0 ─────────────────────────────────
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
with torch.no_grad():
fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
x = x0.detach()
for i in range(steps):
comfy.model_management.throw_exception_if_processing_interrupted()
t = fm_ts[i]
dt = fm_ts[i + 1] - t
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
x = x + dt * flow
pbar.update(1)
x1_unnorm = net_generator.unnormalize(x)
spec = feature_utils.decode(x1_unnorm)
audio = feature_utils.vocode(spec)
print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
flush=True)
audio = audio.float()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
if normalize:
target_rms = 10 ** (target_lufs / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = audio * (target_rms / rms)
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
return {"waveform": audio.cpu(), "sample_rate": sample_rate}
+101 -25
View File
@@ -35,6 +35,66 @@ def _resize_frames(frames, size):
return x.clamp(0.0, 1.0) # [N, C, H, W] return x.clamp(0.0, 1.0) # [N, C, H, W]
def _compute_mask_bbox(mask, frame_h, frame_w, margin=0.1, square=True):
"""
Compute a bounding box around the union of all mask frames.
mask: [M, H', W'] float [0,1]
square: if True, expand bbox to a square and shift into frame bounds;
if False, apply margin independently on each axis (rect crop).
Returns (y0, x0, y1, x1) in pixel coords clamped to (frame_h, frame_w).
"""
if mask.shape[1] != frame_h or mask.shape[2] != frame_w:
m = F.interpolate(
mask.float().unsqueeze(1), size=(frame_h, frame_w), mode="nearest-exact"
).squeeze(1)
else:
m = mask.float()
union = (m > 0.5).max(dim=0).values # [H, W] bool
if not union.any():
if square:
# Empty mask — center square crop
side = min(frame_h, frame_w)
cy, cx = frame_h // 2, frame_w // 2
y0 = max(0, cy - side // 2)
x0 = max(0, cx - side // 2)
return y0, x0, min(frame_h, y0 + side), min(frame_w, x0 + side)
else:
# Empty mask — return full frame (no meaningful rect to crop to)
return 0, 0, frame_h, frame_w
ys = union.any(dim=1).nonzero(as_tuple=True)[0]
xs = union.any(dim=0).nonzero(as_tuple=True)[0]
y0, y1 = int(ys[0]), int(ys[-1]) + 1
x0, x1 = int(xs[0]), int(xs[-1]) + 1
if square:
side = max(y1 - y0, x1 - x0)
pad = int(side * margin)
side += 2 * pad
cy = (y0 + y1) // 2
cx = (x0 + x1) // 2
y0n = cy - side // 2
x0n = cx - side // 2
y1n = y0n + side
x1n = x0n + side
# Shift into frame bounds to preserve square shape
if y0n < 0: y1n -= y0n; y0n = 0
if y1n > frame_h: y0n -= y1n - frame_h; y1n = frame_h
if x0n < 0: x1n -= x0n; x0n = 0
if x1n > frame_w: x0n -= x1n - frame_w; x1n = frame_w
return max(0, int(y0n)), max(0, int(x0n)), min(frame_h, int(y1n)), min(frame_w, int(x1n))
else:
pad_y = int(max(1, y1 - y0) * margin)
pad_x = int(max(1, x1 - x0) * margin)
return max(0, y0 - pad_y), max(0, x0 - pad_x), min(frame_h, y1 + pad_y), min(frame_w, x1 + pad_x)
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0): def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
""" """
Apply a ComfyUI MASK to resized frames. Apply a ComfyUI MASK to resized frames.
@@ -68,20 +128,9 @@ def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
return frames * alpha + 0.5 * (1.0 - alpha) return frames * alpha + 0.5 * (1.0 - alpha)
def _resolve_named_path(cache_dir: str, name: str) -> str:
"""Return cache_dir/name.npz, incrementing to name_001.npz etc. if the file already exists."""
# Sanitize: replace path separators so the name stays inside cache_dir
name = name.replace("/", "_").replace("\\", "_").replace("\x00", "_")
i = 1
while True:
p = os.path.join(cache_dir, f"{name}_{i:03d}.npz")
if not os.path.exists(p):
return p
i += 1
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None, def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True): mask_strength=1.0, mask_clip=True, mask_sync=True,
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
h = hashlib.sha256() h = hashlib.sha256()
raw = video_tensor.cpu().numpy().tobytes() raw = video_tensor.cpu().numpy().tobytes()
n = len(raw) n = len(raw)
@@ -99,6 +148,10 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
h.update(str(round(mask_strength, 4)).encode()) h.update(str(round(mask_strength, 4)).encode())
h.update(str(mask_clip).encode()) h.update(str(mask_clip).encode())
h.update(str(mask_sync).encode()) h.update(str(mask_sync).encode())
h.update(str(crop_to_mask).encode())
h.update(str(crop_rect).encode())
if crop_to_mask or crop_rect:
h.update(str(round(crop_margin, 4)).encode())
h.update(prompt.encode()) h.update(prompt.encode())
h.update(str(fps).encode()) h.update(str(fps).encode())
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
@@ -128,8 +181,6 @@ class SelvaFeatureExtractor:
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}), "tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
"cache_dir": ("STRING", {"default": "", "cache_dir": ("STRING", {"default": "",
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}), "tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
"name": ("STRING", {"default": "",
"tooltip": "Optional filename for the saved .npz (without extension). If provided, features are always saved with this name instead of a content hash — useful for building a named training dataset. Auto-increments: dog_bark → dog_bark_001 → dog_bark_002 if the file already exists. Leave empty to use the default content-hash cache."}),
"mask": ("MASK", { "mask": ("MASK", {
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.", "tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
}), }),
@@ -145,6 +196,18 @@ class SelvaFeatureExtractor:
"default": True, "default": True,
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.", "tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
}), }),
"crop_to_mask": ("BOOLEAN", {
"default": False,
"tooltip": "Experimental. Crops frames to a square region around the mask bounding box before resizing. The model sees an undistorted view of the subject. Requires mask. Takes priority over crop_rect.",
}),
"crop_rect": ("BOOLEAN", {
"default": False,
"tooltip": "Experimental. Crops frames to a rectangle around the mask bounding box (with margin) before resizing. The model still stretches the crop to a square, but only sees the region around the target element. Simpler than crop_to_mask. Requires mask.",
}),
"crop_margin": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Margin added around the bounding box as a fraction of the bbox size. Shared by crop_to_mask and crop_rect. 0.1 = 10% on each side.",
}),
}, },
} }
@@ -155,14 +218,14 @@ class SelvaFeatureExtractor:
"Source fps of the video — wire to VHS_VideoCombine frame_rate.", "Source fps of the video — wire to VHS_VideoCombine frame_rate.",
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.", "The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
) )
OUTPUT_NODE = True # always execute: the node's purpose is saving .npz files to disk
FUNCTION = "extract_features" FUNCTION = "extract_features"
CATEGORY = SELVA_CATEGORY CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant." DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
def extract_features(self, model, video, prompt, video_info=None, fps=30.0, def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
duration=0.0, cache_dir="", name="", mask=None, duration=0.0, cache_dir="", mask=None,
mask_strength=1.0, mask_clip=True, mask_sync=True): mask_strength=1.0, mask_clip=True, mask_sync=True,
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
if video_info is not None: if video_info is not None:
fps = video_info["loaded_fps"] fps = video_info["loaded_fps"]
@@ -178,15 +241,11 @@ class SelvaFeatureExtractor:
if not cache_dir: if not cache_dir:
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features") cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
if name.strip():
# Named mode: always extract and save to an incremented filename
cached_path = _resolve_named_path(cache_dir, name.strip())
else:
# Hash mode: skip extraction if identical inputs were already processed
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask, cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync) mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync,
crop_to_mask=crop_to_mask, crop_rect=crop_rect, crop_margin=crop_margin)
cached_path = os.path.join(cache_dir, f"{cache_key}.npz") cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
if os.path.exists(cached_path): if os.path.exists(cached_path):
print(f"[SelVA] Using cached features: {cached_path}", flush=True) print(f"[SelVA] Using cached features: {cached_path}", flush=True)
cached = _load_cached(cached_path) cached = _load_cached(cached_path)
@@ -206,10 +265,24 @@ class SelvaFeatureExtractor:
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True) print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
pbar = comfy.utils.ProgressBar(3) pbar = comfy.utils.ProgressBar(3)
# Pre-compute crop bbox once from the original-resolution mask
crop_bbox = None
if mask is not None and (crop_to_mask or crop_rect):
H_vid, W_vid = video.shape[1], video.shape[2]
_square = crop_to_mask # crop_to_mask takes priority; crop_rect is rect-only
crop_bbox = _compute_mask_bbox(mask, H_vid, W_vid, crop_margin, square=_square)
cy0, cx0, cy1, cx1 = crop_bbox
_mode = "square" if _square else "rect"
print(f"[SelVA] Mask crop ({_mode}): y={cy0}:{cy1} x={cx0}:{cx1} "
f"({cy1-cy0}×{cx1-cx0}px from {H_vid}×{W_vid})", flush=True)
try: try:
with torch.no_grad(): with torch.no_grad():
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] --- # --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C] clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
if crop_bbox is not None:
cy0, cx0, cy1, cx1 = crop_bbox
clip_frames = clip_frames[:, cy0:cy1, cx0:cx1, :]
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384] clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
if mask is not None and mask_clip: if mask is not None and mask_clip:
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength) clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
@@ -222,6 +295,9 @@ class SelvaFeatureExtractor:
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] --- # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C] sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
if crop_bbox is not None:
cy0, cx0, cy1, cx1 = crop_bbox
sync_frames = sync_frames[:, cy0:cy1, cx0:cx1, :]
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224] sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
if mask is not None and mask_sync: if mask is not None and mask_sync:
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength) sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
-421
View File
@@ -1,421 +0,0 @@
"""SelVA LoRA Evaluator — generates eval samples from multiple adapters for comparison.
JSON format:
{
"name": "eval_batch_1",
"data_dir": "/path/to/features",
"output_dir": "/path/to/evals/batch1",
"steps": 25,
"seed": 42,
"adapters": [
{"id": "baseline"},
{"id": "lr_3e4_10k", "path": "/path/to/adapter_final.pt"},
{"id": "lr_5e4_10k", "path": "/path/to/adapter_final.pt"}
]
}
Empty / missing "path" = baseline (no LoRA applied).
"""
import copy
import json
import sys
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import torchaudio
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from .selva_lora_trainer import (
_prepare_dataset,
_eval_sample,
_spectral_metrics,
_save_spectrogram,
_pil_to_tensor,
_find_audio,
_load_audio,
)
from selva_core.model.lora import apply_lora, load_lora
def _avg_metrics(metrics_list: list) -> dict:
"""Average spectral metrics across multiple clips, ignoring None entries."""
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
"spectral_flatness", "temporal_variance"]
valid = [m for m in metrics_list if m]
if not valid:
return {}
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
def _resolve_path(raw: str) -> Path:
p = Path(raw.strip())
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _safe_stem(adapter_id: str) -> str:
"""Replace characters illegal in filenames."""
for ch in r'/\:*?"<>|':
adapter_id = adapter_id.replace(ch, "_")
return adapter_id
def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: Path):
"""Draw a 2×2 grid of horizontal bar charts comparing spectral metrics.
Saves a PNG to output_path and returns a ComfyUI IMAGE tensor.
"""
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
METRICS = [
("hf_energy_ratio", "HF Energy Ratio (>4 kHz)"),
("spectral_centroid_hz", "Spectral Centroid (Hz)"),
("spectral_flatness", "Spectral Flatness"),
("temporal_variance", "Temporal Variance"),
]
COLORS = [
"#4285F4", "#EA4335", "#34A853", "#FBBC05",
"#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6",
]
fig = Figure(figsize=(12, max(4, len(adapter_ids) * 0.6 + 2)), dpi=110, tight_layout=True)
axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]
for ax, (key, title) in zip(axes, METRICS):
values = []
colors = []
for i, m in enumerate(metrics_list):
v = m.get(key, 0.0) if m else 0.0
values.append(v)
colors.append(COLORS[i % len(COLORS)])
bars = ax.barh(adapter_ids, values, color=colors, height=0.6)
ax.set_title(title, fontsize=9)
ax.set_xlabel(key, fontsize=8)
ax.tick_params(axis="y", labelsize=7)
ax.tick_params(axis="x", labelsize=7)
# Value labels on bars
for bar, val in zip(bars, values):
w = bar.get_width()
ax.text(w * 1.01, bar.get_y() + bar.get_height() / 2,
f"{val:.3f}", va="center", ha="left", fontsize=6)
canvas = FigureCanvasAgg(fig)
canvas.draw()
canvas.print_figure(str(output_path), dpi=110)
buf = canvas.buffer_rgba()
w, h = canvas.get_width_height()
arr = np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 4)[:, :, :3]
from PIL import Image
return _pil_to_tensor(Image.fromarray(arr))
class SelvaLoraEvaluator:
"""Evaluates a batch of LoRA adapters on a fixed reference clip.
Generates one audio sample per adapter, computes spectral metrics for each,
and produces a comparison chart. Use this after a sweep to compare candidates
before running the next round of training.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_image")
OUTPUT_TOOLTIPS = (
"Path to eval_summary.json — contains spectral metrics per adapter.",
"Bar chart comparing spectral metrics across all evaluated adapters.",
)
DESCRIPTION = (
"Evaluates multiple LoRA adapters by generating one audio sample per adapter "
"from a fixed reference clip, then collects spectral metrics for comparison. "
"Input is a JSON file listing adapter paths. Empty path = baseline (no LoRA)."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"eval_file": ("STRING", {
"default": "eval_batch.json",
"tooltip": (
"Path to the JSON evaluation spec. Relative paths resolve "
"to the ComfyUI output directory. "
"Each adapter entry needs an 'id' and an optional 'path'. "
"Omit 'path' for a no-LoRA baseline."
),
}),
}
}
def run(self, model, eval_file):
# ------------------------------------------------------------------
# 1. Resolve and parse the JSON file
# ------------------------------------------------------------------
eval_path = Path(eval_file.strip())
if not eval_path.is_absolute():
candidate = Path(folder_paths.models_dir) / eval_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / eval_path
eval_path = candidate
if not eval_path.exists():
raise FileNotFoundError(f"[LoRA Evaluator] Eval file not found: {eval_path}")
spec = json.loads(eval_path.read_text(encoding="utf-8"))
if "adapters" not in spec or not spec["adapters"]:
raise ValueError("[LoRA Evaluator] 'adapters' list is missing or empty.")
for i, a in enumerate(spec["adapters"]):
if "id" not in a:
raise ValueError(f"[LoRA Evaluator] Adapter at index {i} missing 'id'.")
if "data_dir" not in spec:
raise ValueError("[LoRA Evaluator] 'data_dir' is required.")
if "output_dir" not in spec:
raise ValueError("[LoRA Evaluator] 'output_dir' is required.")
name = spec.get("name", eval_path.stem)
data_dir = _resolve_path(spec["data_dir"])
output_dir = _resolve_path(spec["output_dir"])
steps = int(spec.get("steps", 25))
seed = int(spec.get("seed", 42))
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n[LoRA Evaluator] '{name}': {len(spec['adapters'])} adapter(s)", flush=True)
print(f"[LoRA Evaluator] data_dir = {data_dir}", flush=True)
print(f"[LoRA Evaluator] output_dir = {output_dir}\n", flush=True)
# ------------------------------------------------------------------
# 2. Prepare dataset (VAE encode once)
# ------------------------------------------------------------------
device = get_device()
dtype = model["dtype"]
dataset = _prepare_dataset(model, data_dir, device)
feature_utils_orig = model["feature_utils"]
seq_cfg = model["seq_cfg"]
# ------------------------------------------------------------------
# 3. Collect reference metrics for all dataset clips
# ------------------------------------------------------------------
import shutil
npz_files = sorted(data_dir.glob("*.npz"))
ref_dir = output_dir / "reference"
ref_dir.mkdir(exist_ok=True)
ref_clips = [] # list of {clip, wav_path, spectral_metrics}
print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
flush=True)
for npz_path in npz_files:
audio_path = _find_audio(npz_path)
if audio_path is None:
continue
try:
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
ref_wav = ref_wav.unsqueeze(0) # [1, L]
ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
shutil.copy2(str(audio_path), str(ref_out))
metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
ref_clips.append({
"clip": npz_path.stem,
"wav_path": str(ref_out),
"spectral_metrics": metrics,
})
except Exception as e:
print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
# Average reference metrics across all clips
ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
print(f"[LoRA Evaluator] Reference avg — "
f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)
# ------------------------------------------------------------------
# 4. Build summary skeleton
# ------------------------------------------------------------------
summary = {
"name": name,
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"data_dir": str(data_dir),
"output_dir": str(output_dir),
"n_clips": len(ref_clips),
"steps": steps,
"seed": seed,
"reference_avg": ref_avg,
"reference_clips": ref_clips,
"adapters": [],
}
summary_path = output_dir / "eval_summary.json"
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
_write_summary()
# ------------------------------------------------------------------
# 5. Per-adapter evaluation loop (all clips)
# ------------------------------------------------------------------
n_clips = len(dataset)
pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
for adapter_spec in spec["adapters"]:
adapter_id = adapter_spec["id"]
adapter_path = (adapter_spec.get("path") or "").strip()
safe_id = _safe_stem(adapter_id)
clip_dir = output_dir / safe_id
clip_dir.mkdir(exist_ok=True)
record = {
"id": adapter_id,
"path": adapter_path or None,
"meta": None,
"clips": [],
"avg_metrics": None,
"status": "running",
}
print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)
try:
with torch.inference_mode(False):
generator = copy.deepcopy(model["generator"])
if adapter_path:
pt_path = Path(adapter_path)
if not pt_path.is_absolute():
pt_path = Path(folder_paths.base_path) / pt_path
if not pt_path.exists():
raise FileNotFoundError(f"Adapter not found: {pt_path}")
ckpt = torch.load(str(pt_path), map_location="cpu",
weights_only=False)
if isinstance(ckpt, dict) and "state_dict" in ckpt:
state_dict = ckpt["state_dict"]
meta = ckpt.get("meta", {})
else:
state_dict = ckpt
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
dropout = float(meta.get("lora_dropout", 0.0))
use_rslora = meta.get("use_rslora", False)
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
# Always use standard init for loading — PiSSA checkpoints
# include linear.weight (residual) in state_dict
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target), dropout=dropout,
init_mode="standard", use_rslora=use_rslora)
if n == 0:
raise RuntimeError(
f"apply_lora matched 0 layers (target={target})"
)
load_lora(generator, state_dict)
print(f"[LoRA Evaluator] Loaded {pt_path.name} "
f"(rank={rank}, {n} layers)", flush=True)
else:
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
generator = generator.to(device, dtype)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
clip_metrics_list = []
for clip_idx in range(n_clips):
clip_stem = npz_files[clip_idx].stem
wav, sr = _eval_sample(
generator, feature_utils_orig, dataset,
seq_cfg, device, dtype,
num_steps=steps, seed=seed, clip_idx=clip_idx,
)
if wav is None:
pbar.update(1)
continue
wav_path = clip_dir / f"{clip_stem}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
metrics = _spectral_metrics(wav, sr)
clip_metrics_list.append(metrics)
record["clips"].append({
"clip": clip_stem,
"wav_path": str(wav_path),
"spectral_metrics": metrics,
})
print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
pbar.update(1)
record["avg_metrics"] = _avg_metrics(clip_metrics_list)
record["status"] = "completed"
avg = record["avg_metrics"]
print(f"[LoRA Evaluator] '{adapter_id}' avg — "
f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
f"hf={avg.get('hf_energy_ratio', 0):.3f} "
f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)
except Exception as e:
record["status"] = "failed"
record["error"] = str(e)
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
traceback.print_exc()
pbar.update(n_clips - len(record["clips"]))
finally:
try:
del generator
except NameError:
pass
soft_empty_cache()
summary["adapters"].append(record)
_write_summary()
# ------------------------------------------------------------------
# 5. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[LoRA Evaluator] Done. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 6. Comparison chart
# ------------------------------------------------------------------
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
if completed:
ids = ["reference"] + [r["id"] for r in completed]
metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
chart_path = output_dir / "metric_comparison.png"
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
else:
from PIL import Image
comparison = _pil_to_tensor(Image.new("RGB", (400, 200), (255, 255, 255)))
return (str(summary_path), comparison)
-109
View File
@@ -1,109 +0,0 @@
import copy
import torch
import folder_paths
from .utils import SELVA_CATEGORY
from selva_core.model.lora import apply_lora, load_lora
class SelvaLoraLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"adapter_path": ("STRING", {
"default": "",
"tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.",
}),
"strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05,
"tooltip": "Scale applied to all LoRA contributions. "
"1.0 = full adapter strength. "
"0.0 = effectively disables the adapter. "
"Values above 1.0 exaggerate the effect.",
}),
},
}
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Loads a LoRA adapter produced by train_lora.py and applies it to the generator. "
"The base model is not modified — a shallow copy of the model bundle is returned."
)
def load(self, model: dict, adapter_path: str, strength: float) -> tuple:
if not adapter_path.strip():
raise ValueError("[SelVA LoRA] adapter_path is empty.")
# Resolve path: allow absolute or relative to ComfyUI base
from pathlib import Path
p = Path(adapter_path)
if not p.is_absolute():
p = Path(folder_paths.base_path) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}")
checkpoint = torch.load(str(p), map_location="cpu", weights_only=False)
# Support both raw state_dict and {state_dict, meta} formats
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
meta = checkpoint.get("meta", {})
else:
state_dict = checkpoint
meta = {}
rank = int(meta.get("rank", 16))
alpha = float(meta.get("alpha", float(rank)))
target = list(meta.get("target", ["attn.qkv"]))
init_mode = meta.get("init_mode", "standard")
use_rslora = meta.get("use_rslora", False)
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} "
f"init={init_mode} rslora={use_rslora} strength={strength}",
flush=True)
# Shallow-copy the model bundle so the original generator is not mutated
patched = {**model}
generator = copy.deepcopy(model["generator"])
# For PiSSA, use standard init (the base weights will be overwritten
# by load_state_dict since the checkpoint includes linear.weight)
n = apply_lora(generator, rank=rank, alpha=alpha,
target_suffixes=tuple(target),
init_mode="standard", use_rslora=use_rslora)
if n == 0:
raise RuntimeError(
f"[SelVA LoRA] No layers matched target={target}. "
"Check that the adapter was trained with the same target suffixes."
)
load_lora(generator, state_dict)
# Sanity check: confirm lora_A weights are non-zero (lora_B starts at zero by design)
norms = [p.norm().item() for name, p in generator.named_parameters()
if "lora_A" in name]
if norms:
print(f"[SelVA LoRA] lora_A weight norms: min={min(norms):.4f} "
f"max={max(norms):.4f} mean={sum(norms)/len(norms):.4f}", flush=True)
else:
print("[SelVA LoRA] WARNING: no lora_A params found after loading!", flush=True)
# Apply strength scaling: multiply all lora_B params by strength
# (lora_B is initialised to zero, so scaling A is equivalent but less clean)
if strength != 1.0:
with torch.no_grad():
for name, param in generator.named_parameters():
if "lora_B" in name:
param.mul_(strength)
generator.to(model["generator"].parameters().__next__().device)
patched["generator"] = generator
print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True)
return (patched,)
-539
View File
@@ -1,539 +0,0 @@
"""SelVA LoRA Scheduler — runs a sweep of training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
{
"name": "tier1_sweep",
"description": "optional human note",
"data_dir": "dataset/dog_bark",
"output_root": "lora_output/tier1_sweep",
"base": { "rank": 16, "lr": 1e-4, "steps": 2000, ... },
"experiments": [
{"id": "baseline", "description": "..."},
{"id": "lora_plus_16", "lora_plus_ratio": 16.0},
...
]
}
"""
import copy
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
from PIL import Image, ImageDraw
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
SelvaLoraTrainer,
SkipExperiment,
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
def _get_system_info() -> dict:
"""Collect GPU / torch version info for the summary header."""
info: dict = {
"torch_version": torch.__version__,
"cuda_version": torch.version.cuda or "N/A",
"gpu_name": None,
"gpu_vram_gb": None,
}
if torch.cuda.is_available():
try:
info["gpu_name"] = torch.cuda.get_device_name(0)
props = torch.cuda.get_device_properties(0)
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
except Exception:
pass
return info
# Defaults mirror SelvaLoraTrainer INPUT_TYPES defaults
_PARAM_DEFAULTS = {
"alpha": 0.0,
"target": "attn.qkv",
"batch_size": 4,
"warmup_steps": 100,
"grad_accum": 1,
"save_every": 500,
"resume_path": "",
"seed": 42,
"timestep_mode": "uniform",
"logit_normal_sigma": 1.0,
"curriculum_switch": 0.6,
"lora_dropout": 0.0,
"lora_plus_ratio": 1.0,
"lr_schedule": "constant",
"init_mode": "pissa",
"use_rslora": True,
"latent_mixup_alpha": 0.0,
"latent_noise_sigma": 0.0,
}
# Palette for comparison chart: one color per experiment (cycles if > 8)
_PALETTE = [
(66, 133, 244), # blue
(234, 67, 53), # red
(52, 168, 83), # green
(251, 188, 5), # yellow
(155, 89, 182), # purple
(26, 188, 156), # teal
(230, 126, 34), # orange
(149, 165, 166), # grey
]
def _resolve_path(raw: str) -> Path:
"""Resolve path the same way SelvaLoraTrainer does (relative → ComfyUI output dir)."""
p = Path(raw.strip())
unix_style_on_windows = (
sys.platform == "win32" and p.is_absolute() and not p.drive
)
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _merge_config(base: dict, experiment: dict) -> dict:
"""Merge base defaults + file base + experiment overrides."""
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
# Don't carry id/description into the training params
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
start_step: int, total_steps: int) -> dict:
"""Build a dict of {step: loss} at each save_every boundary.
loss_history[i] = average loss over steps [start + i*log_interval + 1 …
start + (i+1)*log_interval].
"""
result = {}
targets = range(save_every, total_steps + 1, save_every)
for target in targets:
# index of the loss entry nearest to this step
idx = (target - start_step) // log_interval - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
def _draw_comparison_curves(
experiments_data: list, # list of dicts: {id, loss_history, log_interval, start_step}
) -> Image.Image:
"""Draw all smoothed loss curves on the same axes, one color per experiment."""
W, H = 900, 420
pl, pr, pt, pb = 75, 160, 30, 50 # wider right margin for legend
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
# Collect all smoothed series
series = []
for i, ed in enumerate(experiments_data):
lh = ed.get("loss_history") or []
if len(lh) < 2:
continue
sm = _smooth_losses(lh)
series.append({
"id": ed["id"],
"smoothed": sm,
"log_interval": ed.get("log_interval", 50),
"start_step": ed.get("start_step", 0),
"color": _PALETTE[i % len(_PALETTE)],
})
if not series:
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
return img
all_vals = [v for s in series for v in s["smoothed"]]
lo, hi = min(all_vals), max(all_vals)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
# Horizontal grid + y-axis labels
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
# Draw each curve
for s in series:
n = len(s["smoothed"])
pts = []
for j, v in enumerate(s["smoothed"]):
x = pl + int(j * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=s["color"], width=2)
# Axes
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 8), "Loss comparison (smoothed)", fill=(40, 40, 40))
# Legend (right side)
lx = W - pr + 10
ly = pt
for s in series:
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
ly += 20
return img
class SelvaLoraScheduler:
"""Runs a sweep of LoRA training experiments defined in a JSON file.
The dataset (VAE encoding + .npz loading) is performed once and shared
across all experiments. Each experiment deep-copies the generator and trains
independently. Results are written to `experiment_summary.json` after every
completed run so partial results are preserved if the sweep is interrupted.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — share this file to compare runs.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of LoRA training experiments defined in a JSON sweep file. "
"The dataset is encoded once and reused across all experiments. "
"Results (loss, config, adapter paths) are collected in experiment_summary.json."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"models directory; absolute paths are used as-is. "
"See LORA_TRAINING.md for the file format."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate the JSON file
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
# Try relative to ComfyUI models dir first, then output dir
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[LoRA Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[LoRA Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[LoRA Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[LoRA Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"lora_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
dtype = model["dtype"]
print(f"\n[LoRA Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[LoRA Scheduler] {description}", flush=True)
print(f"[LoRA Scheduler] data_dir = {data_dir}", flush=True)
print(f"[LoRA Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load + encode dataset once
# ------------------------------------------------------------------
n_clips = len(list(data_dir.glob("*.npz")))
dataset = _prepare_dataset(model, data_dir, device)
# ------------------------------------------------------------------
# 4. Build or restore the summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = [] # collected for comparison image
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
lh = rec["results"].get("loss_history", [])
all_curve_data.append({
"id": rec["id"],
"loss_history": lh,
"log_interval": rec["results"].get("log_interval", 50),
"start_step": 0,
})
# Restore the original summary, clear completed_at so it gets set again
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[LoRA Scheduler] Resuming — skipping {len(completed_ids)} "
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[LoRA Scheduler] Could not read existing summary ({e}) — starting fresh",
flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": n_clips,
"experiments": [],
}
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
_write_summary()
# ------------------------------------------------------------------
# 5. Run each experiment
# ------------------------------------------------------------------
trainer = SelvaLoraTrainer()
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
log_interval = 50 # matches _train_inner
feature_utils_orig = model["feature_utils"]
seq_cfg = model["seq_cfg"]
variant = model["variant"]
mode = model["mode"]
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[LoRA Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
pbar_outer.update(1)
continue
cfg = _merge_config(base_cfg, exp)
# Required training params
steps = int(cfg.get("steps", 2000))
rank = int(cfg.get("rank", 16))
lr = float(cfg.get("lr", 1e-4))
alpha = float(cfg.get("alpha", 0.0))
target = str(cfg.get("target", "attn.qkv"))
batch_size = int(cfg.get("batch_size", 4))
warmup = int(cfg.get("warmup_steps", 100))
grad_accum = int(cfg.get("grad_accum", 1))
save_every = int(cfg.get("save_every", 500))
resume_path = str(cfg.get("resume_path", ""))
seed = int(cfg.get("seed", 42))
ts_mode = str(cfg.get("timestep_mode", "uniform"))
ln_sigma = float(cfg.get("logit_normal_sigma", 1.0))
curr_switch = float(cfg.get("curriculum_switch", 0.6))
dropout = float(cfg.get("lora_dropout", 0.0))
plus_ratio = float(cfg.get("lora_plus_ratio", 1.0))
lr_schedule = str(cfg.get("lr_schedule", "constant"))
init_mode = str(cfg.get("init_mode", "pissa"))
use_rslora = bool(cfg.get("use_rslora", True))
alpha_val = alpha if alpha > 0.0 else float(2 * rank)
target_suffixes = tuple(target.strip().split())
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
print(f"\n[LoRA Scheduler] ── Experiment '{exp_id}' ──", flush=True)
if exp_desc:
print(f"[LoRA Scheduler] {exp_desc}", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"rank": rank, "alpha": alpha_val, "lr": lr, "steps": steps,
"batch_size": batch_size, "warmup_steps": warmup,
"grad_accum": grad_accum, "save_every": save_every,
"seed": seed, "target": list(target_suffixes),
"timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma,
"curriculum_switch": curr_switch,
"lora_dropout": dropout, "lora_plus_ratio": plus_ratio,
"lr_schedule": lr_schedule,
"init_mode": init_mode, "use_rslora": use_rslora,
},
"results": {"status": "running"},
"adapter_path": None,
"output_dir": str(output_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
with torch.inference_mode(False), torch.enable_grad():
r = trainer._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, variant, mode,
data_dir, output_dir, steps, rank, lr,
alpha_val, target_suffixes, batch_size, warmup,
grad_accum, save_every, resume_path, seed,
ts_mode, ln_sigma, curr_switch, dropout, plus_ratio,
lr_schedule, init_mode, use_rslora,
)
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
grad_norm_history = r.get("grad_norm_history", [])
spectral_metrics = r.get("spectral_metrics", {})
run_start_step = r.get("start_step", 0)
smoothed = _smooth_losses(loss_history) if loss_history else []
# Scalar summary metrics
final_loss = round(smoothed[-1], 6) if smoothed else None
min_loss = round(min(smoothed), 6) if smoothed else None
min_idx = smoothed.index(min(smoothed)) if smoothed else None
min_loss_step = (
run_start_step + (min_idx + 1) * log_interval
if min_idx is not None else None
)
# Stability: std-dev of raw loss over last 25% of steps
if loss_history:
quarter = max(1, len(loss_history) // 4)
last_q = loss_history[-quarter:]
loss_std_last_quarter = round(float(np.std(last_q)), 6)
else:
loss_std_last_quarter = None
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std_last_quarter,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval, save_every, run_start_step, steps
),
"loss_history": [round(v, 6) for v in loss_history],
"grad_norm_history": grad_norm_history,
"spectral_metrics": {str(k): v for k, v in spectral_metrics.items()},
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["adapter_path"] = r["adapter_path"]
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
"log_interval": log_interval,
"start_step": 0,
})
except SkipExperiment as e:
duration = time.monotonic() - t_start
print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
partial = getattr(e, "partial", {})
lh = partial.get("loss_history", [])
smoothed = _smooth_losses(lh) if lh else []
exp_record["results"] = {
"status": "skipped",
"stopped_at_step": partial.get("stopped_at_step"),
"final_loss": round(smoothed[-1], 6) if smoothed else None,
"loss_history": [round(v, 6) for v in lh],
"grad_norm_history": partial.get("grad_norm_history", []),
"spectral_metrics": {str(k): v for k, v in partial.get("spectral_metrics", {}).items()},
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
continue
except Exception as e:
duration = time.monotonic() - t_start
print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
# Continue to next experiment rather than aborting the whole sweep
continue
_write_summary()
pbar_outer.update(1)
# ------------------------------------------------------------------
# 6. Finalise summary
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[LoRA Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison image
# ------------------------------------------------------------------
comparison_img = _draw_comparison_curves(all_curve_data)
comparison_img.save(str(output_root / "loss_comparison.png"))
comparison_tensor = _pil_to_tensor(comparison_img)
return (str(summary_path), comparison_tensor)
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -149,7 +149,7 @@ class SelvaModelLoader:
enable_conditions=True, enable_conditions=True,
mode=mode, mode=mode,
bigvgan_vocoder_ckpt=bigvgan_path, bigvgan_vocoder_ckpt=bigvgan_path,
need_vae_encoder=True, need_vae_encoder=False,
).to(device, dtype).eval() ).to(device, dtype).eval()
if strategy == "offload_to_cpu": if strategy == "offload_to_cpu":
+3 -107
View File
@@ -3,7 +3,6 @@ import comfy.utils
import comfy.model_management import comfy.model_management
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
from .selva_textual_inversion_trainer import _inject_tokens
class SelvaSampler: class SelvaSampler:
@@ -32,31 +31,9 @@ class SelvaSampler:
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
}, },
"optional": { "optional": {
"steering_vectors": ("STEERING_VECTORS", {
"tooltip": "Activation steering bundle from SelVA Activation Steering Loader. "
"Nudges each DiT block's hidden state toward the extracted pattern.",
}),
"steering_strength": ("FLOAT", {
"default": 0.1, "min": 0.0, "max": 2.0, "step": 0.05,
"tooltip": "Scale applied to each steering vector before adding to block output. "
"Start around 0.10.3; higher values risk destabilizing the ODE.",
}),
"normalize": ("BOOLEAN", { "normalize": ("BOOLEAN", {
"default": True, "default": True,
"tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.", "tooltip": "Peak-normalize output to [-1, 1]. Disable to preserve the raw decoder output level.",
}),
"target_lufs": ("FLOAT", {
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0,
"tooltip": "Target RMS level in dBFS when normalize=True. -27 matches the measured RMS of LUFS-normalized training clips. Increase toward -20 for louder output.",
}),
"textual_inversion": ("TEXTUAL_INVERSION", {
"tooltip": "Learned token embeddings from SelVA Textual Inversion Loader. "
"Injects style tokens into CLIP conditioning without modifying model weights.",
}),
"ti_strength": ("FLOAT", {
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
"tooltip": "Blends between original CLIP conditioning (0.0) and full TI injection (1.0). "
"Reduce toward 0.30.5 if TI produces buzz artifacts.",
}), }),
}, },
} }
@@ -68,7 +45,7 @@ class SelvaSampler:
CATEGORY = SELVA_CATEGORY CATEGORY = SELVA_CATEGORY
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance." DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, steering_vectors=None, steering_strength=0.1, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0): def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True):
import dataclasses import dataclasses
from selva_core.model.flow_matching import FlowMatching from selva_core.model.flow_matching import FlowMatching
@@ -133,19 +110,6 @@ class SelvaSampler:
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \ neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
if negative_prompt.strip() else None if negative_prompt.strip() else None
# Inject textual inversion tokens into CLIP conditioning
if textual_inversion is not None:
emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024]
K = emb.shape[0]
inject_mode = textual_inversion.get("inject_mode", "suffix")
ti_text = _inject_tokens(text_clip, emb, K, inject_mode)
text_clip = torch.lerp(text_clip, ti_text, ti_strength)
if neg_text_clip is not None:
ti_neg = _inject_tokens(neg_text_clip, emb, K, inject_mode)
neg_text_clip = torch.lerp(neg_text_clip, ti_neg, ti_strength)
print(f"[SelVA] Textual inversion: {K} tokens mode={inject_mode} strength={ti_strength}",
flush=True)
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip) conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
empty_conditions = net_generator.get_empty_conditions( empty_conditions = net_generator.get_empty_conditions(
bs=1, negative_text_features=neg_text_clip bs=1, negative_text_features=neg_text_clip
@@ -159,63 +123,6 @@ class SelvaSampler:
device=gen_device, dtype=dtype, generator=rng, device=gen_device, dtype=dtype, generator=rng,
).to(device) ).to(device)
# Activation steering: apply only during the conditional predict_flow pass
# so steering gets amplified by cfg_strength rather than canceling out.
steering_handles = []
_orig_predict_flow = None
if steering_vectors is not None and steering_strength > 0.0:
vecs = steering_vectors["steering_vectors"]
n_joint = steering_vectors["n_joint"]
# Patch predict_flow to flag which pass is conditional.
# ode_wrapper calls predict_flow(conditions) and predict_flow(empty_conditions);
# identity check tells us which is which.
_is_cond_pass = [False]
_orig_predict_flow = net_generator.predict_flow
def _tracked_predict_flow(latent, t, cond):
_is_cond_pass[0] = (cond is conditions)
return _orig_predict_flow(latent, t, cond)
net_generator.predict_flow = _tracked_predict_flow
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
vec = vec_cpu.to(dev, dt) # [seq, hidden]
def hook(module, input, output):
if not _is_cond_pass[0]:
return # skip unconditional pass; steering amplified by cfg_strength
# Interpolate steering vec to match actual output seq length
# (handles generation at different duration than extraction)
if is_joint:
out_seq = output[0].shape[1]
else:
out_seq = output.shape[1]
v = vec
if v.shape[0] != out_seq:
v = torch.nn.functional.interpolate(
v.T.unsqueeze(0), # [1, hidden, seq_orig]
size=out_seq,
mode="linear",
align_corners=False,
).squeeze(0).T # [seq_new, hidden]
if is_joint:
latent_out = output[0] + strength * v
return (latent_out,) + output[1:]
else:
return output + strength * v
return hook
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
for i, block in enumerate(blocks):
is_joint = i < n_joint
if i < len(vecs):
h = block.register_forward_hook(
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
)
steering_handles.append(h)
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks "
f"strength={steering_strength} (conditional pass only)", flush=True)
# Flow matching ODE (Euler) # Flow matching ODE (Euler)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps) fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
@@ -232,11 +139,6 @@ class SelvaSampler:
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy " "[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
"to 'offload_to_cpu', using a smaller variant, or reducing duration." "to 'offload_to_cpu', using a smaller variant, or reducing duration."
) )
finally:
if _orig_predict_flow is not None:
net_generator.predict_flow = _orig_predict_flow
for h in steering_handles:
h.remove()
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True) print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
@@ -266,14 +168,8 @@ class SelvaSampler:
audio = audio.mean(dim=1, keepdim=True) # stereo → mono audio = audio.mean(dim=1, keepdim=True) # stereo → mono
if normalize: if normalize:
target_rms = 10 ** (target_lufs / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = audio * (target_rms / rms)
# If RMS normalization pushes peaks into clipping, scale back to
# preserve dynamics rather than hard-clipping (no saturation)
peak = audio.abs().max().clamp(min=1e-8) peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0: audio = (audio / peak).clamp(-1, 1)
audio = audio / peak
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True) print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},) return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
-50
View File
@@ -1,50 +0,0 @@
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaSkipExperiment:
"""Writes skip_current.flag into a sweep output_root.
Queue this node while a SelVA LoRA Scheduler sweep is running to skip
the current experiment and move to the next one. The trainer picks up
the flag within 50 steps (~a few seconds).
"""
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"output_root": ("STRING", {
"default": "",
"tooltip": "output_root of the running sweep — same value as in your experiments JSON.",
}),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("flag_path",)
OUTPUT_TOOLTIPS = ("Path where the flag was written.",)
FUNCTION = "skip"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Signals the running SelVA LoRA Scheduler to skip the current experiment "
"and move to the next one. Queue this node while the scheduler is running. "
"Partial scalars collected so far are saved in the summary."
)
def skip(self, output_root: str):
p = Path(output_root.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[SelVA Skip] output_root not found: {p}")
flag = p / "skip_current.flag"
flag.touch()
print(f"[SelVA Skip] Flag written: {flag}", flush=True)
return (str(flag),)
-70
View File
@@ -1,70 +0,0 @@
"""SelVA Textual Inversion Loader.
Loads a .pt file produced by SelvaTextualInversionTrainer and returns a
TEXTUAL_INVERSION bundle that the SelVA Sampler can inject into text conditioning.
"""
from pathlib import Path
import torch
import folder_paths
from .utils import SELVA_CATEGORY
class SelvaTextualInversionLoader:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"path": ("STRING", {
"default": "textual_inversion.pt",
"tooltip": "Path to a .pt file produced by SelVA Textual Inversion Trainer. "
"Relative paths resolve to the ComfyUI output directory.",
}),
},
}
RETURN_TYPES = ("TEXTUAL_INVERSION",)
RETURN_NAMES = ("textual_inversion",)
OUTPUT_TOOLTIPS = ("Learned token embeddings — connect to SelVA Sampler's textual_inversion input.",)
FUNCTION = "load"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Loads learned CLIP token embeddings produced by SelVA Textual Inversion Trainer. "
"Connect the output to the SelVA Sampler's optional textual_inversion input to guide "
"generation toward the training data style without degrading audio quality."
)
def load(self, path: str) -> tuple:
p = Path(path.strip())
if not p.is_absolute():
p = Path(folder_paths.get_output_directory()) / p
if not p.exists():
raise FileNotFoundError(f"[TI Loader] File not found: {p}")
data = torch.load(str(p), map_location="cpu", weights_only=False)
embeddings = data["embeddings"] # [K, 1024]
n_tokens = int(data.get("n_tokens", embeddings.shape[0]))
print(f"[TI Loader] Loaded '{p.name}' n_tokens={n_tokens} "
f"shape={tuple(embeddings.shape)}", flush=True)
if data.get("init_text"):
print(f"[TI Loader] init_text='{data['init_text']}'", flush=True)
if data.get("step"):
print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps "
f"lr={data.get('lr', '?')}", flush=True)
inject_mode = data.get("inject_mode", "suffix")
print(f"[TI Loader] inject_mode='{inject_mode}'", flush=True)
bundle = {
"embeddings": embeddings, # [K, 1024] float32 on CPU
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"path": str(p),
"init_text": data.get("init_text", ""),
}
return (bundle,)
-450
View File
@@ -1,450 +0,0 @@
"""SelVA Textual Inversion Trainer.
Learns K token embedding vectors in CLIP space that guide the base model
to generate audio in the style of the training clips — without modifying
any model weights.
Key difference from LoRA:
- ALL generator parameters are frozen (requires_grad=False)
- Only K×1024 token embeddings receive gradients
- Latents stay on the decoder's natural manifold → no quality degradation
- The learned tokens shift WHICH latents are generated, not HOW
Usage:
1. Train on your .npz audio features
2. Load result with SelVA Textual Inversion Loader
3. Connect to SelVA Sampler optional input
"""
import copy
import random
import traceback
from pathlib import Path
import torch
import torchaudio
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
from selva_core.model.flow_matching import FlowMatching
from .selva_lora_trainer import (
_prepare_dataset,
_eval_sample,
_spectral_metrics,
_save_spectrogram,
_smooth_losses,
_draw_loss_curve,
_pil_to_tensor,
)
# ---------------------------------------------------------------------------
# Eval helper with token injection
# ---------------------------------------------------------------------------
def _inject_tokens(text_clip: torch.Tensor, tokens: torch.Tensor,
n_tokens: int, inject_mode: str) -> torch.Tensor:
"""Build a text_clip tensor with learned tokens injected.
inject_mode:
"suffix" — replace last n_tokens positions (EOS/padding zone)
"prefix" — replace positions 1:1+n_tokens (after BOS, before content)
Always uses torch.cat so gradient flows to `tokens` when tokens.requires_grad.
Works for both training (tokens is a Parameter) and eval (tokens is detached).
"""
if inject_mode == "prefix":
bos = text_clip[:, :1, :].detach() # [B, 1, D]
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
rest = text_clip[:, 1 + n_tokens:, :].detach() # [B, 75-K, D]
return torch.cat([bos, toks, rest], dim=1) # [B, 77, D]
else: # suffix (default)
front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D]
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
return torch.cat([front, toks], dim=1) # [B, 77, D]
def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
feature_utils_orig, dataset, seq_cfg,
device, dtype, num_steps=25, seed=42, clip_idx=0):
"""Inference pass with learned tokens injected into text conditioning."""
generator.eval()
try:
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
clip_f = clip_f_cpu.to(device, dtype)
sync_f = sync_f_cpu.to(device, dtype)
text_clip = text_clip_cpu.to(device, dtype).clone()
emb = learned_tokens.detach().to(device, dtype)
text_input = _inject_tokens(text_clip, emb, n_tokens, inject_mode)
rng = torch.Generator(device=device).manual_seed(seed)
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
device=device, dtype=dtype, generator=rng)
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
def velocity_fn(t, x):
return generator.forward(x, clip_f, sync_f, text_input,
t.reshape(1).to(device, dtype))
with torch.no_grad():
x1_pred = eval_fm.to_data(velocity_fn, x0)
x1_unnorm = generator.unnormalize(x1_pred)
tod = feature_utils_orig.tod
tod_orig_dev = next(tod.parameters()).device
tod.to(device)
try:
spec = feature_utils_orig.decode(x1_unnorm)
audio = feature_utils_orig.vocode(spec)
finally:
tod.to(tod_orig_dev)
audio = audio.float().cpu()
if audio.dim() == 2:
audio = audio.unsqueeze(1)
elif audio.dim() == 3 and audio.shape[1] != 1:
audio = audio.mean(dim=1, keepdim=True)
target_rms = 10 ** (-27.0 / 20.0)
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
audio = (audio * (target_rms / rms))
peak = audio.abs().max().clamp(min=1e-8)
if peak > 1.0:
audio = audio / peak
return audio.squeeze(0), seq_cfg.sampling_rate
except Exception as e:
print(f"[TI Trainer] Eval sample failed: {e}", flush=True)
traceback.print_exc()
return None, None
finally:
generator.train()
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaTextualInversionTrainer:
"""Learns K CLIP token embeddings that steer SelVA toward a target audio style.
Unlike LoRA, all model weights are frozen. Only the K×1024 embedding tensor
receives gradients, keeping generated latents on the decoder's natural manifold
and preserving base model audio quality while shifting generation style.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "train"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("embeddings_path", "loss_curve")
OUTPUT_TOOLTIPS = (
"Path to saved .pt embeddings — load with SelVA Textual Inversion Loader.",
"Smoothed training loss curve.",
)
DESCRIPTION = (
"Trains K learnable CLIP token embeddings against your audio dataset "
"with all model weights frozen. The tokens are then injected into the "
"sampler to guide generation toward the training data style without "
"degrading audio quality."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"data_dir": ("STRING", {
"default": "",
"tooltip": "Directory containing .npz feature files and paired audio files (same as LoRA trainer).",
}),
"output_path": ("STRING", {
"default": "textual_inversion.pt",
"tooltip": "Where to save the learned embeddings. Relative paths resolve to ComfyUI output directory.",
}),
"n_tokens": ("INT", {
"default": 4, "min": 1, "max": 16,
"tooltip": "Number of learnable token vectors. More tokens = more expressive but slower to train. 4 is a good default.",
}),
"steps": ("INT", {
"default": 3000, "min": 100, "max": 50000,
"tooltip": "Training steps. 3000 is a reasonable starting point.",
}),
"lr": ("FLOAT", {
"default": 2e-4, "min": 1e-5, "max": 1e-1, "step": 1e-5,
"tooltip": "Learning rate. 2e-4 matches the LoRA working regime. Higher LR (1e-3) causes token norm to drift without plateauing on small datasets.",
}),
"batch_size": ("INT", {
"default": 4, "min": 1, "max": 64,
"tooltip": "Clips sampled per training step. Smaller batch (48) gives more diverse gradients and helps token norm saturate rather than drift.",
}),
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
"save_every": ("INT", {
"default": 1000, "min": 100, "max": 10000,
"tooltip": "Save a checkpoint and generate an eval sample every N steps.",
}),
},
"optional": {
"inject_mode": (["suffix", "prefix"], {
"default": "suffix",
"tooltip": (
"Where to inject the learned tokens in the 77-token CLIP sequence. "
"'suffix' replaces the last K positions (EOS/padding — may be ignored by the model). "
"'prefix' replaces positions 1:1+K right after BOS — higher attention weight, stronger style signal."
),
}),
"init_text": ("STRING", {
"default": "",
"tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.",
}),
"warmup_steps": ("INT", {
"default": 100, "min": 0, "max": 1000,
"tooltip": "Linear LR warmup steps.",
}),
},
}
def train(self, model, data_dir, output_path, n_tokens, steps, lr,
batch_size, seed, save_every,
inject_mode="suffix", init_text="", warmup_steps=100):
device = get_device()
dtype = model["dtype"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
# --- Resolve paths ---
data_dir = Path(data_dir.strip())
if not data_dir.is_absolute():
data_dir = Path(folder_paths.models_dir) / data_dir
if not data_dir.exists():
raise FileNotFoundError(f"[TI Trainer] data_dir not found: {data_dir}")
out_path = Path(output_path.strip())
if not out_path.is_absolute():
out_path = Path(folder_paths.get_output_directory()) / out_path
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[TI Trainer] n_tokens={n_tokens} steps={steps} lr={lr:.2e}", flush=True)
print(f"[TI Trainer] data_dir = {data_dir}", flush=True)
print(f"[TI Trainer] output = {out_path}\n", flush=True)
# --- Load dataset (reuse LoRA trainer helper) ---
dataset = _prepare_dataset(model, data_dir, device)
# Training must run outside inference_mode so autograd works
with torch.inference_mode(False), torch.enable_grad():
r = self._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text, inject_mode,
)
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
return (r["embeddings_path"], _pil_to_tensor(curve_img))
def _train_inner(
self, model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup_steps, seed, save_every, init_text, inject_mode="suffix",
):
torch.manual_seed(seed)
# --- Generator (frozen) ---
generator = copy.deepcopy(model["generator"]).to(device, dtype)
generator.requires_grad_(False)
generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
# --- Init learned tokens ---
# Call encode_text_clip outside the grad context (it has @inference_mode),
# grab values only (no grad needed), then wrap as nn.Parameter.
if init_text.strip():
with torch.no_grad():
init_embed = feature_utils_orig.encode_text_clip([init_text.strip()])
# Positions 1:1+n_tokens — after BOS, before EOS — have actual content
init_vals = init_embed[0, 1:1 + n_tokens, :].detach().clone().float()
if init_vals.shape[0] < n_tokens:
# Prompt was very short; pad remaining with small noise
pad = torch.randn(n_tokens - init_vals.shape[0], init_vals.shape[1]) * 0.02
init_vals = torch.cat([init_vals, pad], dim=0)
learned_tokens = torch.nn.Parameter(init_vals.to(device, dtype))
print(f"[TI Trainer] Init from '{init_text.strip()}' (positions 1{n_tokens})", flush=True)
else:
learned_tokens = torch.nn.Parameter(
torch.randn(n_tokens, 1024, device=device, dtype=dtype) * 0.02
)
print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True)
# --- Measure CLIP token norm from the dataset (content positions 120) ---
# Learned tokens must stay within this range or the model treats them as
# out-of-distribution and produces buzz artifacts instead of style shift.
with torch.no_grad():
sample_norms = []
for item in dataset[:min(len(dataset), 20)]:
tc = item[3].squeeze(0) # [77, 1024]
sample_norms.append(tc[1:20].norm(dim=-1)) # skip BOS/EOS
clip_norm_ref = torch.cat(sample_norms).mean().item()
clip_norm_limit = clip_norm_ref * 1.5 # 50% headroom above real tokens
print(f"[TI Trainer] CLIP token norm ref={clip_norm_ref:.4f} "
f"limit={clip_norm_limit:.4f}", flush=True)
# --- Optimizer + scheduler ---
optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2)
def lr_lambda(s):
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Checkpoint dir ---
ckpt_dir = out_path.parent / out_path.stem
ckpt_dir.mkdir(parents=True, exist_ok=True)
# --- Baseline sample (once, before any training) ---
print(f"[TI Trainer] Generating baseline sample...", flush=True)
baseline_wav, baseline_sr = _eval_sample(
generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed,
)
if baseline_wav is not None:
baseline_path = ckpt_dir / "baseline.wav"
try:
torchaudio.save(str(baseline_path), baseline_wav, baseline_sr)
except RuntimeError:
import soundfile as sf
sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr)
try:
_save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png")
except Exception:
pass
print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True)
# --- Training loop ---
generator.train()
optimizer.zero_grad()
log_interval = 50
pbar = comfy.utils.ProgressBar(steps)
loss_history = []
running_loss = 0.0
print(f"[TI Trainer] Training {steps} steps batch_size={batch_size}\n", flush=True)
for step in range(1, steps + 1):
batch = random.choices(dataset, k=batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
# Inject learned tokens — gradient flows via torch.cat (not in-place assignment).
text_input = _inject_tokens(text_clip, learned_tokens, n_tokens, inject_mode)
x1 = generator.normalize(x1)
t = torch.rand(batch_size, device=device, dtype=dtype)
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
v_pred = generator.forward(xt, clip_f, sync_f, text_input, t)
loss = fm.loss(v_pred, x0, x1).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_([learned_tokens], max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# Clamp token norm to CLIP manifold — prevents out-of-distribution
# embeddings that cause buzz artifacts instead of style shift.
with torch.no_grad():
norms = learned_tokens.norm(dim=-1, keepdim=True).clamp(min=1e-8)
scale = (clip_norm_limit / norms).clamp(max=1.0)
learned_tokens.data.mul_(scale)
running_loss += loss.item()
pbar.update(1)
if step % log_interval == 0:
avg = running_loss / log_interval
loss_history.append(round(avg, 6))
running_loss = 0.0
lr_now = scheduler.get_last_lr()[0]
norm = learned_tokens.norm(dim=-1).mean().item()
print(f"[TI Trainer] step {step:5d}/{steps} "
f"loss={avg:.4f} lr={lr_now:.2e} "
f"token_norm={norm:.4f}/{clip_norm_limit:.4f}", flush=True)
if step % save_every == 0 or step == steps:
# Save checkpoint
ckpt = {
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"step": step,
"init_text": init_text,
"lr": lr,
"steps": steps,
"loss_history": loss_history,
}
ckpt_path = ckpt_dir / f"step_{step:05d}.pt"
torch.save(ckpt, str(ckpt_path))
# Eval sample
wav, sr = _eval_sample_ti(
generator, learned_tokens, n_tokens, inject_mode,
feature_utils_orig, dataset, seq_cfg,
device, dtype, seed=seed,
)
if wav is not None:
wav_path = ckpt_dir / f"step_{step:05d}.wav"
try:
torchaudio.save(str(wav_path), wav, sr)
except RuntimeError:
import soundfile as sf
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
try:
metrics = _spectral_metrics(wav, sr)
_save_spectrogram(wav, sr, ckpt_dir / f"step_{step:05d}.png")
print(f"[TI Trainer] step {step} "
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
f"flatness={metrics['spectral_flatness']:.4f} "
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
except Exception as e:
print(f"[TI Trainer] Spectral/spectrogram failed: {e}", flush=True)
print(f"[TI Trainer] Checkpoint: {ckpt_path}", flush=True)
# --- Final save ---
final = {
"embeddings": learned_tokens.detach().cpu(),
"n_tokens": n_tokens,
"inject_mode": inject_mode,
"step": steps,
"init_text": init_text,
"lr": lr,
"steps": steps,
"loss_history": loss_history,
}
torch.save(final, str(out_path))
print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True)
soft_empty_cache()
return {
"embeddings_path": str(out_path),
"loss_history": loss_history,
}
-479
View File
@@ -1,479 +0,0 @@
"""SelVA Textual Inversion Scheduler — sweeps TI training experiments from a JSON file.
Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.
JSON format:
{
"name": "ti_sweep_1",
"description": "optional human note",
"data_dir": "dataset/bj_sounds",
"output_root": "ti_output/sweep_1",
"base": {
"n_tokens": 4,
"lr": 1e-3,
"steps": 3000,
"batch_size": 16,
"warmup_steps": 100,
"seed": 42,
"save_every": 1000
},
"experiments": [
{"id": "baseline", "description": "default 4 tokens"},
{"id": "n8_tokens", "n_tokens": 8},
{"id": "lr_5e4", "lr": 5e-4},
{"id": "warm_init", "init_text": "industrial sound design"},
{"id": "n4_more_steps", "steps": 5000}
]
}
"""
import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
import comfy.utils
import folder_paths
from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
_prepare_dataset,
_smooth_losses,
_pil_to_tensor,
)
from .selva_textual_inversion_trainer import SelvaTextualInversionTrainer
# ---------------------------------------------------------------------------
# Helpers (shared with LoRA scheduler, inlined to keep modules independent)
# ---------------------------------------------------------------------------
def _get_system_info() -> dict:
info: dict = {
"torch_version": torch.__version__,
"cuda_version": torch.version.cuda or "N/A",
"gpu_name": None,
"gpu_vram_gb": None,
}
if torch.cuda.is_available():
try:
info["gpu_name"] = torch.cuda.get_device_name(0)
props = torch.cuda.get_device_properties(0)
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
except Exception:
pass
return info
_PARAM_DEFAULTS = {
"n_tokens": 4,
"lr": 2e-4,
"steps": 3000,
"batch_size": 4,
"warmup_steps": 100,
"seed": 42,
"save_every": 1000,
"init_text": "",
"inject_mode": "suffix",
}
_PALETTE = [
(66, 133, 244),
(234, 67, 53),
(52, 168, 83),
(251, 188, 5),
(155, 89, 182),
(26, 188, 156),
(230, 126, 34),
(149, 165, 166),
]
def _resolve_path(raw: str) -> Path:
p = Path(raw.strip())
unix_style_on_windows = (
sys.platform == "win32" and p.is_absolute() and not p.drive
)
if not p.is_absolute() or unix_style_on_windows:
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
return p
def _merge_config(base: dict, experiment: dict) -> dict:
cfg = dict(_PARAM_DEFAULTS)
cfg.update(base)
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
return cfg
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
total_steps: int) -> dict:
result = {}
for target in range(save_every, total_steps + 1, save_every):
idx = target // log_interval - 1
if 0 <= idx < len(loss_history):
result[str(target)] = round(loss_history[idx], 6)
return result
def _draw_comparison_curves(experiments_data: list) -> "Image.Image":
from PIL import Image, ImageDraw
W, H = 900, 420
pl, pr, pt, pb = 75, 160, 30, 50
img = Image.new("RGB", (W, H), (255, 255, 255))
draw = ImageDraw.Draw(img)
pw = W - pl - pr
ph = H - pt - pb
series = []
for i, ed in enumerate(experiments_data):
lh = ed.get("loss_history") or []
if len(lh) < 2:
continue
sm = _smooth_losses(lh)
series.append({
"id": ed["id"],
"smoothed": sm,
"color": _PALETTE[i % len(_PALETTE)],
})
if not series:
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
return img
all_vals = [v for s in series for v in s["smoothed"]]
lo, hi = min(all_vals), max(all_vals)
if hi == lo:
hi = lo + 1e-6
rng = hi - lo
for i in range(5):
y = pt + int(i * ph / 4)
val = hi - i * rng / 4
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
for s in series:
n = len(s["smoothed"])
pts = []
for j, v in enumerate(s["smoothed"]):
x = pl + int(j * pw / max(n - 1, 1))
y = pt + int((1.0 - (v - lo) / rng) * ph)
pts.append((x, y))
draw.line(pts, fill=s["color"], width=2)
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
draw.text((pl + 4, 8), "TI loss comparison (smoothed)", fill=(40, 40, 40))
lx, ly = W - pr + 10, pt
for s in series:
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
ly += 20
return img
# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------
class SelvaTiScheduler:
"""Runs a sweep of Textual Inversion experiments defined in a JSON file.
The dataset is loaded once and reused. Each experiment calls
SelvaTextualInversionTrainer._train_inner() with its own config.
Results are written to experiment_summary.json after every completed run.
"""
OUTPUT_NODE = True
CATEGORY = SELVA_CATEGORY
FUNCTION = "run"
RETURN_TYPES = ("STRING", "IMAGE")
RETURN_NAMES = ("summary_path", "comparison_curves")
OUTPUT_TOOLTIPS = (
"Path to experiment_summary.json — compare runs across sweeps.",
"All smoothed loss curves overlaid on the same axes.",
)
DESCRIPTION = (
"Runs a series of Textual Inversion experiments from a JSON sweep file. "
"The dataset is encoded once and reused. Results (loss, config, embeddings "
"paths) are collected in experiment_summary.json after each run."
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"experiments_file": ("STRING", {
"default": "ti_experiments.json",
"tooltip": (
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
"output directory. See node description for the file format."
),
}),
}
}
def run(self, model, experiments_file):
# ------------------------------------------------------------------
# 1. Read + validate JSON
# ------------------------------------------------------------------
exp_path = Path(experiments_file.strip())
if not exp_path.is_absolute():
candidate = Path(folder_paths.models_dir) / exp_path
if not candidate.exists():
candidate = Path(folder_paths.get_output_directory()) / exp_path
exp_path = candidate
if not exp_path.exists():
raise FileNotFoundError(
f"[TI Scheduler] Experiment file not found: {exp_path}"
)
spec = json.loads(exp_path.read_text(encoding="utf-8"))
if "experiments" not in spec or not spec["experiments"]:
raise ValueError("[TI Scheduler] 'experiments' list is missing or empty.")
for i, exp in enumerate(spec["experiments"]):
if "id" not in exp:
raise ValueError(
f"[TI Scheduler] Experiment at index {i} is missing required 'id' field."
)
sweep_name = spec.get("name", exp_path.stem)
description = spec.get("description", "")
base_cfg = spec.get("base", {})
# ------------------------------------------------------------------
# 2. Resolve data_dir and output_root
# ------------------------------------------------------------------
if "data_dir" not in spec:
raise ValueError("[TI Scheduler] 'data_dir' is required in the sweep file.")
data_dir = _resolve_path(spec["data_dir"])
output_root = _resolve_path(spec.get("output_root", f"ti_sweeps/{sweep_name}"))
output_root.mkdir(parents=True, exist_ok=True)
device = get_device()
dtype = model["dtype"]
mode = model["mode"]
seq_cfg = model["seq_cfg"]
feature_utils_orig = model["feature_utils"]
print(f"\n[TI Scheduler] Sweep '{sweep_name}': "
f"{len(spec['experiments'])} experiment(s)", flush=True)
if description:
print(f"[TI Scheduler] {description}", flush=True)
print(f"[TI Scheduler] data_dir = {data_dir}", flush=True)
print(f"[TI Scheduler] output_root = {output_root}\n", flush=True)
# ------------------------------------------------------------------
# 3. Load dataset once
# ------------------------------------------------------------------
n_clips = len(list(data_dir.glob("*.npz")))
dataset = _prepare_dataset(model, data_dir, device)
# ------------------------------------------------------------------
# 4. Build or restore summary (resume-aware)
# ------------------------------------------------------------------
summary_path = output_root / "experiment_summary.json"
completed_ids = set()
all_curve_data = []
if summary_path.exists():
try:
existing = json.loads(summary_path.read_text(encoding="utf-8"))
for rec in existing.get("experiments", []):
if rec.get("results", {}).get("status") == "completed":
completed_ids.add(rec["id"])
all_curve_data.append({
"id": rec["id"],
"loss_history": rec["results"].get("loss_history", []),
})
summary = existing
summary["completed_at"] = None
if completed_ids:
print(f"[TI Scheduler] Resuming — skipping {len(completed_ids)} "
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
except Exception as e:
print(f"[TI Scheduler] Could not read existing summary ({e}) — starting fresh",
flush=True)
completed_ids = set()
all_curve_data = []
summary = None
if not completed_ids:
summary = {
"sweep_name": sweep_name,
"description": description,
"sweep_file": str(exp_path),
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": None,
"system": _get_system_info(),
"data_dir": str(data_dir),
"n_clips": n_clips,
"experiments": [],
}
comparison_img_path = output_root / "loss_comparison.png"
def _write_summary():
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
def _save_comparison():
try:
img = _draw_comparison_curves(all_curve_data)
img.save(str(comparison_img_path))
except Exception as e:
print(f"[TI Scheduler] Comparison image failed: {e}", flush=True)
_write_summary()
# ------------------------------------------------------------------
# 5. Run each experiment
# ------------------------------------------------------------------
trainer = SelvaTextualInversionTrainer()
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
log_interval = 50 # matches _train_inner
for exp in spec["experiments"]:
exp_id = exp["id"]
exp_desc = exp.get("description", "")
if exp_id in completed_ids:
print(f"[TI Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
pbar_outer.update(1)
continue
cfg = _merge_config(base_cfg, exp)
n_tokens = int(cfg["n_tokens"])
lr = float(cfg["lr"])
steps = int(cfg["steps"])
batch_size = int(cfg["batch_size"])
warmup = int(cfg["warmup_steps"])
seed = int(cfg["seed"])
save_every = int(cfg["save_every"])
init_text = str(cfg["init_text"])
inject_mode = str(cfg["inject_mode"])
output_dir = output_root / exp_id
output_dir.mkdir(parents=True, exist_ok=True)
out_path = output_dir / "embeddings.pt"
print(f"\n[TI Scheduler] ── Experiment '{exp_id}' ──", flush=True)
if exp_desc:
print(f"[TI Scheduler] {exp_desc}", flush=True)
print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
f"batch_size={batch_size} warmup={warmup} seed={seed} "
f"inject_mode={inject_mode}", flush=True)
if init_text:
print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
exp_record = {
"id": exp_id,
"description": exp_desc,
"config": {
"n_tokens": n_tokens,
"lr": lr,
"steps": steps,
"batch_size": batch_size,
"warmup_steps": warmup,
"seed": seed,
"save_every": save_every,
"init_text": init_text,
"inject_mode": inject_mode,
},
"results": {"status": "running"},
"embeddings_path": None,
"output_dir": str(output_dir),
}
summary["experiments"].append(exp_record)
_write_summary()
t_start = time.monotonic()
try:
with torch.inference_mode(False), torch.enable_grad():
r = trainer._train_inner(
model, dataset, feature_utils_orig, seq_cfg,
device, dtype, mode,
data_dir, out_path,
n_tokens, steps, lr, batch_size,
warmup, seed, save_every, init_text, inject_mode,
)
duration = time.monotonic() - t_start
loss_history = r["loss_history"]
smoothed = _smooth_losses(loss_history) if loss_history else []
final_loss = round(smoothed[-1], 6) if smoothed else None
min_loss = round(min(smoothed), 6) if smoothed else None
min_idx = smoothed.index(min(smoothed)) if smoothed else None
min_loss_step = (min_idx + 1) * log_interval if min_idx is not None else None
loss_std_last_quarter = None
if loss_history:
quarter = max(1, len(loss_history) // 4)
loss_std_last_quarter = round(float(np.std(loss_history[-quarter:])), 6)
exp_record["results"] = {
"status": "completed",
"final_loss": final_loss,
"min_loss": min_loss,
"min_loss_step": min_loss_step,
"loss_std_last_quarter": loss_std_last_quarter,
"loss_at_steps": _loss_at_steps(
loss_history, log_interval, save_every, steps
),
"loss_history": [round(v, 6) for v in loss_history],
"log_interval": log_interval,
"duration_seconds": round(duration, 1),
}
exp_record["embeddings_path"] = r["embeddings_path"]
all_curve_data.append({
"id": exp_id,
"loss_history": loss_history,
})
except Exception as e:
duration = time.monotonic() - t_start
print(f"[TI Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
traceback.print_exc()
exp_record["results"] = {
"status": "failed",
"error": str(e),
"duration_seconds": round(duration, 1),
}
_write_summary()
pbar_outer.update(1)
continue
_write_summary()
_save_comparison()
pbar_outer.update(1)
# ------------------------------------------------------------------
# 6. Finalise
# ------------------------------------------------------------------
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
_write_summary()
print(f"\n[TI Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
# ------------------------------------------------------------------
# 7. Comparison image (final update, then return to ComfyUI)
# ------------------------------------------------------------------
_save_comparison()
comparison_img = _draw_comparison_curves(all_curve_data)
return (str(summary_path), _pil_to_tensor(comparison_img))
-157
View File
@@ -1,157 +0,0 @@
"""SelVA VAE Roundtrip — encode audio through the VAE then decode straight back.
Useful for diagnosing codec reconstruction quality: if the output sounds
saturated/degraded compared to the input, the VAE/DAC is the bottleneck,
not the diffusion model or LoRA.
"""
import torch
import torch.nn.functional as F
import torchaudio
from pathlib import Path
import folder_paths
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
class SelvaVaeRoundtrip:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("SELVA_MODEL",),
"audio": ("AUDIO",),
},
}
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio_reconstructed",)
OUTPUT_TOOLTIPS = (
"Audio after VAE encode → decode roundtrip. "
"Compare to the input to hear codec reconstruction quality.",
)
FUNCTION = "roundtrip"
CATEGORY = SELVA_CATEGORY
DESCRIPTION = (
"Encodes the input audio through the SelVA VAE then decodes it straight back. "
"Use this to isolate codec reconstruction quality from generation quality. "
"If the output sounds degraded compared to the input, the VAE/DAC is the "
"bottleneck — not the model or LoRA."
)
def roundtrip(self, model, audio):
from selva_core.model.utils.features_utils import FeaturesUtils
mode = model["mode"]
seq_cfg = model["seq_cfg"]
dtype = model["dtype"]
device = get_device()
generator = model["generator"]
feature_utils = model["feature_utils"]
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
vae_path = _SELVA_DIR / "ext" / vae_name
if not vae_path.exists():
raise FileNotFoundError(
f"[VAE Roundtrip] VAE weight not found: {vae_path}. "
"Run SelVA Model Loader first to auto-download weights."
)
# Load encoder only — decoder/vocoder come from model["feature_utils"]
# to mirror exactly what the sampler uses.
# AutoEncoderModule requires vocoder_ckpt_path even when only encoding,
# so pass the BigVGAN path (weights won't actually be used for decode here).
bigvgan_path = _SELVA_DIR / "ext" / "best_netG.pt"
print("[VAE Roundtrip] Loading VAE encoder...", flush=True)
vae_enc = FeaturesUtils(
tod_vae_ckpt=str(vae_path),
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
bigvgan_vocoder_ckpt=str(bigvgan_path) if bigvgan_path.exists() else None,
).to(device).eval()
try:
# Prepare input audio
waveform = audio["waveform"] # [1, C, L]
sr_in = audio["sample_rate"]
wav = waveform[0].mean(0) # mono [L]
if sr_in != seq_cfg.sampling_rate:
wav = torchaudio.functional.resample(
wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate
).squeeze(0)
print(f"[VAE Roundtrip] Resampled {sr_in}{seq_cfg.sampling_rate} Hz",
flush=True)
target_len = int(seq_cfg.duration * seq_cfg.sampling_rate)
if wav.shape[0] > target_len:
wav = wav[:target_len]
elif wav.shape[0] < target_len:
wav = F.pad(wav, (0, target_len - wav.shape[0]))
wav_b = wav.unsqueeze(0).to(device).float() # [1, L]
with torch.no_grad():
# Encode: audio → raw latent [1, latent_dim, T]
dist = vae_enc.encode_audio(wav_b)
latent = dist.mode().clone()
# Trim/pad to exact model sequence length (same as _prepare_dataset)
tgt = seq_cfg.latent_seq_len
if latent.shape[2] < tgt:
latent = F.pad(latent, (0, tgt - latent.shape[2]))
elif latent.shape[2] > tgt:
latent = latent[:, :, :tgt]
# To [B, T, latent_dim] — layout the generator uses
latent_t = latent.transpose(1, 2).to(dtype)
print(f"[VAE Roundtrip] Encoded: mean={latent_t.mean():.4f} std={latent_t.std():.4f}",
flush=True)
# Normalize → unnormalize mirrors the training/inference pipeline:
# training normalizes encoded latents; sampler unnormalizes before decode.
# This ensures the latent is in the same space the decoder expects.
latent_norm = generator.normalize(latent_t.clone())
latent_unnorm = generator.unnormalize(latent_norm)
print(f"[VAE Roundtrip] Norm→unnorm: mean={latent_unnorm.mean():.4f} std={latent_unnorm.std():.4f}",
flush=True)
# Decode using model's feature_utils — same path as the sampler
tod = feature_utils.tod
tod_orig_device = next(tod.parameters()).device
tod.to(device)
try:
spec = feature_utils.decode(latent_unnorm)
out = feature_utils.vocode(spec)
finally:
tod.to(tod_orig_device)
out = out.float().cpu()
if out.dim() == 1:
out = out.unsqueeze(0).unsqueeze(0)
elif out.dim() == 2:
out = out.unsqueeze(1)
elif out.dim() == 3 and out.shape[1] != 1:
out = out.mean(dim=1, keepdim=True)
rms = out.pow(2).mean().sqrt().clamp(min=1e-8)
target_rms = 10 ** (-27.0 / 20.0)
out = out * (target_rms / rms)
out = out.clamp(-1.0, 1.0)
print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} "
f"peak={out.abs().max():.4f} rms={out.pow(2).mean().sqrt():.4f}",
flush=True)
finally:
del vae_enc
soft_empty_cache()
return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)
-309
View File
@@ -1,309 +0,0 @@
"""
LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator.
Supports two initialization modes:
- **standard**: Kaiming-uniform A, zero B (classic LoRA).
- **pissa**: A and B from the top-r SVD of the pretrained weight.
Starts on-manifold, eliminates intruder dimensions at init
(arXiv:2404.02948, NeurIPS 2024 Spotlight).
Supports two scaling modes:
- **standard**: alpha / rank
- **rslora**: alpha / sqrt(rank) — rank-stabilized scaling that prevents
gradient collapse at high ranks (arXiv:2312.03732).
Usage:
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
n = apply_lora(net_generator, rank=16, alpha=16.0)
print(f"Wrapped {n} linear layers with LoRA")
# ... train only LoRA params ...
torch.save(get_lora_state_dict(net_generator), "adapter.pt")
# Later, at inference:
apply_lora(net_generator, rank=16, alpha=16.0)
load_lora(net_generator, torch.load("adapter.pt"))
"""
import math
import torch
import torch.nn as nn
class LoRALinear(nn.Module):
"""nn.Linear with a frozen base weight and trainable low-rank A/B matrices.
Output: base(x) + (dropout(x) @ A.T @ B.T) * scale
Standard init: A is Kaiming uniform, B is zero → adapter starts at zero.
PiSSA init: A and B from top-r SVD of pretrained weight → adapter starts
at the principal components, base weight stores the residual.
"""
def __init__(self, linear: nn.Linear, rank: int, alpha: float,
dropout: float = 0.0, init_mode: str = "standard",
use_rslora: bool = False):
super().__init__()
in_f = linear.in_features
out_f = linear.out_features
self.linear = linear
linear.weight.requires_grad_(False)
if linear.bias is not None:
linear.bias.requires_grad_(False)
ref_dtype = linear.weight.dtype
ref_device = linear.weight.device
if use_rslora:
self.scale = alpha / math.sqrt(rank)
else:
self.scale = alpha / rank
self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
if init_mode == "pissa":
# PiSSA: init from top-r SVD of pretrained weight.
# SVD in float32 for numerical stability, then cast back.
W = linear.weight.data.float() # [out_f, in_f]
U, S, Vt = torch.linalg.svd(W, full_matrices=False)
sqrt_S = S[:rank].sqrt()
# A: [rank, in_f], B: [out_f, rank]
A_init = sqrt_S.unsqueeze(1) * Vt[:rank, :]
B_init = U[:, :rank] * sqrt_S.unsqueeze(0)
# Residual: W_res = W - B_init @ A_init * scale
# so that base(x) + LoRA(x) = W_res@x + (B@A)*scale@x = W@x at init
linear.weight.data = (W - B_init @ A_init * self.scale).to(ref_dtype)
self.lora_A = nn.Parameter(A_init.to(dtype=ref_dtype, device=ref_device))
self.lora_B = nn.Parameter(B_init.to(dtype=ref_dtype, device=ref_device))
else:
# Standard LoRA: Kaiming A, zero B → starts at identity
self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale
def extra_repr(self) -> str:
rank = self.lora_A.shape[0]
p = self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.0
return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
f"rank={rank}, scale={self.scale:.4f}, dropout={p}")
def apply_lora(
model: nn.Module,
rank: int = 16,
alpha: float = None,
target_suffixes: tuple = ("attn.qkv",),
dropout: float = 0.0,
init_mode: str = "standard",
use_rslora: bool = False,
) -> int:
"""Replace matching nn.Linear layers with LoRALinear in-place.
Args:
model: The module to modify (typically net_generator).
rank: LoRA rank.
alpha: LoRA alpha (scaling). Defaults to rank (scale = 1.0).
target_suffixes: Tuple of module name suffixes to wrap. Default is
("attn.qkv",) which targets all SelfAttention QKV
projections in the MM-DiT generator.
Add "linear1" to also wrap post-attention output projections.
dropout: Dropout probability on the LoRA path (not the base linear).
0.050.1 helps regularize on small datasets.
Must be 0 when using PiSSA (principal components shouldn't be dropped).
init_mode: "standard" (Kaiming/zero) or "pissa" (SVD-based).
use_rslora: If True, scale by alpha/sqrt(rank) instead of alpha/rank.
Returns:
Number of linear layers wrapped.
"""
if alpha is None:
alpha = float(rank)
if init_mode == "pissa" and dropout > 0.0:
print("[LoRA] Warning: dropout forced to 0 for PiSSA init "
"(principal components should not be dropped).")
dropout = 0.0
count = 0
for name, module in list(model.named_modules()):
if not any(name.endswith(s) for s in target_suffixes):
continue
if not isinstance(module, nn.Linear):
continue
parts = name.split(".")
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], LoRALinear(
module, rank, alpha, dropout=dropout,
init_mode=init_mode, use_rslora=use_rslora,
))
count += 1
return count
def get_lora_state_dict(model: nn.Module) -> dict:
"""Return a state dict containing only LoRA parameters (lora_A and lora_B)."""
return {k: v for k, v in model.state_dict().items() if "lora_" in k}
def get_lora_and_base_state_dict(model: nn.Module) -> dict:
"""Return state dict with LoRA params AND base linear weights.
Needed for PiSSA checkpoints where the base weight stores the residual
(W - top_r(W)*scale), not the original pretrained weight.
"""
result = {}
for name, module in model.named_modules():
if isinstance(module, LoRALinear):
prefix = name + "."
result[prefix + "lora_A"] = module.lora_A.data
result[prefix + "lora_B"] = module.lora_B.data
result[prefix + "linear.weight"] = module.linear.weight.data
if module.linear.bias is not None:
result[prefix + "linear.bias"] = module.linear.bias.data
return result
def spectral_surgery(
model: nn.Module,
calibration_fn,
n_calibration: int = 128,
policy: str = "smooth_abs",
):
"""Post-training Spectral Surgery: reweight LoRA singular values to suppress
intruder dimensions and amplify useful components (arXiv:2603.03995).
Args:
model: Model with LoRA applied.
calibration_fn: Callable that takes (model, step_idx) and runs one forward+backward
pass on a calibration sample. Must call loss.backward().
n_calibration: Number of calibration samples to average gradients over.
policy: Reweighting policy: "smooth_abs" (recommended), "hard" (binary).
Modifies LoRA A and B in-place. Returns number of layers processed.
"""
model.eval()
lora_layers = [(name, mod) for name, mod in model.named_modules()
if isinstance(mod, LoRALinear)]
if not lora_layers:
return 0
# Accumulate per-layer gradient sensitivity: g_k = u_k^T * (dL/dΔW) * v_k
sensitivities = {}
for name, mod in lora_layers:
sensitivities[name] = None
for step in range(n_calibration):
model.zero_grad()
# Enable grad temporarily on LoRA params
for _, mod in lora_layers:
mod.lora_A.requires_grad_(True)
mod.lora_B.requires_grad_(True)
calibration_fn(model, step)
for name, mod in lora_layers:
A = mod.lora_A.data.float() # [rank, in_f]
B = mod.lora_B.data.float() # [out_f, rank]
# ΔW = B @ A * scale → gradient dL/dΔW ≈ (dL/dB @ A + B^T @ dL/dA) / 2
# Per-component sensitivity: project onto SVD directions
delta_W = (B @ A * mod.scale).detach()
U, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
r = A.shape[0]
U_r, S_r, Vt_r = U[:, :r], S[:r], Vt[:r, :]
# Compute sensitivity from LoRA gradients
if mod.lora_A.grad is not None and mod.lora_B.grad is not None:
grad_A = mod.lora_A.grad.float() # [rank, in_f]
grad_B = mod.lora_B.grad.float() # [out_f, rank]
# dL/d(ΔW) ≈ grad_B @ A + B^T @ grad_A (chain rule through B@A)
grad_dW = grad_B @ A + B.T @ grad_A # approximate
# Per-component: g_k = u_k^T @ grad_dW @ v_k
g = torch.einsum("ik,ij,jk->k", U_r, grad_dW, Vt_r.T) # [r]
else:
g = torch.zeros(r, device=A.device)
if sensitivities[name] is None:
sensitivities[name] = g
else:
sensitivities[name] += g
# Disable grad again
for _, mod in lora_layers:
mod.lora_A.requires_grad_(False)
mod.lora_B.requires_grad_(False)
# Apply reweighting per layer
count = 0
for name, mod in lora_layers:
g = sensitivities[name] / n_calibration
A = mod.lora_A.data.float()
B = mod.lora_B.data.float()
delta_W = B @ A * mod.scale
U, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
r = A.shape[0]
S_r = S[:r]
if policy == "hard":
# Keep components with positive sensitivity, zero out negative
mask = (g > 0).float()
else:
# smooth_abs: sigmoid-weighted by sensitivity magnitude
# Normalize g to [-1, 1] range, apply sigmoid
g_norm = g / (g.abs().max() + 1e-8)
mask = torch.sigmoid(5.0 * g_norm) # steep sigmoid
# L1 norm preservation: scale mask so total nuclear norm is preserved
mask = mask * (S_r.sum() / (mask * S_r).sum().clamp(min=1e-8))
# Reconstruct: ΔW' = U_r @ diag(mask * S_r) @ Vt_r
S_new = mask * S_r
delta_W_new = U[:, :r] @ torch.diag(S_new) @ Vt[:r, :]
# Factor back into B' @ A' * scale: use SVD of ΔW'/scale
dW_unscaled = delta_W_new / mod.scale
U2, S2, Vt2 = torch.linalg.svd(dW_unscaled, full_matrices=False)
sqrt_S2 = S2[:r].sqrt()
A_new = sqrt_S2.unsqueeze(1) * Vt2[:r, :]
B_new = U2[:, :r] * sqrt_S2.unsqueeze(0)
ref_dtype = mod.lora_A.dtype
mod.lora_A.data = A_new.to(ref_dtype)
mod.lora_B.data = B_new.to(ref_dtype)
count += 1
kept = (mask > 0.5).sum().item()
print(f"[Spectral Surgery] {name}: kept {kept}/{r} components, "
f"sensitivity range [{g.min():.3f}, {g.max():.3f}]", flush=True)
return count
def load_lora(model: nn.Module, state_dict: dict) -> None:
"""Load LoRA weights into a model that has already had apply_lora() called.
Non-LoRA keys in state_dict are ignored (strict=False). Non-LoRA model
parameters are not modified.
"""
missing, unexpected = model.load_state_dict(state_dict, strict=False)
bad = [k for k in unexpected if "lora_" not in k]
if bad:
print(f"[LoRA] Warning: unexpected non-LoRA keys ignored: {bad}")
lora_missing = [k for k in missing if "lora_" in k]
if lora_missing:
print(f"[LoRA] Warning: missing LoRA keys (wrong rank/target?): {lora_missing}")
-465
View File
@@ -1,465 +0,0 @@
#!/usr/bin/env python3
"""
LoRA fine-tuning for SelVA / MMAudio generator.
Teaches the model new or partially-known sound classes from custom video+audio pairs.
Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model).
Data layout:
data/my_sound/
clip01.npz # visual features extracted by SelvaFeatureExtractor in ComfyUI
clip01.wav # paired clean audio (same filename stem, any format)
prompts.txt # optional: "clip01.npz: description" — overrides embedded prompt
If prompts.txt is absent, the prompt embedded in each .npz is used.
If the .npz has no embedded prompt, the directory name is used as fallback.
Usage:
python train_lora.py \\
--data_dir data/my_sound \\
--output_dir lora_output \\
--variant large_44k \\
--selva_dir /path/to/ComfyUI/models/selva \\
--rank 16 --steps 2000 --lr 1e-4
"""
import argparse
import os
import sys
import random
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import open_clip
from open_clip import create_model_from_pretrained
sys.path.insert(0, os.path.dirname(__file__))
from selva_core.model.networks_generator import get_my_mmaudio
from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_VARIANTS = {
"small_16k": ("generator_small_16k_sup_5.pth", "16k"),
"small_44k": ("generator_small_44k_sup_5.pth", "44k"),
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k"),
"large_44k": ("generator_large_44k_sup_5.pth", "44k"),
}
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------
def load_prompts(data_dir: Path) -> dict:
"""Load filename → prompt overrides from prompts.txt."""
p = data_dir / "prompts.txt"
if not p.exists():
return {}
mapping = {}
for line in p.read_text(encoding="utf-8").splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if ":" in line:
fname, prompt = line.split(":", 1)
mapping[fname.strip()] = prompt.strip()
return mapping
def find_audio_for_npz(npz_path: Path) -> Path | None:
"""Find a paired audio file with the same stem as the .npz."""
for ext in _AUDIO_EXTS:
candidate = npz_path.with_suffix(ext)
if candidate.exists():
return candidate
return None
def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
"""Load an audio file → [L] float32 [-1, 1], resampled and trimmed/padded to duration."""
waveform, sr = torchaudio.load(str(path))
# Stereo → mono
if waveform.shape[0] > 1:
waveform = waveform.mean(0, keepdim=True)
waveform = waveform.squeeze(0).float()
# Resample
if sr != target_sr:
waveform = torchaudio.functional.resample(
waveform.unsqueeze(0), sr, target_sr
).squeeze(0)
target_len = int(duration * target_sr)
if waveform.shape[0] >= target_len:
return waveform[:target_len]
return F.pad(waveform, (0, target_len - waveform.shape[0]))
def load_npz(path: Path) -> dict:
"""Load a feature bundle produced by SelvaFeatureExtractor."""
data = np.load(str(path), allow_pickle=False)
bundle = {
"clip_features": torch.from_numpy(data["clip_features"]), # [1, N, 1024]
"sync_features": torch.from_numpy(data["sync_features"]), # [1, T, 768]
}
if "prompt" in data:
bundle["prompt"] = str(data["prompt"])
if "variant" in data:
bundle["variant"] = str(data["variant"])
return bundle
# ---------------------------------------------------------------------------
# Feature extraction (audio + text only — visual features come from .npz)
# ---------------------------------------------------------------------------
def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor:
tokens = tokenizer(text).to(device)
with torch.inference_mode():
return clip_model.encode_text(tokens, normalize=True)
def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor:
"""Encode a waveform to the generator's latent space via the VAE.
encode_audio is @inference_mode — .clone() is required before the autograd path.
"""
audio_b = audio.unsqueeze(0).to(device, dtype) # [1, L]
dist = feature_utils.encode_audio(audio_b)
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
return dist.mode().clone().transpose(1, 2).cpu() # [1, seq_len, latent_dim]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator")
parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt")
parser.add_argument("--output_dir", default="lora_output")
parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys()))
parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)")
parser.add_argument("--rank", type=int, default=16, help="LoRA rank")
parser.add_argument("--alpha", type=float, default=None, help="LoRA alpha (default: rank)")
parser.add_argument("--target", nargs="+", default=["attn.qkv"],
help="Module name suffixes to wrap with LoRA. Also try 'linear1'.")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--steps", type=int, default=2000)
parser.add_argument("--warmup_steps",type=int, default=100)
parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
parser.add_argument("--save_every", type=int, default=500)
parser.add_argument("--resume", default=None,
help="Path to a step checkpoint (.pt) to resume training from.")
parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal", "curriculum"],
help="Timestep sampling. uniform=original MMAudio, logit_normal=concentrated near t=0.5, curriculum=logit_normal then uniform.")
parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
help="Spread of logit-normal distribution.")
parser.add_argument("--curriculum_switch", type=float, default=0.6,
help="Fraction of steps to use logit_normal before switching to uniform (curriculum mode only).")
parser.add_argument("--lora_dropout", type=float, default=0.0,
help="Dropout on the LoRA path only. 0.050.1 helps on small datasets.")
parser.add_argument("--lora_plus_ratio", type=float, default=1.0,
help="LoRA+ LR ratio: lr_B = lr * ratio. 1.0=standard, 16.0=LoRA+.")
args = parser.parse_args()
torch.manual_seed(args.seed)
random.seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
print("[LoRA] bf16 not supported on this GPU — falling back to fp16")
args.precision = "fp16"
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]
data_dir = Path(args.data_dir)
output_dir = Path(args.output_dir)
selva_dir = Path(args.selva_dir)
output_dir.mkdir(parents=True, exist_ok=True)
gen_filename, mode = _VARIANTS[args.variant]
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
duration = seq_cfg.duration
sample_rate = seq_cfg.sampling_rate
# --- Weight paths ---
def w(name): return str(selva_dir / name)
def wext(name): return str(selva_dir / "ext" / name)
vae_weight = wext("v1-16.pth" if mode == "16k" else "v1-44.pth")
gen_weight = w(gen_filename)
for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]:
if not Path(path).exists():
print(f"[LoRA] Missing weight: {path} ({label})")
print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.")
sys.exit(1)
# --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) ---
print("[LoRA] Loading CLIP text encoder...")
clip_model = create_model_from_pretrained(
'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False
).to(device, dtype).eval()
clip_model = patch_clip(clip_model)
tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')
# --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) ---
print("[LoRA] Loading VAE encoder...")
feature_utils = FeaturesUtils(
tod_vae_ckpt=vae_weight,
enable_conditions=False,
mode=mode,
need_vae_encoder=True,
).to(device, dtype).eval()
# --- Load generator ---
print(f"[LoRA] Loading generator ({args.variant})...")
net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval()
net_generator.load_weights(
torch.load(gen_weight, map_location="cpu", weights_only=False)
)
# --- Apply LoRA ---
n_lora = apply_lora(
net_generator,
rank=args.rank,
alpha=args.alpha,
target_suffixes=tuple(args.target),
dropout=args.lora_dropout,
)
print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target}, dropout={args.lora_dropout})")
if n_lora == 0:
print("[LoRA] ERROR: no layers were wrapped — check --target names.")
sys.exit(1)
# Freeze everything except LoRA params
for name, p in net_generator.named_parameters():
p.requires_grad_("lora_" in name)
trainable = sum(p.numel() for p in net_generator.parameters() if p.requires_grad)
total = sum(p.numel() for p in net_generator.parameters())
print(f"[LoRA] Trainable: {trainable:,} / {total:,} params "
f"({100 * trainable / total:.2f}%)")
net_generator.update_seq_lengths(
latent_seq_len=seq_cfg.latent_seq_len,
clip_seq_len=seq_cfg.clip_seq_len,
sync_seq_len=seq_cfg.sync_seq_len,
)
# --- Dataset ---
npz_files = sorted(data_dir.glob("*.npz"))
if not npz_files:
print(f"[LoRA] No .npz files found in {data_dir}")
sys.exit(1)
prompt_map = load_prompts(data_dir)
default_prompt = data_dir.name
print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...")
dataset = []
for npz_path in npz_files:
audio_path = find_audio_for_npz(npz_path)
if audio_path is None:
print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping")
continue
bundle = load_npz(npz_path)
# Prompt priority: prompts.txt override > embedded in .npz > directory name
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'")
try:
audio = load_audio(audio_path, sample_rate, duration)
x1 = extract_audio_latent(audio, feature_utils, device, dtype)
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
tgt = seq_cfg.latent_seq_len
if x1.shape[1] < tgt:
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
elif x1.shape[1] > tgt:
x1 = x1[:, :tgt, :]
text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()
# Pad/trim clip and sync features to fixed seq lengths — shorter clips
# have fewer frames and would cause stack() to fail during batching
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
c_tgt = seq_cfg.clip_seq_len
if clip_f.shape[1] < c_tgt:
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
elif clip_f.shape[1] > c_tgt:
clip_f = clip_f[:, :c_tgt, :]
sync_f = bundle["sync_features"] # [1, N_sync, 768]
s_tgt = seq_cfg.sync_seq_len
if sync_f.shape[1] < s_tgt:
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
elif sync_f.shape[1] > s_tgt:
sync_f = sync_f[:, :s_tgt, :]
dataset.append((x1, clip_f, sync_f, text_clip))
except Exception as e:
print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")
if not dataset:
print("[LoRA] No clips could be loaded.")
sys.exit(1)
print(f"[LoRA] {len(dataset)} clip(s) ready.")
# --- Optimizer + LR scheduler ---
# LoRA+: separate param groups for A and B with different LRs.
# ratio=1.0 = standard LoRA. ratio=16 = LoRA+ (arXiv:2402.12354).
lora_A_params = [p for n, p in net_generator.named_parameters() if "lora_A" in n and p.requires_grad]
lora_B_params = [p for n, p in net_generator.named_parameters() if "lora_B" in n and p.requires_grad]
optimizer = torch.optim.AdamW([
{"params": lora_A_params, "lr": args.lr},
{"params": lora_B_params, "lr": args.lr * args.lora_plus_ratio},
], weight_decay=1e-2)
if args.lora_plus_ratio != 1.0:
print(f"[LoRA] LoRA+: lr_A={args.lr:.2e} lr_B={args.lr * args.lora_plus_ratio:.2e}")
def lr_lambda(step):
if step < args.warmup_steps:
return step / max(1, args.warmup_steps)
return 1.0
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
# --- Resume ---
start_step = 0
if args.resume:
ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
if "step" not in ckpt:
print("[LoRA] ERROR: checkpoint has no step info — was it saved by this script?")
sys.exit(1)
start_step = ckpt["step"]
if start_step >= args.steps:
print(f"[LoRA] Checkpoint is already at step {start_step} >= --steps {args.steps}. Nothing to do.")
sys.exit(0)
net_generator.load_state_dict(ckpt["state_dict"], strict=False)
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
print(f"[LoRA] Resumed from {Path(args.resume).name} (step {start_step}{args.steps})")
# --- Training loop ---
net_generator.train()
optimizer.zero_grad()
remaining = args.steps - start_step
print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1}{args.steps}), "
f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")
curriculum_switch_step = start_step + int((args.steps - start_step) * args.curriculum_switch)
_curriculum_switched = False
total_loss = 0.0
for step in range(start_step + 1, args.steps + 1):
batch = random.choices(dataset, k=args.batch_size)
x1_list, clip_list, sync_list, text_list = zip(*batch)
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
net_generator.normalize(x1)
if args.timestep_mode == "logit_normal" or (
args.timestep_mode == "curriculum" and step <= curriculum_switch_step
):
u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
t = torch.sigmoid(u)
else:
t = torch.rand(args.batch_size, device=device, dtype=dtype)
if args.timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
print(f"[LoRA] Curriculum switch: logit_normal → uniform at step {step}")
_curriculum_switched = True
x0 = torch.randn_like(x1)
xt = fm.get_conditional_flow(x0, x1, t)
v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t)
loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum
loss.backward()
total_loss += loss.item() * args.grad_accum
if step % args.grad_accum == 0:
torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % 50 == 0:
avg = total_loss / 50
lr_now = scheduler.get_last_lr()[0]
print(f"[LoRA] step {step:5d}/{args.steps} loss={avg:.4f} lr={lr_now:.2e}")
total_loss = 0.0
if step % args.save_every == 0 or step == args.steps:
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
torch.save({
"state_dict": get_lora_state_dict(net_generator),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"step": step,
"meta": {
"variant": args.variant,
"rank": args.rank,
"alpha": args.alpha if args.alpha is not None else float(args.rank),
"target": args.target,
"steps": args.steps,
"timestep_mode": args.timestep_mode,
"logit_normal_sigma": args.logit_normal_sigma,
"curriculum_switch": args.curriculum_switch,
"lora_dropout": args.lora_dropout,
"lora_plus_ratio": args.lora_plus_ratio,
},
}, ckpt_path)
print(f"[LoRA] Saved {ckpt_path}")
# Save final adapter with embedded metadata
# Increment filename if a previous final already exists (resume case)
final = output_dir / "adapter_final.pt"
if final.exists():
i = 1
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
i += 1
final = output_dir / f"adapter_final_{i:03d}.pt"
meta = {
"variant": args.variant,
"rank": args.rank,
"alpha": args.alpha if args.alpha is not None else float(args.rank),
"target": args.target,
"steps": args.steps,
"timestep_mode": args.timestep_mode,
"logit_normal_sigma": args.logit_normal_sigma,
"curriculum_switch": args.curriculum_switch,
"lora_dropout": args.lora_dropout,
"lora_plus_ratio": args.lora_plus_ratio,
}
torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
print(f"\n[LoRA] Training complete. Adapter saved to {final}")
if __name__ == "__main__":
main()