Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c9550ce693 | |||
| f3cabcad90 |
@@ -1,459 +0,0 @@
|
|||||||
# LoRA Training for SelVA
|
|
||||||
|
|
||||||
LoRA lets you teach the model new or partially-known sound classes using a small set of video+audio pairs. Only ~10 MB of adapter weights are trained instead of the full 4.4 GB model.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
Training is split into two steps:
|
|
||||||
|
|
||||||
1. **Dataset preparation** (in ComfyUI) — extract visual features from your video clips using the `SelVA Feature Extractor` node, and collect clean matching audio files.
|
|
||||||
2. **Training** (in ComfyUI or command line) — run the `SelVA LoRA Trainer` node or `train_lora.py`.
|
|
||||||
|
|
||||||
The training script loads only the generator, the VAE encoder, and the CLIP text encoder. CLIP visual features and sync features come pre-computed from the `.npz` files, so Synchformer and T5 are not loaded during training, which saves 3–4 GB of VRAM.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
Same environment as SelVA inference. Additional Python packages:
|
|
||||||
|
|
||||||
```
torchaudio
soundfile
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 1 — Prepare the dataset
|
|
||||||
|
|
||||||
### 1.1 Video format
|
|
||||||
|
|
||||||
The feature extractor accepts any input but internally resamples frames to fixed square resolutions (384×384 for CLIP, 224×224 for Synchformer). Both encoders were trained on standard video datasets — predominantly landscape footage. This has two practical implications:
|
|
||||||
|
|
||||||
**Aspect ratio** — use **16:9 landscape** whenever possible. Portrait clips (9:16) are mechanically supported but the bicubic stretch into square distorts the image relative to the encoders' training distribution, which can degrade sync feature quality. If your source is portrait, center-crop to square before extraction. Square (1:1) is also fine.
|
|
||||||
|
|
||||||
**Resolution** — anything ≥ 480p is sufficient. The extractor downscales to 384px and 224px regardless of source resolution; higher resolution adds no benefit.
|
|
||||||
|
|
||||||
**Frame rate** — any. Connect `VHS_VIDEOINFO` from VHS LoadVideo to the feature extractor so fps is read automatically from the file instead of being entered manually.
|
|
||||||
|
|
||||||
| Property | Recommendation |
|---|---|
| Aspect ratio | 16:9 landscape (preferred) or 1:1 square |
| Resolution | ≥ 480p (720p+ is fine, no upper limit that matters) |
| Frame rate | Any (set via `VHS_VIDEOINFO`) |
| Portrait (9:16) | Center-crop to square before extraction |
|
|
||||||
|
|
||||||
### 1.2 Extract visual features in ComfyUI
|
|
||||||
|
|
||||||
For each video clip you want to train on:
|
|
||||||
|
|
||||||
1. Load the video with a VHS LoadVideo node.
2. Connect it to **SelVA Feature Extractor**.
3. Set **`cache_dir`** to a dedicated dataset folder, e.g. `dataset/my_sound`.
4. Set **`name`** to a short descriptive label, e.g. `dog_bark`. The node will save `dog_bark_001.npz`, then `dog_bark_002.npz`, etc. automatically as you process more clips.
5. Set the **`prompt`** to describe the sound (e.g. `a dog barking`). This prompt conditions the Synchformer sync features, so be as specific as possible (see the prompt guide below).
6. Optionally connect a **mask** to isolate the sound source in frame (strongly recommended when multiple objects are visible; see the masking note below).
|
|
||||||
|
|
||||||
> **Tip:** The prompt used for feature extraction conditions the *visual sync features*. You can use a different, more precise prompt at training time — see Step 2.
|
|
||||||
|
|
||||||
### Prompt guide
|
|
||||||
|
|
||||||
The prompt is not just a label — it directly shapes what the Synchformer pays attention to in the video. Imprecise prompts produce unfocused sync features, which the LoRA then has to compensate for, often introducing noise.
|
|
||||||
|
|
||||||
**Good prompts are specific about:**
|
|
||||||
- The sound source (what object is making the sound)
|
|
||||||
- The acoustic character (loud/quiet, sharp/soft, wet/dry)
|
|
||||||
- The action producing the sound (if applicable)
|
|
||||||
|
|
||||||
| Sound | Weak prompt | Strong prompt |
|---|---|---|
| Dog bark | `dog` | `a large dog barking loudly` |
| Footsteps | `walking` | `heavy boots on a wooden floor` |
| Water | `water` | `water dripping into a metal bucket` |
| Explosion | `explosion` | `a large explosion with deep bass rumble` |
| Door | `door` | `a heavy wooden door slamming shut` |
|
|
||||||
|
|
||||||
**Rules of thumb:**
|
|
||||||
- Describe the *sound*, not the visual scene. `a person hitting a drum` is better than `a drummer on stage`.
|
|
||||||
- Keep prompts consistent across all clips for the same sound class. Mixing `a dog barking` and `loud barking dog` in the same dataset creates conflicting sync features.
|
|
||||||
- Avoid negations (`no background noise`) — the model does not understand negations in sync feature conditioning.
|
|
||||||
|
|
||||||
### Masking note
|
|
||||||
|
|
||||||
If the video frame contains multiple moving objects, CLIP and sync features will be diluted by irrelevant motion. Use a segmentation mask (SAM2 or Grounding DINO+SAM) to isolate the sound source:
|
|
||||||
|
|
||||||
- Connect the mask to the **`mask`** input on SelVA Feature Extractor.
|
|
||||||
- Leave `mask_strength` at `1.0` for clean isolation; lower it only if the masked region is very small and the model loses context.
|
|
||||||
- Re-extract features with a mask even if you already have `.npz` files — better features directly reduce training noise.
|
|
||||||
|
|
||||||
### 1.3 Collect clean audio
|
|
||||||
|
|
||||||
For each `.npz` file, place a matching audio file with the **same filename stem** in the same directory:
|
|
||||||
|
|
||||||
```
dataset/my_sound/
  dog_bark_001.npz   ← from SelVA Feature Extractor
  dog_bark_001.wav   ← clean isolated audio recording
  dog_bark_002.npz
  dog_bark_002.wav
  dog_bark_003.npz
  dog_bark_003.wav
```
|
|
||||||
|
|
||||||
Supported audio formats: `.wav`, `.flac`, `.ogg`, `.aiff`, `.aif`
|
|
||||||
|
|
||||||
> `.mp3` is not recommended — lossy compression degrades training quality. Use `.flac` or `.wav`.
|
|
||||||
|
|
||||||
The audio will be automatically resampled and trimmed/padded to match the model's expected duration. Use clean, isolated recordings — no background noise.
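Before training, it can help to confirm that every `.npz` has a matching audio file. The sketch below is not part of SelVA: `find_unpaired` is a hypothetical helper that assumes only the same-stem pairing rule and the extension list given above (lower-case extensions only).

```python
from pathlib import Path

# Extensions listed as supported in this guide.
AUDIO_EXTS = (".wav", ".flac", ".ogg", ".aiff", ".aif")


def find_unpaired(data_dir):
    """Return the stems of .npz files that have no audio file with the same stem."""
    data_dir = Path(data_dir)
    missing = []
    for npz in sorted(data_dir.glob("*.npz")):
        # A clip is paired if any supported extension exists next to the .npz.
        has_audio = any(npz.with_suffix(ext).exists() for ext in AUDIO_EXTS)
        if not has_audio:
            missing.append(npz.stem)
    return missing
```

Calling `find_unpaired("dataset/my_sound")` should return an empty list once every clip has its audio in place.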
|
|
||||||
|
|
||||||
### 1.4 Optional: prompts.txt
|
|
||||||
|
|
||||||
If you want a different prompt at training time than the one embedded in the `.npz`, create a `prompts.txt` file in the dataset directory:
|
|
||||||
|
|
||||||
```
# One line per file: filename: prompt text
dog_bark.npz: a large dog barking aggressively
dog_bark_001.npz: a dog barking in the distance
```
|
|
||||||
|
|
||||||
Priority: `prompts.txt` > prompt embedded in `.npz` > directory name as fallback.
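The priority order can be sketched as follows. This is an illustration, not the code in `train_lora.py`: `parse_prompts_file` and `resolve_prompt` are hypothetical names, the embedded prompt is passed in as a plain string because the key it is stored under in the `.npz` is not stated here, and using the directory name verbatim as the fallback is an assumption.

```python
from pathlib import Path


def parse_prompts_file(path):
    """Parse 'filename: prompt text' lines, skipping blank lines and '#' comments."""
    prompts = {}
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, sep, text = line.partition(":")
        if sep and text.strip():
            prompts[name.strip()] = text.strip()
    return prompts


def resolve_prompt(npz_path, overrides, embedded_prompt=None):
    """Apply the priority: prompts.txt > prompt embedded in .npz > directory name."""
    npz_path = Path(npz_path)
    if npz_path.name in overrides:
        return overrides[npz_path.name]
    if embedded_prompt:
        return embedded_prompt
    # Fallback (assumed): the name of the dataset directory, unchanged.
    return npz_path.parent.name
```

Keeping the lookup keyed by full filename (including `.npz`) mirrors the `prompts.txt` format shown above.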
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 2 — Train
|
|
||||||
|
|
||||||
### Option A — SelVA LoRA Trainer node (ComfyUI)
|
|
||||||
|
|
||||||
Connect the node and set parameters directly in the UI. The node outputs the trained model ready to wire into the Sampler, and saves loss curve images to the output directory.
|
|
||||||
|
|
||||||
```
SelVA Model Loader → SelVA LoRA Trainer → SelVA Sampler
```
|
|
||||||
|
|
||||||
### Option B — Command line
|
|
||||||
|
|
||||||
```bash
python train_lora.py \
    --data_dir dataset/my_sound \
    --output_dir lora_output/my_sound \
    --variant large_44k \
    --selva_dir /path/to/ComfyUI/models/selva \
    --rank 16 \
    --steps 4000 \
    --batch_size 4 \
    --lr 1e-4
```
|
|
||||||
|
|
||||||
The script will:
|
|
||||||
1. Load the VAE, CLIP text encoder, and generator.
2. Pre-load all clips (audio encoded to latents, features loaded from `.npz`).
3. Train LoRA adapters for the specified number of steps.
4. Save a checkpoint every `--save_every` steps, a final `adapter_final.pt`, and loss curve images.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## CLI Reference
|
|
||||||
|
|
||||||
| Argument | Default | Description |
|---|---|---|
| `--data_dir` | required | Directory containing `.npz` + audio pairs |
| `--output_dir` | `lora_output` | Where to save adapter checkpoints |
| `--variant` | `large_44k` | Model variant: `small_16k`, `small_44k`, `medium_44k`, `large_44k` |
| `--selva_dir` | required | Path to SelVA model weights directory |
| `--rank` | `16` | LoRA rank: higher = more capacity, more VRAM |
| `--alpha` | `rank` | LoRA alpha scaling. Default (= rank) means scale = 1.0 |
| `--target` | `attn.qkv` | Which layers to adapt. Add `linear1` for post-attention projections |
| `--lr` | `1e-4` | Learning rate |
| `--steps` | `2000` | Total training steps |
| `--warmup_steps` | `100` | Linear LR warmup steps |
| `--batch_size` | `4` | Clips per training step: higher is more stable, uses more VRAM |
| `--grad_accum` | `1` | Gradient accumulation steps (use when batch_size is already > 1) |
| `--save_every` | `500` | Save a checkpoint every N steps |
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
| `--seed` | `42` | Random seed |
| `--timestep_mode` | `uniform` | Timestep sampling: `uniform`, `logit_normal`, or `curriculum` |
| `--logit_normal_sigma` | `1.0` | Spread of the logit-normal distribution. Only used with `logit_normal` / `curriculum` |
| `--curriculum_switch` | `0.6` | Fraction of steps to use logit_normal before switching to uniform. Only with `curriculum` |
| `--lora_dropout` | `0.0` | Dropout on the LoRA path only. `0.05`–`0.1` helps regularize on small datasets |
| `--lora_plus_ratio` | `1.0` | LoRA+ LR ratio: `lr_B = lr × ratio`. `1.0` = standard LoRA, `16.0` = LoRA+ |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 3 — Load the adapter in ComfyUI
|
|
||||||
|
|
||||||
Connect **SelVA LoRA Loader** between the model loader and the sampler:
|
|
||||||
|
|
||||||
```
SelVA Model Loader → SelVA LoRA Loader → SelVA Sampler
```
|
|
||||||
|
|
||||||
> **Important:** Wire the LoRA Loader output to the **Sampler**, not the Feature Extractor. The LoRA adapts the generator which only runs in the Sampler.
|
|
||||||
|
|
||||||
| Input | Description |
|---|---|
| `model` | SELVA_MODEL from the model loader |
| `adapter_path` | Path to `adapter_final.pt` or any `adapter_stepXXXXX.pt` |
| `strength` | 0.0 = adapter disabled, 1.0 = full strength, >1.0 = exaggerated |
|
|
||||||
|
|
||||||
The loader reads rank, alpha, and target layers from the metadata embedded in the `.pt` file — no need to set them manually.
|
|
||||||
|
|
||||||
> The base model is not modified. The loader returns a shallow copy with a deep-copied generator so the original stays intact.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tuning Guide
|
|
||||||
|
|
||||||
### Clip length
|
|
||||||
|
|
||||||
The model has a **fixed input duration of 8 seconds** for all variants (both 16k and 44k). This is not a parameter you can change.
|
|
||||||
|
|
||||||
- Audio shorter than 8 s is **zero-padded** (silence appended). The model will learn the sound but may also learn silence as part of the pattern — keep in mind for very short sounds.
|
|
||||||
- Audio longer than 8 s is **trimmed** at 8 s. Content beyond that is lost.
|
|
||||||
- Video shorter than 8 s has its **last frame repeated** to fill the clip.
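The audio trim/pad behaviour described above can be sketched as below. This is an illustration under stated assumptions, not the trainer's code: `fit_to_duration` and `padded_fraction` are hypothetical helpers, and plain Python lists stand in for the tensors a real pipeline would use.

```python
def fit_to_duration(samples, sample_rate, duration_s=8.0):
    """Trim or zero-pad a mono sample list to exactly duration_s seconds."""
    target = int(round(sample_rate * duration_s))
    if len(samples) >= target:
        # Content beyond the fixed window is dropped.
        return list(samples[:target])
    # Shorter clips get silence appended at the end.
    return list(samples) + [0.0] * (target - len(samples))


def padded_fraction(num_samples, sample_rate, duration_s=8.0):
    """Fraction of the fixed window that would be zero padding (0.0 if none)."""
    target = int(round(sample_rate * duration_s))
    return max(0, target - num_samples) / target
```

A high `padded_fraction` is a hint that the clip is mostly silence after padding, which is the case the first bullet warns about.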
|
|
||||||
|
|
||||||
**Practical recommendations:**
|
|
||||||
|
|
||||||
| Sound type | Clip strategy |
|---|---|
| Continuous sound (rain, engine, wind) | 8 s recordings, as many positions in the audio as possible |
| Single event < 2 s (click, bark, knock) | Center the event: pad deliberately with silence before/after, or loop the event 2–3 times per clip |
| Repeating event (footsteps, dripping) | Record full 8 s with natural repetition at the intended cadence |
| Sound with a clear onset (explosion, splash) | Put the onset at ~1–2 s from the start, not at 0 s, to give the model context |
|
|
||||||
|
|
||||||
> **Tip:** When extracting features in ComfyUI, set `duration` to 0 to use the full video length up to 8 s. Clips longer than 8 s are automatically clamped.
|
|
||||||
|
|
||||||
### How many clips do I need?
|
|
||||||
|
|
||||||
The table below gives a rough scaling guide. Quality and diversity of recordings matter more than raw count.
|
|
||||||
|
|
||||||
| Dataset size | Scenario | Expected result |
|---|---|---|
| **5–10 clips** | Quick test / proof of concept | May work if the model already partially knows the sound; often underfits |
| **15–30 clips** | Fine-tuning a sound the model knows but gets wrong | Good starting point; covers the main variations |
| **30–60 clips** | Teaching a new but acoustically simple sound class | Reliable convergence with default hyperparameters |
| **60–150 clips** | Unusual or complex sounds, strong style shift | Needed for stable generalization across video contexts |
| **150–300 clips** | Sounds the model has never encountered | Required to avoid overfitting; increase rank to 32 |
| **300+** | Large-scale domain shift | Consider also targeting `linear1` in addition to `attn.qkv` |
|
|
||||||
|
|
||||||
**Diversity beats quantity.** Ten clips of a dog barking in different environments (indoors, outdoors, distant, close) train better than fifty clips of the same recording. Vary: distance, room acoustics, intensity, speed.
|
|
||||||
|
|
||||||
### Batch size
|
|
||||||
|
|
||||||
| Batch size | VRAM (large_44k) | Use case |
|---|---|---|
| `1` | ~9 GB | Minimal VRAM, noisy gradients |
| `4` | ~12 GB | Good default: stable gradients, reasonable speed |
| `8` | ~15 GB | Better convergence on larger datasets |
| `16` | ~20 GB | Best gradient quality when VRAM allows |
|
|
||||||
|
|
||||||
Higher batch size gives smoother loss curves and faster convergence. If you have headroom, prefer larger batches over more steps.
|
|
||||||
|
|
||||||
**Observed results:** batch 16 reaches the same loss in ~2600 steps that batch 1 needed 8000+ steps to reach, with a near-perfectly smooth curve. On a 24 GB GPU, batch 16 is the recommended default for `large_44k`.
|
|
||||||
|
|
||||||
### Rank
|
|
||||||
|
|
||||||
| Rank | Use case |
|---|---|
| `8` | Fine details on a sound the model already knows well |
| `16` | Default: good balance of capacity and VRAM |
| `32` | Harder sounds or larger style shifts (30+ clips recommended) |
|
|
||||||
|
|
||||||
Higher rank increases VRAM usage and overfitting risk on small datasets.
|
|
||||||
|
|
||||||
### Steps
|
|
||||||
|
|
||||||
With `batch_size=4` as the default, these are rough guidelines:
|
|
||||||
|
|
||||||
| Dataset size | Recommended steps |
|---|---|
| 10–20 clips | 2000–4000 |
| 20–50 clips | 4000–8000 |
| 50+ clips | 6000–15000 |
|
|
||||||
|
|
||||||
Watch the loss curve — if the smoothed line has been flat for 2000+ steps, training has converged for your dataset size. Adding more clips will let it go lower.
|
|
||||||
|
|
||||||
### Learning rate
|
|
||||||
|
|
||||||
`1e-4` is the recommended default for any batch size. If training is unstable (loss spikes in the first 200 steps), try `5e-5`. If convergence is very slow, try `2e-4`.
|
|
||||||
|
|
||||||
Warmup (default 100 steps) ramps the LR from 0 to avoid instability at the start.
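A minimal sketch of a linear warmup schedule with the defaults above. `lr_at_step` is a hypothetical helper; holding the LR constant after warmup is an assumption, since this guide does not say whether the script decays it afterwards.

```python
def lr_at_step(step, base_lr=1e-4, warmup_steps=100):
    """Linear warmup from 0 to base_lr over warmup_steps, then constant (assumed)."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return base_lr
    # During warmup the LR grows proportionally to the step index.
    return base_lr * step / warmup_steps
```

With the defaults, step 50 runs at half the base learning rate, so early loss spikes are evaluated against a much smaller effective LR than the one passed via `--lr`.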
|
|
||||||
|
|
||||||
### Target layers
|
|
||||||
|
|
||||||
`attn.qkv` (default) adapts only the self-attention QKV projections. This is the recommended starting point for all dataset sizes.
|
|
||||||
|
|
||||||
Add `linear1` to also adapt post-attention projections for large-scale domain shifts or when `attn.qkv` alone plateaus too early:
|
|
||||||
|
|
||||||
```bash
--target attn.qkv linear1
```
|
|
||||||
|
|
||||||
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
|
|
||||||
|
|
||||||
### Timestep sampling mode
|
|
||||||
|
|
||||||
Controls how training timesteps are sampled at each step.
|
|
||||||
|
|
||||||
`uniform` (default) samples all timesteps equally — equivalent to original MMAudio training.
|
|
||||||
|
|
||||||
`logit_normal` concentrates more steps near t=0.5 via `sigmoid(N(0, σ))`. This is the semantically rich mid-noise region. Consistently reaches a lower loss floor but the perceptual improvement on small datasets is marginal.
|
|
||||||
|
|
||||||
`curriculum` uses logit_normal for the first `curriculum_switch` fraction of steps (default 60%), then switches to uniform for the remainder. The motivation: logit_normal accelerates early structure learning but undertrains the high-t boundary region; uniform then fills in the fine detail. A switch message is logged when the transition happens.
|
|
||||||
|
|
||||||
| Mode | When to use |
|---|---|
| `uniform` (default) | Baseline: safe, equivalent to original training |
| `logit_normal` | When you want a lower loss floor; marginal on small datasets |
| `curriculum` | Experimental: may improve convergence quality on small datasets |
|
|
||||||
|
|
||||||
The `logit_normal_sigma` parameter controls the width of the logit-normal distribution (used by both `logit_normal` and the first phase of `curriculum`):
|
|
||||||
- σ=1.0: moderate peak at t=0.5, balanced coverage (default)
|
|
||||||
- σ=0.5: sharper peak, less coverage of extremes
|
|
||||||
- σ=2.0: broader, approaches uniform
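The three modes can be sketched with the standard library as follows. This is an illustration of the sampling rules described in this section, not the trainer's implementation; `sample_timestep` and its argument names are hypothetical.

```python
import math
import random


def sample_timestep(mode, step, total_steps, sigma=1.0, curriculum_switch=0.6, rng=random):
    """Draw one training timestep t in [0, 1] for the given sampling mode."""
    if mode == "curriculum":
        # logit_normal for the first fraction of training, uniform afterwards.
        mode = "logit_normal" if step < curriculum_switch * total_steps else "uniform"
    if mode == "uniform":
        return rng.random()
    if mode == "logit_normal":
        z = rng.gauss(0.0, sigma)            # z ~ N(0, sigma)
        return 1.0 / (1.0 + math.exp(-z))    # sigmoid(z), densest around t = 0.5
    raise ValueError(f"unknown timestep mode: {mode}")
```

Drawing a few thousand samples per mode and plotting a histogram is a quick way to see how `sigma` trades mid-range focus against coverage of the extremes.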
|
|
||||||
|
|
||||||
### LoRA dropout
|
|
||||||
|
|
||||||
`lora_dropout` applies dropout to the input of the LoRA path (not the frozen base linear). It regularizes the low-rank update without disturbing pretrained weights — helpful on small datasets where the LoRA would otherwise overfit to the training clips.
|
|
||||||
|
|
||||||
| Value | Use case |
|---|---|
| `0.0` (default) | No regularization; fine for 30+ clips |
| `0.05` | Light regularization; recommended starting point on 10–20 clips |
| `0.1` | Stronger regularization; use if loss plateaus but audio is still noisy |
|
|
||||||
|
|
||||||
Dropout is not saved in the adapter file — it only affects training. Loading the adapter at inference does not require setting dropout.
|
|
||||||
|
|
||||||
### LoRA+ (asymmetric learning rate)
|
|
||||||
|
|
||||||
`lora_plus_ratio` splits the learning rate between LoRA A and B matrices: `lr_B = lr × ratio`. The B matrix is the output-side projection and benefits from a higher LR. Setting ratio to 16 enables the LoRA+ scheme from arXiv:2402.12354.
|
|
||||||
|
|
||||||
| Ratio | Effect |
|---|---|
| `1.0` (default) | Standard LoRA: identical A and B learning rates |
| `4.0` | Mild asymmetry |
| `16.0` | LoRA+: faster convergence, especially on early steps |
|
|
||||||
|
|
||||||
LoRA+ is orthogonal to dropout and curriculum sampling — all three can be combined.
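One way to express the `lr_B = lr × ratio` split is as optimizer parameter groups. The sketch below is an assumption about how this could be wired, not the script's code: `lora_plus_param_groups` is a hypothetical helper, and the `lora_B` suffix is assumed by symmetry with the `lora_A` key shown in the adapter format.

```python
def lora_plus_param_groups(named_params, lr=1e-4, ratio=16.0):
    """Split LoRA parameters into A and B groups with lr_B = lr * ratio."""
    a_params, b_params = [], []
    for name, param in named_params:
        if name.endswith("lora_B"):
            b_params.append(param)
        elif name.endswith("lora_A"):
            a_params.append(param)
        # Anything else (frozen base weights) is left out of the optimizer.
    return [
        {"params": a_params, "lr": lr},
        {"params": b_params, "lr": lr * ratio},
    ]
```

The returned list uses the `{"params": ..., "lr": ...}` layout that PyTorch optimizers accept for per-group learning rates, so it could be passed to an optimizer constructor together with `model.named_parameters()` as input.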
|
|
||||||
|
|
||||||
### Adapter strength at inference
|
|
||||||
|
|
||||||
| Strength | Effect |
|---|---|
| `0.5–0.7` | Conservative: blends adapter with base model, less noise |
| `1.0` | Full adapter strength (default) |
| `>1.0` | Exaggerated effect, may introduce artifacts |
|
|
||||||
|
|
||||||
If the generated audio has noticeable white noise or artifacts, lower the strength to `0.6–0.7` before adjusting anything else. Also try lowering CFG scale in the Sampler.
|
|
||||||
|
|
||||||
### Loss interpretation
|
|
||||||
|
|
||||||
A typical loss curve:
|
|
||||||
- Starts around `0.8–1.0`
- Should reach `0.55–0.65` after convergence on a clean sound class with 10–30 clips
- Below `0.4` indicates strong learning and usually requires 50+ diverse clips
- Below `0.1` on a small dataset means overfitting
|
|
||||||
|
|
||||||
The smoothed curve flattening for 2000+ steps is the clearest sign to stop or add more data.
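The "flat for 2000+ steps" check can be approximated in code. This sketch assumes you have the per-step loss values as a list; `ema_smooth`, `has_plateaued`, the smoothing factor `beta=0.98`, and the `min_improvement` threshold are all hypothetical choices, not values taken from the trainer.

```python
def ema_smooth(losses, beta=0.98):
    """Exponential moving average of a loss series (beta is an assumed factor)."""
    smoothed, avg = [], None
    for x in losses:
        avg = x if avg is None else beta * avg + (1.0 - beta) * x
        smoothed.append(avg)
    return smoothed


def has_plateaued(smoothed, window=2000, min_improvement=0.005):
    """True if the smoothed loss fell by less than min_improvement over the last window steps."""
    if len(smoothed) <= window:
        return False  # not enough history to judge
    return smoothed[-window - 1] - smoothed[-1] < min_improvement
```

A plateau flagged this way is only a prompt to listen to the checkpoint samples; the loss value alone does not say whether the audio improved.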
|
|
||||||
|
|
||||||
### Precision
|
|
||||||
|
|
||||||
Use `bf16` on Ampere+ GPUs (RTX 3xxx/4xxx, A100). Fall back to `fp16` on older GPUs. `fp32` is only needed for debugging — 2× more VRAM.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Output files
|
|
||||||
|
|
||||||
```
lora_output/my_sound/
  adapter_step00500.pt    ← step checkpoint (includes optimizer state for resume)
  adapter_step01000.pt
  ...
  adapter_final.pt        ← final adapter with embedded metadata (inference only)
  meta.json               ← human-readable metadata
  sample_step00500.wav    ← quick eval sample at each checkpoint
  loss_raw.png            ← raw loss curve
  loss_smoothed.png       ← EMA-smoothed loss curve
```
|
|
||||||
|
|
||||||
`adapter_final.pt` format:
|
|
||||||
```python
{
    "state_dict": { "blocks.0.attn.qkv.lora_A": ..., ... },
    "meta": {
        "variant": "large_44k",
        "rank": 16,
        "alpha": 16.0,
        "target": ["attn.qkv"],
        "steps": 2000
    }
}
```
|
|
||||||
|
|
||||||
Step checkpoints (e.g. `adapter_step01000.pt`) additionally contain `optimizer` and `scheduler` state for resuming.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
**`No layers matched target=...`**
|
|
||||||
The `--target` suffixes do not match any layer names. The default `attn.qkv` targets `SelfAttention.qkv` in all transformer blocks. If you changed `--target`, verify the layer names with `model.named_modules()`.
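To see which names a given `--target` would select, a suffix match over the module names can be sketched as below. The matching rule is inferred from the error message and may differ from what `train_lora.py` does; `match_target_layers` is a hypothetical helper, and the `blocks.N.…` names in the usage note are illustrative.

```python
def match_target_layers(module_names, targets=("attn.qkv",)):
    """Return module names whose dotted path ends with one of the target suffixes."""
    matched = []
    for name in module_names:
        # Match on whole dotted components so 'qkv' does not match 'my_qkv'.
        if any(name == t or name.endswith("." + t) for t in targets):
            matched.append(name)
    return matched
```

For example, passing `[name for name, _ in model.named_modules()]` as `module_names` and getting an empty list back corresponds to the error above.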
|
|
||||||
|
|
||||||
**`No .npz files found in ...`**
|
|
||||||
The `--data_dir` path is wrong or no `.npz` files were extracted there yet. Run SelVA Feature Extractor in ComfyUI first with the matching `cache_dir`.
|
|
||||||
|
|
||||||
**`No audio file found for clip.npz`**
|
|
||||||
Place an audio file with the exact same stem next to the `.npz`: `clip.wav`, `clip.flac`, etc.
|
|
||||||
|
|
||||||
**The sound is audible but there is white noise on top**
|
|
||||||
Lower the adapter strength to `0.6–0.7` in SelVA LoRA Loader. Also try lowering CFG scale in the Sampler. This is normal when the model hasn't fully converged — more clips and more steps will reduce it.
|
|
||||||
|
|
||||||
**LoRA appears to have no effect**
|
|
||||||
Make sure the SelVA LoRA Loader output is wired to the **Sampler** input, not the Feature Extractor. The Feature Extractor does not use the generator.
|
|
||||||
|
|
||||||
**Loss does not decrease**
|
|
||||||
- Increase `batch_size` for more stable gradients.
- Try a higher learning rate (`2e-4`) or check that warmup isn't too long.
- Check that the audio files are clean and actually contain the target sound.
- Check that the `.npz` features were extracted with a relevant prompt.
|
|
||||||
|
|
||||||
**Loss explodes or NaN**
|
|
||||||
- Lower the learning rate (`5e-5`).
- Make sure audio is normalized to `[-1, 1]`. 16-bit integer PCM files may need to be converted to floating point first, for example: `ffmpeg -i input.wav -ar 44100 -c:a pcm_f32le output.wav`
|
|
||||||
|
|
||||||
**Loss plateaus early (above 0.7)**
|
|
||||||
Dataset is the bottleneck. Add more clips — diversity matters more than quantity.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Observations (work in progress)
|
|
||||||
|
|
||||||
These are empirical findings from ongoing experiments. They will be promoted to the main guide once more validated.
|
|
||||||
|
|
||||||
### Precision and batch size
|
|
||||||
|
|
||||||
| Config | Smoothed loss at step 2000 | Notes |
|---|---|---|
| bf16 batch 1 | ~0.73 | Noisy gradients, slow |
| bf16 batch 16 | ~0.65 | Stable, plateaued around step 6000–8000 at ~0.59 |
| bf16 batch 16 logit_normal | ~0.47 | Lower loss floor, similar or marginally better audio |
| fp32 batch 32 | ~0.58 | At step 2000, already matches the loss bf16 batch 16 reaches at step 6000 |
|
|
||||||
|
|
||||||
**Key finding:** fp32 batch 32 converges to the same perceptual quality point in ~2000 steps that bf16 batch 16 needs 6000+ steps to reach. However, fp32 batch 32 continues descending well past that point on small datasets (10 clips), eventually overfitting. **Stop fp32 batch 32 around step 2000 on a 10-clip dataset** — later checkpoints sound worse despite lower loss.
|
|
||||||
|
|
||||||
**Lower loss ≠ better audio.** Once overfitting begins the model memorizes training clips rather than generalizing to new video inputs. Test intermediate checkpoints (e.g. step 500, 1000, 2000) to find the perceptual sweet spot.
|
|
||||||
|
|
||||||
### logit_normal vs uniform
|
|
||||||
|
|
||||||
logit_normal consistently reaches a lower loss floor than uniform. However, the perceptual improvement is dataset-dependent: on 10 clips the difference is marginal. It may be more impactful with larger datasets; no conclusion yet.
|
|
||||||
|
|
||||||
### White noise
|
|
||||||
|
|
||||||
Residual white noise on generated audio is primarily a **dataset** problem, not a training one. Observed with all configs on 10 clips. Likely causes:
|
|
||||||
- Too few clips for the model to confidently predict the target sound
|
|
||||||
- Imprecise extraction prompts producing unfocused sync features
|
|
||||||
- Missing mask when multiple objects are in frame
|
|
||||||
|
|
||||||
CFG scale amplifies any adapter noise bias. Reducing CFG to 3.0–3.5 or adapter strength to 0.6–0.7 helps at inference.
|
|
||||||
@@ -58,7 +58,7 @@ Generates audio from video features. Runs the rectified flow ODE with classifier
|
|||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| `model` | From SelVA Model Loader (or any loader/loader chain) |
|
| `model` | From SelVA Model Loader |
|
||||||
| `features` | From SelVA Feature Extractor |
|
| `features` | From SelVA Feature Extractor |
|
||||||
| `prompt` | Text description — leave empty to use the prompt stored in features |
|
| `prompt` | Text description — leave empty to use the prompt stored in features |
|
||||||
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
|
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
|
||||||
@@ -66,261 +66,22 @@ Generates audio from video features. Runs the rectified flow ODE with classifier
|
|||||||
| `steps` | Sampling steps (default: 25) |
|
| `steps` | Sampling steps (default: 25) |
|
||||||
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
|
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
|
||||||
| `seed` | RNG seed |
|
| `seed` | RNG seed |
|
||||||
| `normalize` | RMS-normalize output to `target_lufs` (default: true) |
|
| `normalize` | Peak-normalize output to [-1, 1] (default: true) |
|
||||||
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
|
|
||||||
| `steering_vectors` | *(optional)* From SelVA Activation Steering Loader |
|
|
||||||
| `steering_strength` | *(optional)* Scale for steering vectors (default: 0.1) |
|
|
||||||
| `textual_inversion` | *(optional)* From SelVA Textual Inversion Loader |
|
|
||||||
| `ti_strength` | *(optional)* Blend strength for TI tokens (default: 1.0) |
|
|
||||||
|
|
||||||
**Output:** `AUDIO`
|
**Output:** `AUDIO`
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### SelVA LoRA Loader
|
## Workflow
|
||||||
|
|
||||||
Injects a trained LoRA adapter into the generator. Connect between Model Loader and Sampler.
|
|
||||||
|
|
||||||
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `adapter_path` | Path to `adapter_final.pt` or any step checkpoint |
| `strength` | 0.0 = disabled, 1.0 = full, >1.0 = exaggerated |
|
|
||||||
|
|
||||||
**Output:** `model` (SELVA_MODEL with adapter injected)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA LoRA Trainer
|
|
||||||
|
|
||||||
Fine-tunes LoRA adapters on a `.npz` feature dataset. See [LORA_TRAINING.md](LORA_TRAINING.md) for the full guide.
|
|
||||||
|
|
||||||
**Output:** `adapter` (SELVA_LORA) and `summary_path` (STRING)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA LoRA Scheduler
|
|
||||||
|
|
||||||
Runs a series of LoRA experiments from a JSON sweep file. The dataset is encoded once and reused across all runs. Results are collected in `experiment_summary.json` with overlaid loss curves.
|
|
||||||
|
|
||||||
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `experiments_file` | Path to JSON sweep config |
|
|
||||||
|
|
||||||
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Skip Experiment
|
|
||||||
|
|
||||||
Signals a running SelVA LoRA Scheduler to skip the current experiment and move to the next. Queue this node while the scheduler is running.
|
|
||||||
|
|
||||||
**Output:** `flag_path` (STRING)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA LoRA Evaluator
|
|
||||||
|
|
||||||
Evaluates multiple LoRA adapters by generating audio from a fixed reference clip, then reports spectral metrics per adapter for comparison. Input is a JSON file listing adapter paths; an empty path means baseline (no LoRA).
|
|
||||||
|
|
||||||
**Outputs:** `summary_path` (STRING), `comparison_image` (IMAGE)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Dataset Browser
|
|
||||||
|
|
||||||
Reads a `dataset.json` produced by the SelVA dataset preparation pipeline and exposes one entry at a time via an index. Useful for previewing and iterating through a prepared dataset.
|
|
||||||
|
|
||||||
**Outputs:** video path, audio path, frames directory, label, total count
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA VAE Roundtrip
|
|
||||||
|
|
||||||
Encodes audio through the SelVA VAE then decodes it back. Use this to measure codec reconstruction quality in isolation — if the output sounds degraded relative to the input, the codec ceiling will limit any downstream fine-tuning approach.
|
|
||||||
|
|
||||||
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `audio` | AUDIO to test |
|
|
||||||
|
|
||||||
**Output:** `audio_reconstructed` (AUDIO)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA HF Smoother
|
|
||||||
|
|
||||||
Attenuates high-frequency content that the SelVA codec handles poorly, by blending a low-pass filtered version of the audio with the original. Use before feature extraction to improve LoRA training targets.
|
|
||||||
|
|
||||||
**Output:** `audio` (AUDIO)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Spectral Matcher
|
|
||||||
|
|
||||||
Applies a per-band gain correction to bring audio's spectral profile in line with the MMAudio VAE's expected distribution, derived from the normalization statistics baked into the VAE weights. Use on training audio to reduce codec mismatch.
|
|
||||||
|
|
||||||
**Output:** `audio` (AUDIO)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Textual Inversion Trainer
|
|
||||||
|
|
||||||
Trains K learnable CLIP token embeddings against an audio dataset with all model weights frozen. The tokens are injected into the Sampler to guide generation toward a target style.
|
|
||||||
|
|
||||||
> **Note:** Textual inversion via the text conditioning path has limited effectiveness for fine-grained timbral style transfer in SelVA due to mean-pooling in the text conditioning path. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the current recommended approach.
|
|
||||||
|
|
||||||
**Outputs:** `embeddings_path` (STRING), `loss_curve` (IMAGE)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Textual Inversion Loader
|
|
||||||
|
|
||||||
Loads CLIP token embeddings from a `.pt` file produced by the Textual Inversion Trainer. Connect to the Sampler's `textual_inversion` input.
|
|
||||||
|
|
||||||
**Output:** `textual_inversion` (TEXTUAL_INVERSION)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA TI Scheduler
|
|
||||||
|
|
||||||
Runs a series of Textual Inversion experiments from a JSON sweep file, reusing the encoded dataset across runs.
|
|
||||||
|
|
||||||
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Activation Steering Extractor
|
|
||||||
|
|
||||||
Computes per-block activation steering vectors from a training dataset by comparing DiT hidden states under BJ conditioning vs. empty conditioning. The resulting vectors can nudge the denoising trajectory toward the target style at inference.
|
|
||||||
|
|
||||||
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with `.npz` feature files |
| `output_path` | Where to save `steering_vectors.pt` |
| `n_samples` | Clips to average over (default: 16) |
| `seed` | RNG seed |
|
|
||||||
|
|
||||||
**Output:** `steering_path` (STRING)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Activation Steering Loader
|
|
||||||
|
|
||||||
Loads steering vectors from a `.pt` file produced by the Extractor. Connect to the Sampler's `steering_vectors` input.
|
|
||||||
|
|
||||||
**Output:** `steering_vectors` (STEERING_VECTORS)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA BigVGAN Trainer
|
|
||||||
|
|
||||||
Fine-tunes the BigVGAN vocoder (mel → waveform) on a set of target-style audio clips. Only the vocoder is modified — the DiT generator and VAE are completely untouched.
|
|
||||||
|
|
||||||
Default mode (`snake_alpha_only`) tunes only the ~27K per-channel α parameters in Snake/SnakeBeta activations, which directly control harmonic periodicity. With 0.024% of parameters trainable the model cannot produce spectral averaging artifacts regardless of loss function. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the full rationale.
|
|
||||||
|
|
||||||
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `data_dir` | Directory with target-style audio files (searched recursively) |
| `output_path` | Where to save the fine-tuned vocoder `.pt` |
| `train_mode` | `snake_alpha_only` (default) or `all_params` |
| `steps` | Training steps (default: 2000) |
| `lr` | Learning rate (default: 1e-4 for snake_alpha_only) |
| `batch_size` | Clips per step (default: 4) |
| `segment_seconds` | Audio segment length per training sample (default: 1.0 s) |
| `lambda_l2sp` | L2-SP anchor regularization strength; penalizes drift from pretrained weights (default: 1e-3) |
| `save_every` | Checkpoint interval in steps (default: 500) |
| `seed` | RNG seed |
| `discriminator_path` | *(optional)* Path to `bigvgan_discriminator_optimizer.pt`; when provided, frozen MPD+MRD feature matching replaces mel L1, directly penalizing harmonic smearing |
|
|
||||||
|
|
||||||
**Output:** `checkpoint_path` (STRING) — load with SelVA BigVGAN Loader
|
|
||||||
|
|
||||||
Saves eval samples and mel spectrogram PNGs at baseline, each checkpoint, and final.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA BigVGAN Loader
|
|
||||||
|
|
||||||
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer and replaces the vocoder weights in a SELVA_MODEL in-place. Connect the output to SelVA Sampler instead of the base Model Loader.
|
|
||||||
|
|
||||||
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL from Model Loader |
| `path` | Path to fine-tuned vocoder `.pt` (relative = ComfyUI output directory) |
|
|
||||||
|
|
||||||
**Output:** `model` (SELVA_MODEL with fine-tuned vocoder)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA DITTO Optimizer
|
|
||||||
|
|
||||||
Inference-time noise optimization ([arXiv:2401.12179](https://arxiv.org/abs/2401.12179), ICML 2024 Oral). Optimizes the initial noise latent x₀ to make the generated audio match a set of BJ reference clips, by backpropagating a mel style loss through the ODE solver. All model weights remain frozen — zero quality degradation risk.
|
|
||||||
|
|
||||||
Style loss: mean spectrum + Gram matrix computed against reference mels. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference clips. Optimization runs only through the DiT + VAE decoder; the vocoder is only invoked for the final output pass.
|
|
||||||
|
|
||||||
| Input | Description |
|-------|-------------|
| `model` | SELVA_MODEL |
| `features` | From SelVA Feature Extractor |
| `prompt` | Sound description (leave empty to use features prompt) |
| `negative_prompt` | Sounds to suppress |
| `reference_dir` | Directory with BJ reference audio clips (.wav/.flac/.mp3) |
| `n_opt_steps` | Gradient optimization steps on x₀ (default: 50) |
| `opt_lr` | Adam LR for x₀ optimization (default: 0.1) |
| `n_ode_steps` | ODE steps per optimization iteration (default: 10; lower = faster) |
| `n_grad_steps` | ODE steps to differentiate through, i.e. truncated BPTT (default: 5) |
| `style_weight` | Style loss weight (default: 1.0; increase for stronger BJ shift) |
| `steps` | Euler steps for the final generation pass (default: 25) |
| `cfg_strength` | CFG scale (default: 4.5) |
| `seed` | RNG seed |
| `normalize` | *(optional)* RMS normalize output (default: true) |
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
|
|
||||||
|
|
||||||
**Output:** `AUDIO`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Workflows
|
|
||||||
|
|
||||||
### Basic generation
|
|
||||||
|
|
||||||
```
|
```
|
||||||
VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio
|
VHS LoadVideo ──► SelVA Feature Extractor ──────────────────────► SelVA Sampler ──► Save Audio
|
||||||
│ (video_info) ▲
|
│ (video_info) ─► (fps auto) ▲
|
||||||
│ (features) ──────────────────────────────────►│
|
│ (features) ────────────────────────────────────►│
|
||||||
│ (prompt) ────────────────────────────────────►│
|
│ (prompt) ──────────────────────────────────────►│
|
||||||
```
|
```
|
||||||
|
|
||||||
### DITTO style transfer (recommended first approach)
|
Connect the `prompt` output of Feature Extractor directly to Sampler's `prompt` to keep them in sync. Leave Sampler's `prompt` empty and it will use whatever was stored during extraction.
|
||||||
|
|
||||||
```
SelVA Model Loader ─────────────────────────────────────────────► SelVA DITTO Optimizer ──► Save Audio
▲
SelVA Feature Extractor ──(features)────────────────────────────────────►│
(prompt) ──────────────────────────────────────►│
BJ reference_dir ───────────────────────────────────────────────────────►│
```
|
|
||||||
|
|
||||||
No training required. Each run optimizes x₀ independently for the current video and reference set.
|
|
||||||
|
|
||||||
### Vocoder fine-tuning
|
|
||||||
|
|
||||||
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► (checkpoint .pt)
▲
BJ audio clips ──(data_dir)──►│

SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler ──► Save Audio
▲ ▲
checkpoint .pt SelVA Feature Extractor
```
|
|
||||||
|
|
||||||
### LoRA training
|
|
||||||
|
|
||||||
See [LORA_TRAINING.md](LORA_TRAINING.md).
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -366,15 +127,8 @@ The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available,
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Style Transfer
|
|
||||||
|
|
||||||
For adapting SelVA to a specific audio style (e.g. BJ / Bladee / Jersey Club), see [STYLE_TRANSFER.md](STYLE_TRANSFER.md).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Credits

- [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al.: TextSynchformer and SelVA training
- [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al.: MM-DiT audio generator and flow matching framework
- [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA: neural vocoder for 16 kHz synthesis
- [DITTO](https://arxiv.org/abs/2401.12179) by Novack et al.: inference-time diffusion optimization
|
|
||||||
|
|||||||
@@ -1,158 +0,0 @@
|
|||||||
# Style Transfer for SelVA
|
|
||||||
|
|
||||||
This document covers approaches for adapting SelVA's audio output to a specific timbral style using a small reference dataset (~50 clips). The context here is BJ / Bladee / Jersey Club style — sharp metallic transients, saturated harmonics, 808 sub bass, glassy high-frequency content — but the methods apply to any style target.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Why standard fine-tuning is hard
|
|
||||||
|
|
||||||
SelVA's generation quality depends on the DiT (generator) outputting latents that fall in the high-density region of the VAE decoder's training distribution. BJ's audio maps to a sparse, tail region of that space — the VAE roundtrip already shows ~10–15 dB elevated HF noise floor on BJ material. Any training that pushes the generator toward exact BJ encoder outputs is training toward an already-degraded target.
|
|
||||||
|
|
||||||
**LoRA** makes this worse: it introduces "intruder dimensions" — new high-rank singular vectors absent from the pretrained weight spectrum — that push DiT outputs further off-manifold. This mechanism is LR- and scale-independent. Reducing LoRA scale does not fix the direction, only the magnitude. Empirically: spectral flatness degrades to ~0.21–0.26 (vs. baseline 0.013) at every scale from 0.0625 to 1.0.
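For reference, spectral flatness is the ratio of the geometric mean to the arithmetic mean of the power spectrum: values near 0 indicate tonal, peaky spectra and values toward 1 indicate noise-like spectra. The sketch below uses that standard definition on the whole signal; how the evaluator actually frames and averages the metric is not stated here, so treat `spectral_flatness` as a hypothetical helper.

```python
import numpy as np

def spectral_flatness(audio, eps=1e-12):
    """Geometric mean / arithmetic mean of the power spectrum, in (0, 1]."""
    power = np.abs(np.fft.rfft(audio)) ** 2 + eps  # eps avoids log(0)
    geometric_mean = np.exp(np.mean(np.log(power)))
    arithmetic_mean = np.mean(power)
    return float(geometric_mean / arithmetic_mean)
```

Under this definition, a jump from about 0.013 to above 0.2 means the output spectrum has moved from clearly tonal toward noise-like.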
|
|
||||||
|
|
||||||
**Textual inversion** via the text conditioning path suffers from mean-pooling: SelVA's text features are pooled into a single global vector before injection into the DiT. The optimizer finds a spectral bias (noise/buzz) as the cheapest way to reduce reconstruction loss — not a semantic style shift.
|
|
||||||
|
|
||||||
The approaches below are ordered by expected quality and ease of use.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tier 1 — DITTO (recommended first try)
|
|
||||||
|
|
||||||
**Node: SelVA DITTO Optimizer**
|
|
||||||
|
|
||||||
Inference-time noise optimization. Keeps all model weights frozen and only optimizes the initial noise latent x₀ using a style loss computed against the reference clips. Since the weights never change, there is zero risk of quality degradation — the model still generates from its original manifold, just from a better starting point.
|
|
||||||
|
|
||||||
**Style loss:** mean spectrum + Gram matrix of mel spectrograms. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference. Optimization runs entirely before the vocoder — BigVGAN is only called for the final output pass.
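A minimal NumPy sketch of such a style loss follows. It is an illustration only: `gram_matrix` and `style_loss` are hypothetical names, the two terms are weighted equally by assumption, and the Gram matrix here is the uncentered band-by-band second moment averaged over frames.

```python
import numpy as np

def gram_matrix(mel):
    """mel: (n_mels, n_frames) -> (n_mels, n_mels), averaged over time."""
    return mel @ mel.T / mel.shape[1]

def style_loss(mel_gen, mel_refs):
    """Mean-spectrum term plus Gram term against a list of reference mels.

    Reference clips may have different lengths because both statistics
    are averaged over the time axis before comparison.
    """
    ref_mean = np.mean([m.mean(axis=1) for m in mel_refs], axis=0)
    ref_gram = np.mean([gram_matrix(m) for m in mel_refs], axis=0)
    mean_term = np.mean((mel_gen.mean(axis=1) - ref_mean) ** 2)
    gram_term = np.mean((gram_matrix(mel_gen) - ref_gram) ** 2)
    return float(mean_term + gram_term)
```

Because both statistics sum over frames, shuffling the frames of `mel_gen` leaves the loss unchanged, which is the sense in which no temporal alignment with the references is needed.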
|
|
||||||
|
|
||||||
**How it works:**
|
|
||||||
|
|
||||||
For each video clip you want to process:
|
|
||||||
1. Run SelVA Feature Extractor as usual.
|
|
||||||
2. Instead of SelVA Sampler, connect to **SelVA DITTO Optimizer** with your BJ `reference_dir`.
|
|
||||||
3. The node runs N optimization steps, each backpropagating through the last few ODE Euler steps to compute `∂loss/∂x₀`.
|
|
||||||
4. After optimization, one final full-ODE pass generates the output audio from the refined x₀.
|
|
||||||
|
|
||||||
```
SelVA Model Loader ────────────────────────────────► SelVA DITTO Optimizer ──► audio
▲
SelVA Feature Extractor ──(features)────────────────────────►│
(prompt) ──────────────────────────►│
BJ clips ───────────────────────────(reference_dir) ─────────►│
```
|
|
||||||
|
|
||||||
**Tuning guide:**
|
|
||||||
|
|
||||||
| Parameter | Starting value | When to adjust |
|---|---|---|
| `n_opt_steps` | 50 | Increase to 100–200 if style shift is too subtle |
| `opt_lr` | 0.1 | Lower to 0.05 if coherence breaks; raise to 0.3 for stronger shift |
| `n_ode_steps` | 10 | Lower = faster optimization, less accurate gradient |
| `n_grad_steps` | 5 | Number of ODE steps to differentiate through; must be ≤ `n_ode_steps` |
| `style_weight` | 1.0 | Increase to 2–5 for stronger BJ character; watch for incoherence |
|
|
||||||
|
|
||||||
**Memory:** Each opt step stores activations for `n_grad_steps` DiT forward passes with gradient checkpointing. At n_grad_steps=5, expect ~4–6 GB additional VRAM over baseline inference.
|
|
||||||
|
|
||||||
**Time per video clip:** ~50 opt steps × (10 ODE steps × 2 passes for checkpointing) + 25 final steps ≈ 5–15 minutes depending on GPU.
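The estimate above is a count of DiT evaluations multiplied by the per-evaluation cost on your GPU. A small sketch of the arithmetic, using the default values quoted in this document; `ditto_forward_passes`, `estimate_minutes`, and `seconds_per_pass` are hypothetical names, and any extra passes for classifier-free guidance are ignored:

```python
def ditto_forward_passes(n_opt_steps=50, n_ode_steps=10, final_steps=25, passes_per_step=2):
    """Count DiT evaluations for one clip.

    passes_per_step=2 models the recomputation cost of gradient checkpointing.
    """
    return n_opt_steps * n_ode_steps * passes_per_step + final_steps

def estimate_minutes(n_passes, seconds_per_pass):
    """Convert a pass count to minutes for a measured per-pass time."""
    return n_passes * seconds_per_pass / 60.0

total = ditto_forward_passes()
print(total)  # 1025
```

With a measured per-pass time of, say, 0.5 s (a made-up figure), `estimate_minutes(total, 0.5)` gives roughly 8.5 minutes, inside the quoted range.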
|
|
||||||
|
|
||||||
**Limitations:** DITTO with mel Gram matrix loss shifts timbral statistics but cannot precisely match the BJ transient sharpness — the Gram matrix is a texture descriptor, not a transient detector. See Tier 2 (vocoder fine-tuning) for that.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tier 2 — Vocoder Fine-tuning
|
|
||||||
|
|
||||||
**Nodes: SelVA BigVGAN Trainer → SelVA BigVGAN Loader**
|
|
||||||
|
|
||||||
The BigVGAN vocoder (mel → waveform) is the component most responsible for the final timbral character of the output. Fine-tuning only the vocoder keeps the DiT completely untouched — latents stay on-manifold, only the waveform rendering changes.
|
|
||||||
|
|
||||||
### Why plain mel L1 loss fails
|
|
||||||
|
|
||||||
BigVGAN was trained with `L_G = Σ[L_adv + 2·L_fm] + 45·L_mel`. The adversarial and feature-matching terms do the perceptual heavy lifting — they prevent the generator from averaging over high-variance harmonic content. Dropping them for a plain mel L1 loss is a loss-function topology problem: the model minimizes expected reconstruction error by averaging over harmonic uncertainty, eroding the saturated 3–8 kHz harmonics visible as "green smear" in spectrograms. This happens regardless of LR or step count.
|
|
||||||
|
|
||||||
### `snake_alpha_only` mode (default, recommended)
|
|
||||||
|
|
||||||
BigVGAN's AMP blocks use Snake/SnakeBeta activations: `y = x + (1/α)·sin²(α·x)` where α is a per-channel learnable scalar. Alpha parameters directly control the harmonic periodicity of each layer's output — they are the "harmonic tuning knobs" of the vocoder.
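The formula can be checked numerically. The sketch below implements the plain Snake form quoted above in NumPy (the SnakeBeta variant with a separate magnitude parameter is not shown); `snake` is a name chosen for this example.

```python
import numpy as np

def snake(x, alpha):
    """Snake activation: identity plus a periodic ripple of period pi / alpha."""
    return x + np.sin(alpha * x) ** 2 / alpha

x = np.linspace(-3.0, 3.0, 101)
alpha = 2.0
# Shifting the input by one ripple period shifts the output by the same amount,
# so the non-linear part repeats every pi / alpha.
shift = np.pi / alpha
periodic = np.allclose(snake(x + shift, alpha) - snake(x, alpha), shift)
```

Since the ripple period is π/α, raising α packs more oscillations into the same input range, which is why α acts as a per-channel periodicity control.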
|
|
||||||
|
|
||||||
With `train_mode=snake_alpha_only`, only the ~27K alpha parameters (0.024% of the 112M parameter model) are trained. The conv weights encoding waveform structure remain frozen. With this few trainable parameters the model physically cannot reshape the spectrum significantly regardless of loss function — no capacity for the green smear.
|
|
||||||
|
|
||||||
**Loss in snake_alpha_only mode:** mel L1 + multi-resolution STFT L1 are still used but can only shift harmonic emphasis, not spectral shape.
|
|
||||||
|
|
||||||
### `all_params` mode with discriminator
|
|
||||||
|
|
||||||
For a stronger shift — or to use proper perceptual losses — run with `train_mode=all_params` and provide a `discriminator_path` (the `bigvgan_discriminator_optimizer.pt` from the BigVGAN pretrained release):
|
|
||||||
|
|
||||||
1. The frozen pretrained MPD and MRD discriminators are loaded and used as fixed perceptual feature extractors.
|
|
||||||
2. Loss becomes `2·L_fm(frozen_D) + 0.1·L_mel` — feature matching directly penalizes harmonic smearing through the discriminator's learned perceptual space.
|
|
||||||
3. `lambda_l2sp` (default 1e-3) anchors all parameters to their pretrained values — prevents catastrophic drift on 50 clips.
|
|
||||||
|
|
||||||
This is the highest-quality vocoder fine-tuning path but requires the discriminator checkpoint.
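The L2-SP anchor term mentioned in step 3 can be sketched as a squared distance to the pretrained weights. This is a generic illustration in NumPy, assuming a plain sum of squares with no extra normalization; whether the trainer sums, averages, or halves the term is not specified here, and `l2sp_penalty` is a hypothetical name.

```python
import numpy as np

def l2sp_penalty(params, pretrained, lam=1e-3):
    """lam * sum over all tensors of ||theta - theta_pretrained||^2.

    params / pretrained: dicts mapping parameter names to arrays of equal shape.
    """
    total = 0.0
    for name, value in params.items():
        total += float(np.sum((value - pretrained[name]) ** 2))
    return lam * total
```

The penalty is zero at the pretrained weights and grows quadratically with drift, so raising `lambda_l2sp` keeps the fine-tuned vocoder closer to its starting point.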
|
|
||||||
|
|
||||||
### Workflow
|
|
||||||
|
|
||||||
```
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► bigvgan_bj.pt
▲
BJ audio clips ──(data_dir)──►│

SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor
```
|
|
||||||
|
|
||||||
### Tuning guide
|
|
||||||
|
|
||||||
| Parameter | Default | Notes |
|---|---|---|
| `train_mode` | snake_alpha_only | Safe default; use all_params only with discriminator_path |
| `steps` | 2000 | 1000–2000 for snake_alpha_only; 3000–5000 for all_params |
| `lr` | 1e-4 | For snake_alpha_only; lower to 1e-5 for all_params |
| `lambda_l2sp` | 1e-3 | Increase to 1e-2 for all_params to limit drift |
| `batch_size` | 4 | 4–8 for stable gradients |
| `segment_seconds` | 1.0 | 1–2 s segments recommended |
|
|
||||||
|
|
||||||
**Eval samples:** The trainer saves `.wav` and mel spectrogram `.png` files at baseline, each checkpoint, and final. Compare the spectrograms — saturation (red values in high-frequency bands) should increase relative to baseline.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tier 3 — DITTO + Vocoder (combined)
|
|
||||||
|
|
||||||
Stack both:
|
|
||||||
|
|
||||||
```
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA DITTO Optimizer ──► audio
▲ ▲
bigvgan_bj.pt SelVA Feature Extractor + reference_dir
```
|
|
||||||
|
|
||||||
The fine-tuned vocoder handles waveform rendering; DITTO shifts the latent trajectory. Each addresses a different aspect of style transfer.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## What doesn't work (and why)
|
|
||||||
|
|
||||||
### Standard LoRA
|
|
||||||
|
|
||||||
LoRA introduces "intruder dimensions" (new high-ranking singular vectors absent from the pretrained weight spectrum) at initialization. These push DiT outputs into decoder-hostile latent regions regardless of scale or LR. The failure is direction-based, not magnitude-based, so reducing LoRA scale does not fix it.
|
|
||||||
|
|
||||||
PiSSA initialization (`init_lora_weights="pissa"`) and rsLoRA scaling (`use_rslora=True`) reduce intruder dimension formation by starting in the pretrained weight subspace. These are planned as future improvements.
|
|
||||||
|
|
||||||
### Textual inversion
|
|
||||||
|
|
||||||
SelVA mean-pools all 77 CLIP tokens into a single AdaLN bias vector. Every token contributes equally to a scalar offset; the optimizer finds spectral buzz as the minimum-cost way to reduce flow-matching reconstruction loss. More tokens make it worse.
|
|
||||||
|
|
||||||
### Activation steering (global mean difference)
|
|
||||||
|
|
||||||
The raw mean difference between BJ and empty conditions is not a clean style basis — it carries noise from the diversity of the training clips and the many attention blocks that have nothing to do with timbral character. Global injection (all blocks at any strength) kills the sound. Targeted layer injection (only the 3–6 blocks most predictive of BJ style) is theoretically sound but requires per-layer delta magnitude ranking to identify the right layers first.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Reference dataset preparation
|
|
||||||
|
|
||||||
Use the same audio clips for both DITTO and vocoder fine-tuning:
|
|
||||||
|
|
||||||
- **Minimum:** 20–30 clips. DITTO works from 5+; vocoder benefits from 40+.
|
|
||||||
- **Format:** `.wav` or `.flac` at native sample rate. The trainer resamples automatically.
|
|
||||||
- **Length:** Any length ≥ 1 s. Longer is fine — the trainer segments internally.
|
|
||||||
- **Quality:** Clean, full-mix BJ clips. Avoid heavily compressed or streaming-ripped files. Use HF Smoother if HF content sounds brittle after VAE roundtrip.
|
|
||||||
- **Diversity:** Vary tempo, key, vocal density. 20 diverse clips > 50 copies of the same 8-bar loop.
|
|
||||||
|
|
||||||
Normalize all clips to consistent loudness (e.g. -14 LUFS) before training. Inconsistent levels increase loss variance and slow convergence.
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
# Audio Dataset Pipeline for Generative Model Training
|
|
||||||
|
|
||||||
Research notes on audio cleaning, augmentation, and quality metrics for LoRA fine-tuning of MMAudio/SelVA. Based on papers and tooling survey (April 2026).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Core Principle
|
|
||||||
|
|
||||||
Augmentation for generative models ≠ augmentation for classifiers.
|
|
||||||
The goal is **not invariance** — it is expanding the training manifold so the model learns the distribution of a sound rather than memorizing a fixed set of waveforms.
|
|
||||||
|
|
||||||
With 10 clips, velocity field collapse (arXiv:2410.23594) is mathematically expected: the flow-matching model memorizes the training trajectories instead of generalizing. More diverse data is the only real fix.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Recommended Pipeline
|
|
||||||
|
|
||||||
### Step 1 — Quality Screening
|
|
||||||
|
|
||||||
```python
# Assumes `audio` is a float NumPy array and `sr` is its sample rate.
import numpy as np
import pyloudnorm as pyln  # pip install pyloudnorm

# Clipping check
clip_ratio = np.sum(np.abs(audio) >= 0.99) / len(audio)  # flag if > 0.001 (0.1%)

# DC offset check + removal
dc = np.mean(audio)
audio -= dc

# LUFS normalization to -14 LUFS (essential for training consistency)
meter = pyln.Meter(sr)
loudness = meter.integrated_loudness(audio)
audio = pyln.normalize.loudness(audio, loudness, -14.0)
# Or via ffmpeg: ffmpeg -af loudnorm=I=-14:LRA=7:TP=-1

# DNSMOS quality gate (discard if OVRL < 3.5 for training; < 2.5 is unusable)
# from Microsoft DNS-Challenge repo
```
|
|
||||||
|
|
||||||
### Step 2 — Cleaning
|
|
||||||
|
|
||||||
| Tool | Install | Use |
|---|---|---|
| **AudioSep** | `pip install audiosep` | Isolate target sound from background (most impactful tool) |
| **noisereduce** | `pip install noisereduce` | Light stationary/non-stationary denoising, preserves character |
| **librosa** | `pip install librosa` | Silence trimming: `librosa.effects.trim(audio, top_db=30)` |
| **torchaudio.transforms.Fade** | (torchaudio) | Prevent click artifacts at clip edges |
| **DeepFilterNet** | `pip install deepfilternet` | Heavy denoising; good for speech, may alter tonal sounds |
|
|
||||||
|
|
||||||
**AudioSep usage:**
|
|
||||||
```python
from audiosep import AudioSep
model = AudioSep.from_pretrained("audio-agi/audiosep")
# ~1.5 GB checkpoint, ~4 GB VRAM
model.inference(audio_path, "a dog barking loudly", output_path)
```
|
|
||||||
|
|
||||||
### Step 3 — Waveform Augmentation (10 clips → 50–100)
|
|
||||||
|
|
||||||
Apply stochastically per clip:
|
|
||||||
|
|
||||||
| Transform | Params | Notes |
|---|---|---|
| **PitchShift** | ±1–3 semitones | 3 variants per clip. Limit to ±1 st for tonal/pitched sounds |
| **ApplyImpulseResponse** | 5 different RIRs | 5 variants per clip; EchoThief (~150 free IRs) or pyroomacoustics |
| **LoudnessNormalization** | ±2 dB random | Subtle level variation |
| **SevenBandParametricEQ** | ±3 dB | Gentle spectral variation |
| **TimeStretch** | 0.9–1.1× only | Do NOT use 2× to pad short clips; it breaks video sync |
|
|
||||||
|
|
||||||
```python
# pip install audiomentations pedalboard pyroomacoustics
# Assumes `audio` is a float32 NumPy array and `sr` is its sample rate.
import audiomentations as A

augment = A.Compose([
    A.PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
    A.ApplyImpulseResponse(ir_path="path/to/irs/", p=0.5),  # keyword is `ir_path`
    A.SevenBandParametricEQ(min_gain_db=-3, max_gain_db=3, p=0.3),
    A.LoudnessNormalization(min_lufs=-16, max_lufs=-12, p=0.5),
    A.TimeStretch(min_rate=0.9, max_rate=1.1, p=0.3),
])
audio_aug = augment(samples=audio, sample_rate=sr)
```
|
|
||||||
|
|
||||||
**RIR sources:**
|
|
||||||
- EchoThief: ~150 free real-world IRs (churches, caves, parking garages)
|
|
||||||
- pyroomacoustics: synthetic room simulation, fully controllable
|
|
||||||
|
|
||||||
### Step 4 — Latent Augmentation (at training time)
|
|
||||||
|
|
||||||
After VAE encoding:
|
|
||||||
|
|
||||||
**Latent mixup** between same-category pairs:
|
|
||||||
```python
import torch

# Mix latents BEFORE flow-matching noise is added
# Only mix clips from the same sound category; cross-category mixing produces garbage
# z1, z2: VAE latents of two same-category clips with identical shape
lam = torch.distributions.Beta(0.4, 0.4).sample()
z_mix = lam * z1 + (1 - lam) * z2
```
|
|
||||||
With 10 clips: C(10,2) = 45 possible pairs → significant expansion without new recordings.
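The pair count can be verified by enumerating the combinations. The sketch below uses NumPy arrays as stand-ins for VAE latents and draws one Beta(0.4, 0.4) mixing weight per pair; `mixup_all_pairs` is a hypothetical helper, and it assumes every latent has the same shape and belongs to the same sound category.

```python
from itertools import combinations

import numpy as np

def mixup_all_pairs(latents, alpha=0.4, seed=0):
    """Return one mixed latent per unordered pair of input latents."""
    rng = np.random.default_rng(seed)
    mixed = []
    for z1, z2 in combinations(latents, 2):
        lam = rng.beta(alpha, alpha)  # U-shaped: most draws favor one clip
        mixed.append(lam * z1 + (1.0 - lam) * z2)
    return mixed

# Ten dummy "latents", each filled with a distinct constant for easy checking.
latents = [np.full((4, 8), float(i)) for i in range(10)]
mixed = mixup_all_pairs(latents)
print(len(mixed))  # 45
```

In a real training loop the mixing weight would typically be resampled every step rather than fixed once, so the 45 pairs yield a continuum of interpolated targets.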
|
|
||||||
|
|
||||||
**Small Gaussian noise:**
|
|
||||||
```python
z_noised = z + torch.randn_like(z) * 0.02 * z.std()
```
|
|
||||||
Prevents trivial memorization of exact latent coordinates.
|
|
||||||
|
|
||||||
MusicLDM (arXiv:2308.01546) shows latent mixup > waveform mixup for generative quality.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Transforms to AVOID for Generative Training
|
|
||||||
|
|
||||||
| Transform | Why |
|---|---|
| ClippingDistortion, BitCrush, TanhDistortion, Mp3Compression | Model learns the artifact |
| Reverse | Breaks temporal structure for video-to-audio task |
| TimeMask (creating silence gaps) | Unnatural; model learns to produce silence |
| TimeStretch > 1.3× | Phase vocoder artifacts become part of the target distribution |
| Heavy background noise (< 15 dB SNR) | Model learns to reproduce the noise |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Quality Metrics
|
|
||||||
|
|
||||||
| Metric | Tool | Threshold |
|---|---|---|
| DNSMOS P.835 (SIG/BAK/OVRL) | Microsoft DNS-Challenge | OVRL > 3.5 for training |
| LUFS | pyloudnorm | Normalize all clips to -14 LUFS |
| WADA-SNR | (standalone) | No-reference SNR estimate |
| Clipping ratio | NumPy | Flag if > 0.1% of samples at ±0.99 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tool Reference
|
|
||||||
|
|
||||||
| Tool | Install | Purpose |
|---|---|---|
| audiomentations | `pip install audiomentations` | Primary augmentation library |
| pedalboard | `pip install pedalboard` | Higher quality pitch shift, IR convolution |
| AudioSep | `pip install audiosep` | Source separation / isolation |
| noisereduce | `pip install noisereduce` | Non-stationary denoising |
| DeepFilterNet | `pip install deepfilternet` | Heavy denoising (speech-optimized) |
| pyloudnorm | `pip install pyloudnorm` | LUFS normalization |
| Silero VAD | `pip install silero-vad` | Voice/silence detection |
| pyroomacoustics | `pip install pyroomacoustics` | Synthetic RIR generation |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Integration with PrismAudio / SelVA
|
|
||||||
|
|
||||||
No established ComfyUI audio preprocessing ecosystem as of early 2026. Build thin wrapper nodes around the tools above. PrismAudio already has all required patterns (subprocess isolation, AUDIO type transport).
|
|
||||||
|
|
||||||
**Target node set:**
|
|
||||||
- `SelVA Dataset Cleaner` — wraps noisereduce + LUFS normalization + trim + DNSMOS gate
|
|
||||||
- `SelVA Dataset Augmenter` — wraps audiomentations Compose pipeline
|
|
||||||
|
|
||||||
Steps 1–3 are preprocessing (run once before feature extraction).
|
|
||||||
Step 4 (latent mixup) is a training loop modification — integrate into `selva_lora_trainer.py`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Key Papers
|
|
||||||
|
|
||||||
| Paper | ArXiv | Finding |
|---|---|---|
| MusicLDM | 2308.01546 | Latent mixup > waveform mixup for generative quality |
| EDMSound | 2311.08667 | Memorization documented; same failure mode as 10-clip training |
| Synthio | 2410.02056 | Synthetic audio as augmentation data (ICLR 2025) |
| HunyuanVideo-Foley | 2508.16930 | V2A data pipeline at scale (100K hrs) |
| FM memorization | 2410.23594 | Velocity field collapse theory; proves early overfitting on small datasets |
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
# AudioX vs SelVA — Evaluation
|
|
||||||
|
|
||||||
AudioX (arXiv:2503.10522, ICLR 2026) is a unified multimodal audio generation model from HKUST.
|
|
||||||
This document compares it against SelVA/MMAudio and assesses the cost of adding it to PrismAudio.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Quick Decision Guide
|
|
||||||
|
|
||||||
| Situation | Use |
|---|---|
| Video → realistic sound effects | **SelVA**: faster, purpose-built, MIT license |
| Music generation from video or text | **AudioX**: SelVA cannot do this |
| Audio inpainting / music continuation | **AudioX**: SelVA cannot do this |
| LoRA fine-tuning on a custom sound | **SelVA**: full training infrastructure already exists |
| Variable output duration | **AudioX**: SelVA is fixed at 8 s |
| Inference speed matters | **SelVA**: 25 steps vs 250 (10× fewer) |
| Non-commercial research | Either |
| Any commercial use | **SelVA only**: AudioX is CC-BY-NC-4.0 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
| Dimension | SelVA (MMAudio) | AudioX-MAF |
|---|---|---|
| Core paradigm | Flow matching | Diffusion (k-diffusion / DPM++) |
| Inference steps | 25 ODE steps (Euler) | 250 diffusion steps (DPM++ 3M SDE) |
| Sample rate | 44.1 kHz (large) / 16 kHz (small) | 48 kHz (fixed) |
| Generator | MM-DiT, velocity prediction | ContinuousMMDiTTransformer |
| Video encoder | Synchformer | Synchformer (AudioX custom re-impl, same concept) |
| VAE / codec | DAC (descript-audio-codec) | DAC + AudioCraft options |
| Text encoder | T5-large | T5 (configurable small → XXL) |
| Video-audio fusion | Cross-attention in MM-DiT | MAF: dual-projection (dim alignment + seq length alignment) |
| Output duration | Fixed 8 s | Configurable via `sample_size` (default ~44 s at 48 kHz) |
| Training data | ~2 M samples (MMAudio paper) | 7 M samples (IF-caps dataset, curated) |
| License | MIT | CC-BY-NC-4.0 |
|
|
||||||
|
|
||||||
**MAF (Multimodal Adaptive Fusion):** AudioX's key architectural contribution. Instead of directly
concatenating multimodal tokens into the DiT's cross-attention, MAF projects each modality to
match the latent's sequence length via a dedicated linear + transposed-conv stack, then applies
`MMDitSingleBlock` layers for cross-modal fusion. The paper reports this improves cross-modal
alignment, particularly for video-to-audio tasks.

**Flow matching vs diffusion:** Flow matching (SelVA) trains a single velocity field to move
directly from noise to data along a near-straight trajectory, which is why 25 steps suffice. Standard
diffusion (AudioX) approximates a longer stochastic path, requiring 250 steps for quality output.
This is not a quality difference per se; flow matching is simply more efficient.

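To make the step-count argument concrete, here is a minimal sketch of fixed-step Euler integration of a velocity field, the sampler type the Architecture table lists for SelVA. The names `euler_sample` and `straight_velocity` and the list-based vectors are illustrative assumptions, not SelVA's actual sampler. On a perfectly straight path the velocity is constant, so Euler integration lands on the target for any step count; a learned field is only approximately straight, which is why a real model still needs a couple dozen steps.

```python
def euler_sample(velocity, x_start, steps=25):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed-step Euler."""
    x = list(x_start)
    dt = 1.0 / steps
    for i in range(steps):
        t = i * dt
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Toy setup (hypothetical): a straight path from a "noise" point to a "data" point.
noise = [0.0, 0.0, 0.0]
data = [0.5, -0.25, 1.0]


def straight_velocity(x, t):
    # On a straight path the velocity is constant: data minus noise.
    return [d - n for d, n in zip(data, noise)]


sample_25 = euler_sample(straight_velocity, noise, steps=25)
sample_1 = euler_sample(straight_velocity, noise, steps=1)
```

Both `sample_25` and `sample_1` end at `data` up to floating-point error, which illustrates why straighter trajectories tolerate coarser step sizes.
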

---

## Capabilities

| Task | SelVA | AudioX |
|---|---|---|
| Video → sound effects | ✓ (primary use case) | ✓ |
| Text → sound effects | Partial (T5 conditions quality but not primary) | ✓ (strong benchmark scores) |
| Video → music | ✗ | ✓ |
| Text → music | ✗ | ✓ |
| Audio inpainting | ✗ | ✓ (mask_args parameter) |
| Music continuation | ✗ | ✓ (init_audio parameter) |
| Variable output duration | ✗ (fixed 8 s) | ✓ |
| Multiple input modalities simultaneously | Partial | ✓ (text + video + audio at once) |

AudioX benchmarks claim superior results on text-to-audio (AudioCaps) and text-to-music
(MusicCaps) vs prior models. Video-to-audio comparison against MMAudio specifically is not
prominently featured in the paper. Perceptual evaluation confirms this: AudioX does not sound
noticeably better than SelVA on video-to-audio tasks. AudioX's advantage is **breadth**
(music, inpainting, variable duration), not raw video-to-audio quality.

---

## Integration Cost

Adding AudioX inference-only nodes to PrismAudio would require:

### New nodes (3 files)

```
nodes/
  audiox_model_loader.py       AUDIOX_MODEL loader — get_pretrained_model("HKUSTAudio/AudioX-MAF")
  audiox_sampler.py            wraps generate_diffusion_cond(), inputs: model + text + video + audio
  audiox_feature_extractor.py  optional — pre-extract Synchformer sync features (caching)
```

### Installation

```bash
pip install git+https://github.com/ZeyueT/AudioX.git
```

New dependencies not currently in PrismAudio:
- `pytorch-lightning==2.4.0`
- `k-diffusion==0.1.1`
- `v-diffusion-pytorch==0.0.2`
- `descript-audio-codec==1.0.0` (already used by SelVA — no conflict, same package)
- `gradio==4.44.1` (optional — only for the upstream Gradio UI)

Model weights: `HKUSTAudio/AudioX-MAF` on HuggingFace (several GB).

### Inference API surface

```python
from audiox import get_pretrained_model
from audiox.inference.generation import generate_diffusion_cond

model, config = get_pretrained_model("HKUSTAudio/AudioX-MAF")

output = generate_diffusion_cond(
    model,
    steps=250,
    cfg_scale=6.0,
    conditioning={
        "text_prompt": "a dog barking",
        "video_prompt": {"video": frames_tensor, "sync_features": sync_feat},
        "seconds_total": 8.0,
    },
    sample_size=384000,  # 8 s at 48 kHz
    sample_rate=48000,
    device="cuda",
)
# output: torch.Tensor (batch, channels, num_samples) float32 [-1, 1]
```

---

## LoRA Training

Adding AudioX LoRA training to PrismAudio is **significantly harder** than the SelVA trainer:

| Aspect | SelVA LoRA | AudioX LoRA |
|---|---|---|
| Loss function | Single MSE velocity loss | Diffusion loss over 250-step schedule |
| Training steps needed | ~2000 steps practical | Unknown — likely much more |
| Step cost | Fast (1 velocity prediction) | Slow (full diffusion forward pass per step) |
| Existing infrastructure | Full trainer + scheduler + experiments | Nothing — would need to build from scratch |
| Noise schedule | Trivial (linear interpolation) | Cosine alpha-sigma schedule |
| Prior art for LoRA | LoRA on flow matching well-studied | Less explored; closer to Stable Diffusion LoRA |

**Conclusion:** AudioX LoRA training is feasible (it would follow SD-style LoRA with the cosine
alpha-sigma noise schedule) but would be a substantial new project. Not worth building until
inference nodes are stable and there is a clear use case that SelVA cannot serve.

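The "single MSE velocity loss" and "linear interpolation" entries in the SelVA column can be sketched in a few lines. This is a generic flow matching objective written with plain Python lists; `flow_matching_loss`, `predict_velocity`, and the convention that `x0` is noise and `x1` is data are assumptions for illustration. SelVA's trainer may differ in details such as a minimum-noise term or how timesteps are sampled.

```python
def flow_matching_loss(predict_velocity, x0, x1, t):
    """MSE between predicted and target velocity at time t in [0, 1].

    x0: noise sample, x1: data sample (assumed convention), both lists of floats.
    """
    # Linear interpolation between noise and data: this is the whole "noise schedule".
    x_t = [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]
    # Along a straight path the target velocity is constant: data minus noise.
    target = [b - a for a, b in zip(x0, x1)]
    pred = predict_velocity(x_t, t)
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)


noise = [0.0, 1.0, -1.0]
data = [1.0, 1.0, 1.0]


def perfect(x_t, t):
    # Hypothetical predictor that already outputs the true velocity.
    return [1.0, 0.0, 2.0]


def zeros(x_t, t):
    # Hypothetical predictor that outputs no motion at all.
    return [0.0, 0.0, 0.0]


loss_perfect = flow_matching_loss(perfect, noise, data, t=0.3)
loss_zeros = flow_matching_loss(zeros, noise, data, t=0.3)
```

Each training step needs one interpolation and one model call, which is the sense in which the table calls the SelVA step "fast".
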

---

## License

AudioX weights are released under **CC-BY-NC-4.0** (Creative Commons Attribution-NonCommercial 4.0).

- Free for personal use, research, and non-commercial projects
- **Cannot be used in commercial products or services** without a separate agreement
- Attribution required
- SelVA/MMAudio: MIT (permissive; commercial use allowed)

If PrismAudio is ever distributed as part of a commercial tool, AudioX nodes must be clearly
opt-in with a license warning, or excluded entirely.

---

## Recommendation

**Short term:** AudioX is not a replacement for SelVA for the current use case (video → custom
sound effects with LoRA fine-tuning). SelVA is faster, has full training infrastructure, and
is MIT licensed.

**When AudioX becomes worth integrating:**
- If you need to generate background music synchronized to video
- If you need audio inpainting (fill a gap in an existing audio track)
- If you need text-to-audio generation without a video input
- After verifying the CC-BY-NC-4.0 license is acceptable for your use

**Estimated integration effort for inference nodes only:** 2–3 days of work (3 new node files,
dependency management, testing). No changes to existing SelVA nodes required — they would
coexist in the same package.

---

## References

- Paper: arXiv:2503.10522 — *AudioX: Diffusion Transformer for Anything-to-Audio Generation*
- GitHub: https://github.com/ZeyueT/AudioX
- Model weights: https://huggingface.co/HKUSTAudio/AudioX-MAF
- Demo: https://huggingface.co/spaces/Zeyue7/AudioX
- Project page: https://zeyuet.github.io/AudioX/
@@ -1,606 +0,0 @@
# Audio Dataset Pipeline Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Add 5 chainable ComfyUI nodes for in-memory audio dataset preprocessing: load → resample → LUFS normalize → inspect/filter → extract single item.

**Architecture:** Single new file `nodes/selva_dataset_pipeline.py` defines a custom `AUDIO_DATASET` type (list of dicts) and all 5 node classes. Nodes are stateless transforms — each takes `AUDIO_DATASET` and returns `AUDIO_DATASET`. No disk I/O except in the Loader. Register all nodes in `nodes/__init__.py`.

**Tech Stack:** `pyloudnorm` (BS.1770-4 LUFS), `soxr` (VHQ resampling), `torchaudio`, `torch`. `pyloudnorm` and `soxr` are both confirmed present in the ComfyUI environment at `/media/p5/miniforge3/envs/latestcomfyui`.

---

## The `AUDIO_DATASET` type

Used as the ComfyUI type string `"AUDIO_DATASET"`. At runtime it is a Python list of dicts:

```python
[
    {
        "waveform": torch.Tensor,   # shape [1, C, L], float32, range [-1, 1]
        "sample_rate": int,
        "name": str,                # original filename stem, for reporting
    },
    ...
]
```

---

### Task 1: Create the file skeleton and AUDIO_DATASET constant

**Files:**
- Create: `nodes/selva_dataset_pipeline.py`

**Step 1: Write the file with imports and type constant only**

```python
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.

Typical chain:
    SelvaDatasetLoader
        ↓ AUDIO_DATASET
    SelvaDatasetResampler      (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetLUFSNormalizer (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetInspector      (optional)
        ↓ AUDIO_DATASET + STRING report
    SelvaDatasetItemExtractor  → AUDIO (bridges to save/preview nodes)
"""

from pathlib import Path

import numpy as np
import torch
import torchaudio

from .utils import SELVA_CATEGORY

# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"

_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
```

**Step 2: Verify import works (no test framework needed — just a quick smoke check)**

```bash
cd /media/p5/Comfyui-Prismaudio
python3 -c "from nodes.selva_dataset_pipeline import AUDIO_DATASET; print(AUDIO_DATASET)"
```
Expected output: `AUDIO_DATASET`

**Step 3: Commit**

```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add audio dataset pipeline skeleton"
```

---

### Task 2: SelvaDatasetLoader

**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`

**Step 1: Add the Loader class**

```python
class SelvaDatasetLoader:
    """Load all audio files in a folder into an in-memory AUDIO_DATASET."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "folder": ("STRING", {
                    "default": "",
                    "tooltip": "Absolute path to folder containing audio files. Searched recursively.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "load"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."

    def load(self, folder: str):
        folder = Path(folder.strip())
        if not folder.exists():
            raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")

        files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
        if not files:
            raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")

        dataset = []
        for f in sorted(files):
            try:
                wav, sr = torchaudio.load(str(f))   # [C, L]
                wav = wav.unsqueeze(0).float()      # [1, C, L]
                dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
            except Exception as e:
                print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)

        print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
        return (dataset,)
```

**Step 2: Smoke test**

```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader
node = SelvaDatasetLoader()
ds, = node.load('/media/unraid/davinci/Selva/BJ')
print(len(ds), 'clips', ds[0]['name'], ds[0]['waveform'].shape, ds[0]['sample_rate'])
"
```
Expected: prints clip count, first clip name, shape like `torch.Size([1, 2, 352800])`, sample rate.

**Step 3: Commit**

```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetLoader node"
```

---

### Task 3: SelvaDatasetResampler

**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`

**Step 1: Add the Resampler class**

Uses `soxr` directly for VHQ quality. `soxr.resample` operates on numpy arrays, shape `[L, C]` (time-first).

```python
class SelvaDatasetResampler:
    """Resample all clips in a dataset to a target sample rate using soxr VHQ."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_sr": ("INT", {
                    "default": 44100, "min": 8000, "max": 192000,
                    "tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "resample"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."

    def resample(self, dataset, target_sr: int):
        import soxr

        out = []
        changed = 0
        for item in dataset:
            sr = item["sample_rate"]
            if sr == target_sr:
                out.append(item)
                continue

            wav = item["waveform"][0]   # [C, L]
            # soxr expects [L, C] (time-first), float64
            wav_np = wav.permute(1, 0).double().numpy()   # [L, C]
            wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
            wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0)   # [1, C, L]
            out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
            changed += 1

        print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
        return (out,)
```

**Step 2: Smoke test**

```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetResampler
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
ds2, = SelvaDatasetResampler().resample(ds, 44100)
print('ok', ds2[0]['sample_rate'], ds2[0]['waveform'].shape)
"
```

**Step 3: Commit**

```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetResampler node (soxr VHQ)"
```

---

### Task 4: SelvaDatasetLUFSNormalizer

**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`

**Step 1: Add the LUFS normalizer class**

`pyloudnorm.Meter` expects a floating-point numpy array (float64 here) of shape `[L]` (mono) or `[L, C]` (multichannel, channels last). The peak ceiling is applied after the LUFS gain. Note that the code limits the sample peak (`abs().max()`), which only approximates a true-peak (dBTP) measurement.

```python
class SelvaDatasetLUFSNormalizer:
    """Normalize each clip to a target integrated LUFS level + true peak limit."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_lufs": ("FLOAT", {
                    "default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
                    "tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
                }),
                "true_peak_dbtp": ("FLOAT", {
                    "default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
                    "tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "normalize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
        "Skips clips that are too short for LUFS measurement (< 0.4 s)."
    )

    def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
        import pyloudnorm as pyln

        tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
        out = []
        skipped = 0

        for item in dataset:
            wav = item["waveform"][0]   # [C, L]
            sr = item["sample_rate"]

            # pyloudnorm wants [L] mono or [L, C] multichannel, float64
            wav_np = wav.permute(1, 0).double().numpy()   # [L, C]
            if wav_np.shape[1] == 1:
                wav_np = wav_np[:, 0]   # [L] mono

            meter = pyln.Meter(sr)
            try:
                loudness = meter.integrated_loudness(wav_np)
            except Exception:
                skipped += 1
                out.append(item)
                continue

            if not np.isfinite(loudness):
                skipped += 1
                out.append(item)
                continue

            gain_db = target_lufs - loudness
            gain_linear = 10.0 ** (gain_db / 20.0)

            wav_norm = wav * gain_linear

            # Peak ceiling (sample peak; approximates the dBTP limit)
            peak = wav_norm.abs().max().item()
            if peak > tp_linear:
                wav_norm = wav_norm * (tp_linear / peak)

            out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})

        print(
            f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
            f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
            flush=True,
        )
        return (out,)
```
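
The gain arithmetic in `normalize` can be checked in isolation. The following is a small sketch in plain Python that mirrors the two steps (LUFS gain, then peak ceiling); `normalization_gain` and its arguments are hypothetical names introduced here for illustration and are not part of the planned node.

```python
def normalization_gain(measured_lufs, target_lufs, peak, ceiling_db):
    """Return the linear gain after LUFS matching and a peak ceiling.

    measured_lufs / target_lufs: integrated loudness in LUFS.
    peak: absolute sample peak of the clip before gain (linear).
    ceiling_db: peak ceiling in dB relative to full scale (e.g. -1.0).
    """
    gain = 10.0 ** ((target_lufs - measured_lufs) / 20.0)
    ceiling = 10.0 ** (ceiling_db / 20.0)
    if peak * gain > ceiling:
        # Scale down so the peak lands exactly on the ceiling.
        gain = ceiling / peak
    return gain


# Loud clip: going from -17 LUFS to -23 LUFS is a 6 dB cut; the ceiling is not reached.
loud_clip_gain = normalization_gain(-17.0, -23.0, peak=0.9, ceiling_db=-1.0)

# Quiet clip with a strong transient: +7 dB would push the peak past the ceiling,
# so the ceiling decides the final gain instead of the LUFS target.
limited_gain = normalization_gain(-30.0, -23.0, peak=0.8, ceiling_db=-1.0)
```

One consequence worth keeping in mind: when the ceiling takes over, the clip ends up quieter than `target_lufs`, because the whole clip is scaled down rather than only its peaks.
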

**Step 2: Smoke test**

```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetLUFSNormalizer
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
ds2, = SelvaDatasetLUFSNormalizer().normalize(ds, -23.0, -1.0)
print('ok', ds2[0]['name'], ds2[0]['waveform'].abs().max().item())
"
```
Expected: peak ≤ ~0.89 (≈ -1 dBTP).

**Step 3: Commit**

```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetLUFSNormalizer node (pyloudnorm BS.1770-4)"
```

---

### Task 5: SelvaDatasetInspector

**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`

**Step 1: Add helper functions for artifact detection**

```python
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
    """Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).

    Method: compare mean energy in 1–5 kHz band vs 15–20 kHz band via STFT.
    A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
    """
    if sr < 32000:
        return False  # can't assess HF at low sample rates

    n_fft = 2048
    hop = 512
    window = torch.hann_window(n_fft)
    mono = wav[0].mean(0)   # [L]
    if mono.shape[0] < n_fft:
        return False  # clip shorter than one analysis window; skip the check
    stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
    mag_sq = stft.abs().pow(2).mean(-1)   # [n_freqs]

    freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1)
    band_lo = (freqs >= 1000) & (freqs < 5000)
    band_hi = (freqs >= 15000) & (freqs < 20000)

    if band_hi.sum() == 0:
        return False

    energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
    energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
    ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
    return ratio_db > 40.0


def _estimate_snr(wav: torch.Tensor) -> float:
    """Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
    mono = wav[0].mean(0)   # [L]
    if mono.shape[0] < 2048:
        # unfold() raises on clips shorter than one frame; treat as "not flagged"
        return float("inf")
    frames = mono.unfold(0, 2048, 512)   # [N, 2048]
    rms = frames.pow(2).mean(-1).sqrt()   # [N]
    p95 = torch.quantile(rms, 0.95).item()
    p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
    return 20.0 * np.log10(p95 / p05 + 1e-8)
```

**Step 2: Add the Inspector class**

```python
class SelvaDatasetInspector:
    """Analyze each clip for quality issues and optionally filter out flagged clips."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "skip_rejected": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "If True, flagged clips are removed from the output dataset. "
                               "If False, all clips pass through but the report still lists issues.",
                }),
                "min_snr_db": ("FLOAT", {
                    "default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
                    "tooltip": "Clips with estimated SNR below this value are flagged.",
                }),
                "check_codec_artifacts": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET, "STRING")
    RETURN_NAMES = ("dataset", "report")
    FUNCTION = "inspect"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Analyze each clip for clipping, low SNR, and codec artifacts. "
        "Outputs a filtered AUDIO_DATASET and a text report. "
        "Connect report to a ShowText node to preview in the UI."
    )

    def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
        clean = []
        flagged = []
        lines = ["SelVA Dataset Inspector Report", "=" * 40]

        for item in dataset:
            wav = item["waveform"]
            sr = item["sample_rate"]
            name = item["name"]
            issues = []

            # Clipping
            peak = wav.abs().max().item()
            if peak > 0.99:
                issues.append(f"clipping (peak={peak:.3f})")

            # Low SNR
            snr = _estimate_snr(wav)
            if snr < min_snr_db:
                issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")

            # Codec artifacts
            if check_codec_artifacts and _check_hf_shelf(wav, sr):
                issues.append("codec artifact (HF shelf > 15 kHz)")

            if issues:
                flagged.append(name)
                lines.append(f"  FLAGGED {name}: {', '.join(issues)}")
                if not skip_rejected:
                    clean.append(item)
            else:
                clean.append(item)
                lines.append(f"  OK {name}")

        lines.append("=" * 40)
        lines.append(
            # Count clean clips from the flag list so the number stays correct
            # when skip_rejected=False keeps flagged clips in the output.
            f"Total: {len(dataset)} Clean: {len(dataset) - len(flagged)} Flagged: {len(flagged)}"
            + (" (removed)" if skip_rejected else " (kept)")
        )
        report = "\n".join(lines)
        print(f"[DatasetInspector]\n{report}", flush=True)
        return (clean, report)
```

**Step 3: Smoke test**

```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetInspector
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
clean, report = SelvaDatasetInspector().inspect(ds, skip_rejected=False, min_snr_db=15.0, check_codec_artifacts=True)
print(report)
"
```
Expected: report with per-clip OK/FLAGGED lines and summary counts.

**Step 4: Commit**

```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetInspector node (codec artifacts, SNR, clipping)"
```

---

### Task 6: SelvaDatasetItemExtractor

**Files:**
- Modify: `nodes/selva_dataset_pipeline.py`

**Step 1: Add the extractor class**

```python
class SelvaDatasetItemExtractor:
    """Extract a single AUDIO item from an AUDIO_DATASET by index.

    Bridges the dataset pipeline to any node that accepts a standard AUDIO
    input — save audio, HF Smoother, Spectral Matcher, etc.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "index": ("INT", {
                    "default": 0, "min": 0, "max": 9999,
                    "tooltip": "0-based index. Wraps around if index >= dataset length.",
                }),
            }
        }

    RETURN_TYPES = ("AUDIO", "STRING", "INT")
    RETURN_NAMES = ("audio", "name", "total")
    FUNCTION = "extract"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Extract one clip from an AUDIO_DATASET by index. "
        "Returns standard AUDIO (compatible with all audio nodes), "
        "the clip name, and the total dataset length."
    )

    def extract(self, dataset, index: int):
        if not dataset:
            raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
        idx = index % len(dataset)
        item = dataset[idx]
        audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
        print(
            f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
            f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
            flush=True,
        )
        return (audio, item["name"], len(dataset))
```

**Step 2: Smoke test**

```bash
python3 -c "
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetItemExtractor
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
audio, name, total = SelvaDatasetItemExtractor().extract(ds, 0)
print(name, total, audio['waveform'].shape, audio['sample_rate'])
"
```

**Step 3: Commit**

```bash
git add nodes/selva_dataset_pipeline.py
git commit -m "feat: add SelvaDatasetItemExtractor node"
```

---

### Task 7: Register all nodes in `__init__.py`

**Files:**
- Modify: `nodes/__init__.py:4-25`

**Step 1: Add the 5 new entries to `_NODES`**

Add inside the `_NODES` dict, after `"SelvaDittoOptimizer"`:

```python
    "SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
    "SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
    "SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
    "SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
    "SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
```
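
For orientation, here is one way a `_NODES` table of `(module, class name, display name)` tuples could be turned into ComfyUI's `NODE_CLASS_MAPPINGS` and `NODE_DISPLAY_NAME_MAPPINGS`. The real `nodes/__init__.py` is not shown in this plan, so `build_mappings`, its injectable `importer` argument, and the stand-in module below are assumptions made so the sketch can run outside ComfyUI; the existing loader may work differently.

```python
import importlib
from types import SimpleNamespace


def build_mappings(nodes, package, importer=importlib.import_module):
    """Resolve (module, class, display name) tuples into two lookup dicts."""
    class_map, display_map = {}, {}
    for key, (module_name, class_name, display_name) in nodes.items():
        try:
            module = importer(module_name, package)
        except ImportError as exc:
            # A missing optional dependency should not break the other nodes.
            print(f"[nodes] Skipping {key}: {exc}", flush=True)
            continue
        class_map[key] = getattr(module, class_name)
        display_map[key] = display_name
    return class_map, display_map


# Stand-in for the real module, used only for this demonstration.
class SelvaDatasetLoader:
    pass


fake_module = SimpleNamespace(SelvaDatasetLoader=SelvaDatasetLoader)


def fake_importer(name, package):
    if name == ".selva_dataset_pipeline":
        return fake_module
    raise ImportError(name)


_NODES = {
    "SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
    "Missing": (".does_not_exist", "Missing", "Missing Node"),
}

classes, names = build_mappings(_NODES, "nodes", importer=fake_importer)
```

With the default `importer`, the same function would resolve the relative module names against the `nodes` package.
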

**Step 2: Verify registration**

```bash
python3 -c "
import sys; sys.path.insert(0, '/media/p5/Comfyui-Prismaudio')
from nodes import NODE_CLASS_MAPPINGS
keys = [k for k in NODE_CLASS_MAPPINGS if 'Dataset' in k]
print(keys)
"
```
Expected: list of 5 dataset node keys.

**Step 3: Final commit**

```bash
git add nodes/__init__.py
git commit -m "feat: register audio dataset pipeline nodes in __init__.py"
```

---

## Summary

5 nodes in `nodes/selva_dataset_pipeline.py`, all registered in `__init__.py`:

| Node | In | Out |
|------|----|-----|
| SelvaDatasetLoader | folder path | AUDIO_DATASET |
| SelvaDatasetResampler | AUDIO_DATASET | AUDIO_DATASET |
| SelvaDatasetLUFSNormalizer | AUDIO_DATASET | AUDIO_DATASET |
| SelvaDatasetInspector | AUDIO_DATASET | AUDIO_DATASET + STRING |
| SelvaDatasetItemExtractor | AUDIO_DATASET + index | AUDIO + name + total |

Dependencies: `pyloudnorm`, `soxr` — both confirmed present in the ComfyUI env.
@@ -1,77 +0,0 @@
{
  "name": "alpha_scale_sweep",
  "description": "Fix LoRA noise contamination (flatness 0.013→0.094 at alpha=rank). Root cause: alpha=rank (scale=1.0) at high rank drowns base model priors. Testing dramatically lower alpha to nudge rather than overwrite. All runs at lr=3e-4 (best stable LR from r128_sweet_spot).",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/alpha_scale_sweep",
  "base": {
    "steps": 6000,
    "lr": 3e-4,
    "batch_size": 16,
    "warmup_steps": 200,
    "grad_accum": 1,
    "save_every": 2000,
    "seed": 42,
    "target": "attn.qkv",
    "timestep_mode": "uniform",
    "logit_normal_sigma": 1.0,
    "curriculum_switch": 0.6,
    "lora_dropout": 0.0,
    "lora_plus_ratio": 1.0,
    "lr_schedule": "constant"
  },
  "experiments": [

    {
      "id": "g1_r16_alpha4",
      "group": "conservative",
      "description": "Back to basics: rank=16 alpha=4 (scale=0.25). Small adapter, gentle scale — cleanest possible LoRA signal.",
      "rank": 16,
      "alpha": 4.0
    },
    {
      "id": "g1_r16_alpha16",
      "group": "conservative",
      "description": "rank=16 alpha=16 (scale=1.0) — the original default. Reference point: is the noise issue rank-specific or universal?",
      "rank": 16,
      "alpha": 16.0
    },

    {
      "id": "g2_r32_alpha8",
      "group": "mid",
      "description": "rank=32 alpha=8 (scale=0.25). More capacity than r16 but still gentle scale.",
      "rank": 32,
      "alpha": 8.0
    },
    {
      "id": "g2_r32_alpha32",
      "group": "mid",
      "description": "rank=32 alpha=32 (scale=1.0). Same rank, full scale — isolates whether scale or rank is causing noise.",
      "rank": 32,
      "alpha": 32.0
    },

    {
      "id": "g3_r128_alpha8",
      "group": "high_rank_low_alpha",
      "description": "rank=128 alpha=8 (scale=0.0625). High capacity, very gentle contribution — can r128 stay clean at low alpha?",
      "rank": 128,
      "alpha": 8.0
    },
    {
      "id": "g3_r128_alpha16",
      "group": "high_rank_low_alpha",
      "description": "rank=128 alpha=16 (scale=0.125). Slightly more signal than alpha=8.",
      "rank": 128,
      "alpha": 16.0
    },
    {
      "id": "g3_r128_alpha32",
      "group": "high_rank_low_alpha",
      "description": "rank=128 alpha=32 (scale=0.25). Same scale as r16_alpha4 and r32_alpha8 — comparable contribution across ranks.",
      "rank": 128,
      "alpha": 32.0
    }

  ]
}
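
The scale values quoted in the `alpha_scale_sweep` descriptions follow from the classic LoRA convention `scale = alpha / rank`. A minimal sketch that recomputes them; the helper `lora_scale` and the `RUNS` table are illustrative names, not part of the training script:

```python
# Sketch: effective LoRA scale for each run in the alpha sweep.
# Assumes the classic convention scale = alpha / rank.

def lora_scale(rank: int, alpha: float) -> float:
    """Return the multiplier applied to the low-rank update."""
    if rank <= 0:
        raise ValueError("rank must be positive")
    return alpha / rank


# (rank, alpha) pairs copied from the sweep's experiment entries.
RUNS = {
    "g1_r16_alpha4": (16, 4.0),
    "g1_r16_alpha16": (16, 16.0),
    "g2_r32_alpha8": (32, 8.0),
    "g2_r32_alpha32": (32, 32.0),
    "g3_r128_alpha8": (128, 8.0),
    "g3_r128_alpha16": (128, 16.0),
    "g3_r128_alpha32": (128, 32.0),
}

SCALES = {run_id: lora_scale(rank, alpha) for run_id, (rank, alpha) in RUNS.items()}

if __name__ == "__main__":
    for run_id, scale in SCALES.items():
        print(f"{run_id}: scale={scale}")
```

Grouping runs by equal scale (0.25 for `g1_r16_alpha4`, `g2_r32_alpha8`, and `g3_r128_alpha32`) is what lets the sweep separate the effect of rank from the effect of scale.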
@@ -1,31 +0,0 @@
{
  "name": "bigvgan_disc_fm_retest",
  "description": "Retest discriminator feature matching after bfloat16 dtype fix. Uses optimal config from overnight sweep (snake_alpha, GAFilter, lr=1e-4, phase=1.0, L2-SP=1e-3, 5000 steps).",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_disc_fm_retest",
  "base": {
    "train_mode": "snake_alpha_only",
    "steps": 5000,
    "lr": 1e-4,
    "batch_size": 8,
    "segment_seconds": 0.5,
    "lambda_l2sp": 1e-3,
    "use_gafilter": true,
    "gafilter_kernel_size": 9,
    "lambda_phase": 1.0,
    "save_every": 1000,
    "seed": 42,
    "lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
  },
  "experiments": [
    {
      "id": "snake_5k_control",
      "description": "Control: best config from overnight sweep without discriminator. Baseline for A/B comparison."
    },
    {
      "id": "disc_fm_5k",
      "description": "Discriminator feature matching at 5k steps. Tests if perceptual FM loss improves over mel+phase alone.",
      "discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
    }
  ]
}
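
`lambda_l2sp` presumably weights an L2-SP term, that is, a penalty on how far the fine-tuned weights drift from their pretrained starting point. A minimal sketch on plain Python lists, assuming the unhalved form `lam * sum((w - w0)^2)`; the function name and the exact normalisation are assumptions, and the real trainer may differ:

```python
# Sketch of an L2-SP ("starting point") penalty.
# Assumption: penalty = lam * sum of squared differences to the pretrained weights.

def l2sp_penalty(params: list[float], anchors: list[float], lam: float) -> float:
    """Penalise the squared distance between current and pretrained weights."""
    if len(params) != len(anchors):
        raise ValueError("params and anchors must have the same length")
    return lam * sum((p - a) ** 2 for p, a in zip(params, anchors))


if __name__ == "__main__":
    pretrained = [1.0, 2.0, -0.5]
    finetuned = [1.1, 1.9, -0.5]
    # With lam=1e-3 (the config value) small drifts cost very little.
    print(l2sp_penalty(finetuned, pretrained, lam=1e-3))
```

The penalty is zero when no weight has moved, so it only constrains the direction of training once the vocoder starts to leave its pretrained solution.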
@@ -1,35 +0,0 @@
{
  "name": "bigvgan_optimized_dataset",
  "description": "BigVGAN fine-tuning on optimized dataset (134 clips, 44.1kHz, LUFS-normalized). Standard mode (no LoRA) — trains decoder to faithfully reconstruct target domain audio from mel spectrograms. Uses optimal config from prior sweeps.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_optimized_dataset",
  "base": {
    "train_mode": "snake_alpha_only",
    "steps": 5000,
    "lr": 1e-4,
    "batch_size": 8,
    "segment_seconds": 0.5,
    "lambda_l2sp": 1e-3,
    "use_gafilter": true,
    "gafilter_kernel_size": 9,
    "lambda_phase": 1.0,
    "save_every": 1000,
    "seed": 42
  },
  "experiments": [
    {
      "id": "standard_5k",
      "description": "Standard mode: mel from clean FLAC → BigVGAN → reconstruct FLAC. No LoRA. Directly improves VAE roundtrip quality."
    },
    {
      "id": "disc_fm_5k",
      "description": "Standard mode + discriminator feature matching. Tests if perceptual loss helps on clean audio reconstruction.",
      "discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
    },
    {
      "id": "standard_10k",
      "description": "Extended 10k steps. More data passes on 134 clips may extract more from the optimized dataset.",
      "steps": 10000
    }
  ]
}
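
The `standard_10k` description refers to extra data passes over the 134 clips. A rough back-of-the-envelope sketch, assuming each step draws `batch_size` random segments and that draws are spread evenly across clips (both are assumptions, since the sampler is not shown here):

```python
# Sketch: average number of segment draws per clip over a training run.
# Assumption: every optimizer step samples `batch_size` segments uniformly over clips.

def draws_per_clip(steps: int, batch_size: int, n_clips: int) -> float:
    """Average number of times a single clip contributes a segment."""
    if n_clips <= 0:
        raise ValueError("n_clips must be positive")
    return steps * batch_size / n_clips


if __name__ == "__main__":
    for steps in (5000, 10000):
        print(steps, round(draws_per_clip(steps, batch_size=8, n_clips=134), 1))
```

Because each draw is only a 0.5 s segment (`segment_seconds`), a draw is not a full pass over a clip, so these figures overstate how often any given moment of audio is seen.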
@@ -1,65 +0,0 @@
{
  "name": "bigvgan_overnight",
  "description": "BigVGAN vocoder quality sweep. Axes: snake_alpha steps, all_params short run, GAFilter on/off, discriminator FM, phase loss weight. All use LoRA-distorted mels as input so vocoder learns to fix LoRA artifacts.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_overnight",
  "base": {
    "train_mode": "snake_alpha_only",
    "steps": 3000,
    "lr": 1e-4,
    "batch_size": 8,
    "segment_seconds": 0.5,
    "lambda_l2sp": 1e-3,
    "use_gafilter": true,
    "gafilter_kernel_size": 9,
    "lambda_phase": 1.0,
    "save_every": 1000,
    "seed": 42,
    "lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
  },
  "experiments": [
    {
      "id": "snake_3k_baseline",
      "description": "Snake alpha + GAFilter baseline. 3000 steps, same as first successful run but longer."
    },
    {
      "id": "snake_5k",
      "description": "Snake alpha + GAFilter, 5000 steps. Test if longer training improves further.",
      "steps": 5000
    },
    {
      "id": "snake_no_gafilter",
      "description": "Snake alpha only, no GAFilter. Isolate GAFilter contribution.",
      "use_gafilter": false
    },
    {
      "id": "snake_no_phase",
      "description": "Snake alpha + GAFilter, no phase loss. Isolate phase loss contribution.",
      "lambda_phase": 0.0
    },
    {
      "id": "snake_phase_2",
      "description": "Snake alpha + GAFilter, phase weight 2.0. Stronger phase penalty.",
      "lambda_phase": 2.0
    },
    {
      "id": "snake_lr5e-5",
      "description": "Snake alpha + GAFilter, lower LR 5e-5. Test if slower converges better.",
      "lr": 5e-5,
      "steps": 5000
    },
    {
      "id": "snake_disc_fm",
      "description": "Snake alpha + GAFilter + discriminator feature matching. Perceptual loss should directly penalize harmonic smearing.",
      "discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
    },
    {
      "id": "all_2k_l2sp1e-2",
      "description": "All params, 2000 steps, strong L2-SP (1e-2). Test if full param tuning with heavy anchor beats snake-only.",
      "train_mode": "all_params",
      "steps": 2000,
      "lr": 1e-5,
      "lambda_l2sp": 1e-2
    }
  ]
}
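
These sweep files appear to share one layout: a `base` block of defaults plus a list of `experiments` whose keys override it. A minimal sketch of how a runner might resolve that, assuming a shallow per-key override; `resolve_experiments` is a hypothetical helper, and the actual sweep runner may merge differently:

```python
import copy

# Sketch: expand a sweep config into one full parameter dict per experiment.
# Assumption: experiment keys shallowly override keys of the same name in "base".

def resolve_experiments(config: dict) -> list[dict]:
    """Return one merged settings dict per entry in config["experiments"]."""
    base = config.get("base", {})
    resolved = []
    for experiment in config.get("experiments", []):
        merged = copy.deepcopy(base)  # never mutate the shared defaults
        merged.update(experiment)
        resolved.append(merged)
    return resolved


# Trimmed-down copy of the bigvgan_overnight sweep for illustration.
CONFIG = {
    "base": {"train_mode": "snake_alpha_only", "steps": 3000, "lr": 1e-4, "lambda_l2sp": 1e-3},
    "experiments": [
        {"id": "snake_3k_baseline"},
        {"id": "all_2k_l2sp1e-2", "train_mode": "all_params", "steps": 2000,
         "lr": 1e-5, "lambda_l2sp": 1e-2},
    ],
}

RESOLVED = resolve_experiments(CONFIG)

if __name__ == "__main__":
    for settings in RESOLVED:
        print(settings["id"], settings["train_mode"], settings["steps"], settings["lr"])
```

Copying `base` for every experiment keeps runs independent, so an override in one entry cannot leak into the next.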
@@ -1,39 +0,0 @@
{
  "name": "eval_r128_candidates",
  "description": "Top candidates from r128_sweet_spot. Comparing the two lowest-loss runs, the stable lr=3e-4, and the curriculum run that hit 0.161 before regressing. Baseline included as perceptual reference.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_dir": "/media/unraid/davinci/Selva/BJ/evals/r128_candidates",
  "steps": 25,
  "seed": 42,
  "adapters": [
    {
      "id": "baseline",
      "description": "No LoRA — base model output for perceptual reference"
    },
    {
      "id": "lr_5e4_r128",
      "description": "Best loss overall (0.137), still descending at step 10k",
      "path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_5e4/adapter_final.pt"
    },
    {
      "id": "lr_3e4_r256",
      "description": "Tied with lr_5e4 at 0.139, higher rank — does extra capacity help perceptually?",
      "path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g4_r256_lr_3e4/adapter_final.pt"
    },
    {
      "id": "lr_3e4_r128",
      "description": "Stable plateau from step 4k to 10k (0.221) — visually confirmed clean spectrograms",
      "path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_3e4/adapter_final.pt"
    },
    {
      "id": "curriculum_lr_3e4",
      "description": "Best min loss of all (0.161 at step 6k), regressed to 0.193 after curriculum switch — curious if the early checkpoint sounds better",
      "path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_final.pt"
    },
    {
      "id": "curriculum_lr_3e4_step6000",
      "description": "Same run at its actual best step (before regression) — compare against adapter_final to hear the regression",
      "path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_step06000.pt"
    }
  ]
}
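
The two `curriculum_lr_3e4` entries point at `adapter_final.pt` and `adapter_step06000.pt` from the same run. Assuming intermediate checkpoints are named with a five-digit zero-padded step (inferred only from that one filename) and are written at every multiple of `save_every`, the files a run would leave behind can be listed as follows. Both helper names are hypothetical:

```python
# Sketch: checkpoint filenames a training run might produce.
# Assumptions: five-digit zero-padded step in the name, one checkpoint per
# multiple of save_every, plus a final adapter file.

def checkpoint_name(step: int) -> str:
    """Format an intermediate checkpoint filename for a given step."""
    return f"adapter_step{step:05d}.pt"


def expected_checkpoints(steps: int, save_every: int) -> list[str]:
    """List intermediate checkpoints followed by the final adapter file."""
    names = [checkpoint_name(s) for s in range(save_every, steps + 1, save_every)]
    names.append("adapter_final.pt")
    return names


if __name__ == "__main__":
    # r128_sweet_spot used steps=10000 and save_every=2000.
    for name in expected_checkpoints(10000, 2000):
        print(name)
```

Evaluating an intermediate checkpoint next to the final one, as this config does, is only possible when `save_every` divides the step of interest.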
@@ -1,33 +0,0 @@
{
  "name": "lora_logit_cosine_combo",
  "description": "Combine the two best findings from optimized dataset sweep: logit-normal timestep sampling + cosine LR schedule. Both individually outperformed baseline by large margins (56% and 68% lower loss). Tests if gains stack.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_logit_cosine_combo",
  "base": {
    "rank": 128,
    "lr": 3e-4,
    "steps": 5000,
    "batch_size": 16,
    "warmup_steps": 100,
    "save_every": 1000,
    "seed": 42,
    "init_mode": "pissa",
    "use_rslora": true,
    "target": "attn.qkv",
    "timestep_mode": "uniform",
    "lr_schedule": "constant"
  },
  "experiments": [
    {
      "id": "logit_normal_cosine",
      "description": "Logit-normal timesteps + cosine LR decay. Combining the two best individual improvements.",
      "timestep_mode": "logit_normal",
      "lr_schedule": "cosine"
    },
    {
      "id": "logit_normal_control",
      "description": "Control: logit-normal only (constant LR). Reproduces previous winner for direct comparison.",
      "timestep_mode": "logit_normal"
    }
  ]
}
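
The `lr_schedule` switch between `"constant"` and `"cosine"` can be pictured with a small sketch. It assumes linear warmup over `warmup_steps` followed by either a flat rate or a cosine decay to zero; the function `lr_at` and the exact warmup shape are assumptions, not the trainer's confirmed behaviour:

```python
import math

# Sketch: learning rate at a given step for the two schedules named in the config.
# Assumptions: linear warmup, cosine decays to exactly zero at total_steps.

def lr_at(step: int, base_lr: float, total_steps: int, warmup_steps: int,
          schedule: str = "constant") -> float:
    """Return the learning rate used at a zero-indexed training step."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    if schedule == "constant":
        return base_lr
    if schedule == "cosine":
        span = max(1, total_steps - warmup_steps)
        progress = min(1.0, (step - warmup_steps) / span)
        return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown schedule: {schedule}")


if __name__ == "__main__":
    for step in (0, 100, 2550, 5000):
        print(step, lr_at(step, 3e-4, total_steps=5000, warmup_steps=100, schedule="cosine"))
```

Under these assumptions the cosine run spends its last steps at a near-zero rate, which is the "lock in" behaviour the cosine experiments are testing against a flat rate.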
@@ -1,64 +0,0 @@
{
  "name": "lora_optimized_dataset",
  "description": "LoRA training on optimized dataset (134 clips: resampled 44.1kHz, LUFS-normalized, spectral matched, HF smoothed, gain-augmented). Tests latent augmentation and schedule variants on top of known-best config (PiSSA, rank=128, lr=3e-4).",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_optimized_dataset",
  "base": {
    "rank": 128,
    "lr": 3e-4,
    "steps": 5000,
    "batch_size": 16,
    "warmup_steps": 100,
    "save_every": 1000,
    "seed": 42,
    "init_mode": "pissa",
    "use_rslora": true,
    "target": "attn.qkv",
    "timestep_mode": "uniform",
    "lr_schedule": "constant"
  },
  "experiments": [
    {
      "id": "baseline",
      "description": "Control: known-best config (PiSSA r128 lr=3e-4) on the optimized dataset. No latent augmentation."
    },
    {
      "id": "latent_mixup",
      "description": "Latent mixup alpha=0.4 (MusicLDM). Tests if mixing training latents reduces memorization on 134 clips.",
      "latent_mixup_alpha": 0.4
    },
    {
      "id": "latent_noise",
      "description": "Latent noise sigma=0.02. Mild Gaussian noise on training latents for regularization.",
      "latent_noise_sigma": 0.02
    },
    {
      "id": "mixup_and_noise",
      "description": "Both latent mixup (0.4) and noise (0.02). Combined regularization.",
      "latent_mixup_alpha": 0.4,
      "latent_noise_sigma": 0.02
    },
    {
      "id": "cosine_schedule",
      "description": "Cosine LR decay. lr=3e-4 was stable with constant, but cosine may extract more from 5k steps.",
      "lr_schedule": "cosine"
    },
    {
      "id": "cosine_mixup",
      "description": "Cosine LR + latent mixup. Best regularization combo candidate.",
      "lr_schedule": "cosine",
      "latent_mixup_alpha": 0.4
    },
    {
      "id": "logit_normal",
      "description": "Logit-normal timestep sampling (sigma=1.0). Concentrates training near t=0.5 where flow matching is hardest.",
      "timestep_mode": "logit_normal"
    },
    {
      "id": "curriculum_mixup",
      "description": "Curriculum timesteps (logit_normal first 60%, then uniform) + latent mixup. Full regularization stack.",
      "timestep_mode": "curriculum",
      "latent_mixup_alpha": 0.4
    }
  ]
}
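
The two latent augmentations named in this sweep can be sketched on plain lists of floats. This assumes the usual mixup recipe (a mixing weight drawn from `Beta(alpha, alpha)`) and additive zero-mean Gaussian noise; how the trainer pairs clips, and whether it also mixes the conditioning features, is not shown here, so treat the function names and details as hypothetical:

```python
import random

# Sketch: latent mixup and latent noise on flat latent vectors.
# Assumptions: lam ~ Beta(alpha, alpha) for mixup; noise ~ N(0, sigma^2) per element.

def mixup_latents(z_a: list[float], z_b: list[float], alpha: float,
                  rng: random.Random) -> tuple[list[float], float]:
    """Blend two latents with a Beta-distributed weight; also return the weight."""
    if len(z_a) != len(z_b):
        raise ValueError("latents must have the same length")
    lam = rng.betavariate(alpha, alpha)
    mixed = [lam * a + (1.0 - lam) * b for a, b in zip(z_a, z_b)]
    return mixed, lam


def add_latent_noise(z: list[float], sigma: float, rng: random.Random) -> list[float]:
    """Add independent Gaussian noise to every latent element."""
    return [value + rng.gauss(0.0, sigma) for value in z]


if __name__ == "__main__":
    rng = random.Random(42)
    mixed, lam = mixup_latents([0.0, 0.0], [1.0, 1.0], alpha=0.4, rng=rng)
    print(lam, add_latent_noise(mixed, sigma=0.02, rng=rng))
```

With `alpha=0.4` the Beta distribution is U-shaped, so most mixed latents stay close to one of the two source clips rather than landing halfway between them.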
@@ -1,62 +0,0 @@
{
  "name": "pissa_sweep",
  "description": "PiSSA vs standard init ablation at rank 128. Best prior config (lr=3e-4, bs=16, 10k steps) as baseline. PiSSA starts on-manifold via SVD init — should eliminate intruder dimensions. rsLoRA stabilises scaling at high rank.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep",
  "base": {
    "steps": 10000,
    "rank": 128,
    "alpha": 0.0,
    "lr": 3e-4,
    "batch_size": 16,
    "warmup_steps": 200,
    "grad_accum": 1,
    "save_every": 2000,
    "seed": 42,
    "target": "attn.qkv",
    "timestep_mode": "uniform",
    "lora_dropout": 0.0,
    "lora_plus_ratio": 1.0,
    "lr_schedule": "constant",
    "init_mode": "pissa",
    "use_rslora": true
  },
  "experiments": [
    {
      "id": "standard_baseline",
      "description": "Standard Kaiming init + classic alpha/rank scaling. Replicates best prior config for A/B comparison.",
      "init_mode": "standard",
      "use_rslora": false
    },
    {
      "id": "pissa_rslora",
      "description": "PiSSA init + rsLoRA scaling. Full Tier-S config. Should start on-manifold and avoid intruder dimensions."
    },
    {
      "id": "pissa_classic_scale",
      "description": "PiSSA init + classic alpha/rank scaling. Isolates PiSSA contribution from rsLoRA.",
      "use_rslora": false
    },
    {
      "id": "standard_rslora",
      "description": "Standard init + rsLoRA only. Isolates rsLoRA contribution from PiSSA.",
      "init_mode": "standard"
    },
    {
      "id": "pissa_rslora_lr1e-4",
      "description": "PiSSA+rsLoRA at lower lr=1e-4. PiSSA starts closer to optimum — may need less aggressive lr.",
      "lr": 1e-4
    },
    {
      "id": "pissa_rslora_lr5e-4",
      "description": "PiSSA+rsLoRA at higher lr=5e-4. Test if on-manifold start tolerates faster learning.",
      "lr": 5e-4
    },
    {
      "id": "pissa_rslora_dropout",
      "description": "PiSSA+rsLoRA with dropout 0.05. Note: PiSSA forces dropout=0 (principal components should not be dropped) — this tests standard init with rsLoRA + dropout.",
      "init_mode": "standard",
      "lora_dropout": 0.05
    }
  ]
}
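
The `use_rslora` flag switches between two scaling rules: classic LoRA divides `alpha` by the rank, while rank-stabilised LoRA divides by the square root of the rank. A minimal sketch of the difference; how the trainer interprets `"alpha": 0.0` is not shown in this config, so the helper below (a hypothetical name) takes an explicit positive `alpha`:

```python
import math

# Sketch: classic vs rank-stabilised LoRA scaling.
# classic: alpha / rank        rsLoRA: alpha / sqrt(rank)

def adapter_scale(rank: int, alpha: float, use_rslora: bool) -> float:
    """Return the multiplier applied to the low-rank update."""
    if rank <= 0:
        raise ValueError("rank must be positive")
    if use_rslora:
        return alpha / math.sqrt(rank)
    return alpha / rank


if __name__ == "__main__":
    for rank in (16, 64, 128):
        classic = adapter_scale(rank, 16.0, use_rslora=False)
        stabilised = adapter_scale(rank, 16.0, use_rslora=True)
        print(f"rank={rank}: classic={classic:.4f} rsLoRA={stabilised:.4f}")
```

For a fixed `alpha`, the classic scale shrinks much faster as rank grows, which is why the two rules are ablated separately at rank 128.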
@@ -1,103 +0,0 @@
{
  "name": "r128_sweet_spot",
  "description": "Find the noise-free sweet spot on rank 128. LoRA+ ratio=16 caused noise — testing higher base LR without LoRA+ as a cleaner alternative. Target loss range 0.25–0.35. Also probing rank 256 since 102GB VRAM allows it.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot",
  "base": {
    "steps": 10000,
    "rank": 128,
    "alpha": 0.0,
    "lr": 1e-4,
    "batch_size": 16,
    "warmup_steps": 200,
    "grad_accum": 1,
    "save_every": 2000,
    "seed": 42,
    "target": "attn.qkv",
    "timestep_mode": "uniform",
    "logit_normal_sigma": 1.0,
    "curriculum_switch": 0.6,
    "lora_dropout": 0.0,
    "lora_plus_ratio": 1.0
  },
  "experiments": [

    {
      "id": "g1_r128_lr_2e4",
      "group": "lr",
      "description": "LR=2e-4. Conservative 2× step up from baseline — noise-free descent toward sweet spot.",
      "lr": 2e-4
    },
    {
      "id": "g1_r128_lr_3e4",
      "group": "lr",
      "description": "LR=3e-4. 3× baseline — landed at 0.41 on r64, should reach 0.25–0.35 on r128.",
      "lr": 3e-4
    },
    {
      "id": "g1_r128_lr_5e4",
      "group": "lr",
      "description": "LR=5e-4. Aggressive but no LoRA+ B-matrix asymmetry — cleaner noise profile.",
      "lr": 5e-4
    },

    {
      "id": "g2_r128_curriculum",
      "group": "curriculum",
      "description": "Curriculum only at baseline LR. Clean slow descent — reference for what curriculum contributes alone.",
      "timestep_mode": "curriculum"
    },
    {
      "id": "g2_r128_lr_3e4_curriculum",
      "group": "curriculum",
      "description": "LR=3e-4 + curriculum. Speed of higher LR with coverage of curriculum — no LoRA+.",
      "lr": 3e-4,
      "timestep_mode": "curriculum"
    },
    {
      "id": "g2_r128_lr_3e4_curriculum_dropout",
      "group": "curriculum",
      "description": "LR=3e-4 + curriculum + dropout=0.05. Full controlled stack without LoRA+.",
      "lr": 3e-4,
      "timestep_mode": "curriculum",
      "lora_dropout": 0.05
    },

    {
      "id": "g3_r128_lora_plus_4",
      "group": "lora_plus",
      "description": "LoRA+ ratio=4 (lr_B=4e-4). Much more conservative than ratio=16 — tests if noise came from ratio not the technique.",
      "lora_plus_ratio": 4.0
    },

    {
      "id": "g4_r256_baseline",
      "group": "rank256",
      "description": "Rank 256 at baseline LR. 102GB VRAM makes this viable — does more capacity keep helping?",
      "rank": 256
    },
    {
      "id": "g4_r256_lr_3e4",
      "group": "rank256",
      "description": "Rank 256 + LR=3e-4. Best rank + best LR candidate combined.",
      "rank": 256,
      "lr": 3e-4
    },

    {
      "id": "g5_r128_lr_2e4_cosine",
      "group": "cosine",
      "description": "LR=2e-4 + cosine decay. Fixes the oscillation observed at step 6000–8000 by decaying LR to ~0 instead of staying flat.",
      "lr": 2e-4,
      "lr_schedule": "cosine"
    },
    {
      "id": "g5_r128_lr_3e4_cosine",
      "group": "cosine",
      "description": "LR=3e-4 + cosine decay. Higher LR with decay — should reach lower loss faster then lock in.",
      "lr": 3e-4,
      "lr_schedule": "cosine"
    }

  ]
}
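
`lora_plus_ratio` sets how much faster the B matrices learn than the A matrices: `lr_B = ratio × lr_A`, so ratio=4 at the base rate of 1e-4 gives the 4e-4 quoted for `g3_r128_lora_plus_4`. A minimal sketch of splitting parameters into two optimizer groups; the `lora_A` / `lora_B` name suffixes and the function name are assumptions about how the adapter parameters are labelled:

```python
# Sketch: LoRA+ style optimizer groups with an asymmetric learning rate.
# Assumption: adapter parameter names end in "lora_A" or "lora_B".

def lora_plus_groups(param_names: list[str], base_lr: float, ratio: float) -> list[dict]:
    """Split parameter names into an A group (base_lr) and a B group (ratio * base_lr)."""
    group_a = [name for name in param_names if not name.endswith("lora_B")]
    group_b = [name for name in param_names if name.endswith("lora_B")]
    return [
        {"params": group_a, "lr": base_lr},
        {"params": group_b, "lr": base_lr * ratio},
    ]


if __name__ == "__main__":
    names = ["blocks.0.attn.qkv.lora_A", "blocks.0.attn.qkv.lora_B"]
    for group in lora_plus_groups(names, base_lr=1e-4, ratio=4.0):
        print(group)
```

With `ratio=1.0`, the default in these configs, both groups share one rate and the setup reduces to plain LoRA training.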
@@ -1,130 +0,0 @@
{
  "name": "r64_overnight",
  "description": "Focused rank-64 overnight sweep. All experiments use rank 64 as base — confirmed best from tier1_thorough early results. 8000 steps to reach convergence (none converged at 4000).",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/r64_overnight",
  "base": {
    "steps": 8000,
    "rank": 64,
    "alpha": 0.0,
    "lr": 1e-4,
    "batch_size": 16,
    "warmup_steps": 200,
    "grad_accum": 1,
    "save_every": 2000,
    "seed": 42,
    "target": "attn.qkv",
    "timestep_mode": "uniform",
    "logit_normal_sigma": 1.0,
    "curriculum_switch": 0.6,
    "lora_dropout": 0.0,
    "lora_plus_ratio": 1.0
  },
  "experiments": [

    {
      "id": "g1_r64_baseline",
      "group": "rank",
      "description": "Rank 64 baseline — clean reference at 8000 steps."
    },
    {
      "id": "g1_r128_baseline",
      "group": "rank",
      "description": "Rank 128 — 102GB VRAM makes this free. Does doubling rank from 64 help further?",
      "rank": 128
    },

    {
      "id": "g2_r64_alpha_32",
      "group": "alpha",
      "description": "Rank 64 alpha=32 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
      "alpha": 32.0
    },
    {
      "id": "g2_r64_alpha_16",
      "group": "alpha",
      "description": "Rank 64 alpha=16 (scale=0.25). More aggressive scale reduction — may over-constrain.",
      "alpha": 16.0
    },

    {
      "id": "g3_r64_lora_plus",
      "group": "regularisation",
      "description": "LoRA+ ratio=16. lr_B = 16 × lr_A. Faster convergence at constant step budget.",
      "lora_plus_ratio": 16.0
    },
    {
      "id": "g3_r64_dropout_0.05",
      "group": "regularisation",
      "description": "Dropout=0.05. Light sparsity regularisation on LoRA path.",
      "lora_dropout": 0.05
    },
    {
      "id": "g3_r64_dropout_0.1",
      "group": "regularisation",
      "description": "Dropout=0.1. Stronger regularisation — tests if 49 clips needs heavier constraint.",
      "lora_dropout": 0.1
    },
    {
      "id": "g3_r64_curriculum",
      "group": "regularisation",
      "description": "Curriculum sampling: logit_normal for steps 1-4800, then uniform (arXiv:2603.12517).",
      "timestep_mode": "curriculum"
    },

    {
      "id": "g4_r64_lr_low",
      "group": "lr",
      "description": "LR=3e-5. 3× lower — checks if 1e-4 is overshooting at rank 64.",
      "lr": 3e-5
    },
    {
      "id": "g4_r64_lr_high",
      "group": "lr",
      "description": "LR=3e-4. 3× higher — may converge faster but risk instability.",
      "lr": 3e-4
    },

    {
      "id": "g5_r64_target_full",
      "group": "target",
      "description": "Rank 64 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
      "target": "attn.qkv linear1"
    },
    {
      "id": "g5_r128_target_full",
      "group": "target",
      "description": "Rank 128 + full target. Maximum possible coverage with available VRAM.",
      "rank": 128,
      "target": "attn.qkv linear1"
    },

    {
      "id": "g6_r64_full_tier1",
      "group": "combined",
      "description": "All Tier 1 at rank 64: LoRA+ 16 + dropout 0.05 + curriculum. Full stack at 8000 steps.",
      "lora_plus_ratio": 16.0,
      "lora_dropout": 0.05,
      "timestep_mode": "curriculum"
    },
    {
      "id": "g6_r64_alpha32_full",
      "group": "combined",
      "description": "Rank 64 alpha=32 + all Tier 1. Best alpha scaling + best regularisation stack.",
      "alpha": 32.0,
      "lora_plus_ratio": 16.0,
      "lora_dropout": 0.05,
      "timestep_mode": "curriculum"
    },
    {
      "id": "g6_r128_full_tier1",
      "group": "combined",
      "description": "Rank 128 + all Tier 1. Tests if more capacity + regularisation beats rank 64 full.",
      "rank": 128,
      "lora_plus_ratio": 16.0,
      "lora_dropout": 0.05,
      "timestep_mode": "curriculum"
    }

  ]
}
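
The curriculum mode described in these configs samples logit-normal timesteps for the first `curriculum_switch` fraction of training and uniform timesteps afterwards (0.6 × 8000 = 4800 steps here). A minimal sketch, assuming a logit-normal draw is the sigmoid of a Gaussian with standard deviation `logit_normal_sigma`; the function names are hypothetical and the trainer's own sampler may differ in detail:

```python
import math
import random

# Sketch: timestep sampling for the "uniform", "logit_normal" and "curriculum" modes.
# Assumption: logit-normal t = sigmoid(N(0, sigma^2)); curriculum switches once.

def switch_step(total_steps: int, curriculum_switch: float) -> int:
    """Number of initial steps that use logit-normal sampling."""
    return round(curriculum_switch * total_steps)


def effective_mode(step: int, total_steps: int, mode: str,
                   curriculum_switch: float = 0.6) -> str:
    """Resolve 'curriculum' to the concrete mode used at a zero-indexed step."""
    if mode != "curriculum":
        return mode
    if step < switch_step(total_steps, curriculum_switch):
        return "logit_normal"
    return "uniform"


def sample_timestep(step: int, total_steps: int, mode: str, rng: random.Random,
                    sigma: float = 1.0, curriculum_switch: float = 0.6) -> float:
    """Draw one flow-matching timestep in [0, 1]."""
    resolved = effective_mode(step, total_steps, mode, curriculum_switch)
    if resolved == "uniform":
        return rng.random()
    if resolved == "logit_normal":
        return 1.0 / (1.0 + math.exp(-rng.gauss(0.0, sigma)))
    raise ValueError(f"unknown timestep mode: {resolved}")


if __name__ == "__main__":
    rng = random.Random(42)
    print(switch_step(8000, 0.6))
    print([round(sample_timestep(s, 8000, "curriculum", rng), 3) for s in (0, 4799, 4800, 7999)])
```

The logit-normal phase concentrates samples around t=0.5, and the uniform phase then restores coverage of the extremes near t=0 and t=1.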
@@ -1,52 +0,0 @@
{
  "name": "ti_sweep_1",
  "description": "First TI sweep. n4_baseline (suffix, batch=16, lr=1e-3) completed — buzz artifact diagnosed as token norm drifting to 3.2x outside CLIP manifold. All new experiments use norm clamping (auto from dataset) + corrected lr/batch.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
  "base": {
    "steps": 3000,
    "batch_size": 4,
    "warmup_steps": 100,
    "save_every": 1000,
    "seed": 42,
    "init_text": "",
    "lr": 2e-4,
    "n_tokens": 4,
    "inject_mode": "suffix"
  },
  "experiments": [

    {
      "id": "n4_baseline",
      "group": "reference",
      "description": "COMPLETED (old code, no norm clamp). batch=16, lr=1e-3. Token norm drifted to 3.2 → buzz artifact. Kept for loss curve comparison only."
    },

    {
      "id": "n4_clamped",
      "group": "norm_clamp",
      "description": "Same as baseline but with norm clamping enabled. Primary diagnostic: does clamping alone fix the buzz? lr=2e-4, batch=4, suffix."
    },
    {
      "id": "n4_prefix_clamped",
      "group": "norm_clamp",
      "description": "Prefix injection + norm clamping. Best of both: high-attention positions, tokens stay on CLIP manifold.",
      "inject_mode": "prefix"
    },
    {
      "id": "n8_prefix_clamped",
      "group": "norm_clamp",
      "description": "8 tokens, prefix, clamped. More capacity without the artifact.",
      "n_tokens": 8,
      "inject_mode": "prefix"
    },
    {
      "id": "n4_prefix_warm_clamped",
      "group": "norm_clamp",
      "description": "4 tokens, prefix, warm init from 'mechanical impact sound design', clamped. Should converge fastest — starts in-manifold, stays in-manifold.",
      "inject_mode": "prefix",
      "init_text": "mechanical impact sound design"
    }

  ]
}
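
The norm clamping that this sweep relies on can be sketched as a projection applied to each learned token after an optimizer step. This assumes a simple L2-norm cap; how the trainer derives the cap "auto from dataset" is not shown, so `max_norm` is just a parameter here and `clamp_norm` is a hypothetical name:

```python
import math

# Sketch: keep a learned token embedding inside a maximum L2 norm.
# Assumption: the cap (max_norm) is supplied by the caller, e.g. measured from
# reference embeddings; the trainer's own rule is not shown in the config.

def clamp_norm(vec: list[float], max_norm: float) -> list[float]:
    """Rescale vec so its L2 norm does not exceed max_norm; direction is preserved."""
    norm = math.sqrt(sum(v * v for v in vec))
    if norm == 0.0 or norm <= max_norm:
        return list(vec)
    scale = max_norm / norm
    return [v * scale for v in vec]


if __name__ == "__main__":
    drifted = [3.0, 4.0]  # norm 5.0, well outside a cap of 1.0
    print(clamp_norm(drifted, max_norm=1.0))
```

Rescaling rather than clipping each element keeps the token's direction intact, so only the magnitude that drifted during training is corrected.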
@@ -1,61 +0,0 @@
{
  "name": "tier1_sweep",
  "description": "Ablation of Tier 1 improvements: LoRA+, dropout, curriculum sampling. Baseline = uniform, no regularisation.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "lora_sweeps/tier1_sweep",
  "base": {
    "steps": 4000,
    "rank": 16,
    "alpha": 0.0,
    "lr": 1e-4,
    "batch_size": 16,
    "warmup_steps": 100,
    "grad_accum": 1,
    "save_every": 500,
    "seed": 42,
    "target": "attn.qkv",
    "timestep_mode": "uniform",
    "logit_normal_sigma": 1.0,
    "curriculum_switch": 0.6,
    "lora_dropout": 0.0,
    "lora_plus_ratio": 1.0
  },
  "experiments": [
    {
      "id": "baseline",
      "description": "Standard LoRA — no Tier 1 changes. Reference point."
    },
    {
      "id": "lora_plus_16",
      "description": "LoRA+ only: lr_B = 16 * lr_A. Should converge faster in early steps.",
      "lora_plus_ratio": 16.0
    },
    {
      "id": "dropout_0.05",
      "description": "LoRA dropout 0.05 only. Light regularisation for 49-clip dataset.",
      "lora_dropout": 0.05
    },
    {
      "id": "dropout_0.1",
      "description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
      "lora_dropout": 0.1
    },
    {
      "id": "curriculum",
      "description": "Curriculum sampling only: logit_normal for steps 1-2400, then uniform. Should improve convergence vs pure uniform.",
      "timestep_mode": "curriculum"
    },
    {
      "id": "full_tier1",
      "description": "All Tier 1 combined: LoRA+ + dropout 0.05 + curriculum.",
      "lora_plus_ratio": 16.0,
      "lora_dropout": 0.05,
      "timestep_mode": "curriculum"
    },
    {
      "id": "rank_64",
      "description": "Rank 64 baseline — MMAudio LoRA guide default. More expressive adapter for 49-clip dataset.",
      "rank": 64
    }
  ]
}
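
An ablation like `tier1_sweep` is easiest to read when each experiment changes a known set of knobs relative to `base`. A small sketch that reports which settings an experiment actually changes, ignoring the bookkeeping keys; `changed_keys` and `METADATA_KEYS` are hypothetical names, not part of the sweep runner:

```python
# Sketch: list the hyperparameters an experiment changes relative to "base".
# Assumption: "id", "group" and "description" are bookkeeping, not hyperparameters.

METADATA_KEYS = {"id", "group", "description"}


def changed_keys(base: dict, experiment: dict) -> list[str]:
    """Return the sorted hyperparameter names whose value differs from base."""
    return sorted(
        key for key, value in experiment.items()
        if key not in METADATA_KEYS and base.get(key) != value
    )


BASE = {"rank": 16, "lora_dropout": 0.0, "lora_plus_ratio": 1.0, "timestep_mode": "uniform"}

EXPERIMENTS = [
    {"id": "baseline"},
    {"id": "lora_plus_16", "lora_plus_ratio": 16.0},
    {"id": "full_tier1", "lora_plus_ratio": 16.0, "lora_dropout": 0.05,
     "timestep_mode": "curriculum"},
]

if __name__ == "__main__":
    for experiment in EXPERIMENTS:
        print(experiment["id"], changed_keys(BASE, experiment))
```

Runs that report exactly one changed key are clean single-factor ablations; runs such as `full_tier1` change several at once and can only be compared against the individual ablations, not attributed to one factor.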
@@ -1,144 +0,0 @@
{
  "name": "tier1_thorough",
  "description": "Full overnight Tier 1 ablation on 49-clip BJ dataset. 4 groups: rank, alpha, regularisation, and best combinations. ~10-12h depending on GPU.",
  "data_dir": "/media/unraid/davinci/Selva/BJ/features",
  "output_root": "/media/unraid/davinci/Selva/BJ/experiment/tier1_thorough",
  "base": {
    "steps": 4000,
    "rank": 16,
    "alpha": 0.0,
    "lr": 1e-4,
    "batch_size": 16,
    "warmup_steps": 100,
    "grad_accum": 1,
    "save_every": 1000,
    "seed": 42,
    "target": "attn.qkv",
    "timestep_mode": "uniform",
    "logit_normal_sigma": 1.0,
    "curriculum_switch": 0.6,
    "lora_dropout": 0.0,
    "lora_plus_ratio": 1.0
  },
  "experiments": [
    {
      "id": "g1_rank_16",
      "group": "rank",
      "description": "Rank 16 baseline — reference point for all groups."
    },
    {
      "id": "g1_rank_32",
      "group": "rank",
      "description": "Rank 32 — midpoint. Does doubling rank improve quality without overfitting?",
      "rank": 32
    },
    {
      "id": "g1_rank_64",
      "group": "rank",
      "description": "Rank 64 — MMAudio LoRA guide default. Maximum expressiveness at 49 clips.",
      "rank": 64
    },
    {
      "id": "g2_alpha_half_r16",
      "group": "alpha",
      "description": "Alpha=8 with rank 16 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
      "alpha": 8.0
    },
    {
      "id": "g2_alpha_half_r64",
      "group": "alpha",
      "description": "Alpha=32 with rank 64 (scale=0.5). Best-practice scaling for high-rank adapters.",
      "rank": 64,
      "alpha": 32.0
    },
    {
      "id": "g3_lora_plus_4",
      "group": "regularisation",
      "description": "LoRA+ ratio=4 — conservative asymmetric LR. Lower bound for the technique.",
      "lora_plus_ratio": 4.0
    },
    {
      "id": "g3_lora_plus_16",
      "group": "regularisation",
      "description": "LoRA+ ratio=16 — standard from FLUX LoRA literature. Faster early convergence.",
      "lora_plus_ratio": 16.0
    },
    {
      "id": "g3_dropout_0.05",
      "group": "regularisation",
      "description": "LoRA dropout 0.05 only. Light sparsity regularisation (arXiv:2404.09610).",
      "lora_dropout": 0.05
    },
    {
      "id": "g3_dropout_0.1",
      "group": "regularisation",
      "description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
      "lora_dropout": 0.1
    },
    {
      "id": "g3_curriculum",
      "group": "regularisation",
      "description": "Curriculum sampling only: logit_normal steps 1-2400, then uniform (arXiv:2603.12517).",
      "timestep_mode": "curriculum"
    },
    {
      "id": "g4_full_r16",
      "group": "combined",
      "description": "All Tier 1 at rank 16: LoRA+ 16 + dropout 0.05 + curriculum.",
      "lora_plus_ratio": 16.0,
      "lora_dropout": 0.05,
      "timestep_mode": "curriculum"
    },
    {
      "id": "g4_full_r64",
      "group": "combined",
      "description": "All Tier 1 at rank 64 + alpha=32. Best expressiveness + best regularisation.",
      "rank": 64,
      "alpha": 32.0,
      "lora_plus_ratio": 16.0,
      "lora_dropout": 0.05,
      "timestep_mode": "curriculum"
    },

    {
      "id": "g5_lr_low",
      "group": "lr",
      "description": "LR=3e-5 — 3× lower than baseline. Tests if 1e-4 is overshooting.",
      "lr": 3e-5
    },
{
|
|
||||||
"id": "g5_lr_high",
|
|
||||||
"group": "lr",
|
|
||||||
"description": "LR=3e-4 — 3× higher than baseline. Tests if 1e-4 is too conservative.",
|
|
||||||
"lr": 3e-4
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g6_target_full_r16",
|
|
||||||
"group": "target",
|
|
||||||
"description": "Rank 16 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
|
|
||||||
"target": "attn.qkv linear1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g6_target_full_r64",
|
|
||||||
"group": "target",
|
|
||||||
"description": "Rank 64 + alpha=32 targeting attn.qkv + linear1. Maximum coverage + expressiveness.",
|
|
||||||
"rank": 64,
|
|
||||||
"alpha": 32.0,
|
|
||||||
"target": "attn.qkv linear1"
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g4_full_r64_6k",
|
|
||||||
"group": "combined",
|
|
||||||
"description": "All Tier 1 at rank 64 + alpha=32, extended to 6000 steps. Checks if convergence is done at 4000.",
|
|
||||||
"rank": 64,
|
|
||||||
"alpha": 32.0,
|
|
||||||
"lora_plus_ratio": 16.0,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"timestep_mode": "curriculum",
|
|
||||||
"steps": 6000,
|
|
||||||
"save_every": 1000
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
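The way a config of this shape resolves into per-run settings can be sketched as follows. This is a minimal illustration, assuming each experiment entry simply overrides keys in `base`, and that `curriculum_switch` is the fraction of total steps after which sampling switches modes (consistent with the "steps 1-2400" note for a 4000-step run). The helper names `resolve_experiment` and `curriculum_switch_step`, and the set of metadata keys, are hypothetical and not taken from the training code.

```python
# Hypothetical helper: merge the shared "base" settings with one experiment's overrides.
META_KEYS = {"id", "group", "description"}  # bookkeeping fields, not hyperparameters


def resolve_experiment(base: dict, experiment: dict) -> dict:
    """Return the effective hyperparameters for one experiment entry."""
    overrides = {k: v for k, v in experiment.items() if k not in META_KEYS}
    return {**base, **overrides}


def curriculum_switch_step(params: dict) -> int:
    """Step at which curriculum sampling would switch modes (assumed: fraction of total steps)."""
    return round(params["curriculum_switch"] * params["steps"])


base = {"steps": 4000, "rank": 16, "alpha": 0.0, "curriculum_switch": 0.6,
        "timestep_mode": "uniform", "lora_dropout": 0.0, "lora_plus_ratio": 1.0}
exp_6k = {"id": "g4_full_r64_6k", "group": "combined", "rank": 64, "alpha": 32.0,
          "lora_plus_ratio": 16.0, "lora_dropout": 0.05,
          "timestep_mode": "curriculum", "steps": 6000}

params = resolve_experiment(base, exp_6k)
print(params["rank"], params["steps"], curriculum_switch_step(params))
```

Under these assumptions, the 6000-step run would switch modes later in absolute terms than the 4000-step runs, which is worth keeping in mind when comparing their loss curves step for step.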
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "vocoder_finetune",
|
|
||||||
"description": "Single run with fine-tuned BJ BigVGAN vocoder injected. Validates vocoder integration with LoRA training. Best known config: lr=3e-4, rank=128.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/vocoder_finetune",
|
|
||||||
"base": {
|
|
||||||
"steps": 10000,
|
|
||||||
"rank": 128,
|
|
||||||
"alpha": 0.0,
|
|
||||||
"lr": 3e-4,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 200,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 2000,
|
|
||||||
"seed": 42,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0,
|
|
||||||
"lr_schedule": "constant"
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "r128_lr_3e4_bj_vocoder",
|
|
||||||
"description": "lr=3e-4 rank=128 with fine-tuned BJ BigVGAN vocoder. Direct comparison baseline against previous best g1_r128_lr_3e4."
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -5,36 +5,6 @@ _NODES = {
|
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
||||||
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
|
|
||||||
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
|
|
||||||
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
|
|
||||||
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
|
|
||||||
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
|
|
||||||
"SelvaLoraEvaluator": (".selva_lora_evaluator", "SelvaLoraEvaluator", "SelVA LoRA Evaluator"),
|
|
||||||
"SelvaVaeRoundtrip": (".selva_vae_roundtrip", "SelvaVaeRoundtrip", "SelVA VAE Roundtrip"),
|
|
||||||
"SelvaHfSmoother": (".selva_audio_preprocessors", "SelvaHfSmoother", "SelVA HF Smoother"),
|
|
||||||
"SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"),
|
|
||||||
"SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"),
|
|
||||||
"SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"),
|
|
||||||
"SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"),
|
|
||||||
"SelvaActivationSteeringExtractor": (".selva_activation_steering_extractor", "SelvaActivationSteeringExtractor", "SelVA Activation Steering Extractor"),
|
|
||||||
"SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
|
|
||||||
"SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
|
|
||||||
"SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
|
|
||||||
"SelvaBigvganScheduler": (".selva_bigvgan_scheduler", "SelvaBigvganScheduler", "SelVA BigVGAN Scheduler"),
|
|
||||||
"SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"),
|
|
||||||
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
|
|
||||||
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
|
|
||||||
"SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
|
|
||||||
"SelvaDatasetCompressor": (".selva_dataset_pipeline", "SelvaDatasetCompressor", "SelVA Dataset Compressor"),
|
|
||||||
"SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
|
|
||||||
"SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
|
|
||||||
"SelvaDatasetSaver": (".selva_dataset_pipeline", "SelvaDatasetSaver", "SelVA Dataset Saver"),
|
|
||||||
"SelvaHarmonicExciter": (".selva_audio_postprocess", "SelvaHarmonicExciter", "SelVA Harmonic Exciter"),
|
|
||||||
"SelvaOutputNormalizer": (".selva_audio_postprocess", "SelvaOutputNormalizer", "SelVA Output Normalizer"),
|
|
||||||
"SelvaDatasetSpectralMatcher": (".selva_dataset_pipeline", "SelvaDatasetSpectralMatcher", "SelVA Dataset Spectral Matcher"),
|
|
||||||
"SelvaDatasetHfSmoother": (".selva_dataset_pipeline", "SelvaDatasetHfSmoother", "SelVA Dataset HF Smoother"),
|
|
||||||
"SelvaDatasetAugmenter": (".selva_dataset_pipeline", "SelvaDatasetAugmenter", "SelVA Dataset Augmenter"),
|
|
||||||
}

for key, (module_path, class_name, display_name) in _NODES.items():
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
"""SelVA Activation Steering Extractor.
|
|
||||||
|
|
||||||
Computes per-block steering vectors by running the frozen generator on the
training dataset and recording how the target style's conditioning shifts the
DiT hidden states relative to empty (unconditional) conditioning.
|
|
||||||
|
|
||||||
For each block i:
|
|
||||||
steering[i] = mean(latent_hidden | target style conditions)
|
|
||||||
- mean(latent_hidden | empty conditions)
|
|
||||||
|
|
||||||
The resulting vectors are injected at inference time (via SelVA Sampler's
steering_strength input) to nudge the denoising trajectory toward the target
style's activation patterns without modifying any model weights.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
from .selva_lora_trainer import _prepare_dataset
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_activations(generator, conditions, latent, t_tensor):
|
|
||||||
"""Run one predict_flow call, collecting latent hidden states per block.
|
|
||||||
|
|
||||||
Returns a list of [seq, hidden_dim] float32 CPU tensors,
|
|
||||||
one per block (joint_blocks first, then fused_blocks).
|
|
||||||
"""
|
|
||||||
activations = []
|
|
||||||
|
|
||||||
def make_hook(is_joint):
|
|
||||||
def hook(module, input, output):
|
|
||||||
h = output[0] if is_joint else output
|
|
||||||
activations.append(h.detach().float().mean(0).cpu()) # [seq, hidden]
|
|
||||||
return hook
|
|
||||||
|
|
||||||
handles = []
|
|
||||||
for block in generator.joint_blocks:
|
|
||||||
handles.append(block.register_forward_hook(make_hook(is_joint=True)))
|
|
||||||
for block in generator.fused_blocks:
|
|
||||||
handles.append(block.register_forward_hook(make_hook(is_joint=False)))
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
generator.predict_flow(latent, t_tensor, conditions)
|
|
||||||
finally:
|
|
||||||
for h in handles:
|
|
||||||
h.remove()
|
|
||||||
|
|
||||||
return activations # list of n_blocks tensors [seq, hidden]
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaActivationSteeringExtractor:
|
|
||||||
"""Computes activation steering vectors from a training dataset.
|
|
||||||
|
|
||||||
Runs the frozen generator on N clips at random timesteps with both
|
|
||||||
target style-conditioned and empty-conditioned inputs, then saves the mean
|
|
||||||
difference per DiT block to a .pt file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "extract"
|
|
||||||
RETURN_TYPES = ("STRING",)
|
|
||||||
RETURN_NAMES = ("steering_path",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Path to saved steering_vectors.pt — load with SelVA Activation Steering Loader.",)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Computes per-block activation steering vectors: mean(target style activations) − "
|
|
||||||
"mean(empty activations) at each DiT block. Load the result with "
|
|
||||||
"SelVA Activation Steering Loader and connect to the Sampler."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"data_dir": ("STRING", {
|
|
||||||
"default": "",
|
|
"tooltip": "Directory containing .npz feature files (same as LoRA/TI trainer). Relative paths resolve to the ComfyUI models directory.",
|
|
||||||
}),
|
|
||||||
"output_path": ("STRING", {
|
|
||||||
"default": "steering_vectors.pt",
|
|
||||||
"tooltip": "Where to save the steering vectors. Relative paths resolve to ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
"n_samples": ("INT", {
|
|
||||||
"default": 16, "min": 1, "max": 256,
|
|
"tooltip": "Number of clip draws to average over (sampled with replacement, so clips may repeat). More = more stable vectors, slower extraction.",
|
|
||||||
}),
|
|
||||||
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def extract(self, model, data_dir, output_path, n_samples, seed):
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
|
|
||||||
data_dir = Path(data_dir.strip())
|
|
||||||
if not data_dir.is_absolute():
|
|
||||||
data_dir = Path(folder_paths.models_dir) / data_dir
|
|
||||||
if not data_dir.exists():
|
|
||||||
raise FileNotFoundError(f"[Steering] data_dir not found: {data_dir}")
|
|
||||||
|
|
||||||
out_path = Path(output_path.strip())
|
|
||||||
if not out_path.is_absolute():
|
|
||||||
out_path = Path(folder_paths.get_output_directory()) / out_path
|
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\n[Steering] Extracting steering vectors n_samples={n_samples}", flush=True)
|
|
||||||
print(f"[Steering] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[Steering] output = {out_path}\n", flush=True)
|
|
||||||
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
generator = model["generator"]
|
|
||||||
generator.eval()
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
random.seed(seed)
|
|
||||||
indices = random.choices(range(len(dataset)), k=n_samples)
|
|
||||||
|
|
||||||
n_blocks = len(generator.joint_blocks) + len(generator.fused_blocks)
|
|
||||||
style_sums = [None] * n_blocks
|
|
||||||
empty_sums = [None] * n_blocks
|
|
||||||
counts = [0] * n_blocks
|
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(n_samples)
|
|
||||||
|
|
||||||
for sample_i, clip_idx in enumerate(indices):
|
|
||||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
|
|
||||||
|
|
||||||
clip_f = clip_f_cpu.to(device, dtype) # [1, T_clip, 1024]
|
|
||||||
sync_f = sync_f_cpu.to(device, dtype) # [1, T_sync, 768]
|
|
||||||
text_clip = text_clip_cpu.to(device, dtype) # [1, 77, 1024]
|
|
||||||
|
|
||||||
# x1 shape is [1, latent_seq_len, latent_dim] — dim 1 is the sequence length.
|
|
||||||
clip_latent_seq_len = x1_cpu.shape[1]
|
|
||||||
|
|
||||||
generator.update_seq_lengths(
|
|
||||||
latent_seq_len=clip_latent_seq_len,
|
|
||||||
clip_seq_len=clip_f.shape[1],
|
|
||||||
sync_seq_len=sync_f.shape[1],
|
|
||||||
)
|
|
||||||
|
|
||||||
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
|
||||||
empty_conditions = generator.get_empty_conditions(bs=1)
|
|
||||||
|
|
||||||
# Random timestep and noise latent for this clip
|
|
||||||
t_val = torch.rand(1).item()
|
|
||||||
t_tensor = torch.tensor([t_val], device=device, dtype=dtype)
|
|
||||||
latent = torch.randn(
|
|
||||||
1, clip_latent_seq_len, generator.latent_dim,
|
|
||||||
device=device, dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
style_acts = _collect_activations(generator, conditions, latent, t_tensor)
|
|
||||||
empty_acts = _collect_activations(generator, empty_conditions, latent, t_tensor)
|
|
||||||
|
|
||||||
for i, (st, em) in enumerate(zip(style_acts, empty_acts)):
|
|
||||||
if style_sums[i] is None:
|
|
||||||
style_sums[i] = st.clone()
|
|
||||||
empty_sums[i] = em.clone()
|
|
||||||
else:
|
|
||||||
style_sums[i] += st
|
|
||||||
empty_sums[i] += em
|
|
||||||
counts[i] += 1
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
if (sample_i + 1) % 4 == 0 or sample_i == n_samples - 1:
|
|
||||||
print(f"[Steering] Processed {sample_i + 1}/{n_samples} clips", flush=True)
|
|
||||||
|
|
||||||
# Steering vector per block: mean(target style) - mean(empty)
|
|
||||||
steering_vectors = []
|
|
||||||
for i in range(n_blocks):
|
|
||||||
vec = (style_sums[i] - empty_sums[i]) / counts[i] # [seq, hidden]
|
|
||||||
steering_vectors.append(vec)
|
|
||||||
|
|
||||||
norm = vec.norm().item()
|
|
||||||
print(f"[Steering] Block {i:2d} steering_norm={norm:.4f}", flush=True)
|
|
||||||
|
|
||||||
n_joint = len(generator.joint_blocks)
|
|
||||||
payload = {
|
|
||||||
"steering_vectors": steering_vectors, # list of [seq, hidden] tensors
|
|
||||||
"n_blocks": n_blocks,
|
|
||||||
"n_joint": n_joint,
|
|
||||||
"n_fused": len(generator.fused_blocks),
|
|
||||||
"latent_seq_len": seq_cfg.latent_seq_len,
|
|
||||||
"n_samples": n_samples,
|
|
||||||
"seed": seed,
|
|
||||||
"mode": model["mode"],
|
|
||||||
"variant": model["variant"],
|
|
||||||
}
|
|
||||||
torch.save(payload, str(out_path))
|
|
||||||
print(f"\n[Steering] Saved: {out_path}", flush=True)
|
|
||||||
|
|
||||||
soft_empty_cache()
|
|
||||||
return (str(out_path),)
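The accumulation in `extract` can be illustrated without the model. The sketch below uses plain Python lists as stand-ins for activation tensors and shows that dividing the difference of the running sums by the sample count gives mean(style) − mean(empty). The function name `steering_vector` and the numbers are made up for illustration; the real node works on per-block tensors collected through forward hooks.

```python
# Toy stand-in for the per-block accumulation: plain lists replace activation tensors.
def steering_vector(style_acts, empty_acts):
    """mean(style) - mean(empty), computed from running sums as in the extractor loop."""
    n = len(style_acts)
    dim = len(style_acts[0])
    style_sum = [0.0] * dim
    empty_sum = [0.0] * dim
    for st, em in zip(style_acts, empty_acts):
        for j in range(dim):
            style_sum[j] += st[j]
            empty_sum[j] += em[j]
    return [(s - e) / n for s, e in zip(style_sum, empty_sum)]


# Two sampled clips, hidden size 3 (made-up numbers).
style = [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]
empty = [[0.0, 1.0, 1.0], [2.0, 1.0, 1.0]]
vec = steering_vector(style, empty)
norm = sum(v * v for v in vec) ** 0.5
print(vec, norm)
```

Because both passes share the same noise latent and timestep, the difference isolates the effect of the conditioning rather than of the random inputs.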
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
"""SelVA Activation Steering Loader.
|
|
||||||
|
|
||||||
Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor
|
|
||||||
and returns a STEERING_VECTORS dict for use by SelVA Sampler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaActivationSteeringLoader:
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "load"
|
|
||||||
RETURN_TYPES = ("STEERING_VECTORS",)
|
|
||||||
RETURN_NAMES = ("steering_vectors",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Loads activation steering vectors from a .pt file produced by "
|
|
||||||
"SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge "
|
|
||||||
"denoising toward the target activation patterns."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"path": ("STRING", {
|
|
||||||
"default": "steering_vectors.pt",
|
|
||||||
"tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def load(self, path):
|
|
||||||
p = Path(path.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[Steering] File not found: {p}")
|
|
||||||
|
|
||||||
payload = torch.load(str(p), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
n_blocks = payload["n_blocks"]
|
|
||||||
n_joint = payload["n_joint"]
|
|
||||||
n_fused = payload["n_fused"]
|
|
||||||
n_vecs = len(payload["steering_vectors"])
|
|
||||||
|
|
||||||
print(f"[Steering] Loaded: {p}", flush=True)
|
|
||||||
print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) "
|
|
||||||
f"latent_seq_len={payload['latent_seq_len']} "
|
|
||||||
f"n_samples={payload['n_samples']}", flush=True)
|
|
||||||
print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True)
|
|
||||||
|
|
||||||
norms = [payload["steering_vectors"][i].norm().item() for i in range(n_vecs)]
|
|
||||||
mean_norm = sum(norms) / len(norms)
|
|
||||||
print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True)
|
|
||||||
|
|
||||||
return (payload,)
|
|
||||||
@@ -1,153 +0,0 @@
|
|||||||
"""SelVA Audio Post-Processing nodes.
|
|
||||||
|
|
||||||
Post-generation enhancement applied to standard AUDIO outputs:
|
|
||||||
SelvaHarmonicExciter — multi-band harmonic exciter (HPF → tanh → mix)
|
|
||||||
SelvaOutputNormalizer — LUFS normalization + true peak limiting
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaHarmonicExciter:
|
|
||||||
"""Multi-band harmonic exciter for post-generation enhancement.
|
|
||||||
|
|
||||||
Isolates high-frequency content above a cutoff, applies tanh saturation
|
|
||||||
to generate odd-order harmonics (3rd, 5th, ...), then mixes back with the dry signal.
|
|
||||||
Restores harmonic richness lost during BigVGAN vocoder reconstruction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
"cutoff_hz": ("FLOAT", {
|
|
||||||
"default": 3000.0, "min": 500.0, "max": 16000.0, "step": 100.0,
|
|
||||||
"tooltip": "Highpass cutoff frequency in Hz. Only content above this is excited. "
|
|
||||||
"3000 Hz targets the upper harmonics BigVGAN tends to smear.",
|
|
||||||
}),
|
|
||||||
"drive": ("FLOAT", {
|
|
||||||
"default": 2.0, "min": 1.0, "max": 10.0, "step": 0.5,
|
|
||||||
"tooltip": "Saturation drive. Higher = more harmonics generated. "
|
|
||||||
"2-3 is subtle, 5+ is aggressive.",
|
|
||||||
}),
|
|
||||||
"mix": ("FLOAT", {
|
|
||||||
"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "Wet/dry blend. 0.1-0.2 is subtle enhancement, "
|
|
||||||
"0.5+ is aggressive harmonic addition.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "excite"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Multi-band harmonic exciter. Applies tanh saturation to the high-frequency band "
|
|
||||||
"to restore harmonics lost during BigVGAN vocoder reconstruction. "
|
|
||||||
"Uses pedalboard.HighpassFilter for band isolation."
|
|
||||||
)
|
|
||||||
|
|
||||||
def excite(self, audio, cutoff_hz: float, drive: float, mix: float):
|
|
||||||
from pedalboard import Pedalboard, HighpassFilter
|
|
||||||
|
|
||||||
wav = audio["waveform"][0] # [C, T]
|
|
||||||
sr = audio["sample_rate"]
|
|
||||||
|
|
||||||
wav_np = wav.float().cpu().numpy() # [C, T]
|
|
||||||
|
|
||||||
# Isolate HF band
|
|
||||||
board = Pedalboard([HighpassFilter(cutoff_frequency_hz=cutoff_hz)])
|
|
||||||
hf = board(wav_np, sr) # [C, T]
|
|
||||||
|
|
||||||
# Tanh saturation — normalize by drive so output stays in [-1, 1]
|
|
||||||
excited = np.tanh(hf * drive) / max(drive, 1.0)
|
|
||||||
|
|
||||||
# Mix back with dry
|
|
||||||
mixed = wav_np + mix * excited
|
|
||||||
|
|
||||||
# Soft clip to prevent going over
|
|
||||||
mixed = np.tanh(mixed)
|
|
||||||
|
|
||||||
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, T]
|
|
||||||
print(
|
|
||||||
f"[HarmonicExciter] cutoff={cutoff_hz}Hz drive={drive} mix={mix:.0%}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return ({"waveform": wav_out, "sample_rate": sr},)
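The harmonic content produced by the tanh stage can be checked with a small standard-library experiment. Since tanh is an odd function, driving it with a pure sine should yield only odd harmonics. The sketch below applies the same shaping expression as `excite` to one period of a sine and measures two DFT bins directly. The helper `harmonic_magnitude` is hypothetical, and the high-pass stage is left out, so this illustrates the waveshaper only.

```python
import math


def harmonic_magnitude(samples, k):
    """Magnitude of the k-th harmonic of one full period of a signal (direct DFT bin)."""
    n = len(samples)
    re = sum(x * math.cos(2 * math.pi * k * i / n) for i, x in enumerate(samples))
    im = sum(x * math.sin(2 * math.pi * k * i / n) for i, x in enumerate(samples))
    return 2.0 * math.hypot(re, im) / n


N = 1024
drive = 2.0
sine = [math.sin(2 * math.pi * i / N) for i in range(N)]
# Same shaping as the exciter: tanh(x * drive) / max(drive, 1.0)
excited = [math.tanh(x * drive) / max(drive, 1.0) for x in sine]

h2 = harmonic_magnitude(excited, 2)
h3 = harmonic_magnitude(excited, 3)
print(f"2nd harmonic: {h2:.2e}, 3rd harmonic: {h3:.4f}")
```

Real program material is not a single sine, so intermodulation products also appear, but the absence of even harmonics for a symmetric input is a property of the tanh curve itself.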
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaOutputNormalizer:
|
|
||||||
"""Normalize generated audio to a target LUFS level with true peak limiting.
|
|
||||||
|
|
||||||
Apply as the final node before saving — brings generated audio to a
|
|
||||||
consistent loudness target regardless of input video loudness variation.
|
|
||||||
Uses pyloudnorm (BS.1770-4).
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
"target_lufs": ("FLOAT", {
|
|
||||||
"default": -14.0, "min": -40.0, "max": -6.0, "step": 0.5,
|
|
||||||
"tooltip": "Target integrated loudness in LUFS. "
|
|
||||||
"-14 LUFS for streaming (Spotify/YouTube), "
|
|
||||||
"-9 to -7 for production masters.",
|
|
||||||
}),
|
|
||||||
"true_peak_dbtp": ("FLOAT", {
|
|
||||||
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
|
|
||||||
"tooltip": "True peak ceiling in dBTP applied after LUFS gain.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "normalize"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Normalize output audio to a target LUFS level (BS.1770-4) with true peak limiting. "
|
|
||||||
"Apply as the last node before saving. Uses pyloudnorm."
|
|
||||||
)
|
|
||||||
|
|
||||||
def normalize(self, audio, target_lufs: float, true_peak_dbtp: float):
|
|
||||||
import pyloudnorm as pyln
|
|
||||||
|
|
||||||
wav = audio["waveform"][0] # [C, T]
|
|
||||||
sr = audio["sample_rate"]
|
|
||||||
|
|
||||||
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
|
|
||||||
|
|
||||||
wav_np = wav.permute(1, 0).double().cpu().numpy() # [T, C]
|
|
||||||
if wav_np.shape[1] == 1:
|
|
||||||
wav_np = wav_np[:, 0] # [T] mono
|
|
||||||
|
|
||||||
meter = pyln.Meter(sr)
|
|
||||||
loudness = meter.integrated_loudness(wav_np)
|
|
||||||
|
|
||||||
if not np.isfinite(loudness):
|
|
||||||
print("[OutputNormalizer] Could not measure loudness — clip too short or silent. Passing through.", flush=True)
|
|
||||||
return (audio,)
|
|
||||||
|
|
||||||
gain_db = target_lufs - loudness
|
|
||||||
gain_linear = 10.0 ** (gain_db / 20.0)
|
|
||||||
|
|
||||||
wav_out = wav * gain_linear
|
|
||||||
|
|
||||||
# Sample-peak check (no oversampling): scale the whole clip down if it exceeds the ceiling.
peak = wav_out.abs().max().item()
|
|
||||||
if peak > tp_linear:
|
|
||||||
wav_out = wav_out * (tp_linear / peak)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"[OutputNormalizer] {loudness:.1f} LUFS → {target_lufs} LUFS "
|
|
||||||
f"gain={gain_db:+.1f}dB TP={true_peak_dbtp}dBTP",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return ({"waveform": wav_out.unsqueeze(0), "sample_rate": sr},)
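The gain arithmetic in `normalize` can be worked through with plain numbers. The sketch below reproduces the two steps: a loudness gain from the LUFS difference, then a uniform reduction if the scaled peak exceeds the ceiling. The function `plan_gain` and the example values are hypothetical. Loudness measurement itself is left to pyloudnorm in the node and is not reimplemented here.

```python
def plan_gain(measured_lufs, target_lufs, peak, ceiling_db):
    """Return the overall linear gain: loudness gain, reduced if the scaled peak exceeds the ceiling."""
    gain = 10.0 ** ((target_lufs - measured_lufs) / 20.0)
    ceiling = 10.0 ** (ceiling_db / 20.0)
    if peak * gain > ceiling:
        # Scale everything down so the peak lands exactly on the ceiling.
        gain *= ceiling / (peak * gain)
    return gain


# A clip measured at -20 LUFS with a sample peak of 0.5, normalized to -14 LUFS with a -1 dB ceiling.
gain = plan_gain(measured_lufs=-20.0, target_lufs=-14.0, peak=0.5, ceiling_db=-1.0)
print(f"final gain: {gain:.4f}")
```

When the ceiling engages, as in this example, the output ends up quieter than the requested LUFS target, because the peak reduction is applied to the whole clip rather than only to the transients.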
|
|
||||||
@@ -1,293 +0,0 @@
|
|||||||
"""SelVA Audio Preprocessors — condition training clips for codec compatibility.
|
|
||||||
|
|
||||||
Two nodes that reduce the domain mismatch between custom training audio and the
|
|
||||||
MMAudio VAE's expected spectral distribution, improving LoRA training quality:
|
|
||||||
|
|
||||||
SelvaHfSmoother — soft low-pass blend to attenuate extreme HF content
|
|
||||||
SelvaSpectralMatcher — adaptive per-band EQ toward the codec's training distribution
|
|
||||||
|
|
||||||
Root cause they address: MMAudio was trained on natural sounds (speech, foley, env)
|
|
||||||
with limited engineered HF content. The BigVGANv2 vocoder (frozen, pre-trained) handles
|
|
||||||
the codec's HF reconstruction poorly for sound design / music training clips, because
|
|
||||||
those clips land in a latent-space region the vocoder never saw during training.
|
|
||||||
|
|
||||||
Recommended order: SpectralMatcher → HfSmoother → feature extraction → LoRA training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio.functional as AF
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
# ── Mel filterbank (same algorithm as selva_core/ext/mel_converter.py) ────────
|
|
||||||
|
|
||||||
def _mel_filterbank(sr: int, n_fft: int, n_mels: int,
|
|
||||||
fmin: float, fmax: float) -> torch.Tensor:
|
|
||||||
"""Returns mel filterbank matrix [n_mels, n_fft//2+1]."""
|
|
||||||
def hz_to_mel(f):
|
|
||||||
return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)
|
|
||||||
|
|
||||||
def mel_to_hz(m):
|
|
||||||
return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)
|
|
||||||
|
|
||||||
n_freqs = n_fft // 2 + 1
|
|
||||||
fft_freqs = np.linspace(0.0, sr / 2.0, n_freqs)
|
|
||||||
mel_pts = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
|
|
||||||
hz_pts = mel_to_hz(mel_pts)
|
|
||||||
|
|
||||||
fb = np.zeros((n_mels, n_freqs), dtype=np.float32)
|
|
||||||
for m in range(1, n_mels + 1):
|
|
||||||
lo, mid, hi = hz_pts[m - 1], hz_pts[m], hz_pts[m + 1]
|
|
||||||
up = (fft_freqs - lo) / (mid - lo + 1e-12)
|
|
||||||
down = (hi - fft_freqs) / (hi - mid + 1e-12)
|
|
||||||
fb[m - 1] = np.maximum(0.0, np.minimum(up, down))
|
|
||||||
return torch.from_numpy(fb)
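The band-edge placement used by `_mel_filterbank` can be reproduced with the standard library alone. The sketch below uses the same Hz-to-mel formulas and computes the `n_mels + 2` edge frequencies for the "16k" settings (80 bands, 0 to 8000 Hz). The helper `mel_band_edges` is a hypothetical name, and the triangular filter weights themselves are not rebuilt here.

```python
import math


def hz_to_mel(f):
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_band_edges(fmin, fmax, n_mels):
    """n_mels + 2 frequencies (Hz), evenly spaced on the mel scale."""
    lo, hi = hz_to_mel(fmin), hz_to_mel(fmax)
    step = (hi - lo) / (n_mels + 1)
    return [mel_to_hz(lo + i * step) for i in range(n_mels + 2)]


edges = mel_band_edges(fmin=0.0, fmax=8_000.0, n_mels=80)  # the "16k" settings
print(len(edges), round(edges[1], 1), round(edges[-1], 1))
```

Each triangular filter spans three consecutive edges, so the bands widen with frequency. That is why the upper target-mean entries each summarize a much wider slice of the spectrum than the lower ones.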
|
|
||||||
|
|
||||||
|
|
||||||
# ── VAE target log-mel means (source: selva_core/ext/autoencoder/vae.py) ──────
|
|
||||||
# These are the per-band expected log-mel energy means from MMAudio's training data.
|
|
||||||
# Used as the spectral matching target: clips are EQ'd to match this profile.
|
|
||||||
|
|
||||||
_TARGET_MEAN_80D = [
|
|
||||||
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439,
|
|
||||||
-1.2922, -1.2927, -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912,
|
|
||||||
-1.4313, -1.4152, -1.4527, -1.4728, -1.4568, -1.5101, -1.5051, -1.5172,
|
|
||||||
-1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131, -1.6081, -1.6331,
|
|
||||||
-1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
|
|
||||||
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377,
|
|
||||||
-1.8417, -1.8643, -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673,
|
|
||||||
-1.9824, -2.0042, -2.0215, -2.0436, -2.0766, -2.1064, -2.1418, -2.1855,
|
|
||||||
-2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282, -2.4659, -2.5072,
|
|
||||||
-2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673,
|
|
||||||
]
|
|
||||||
|
|
||||||
_TARGET_MEAN_128D = [
|
|
||||||
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006,
|
|
||||||
-2.2357, -2.4597, -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047,
|
|
||||||
-2.7483, -2.5926, -2.7462, -2.7033, -2.7386, -2.8112, -2.7502, -2.9594,
|
|
||||||
-2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157, -3.1191, -2.9893,
|
|
||||||
-3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
|
|
||||||
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509,
|
|
||||||
-3.5089, -3.4647, -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747,
|
|
||||||
-3.7072, -3.7279, -3.7283, -3.7795, -3.8259, -3.8447, -3.8663, -3.9182,
|
|
||||||
-3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121, -4.1488, -4.1874,
|
|
||||||
-4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
|
|
||||||
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053,
|
|
||||||
-5.4927, -5.5712, -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103,
|
|
||||||
-6.0955, -6.1673, -6.2362, -6.3120, -6.3926, -6.4797, -6.5565, -6.6511,
|
|
||||||
-6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663, -7.6136, -7.7469,
|
|
||||||
-7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
|
|
||||||
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861,
|
|
||||||
]
|
|
||||||
|
|
||||||
_MEL_CONFIGS = {
|
|
||||||
"16k": dict(sr=16_000, n_fft=1024, n_mels=80, hop=256, fmin=0, fmax=8_000,
|
|
||||||
target=_TARGET_MEAN_80D, log10=True),
|
|
||||||
"44k": dict(sr=44_100, n_fft=2048, n_mels=128, hop=512, fmin=0, fmax=22_050,
|
|
||||||
target=_TARGET_MEAN_128D, log10=False),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Node 1: HF Smoother ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class SelvaHfSmoother:
|
|
||||||
"""Soft high-frequency attenuation for LoRA training clip preprocessing.
|
|
||||||
|
|
||||||
Blends a low-pass filtered copy of the audio with the original. Attenuates
|
|
||||||
the extreme HF content common in engineered sound design that the BigVGANv2
|
|
||||||
vocoder handles poorly, bringing the clip closer to the spectral region the
|
|
||||||
MMAudio codec was trained on (natural sounds with limited HF energy).
|
|
||||||
|
|
||||||
A blend of 0.7 at 12 kHz is a transparent starting point — audible only on
|
|
||||||
close comparison. Increase blend or lower cutoff if roundtrip quality is still
|
|
||||||
poor after spectral matching.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
"cutoff_hz": ("FLOAT", {
|
|
||||||
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
|
|
||||||
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
|
|
||||||
}),
|
|
||||||
"blend": ("FLOAT", {
|
|
||||||
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "0 = original, 1 = fully filtered. 0.7 is a transparent starting point.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "process"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Blends a low-pass filtered version of the audio with the original to gently attenuate "
|
|
||||||
"high-frequency content that the SelVA codec handles poorly. "
|
|
||||||
"Use before feature extraction to improve LoRA training targets. "
|
|
||||||
"Run after SelVA Spectral Matcher for best results."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, audio, cutoff_hz: float, blend: float):
|
|
||||||
waveform = audio["waveform"].float() # [1, C, L]
|
|
||||||
sr = audio["sample_rate"]
|
|
||||||
|
|
||||||
filtered = AF.lowpass_biquad(waveform, sr, cutoff_hz)
|
|
||||||
out = blend * filtered + (1.0 - blend) * waveform
|
|
||||||
|
|
||||||
# Preserve RMS level — LPF removes energy, keep the clip at its original loudness
|
|
||||||
rms_in = waveform.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
rms_out = out.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
out = out * (rms_in / rms_out)
|
|
||||||
|
|
||||||
peak = out.abs().max()
|
|
||||||
if peak > 1.0:
|
|
||||||
out = out / peak
|
|
||||||
|
|
||||||
print(f"[HF Smoother] cutoff={cutoff_hz:.0f} Hz blend={blend:.2f} "
|
|
||||||
f"rms={rms_in:.4f}→{out.pow(2).mean().sqrt():.4f} "
|
|
||||||
f"peak={out.abs().max():.4f}", flush=True)
|
|
||||||
|
|
||||||
return ({"waveform": out, "sample_rate": sr},)
|
|
||||||
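Review note: the blend, RMS-restore, and peak-guard order in `SelvaHfSmoother.process` can be checked without torch. The sketch below is a pure-Python illustration only. It swaps a one-pole low-pass in for `AF.lowpass_biquad`, so the filter response differs from the node's, and the names `one_pole_lowpass`, `rms`, and `hf_smooth` are hypothetical helpers, not part of the repository.

```python
import math


def one_pole_lowpass(x, sr, cutoff_hz):
    """Stand-in filter (assumption): one-pole low-pass, not the node's biquad."""
    a = 1.0 - math.exp(-2.0 * math.pi * cutoff_hz / sr)
    y, prev = [], 0.0
    for s in x:
        prev += a * (s - prev)
        y.append(prev)
    return y


def rms(x):
    """Root-mean-square level, floored like the node's clamp(min=1e-8)."""
    return max(math.sqrt(sum(s * s for s in x) / len(x)), 1e-8)


def hf_smooth(x, sr, cutoff_hz=12000.0, blend=0.7):
    """Blend filtered and original audio, restore RMS, then guard the peak."""
    filtered = one_pole_lowpass(x, sr, cutoff_hz)
    out = [blend * f + (1.0 - blend) * s for f, s in zip(filtered, x)]
    scale = rms(x) / rms(out)          # undo the loudness lost to filtering
    out = [s * scale for s in out]
    peak = max(abs(s) for s in out)
    if peak > 1.0:                     # only rescale when the clip would clip
        out = [s / peak for s in out]
    return out


sr = 44100
x = [0.3 * math.sin(2 * math.pi * 440 * n / sr)
     + 0.3 * math.sin(2 * math.pi * 18000 * n / sr) for n in range(4410)]
y = hf_smooth(x, sr)
```

With `blend=0.0` the output equals the input, which matches the tooltip ("0 = original"). When the peak guard fires, the final RMS drops below the input RMS, which is why the node logs the RMS after normalization rather than assuming it is unchanged.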
|
|
||||||
|
|
||||||
# ── Node 2: Spectral Matcher ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class SelvaSpectralMatcher:
|
|
||||||
"""Adaptive per-band EQ toward the SelVA VAE's expected spectral distribution.
|
|
||||||
|
|
||||||
Computes the log-mel energy profile of the clip and compares it to the per-band
|
|
||||||
means stored in the VAE's normalization buffers (the statistics MMAudio was trained
|
|
||||||
on). Applies a smooth frequency-domain gain correction so the clip's spectral shape
|
|
||||||
matches what the codec expects, improving encode→decode roundtrip quality and
|
|
||||||
therefore LoRA training target quality.
|
|
||||||
|
|
||||||
The correction is additive in log space (multiplicative in linear), so it only
|
|
||||||
changes spectral balance — not the waveform's timing or phase structure.
|
|
||||||
|
|
||||||
max_gain_db clamps the correction to prevent extreme boosts on very quiet bands.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
"mode": (["44k", "16k"], {
|
|
||||||
"tooltip": "Must match the SelVA model you are training. "
|
|
||||||
"44k = large model, 16k = small model.",
|
|
||||||
}),
|
|
||||||
"strength": ("FLOAT", {
|
|
||||||
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "0 = no correction, 1 = full match to VAE distribution. "
|
|
||||||
"0.8 is a good starting point.",
|
|
||||||
}),
|
|
||||||
"max_gain_db": ("FLOAT", {
|
|
||||||
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
|
|
||||||
"tooltip": "Clamps per-band gain to ±dB. Prevents extreme boosts on "
|
|
||||||
"very quiet frequency bands. 12 dB is conservative.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "process"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Applies a smooth per-band gain correction to bring the audio's spectral profile "
|
|
||||||
"in line with the MMAudio VAE's expected distribution, derived from the per-band "
|
|
||||||
"normalization statistics baked into the VAE weights. "
|
|
||||||
"Use before feature extraction to improve LoRA training target quality. "
|
|
||||||
"Run before SelVA HF Smoother."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, audio, mode: str, strength: float, max_gain_db: float):
|
|
||||||
cfg = _MEL_CONFIGS[mode]
|
|
||||||
waveform = audio["waveform"].float() # [1, C, L]
|
|
||||||
sr_in = audio["sample_rate"]
|
|
||||||
sr_tgt = cfg["sr"]
|
|
||||||
n_fft = cfg["n_fft"]
|
|
||||||
hop = cfg["hop"]
|
|
||||||
|
|
||||||
# ── flatten to mono and resample if needed ────────────────────────────
|
|
||||||
wav = waveform[0].mean(0) # [L]
|
|
||||||
if sr_in != sr_tgt:
|
|
||||||
wav = AF.resample(wav.unsqueeze(0), sr_in, sr_tgt).squeeze(0)
|
|
||||||
|
|
||||||
device = wav.device
|
|
||||||
window = torch.hann_window(n_fft, device=device)
|
|
||||||
|
|
||||||
# ── STFT ──────────────────────────────────────────────────────────────
|
|
||||||
stft = torch.stft(wav, n_fft, hop_length=hop, win_length=n_fft,
|
|
||||||
window=window, center=True, return_complex=True) # [n_freqs, T]
|
|
||||||
mag = stft.abs() # [n_freqs, T]
|
|
||||||
|
|
||||||
# ── current log-mel mean per band ─────────────────────────────────────
|
|
||||||
fb = _mel_filterbank(sr_tgt, n_fft, cfg["n_mels"],
|
|
||||||
cfg["fmin"], cfg["fmax"]).to(device) # [n_mels, n_freqs]
|
|
||||||
|
|
||||||
mel_mag = torch.matmul(fb, mag).clamp(min=1e-5) # [n_mels, T]
|
|
||||||
if cfg["log10"]:
|
|
||||||
mel_log = torch.log10(mel_mag)
|
|
||||||
else:
|
|
||||||
mel_log = torch.log(mel_mag)
|
|
||||||
|
|
||||||
current_mean = mel_log.mean(dim=-1) # [n_mels]
|
|
||||||
target_mean = torch.tensor(cfg["target"], device=device) # [n_mels]
|
|
||||||
|
|
||||||
# ── per-mel-band gain (log space) ─────────────────────────────────────
|
|
||||||
mel_gain = (target_mean - current_mean) * strength # [n_mels]
|
|
||||||
|
|
||||||
# Clamp to ±max_gain_db
|
|
||||||
if cfg["log10"]:
|
|
||||||
max_log = max_gain_db / 20.0 # log10 amplitude: dB = 20 * log10(gain)
|
|
||||||
else:
|
|
||||||
max_log = max_gain_db / 8.6859 # ln: 20 * log10(e) ≈ 8.686
|
|
||||||
mel_gain = mel_gain.clamp(-max_log, max_log)
|
|
||||||
|
|
||||||
# ── map mel gains → STFT frequency bins (weighted average) ────────────
|
|
||||||
fb_sum = fb.sum(0).clamp(min=1e-8) # [n_freqs]
|
|
||||||
freq_gain = (mel_gain @ fb) / fb_sum # [n_freqs]
|
|
||||||
|
|
||||||
if cfg["log10"]:
|
|
||||||
linear_gain = 10.0 ** freq_gain # [n_freqs]
|
|
||||||
else:
|
|
||||||
linear_gain = torch.exp(freq_gain) # [n_freqs]
|
|
||||||
|
|
||||||
# ── apply gain in frequency domain and reconstruct ───────────────────
|
|
||||||
stft_out = stft * linear_gain.unsqueeze(-1) # [n_freqs, T]
|
|
||||||
wav_out = torch.istft(stft_out, n_fft, hop_length=hop, win_length=n_fft,
|
|
||||||
window=window, center=True,
|
|
||||||
length=wav.shape[0]) # [L]
|
|
||||||
|
|
||||||
# ── resample back to original sr ──────────────────────────────────────
|
|
||||||
if sr_in != sr_tgt:
|
|
||||||
wav_out = AF.resample(wav_out.unsqueeze(0), sr_tgt, sr_in).squeeze(0)
|
|
||||||
|
|
||||||
# ── preserve original RMS level ───────────────────────────────────────
|
|
||||||
rms_in = wav.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
rms_out = wav_out.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
wav_out = wav_out * (rms_in / rms_out)
|
|
||||||
|
|
||||||
peak = wav_out.abs().max()
|
|
||||||
if peak > 1.0:
|
|
||||||
wav_out = wav_out / peak
|
|
||||||
|
|
||||||
# ── reshape to [1, C, L]; the mono result is copied to every input channel ──
|
|
||||||
out = wav_out.unsqueeze(0).unsqueeze(0)
|
|
||||||
if waveform.shape[1] > 1:
|
|
||||||
out = out.expand(-1, waveform.shape[1], -1).clone()
|
|
||||||
|
|
||||||
gain_db_range = (
|
|
||||||
20.0 * torch.log10(linear_gain.clamp(min=1e-8))
|
|
||||||
)
|
|
||||||
print(f"[Spectral Matcher] mode={mode} strength={strength:.2f} "
|
|
||||||
f"gain [{gain_db_range.min():.1f}, {gain_db_range.max():.1f}] dB "
|
|
||||||
f"rms={rms_in:.4f}→{out.pow(2).mean().sqrt():.4f}", flush=True)
|
|
||||||
|
|
||||||
return ({"waveform": out, "sample_rate": sr_in},)
|
|
||||||
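Review note: the two `max_gain_db` divisors in `SelvaSpectralMatcher.process` (20.0 for the log10 statistics, 8.6859 for the natural-log statistics) should yield the same linear gain limit. A minimal sketch of that arithmetic follows, using scalar values instead of tensors. The function names `max_log_gain`, `to_linear`, and `clamp_gain` are hypothetical and exist only for this illustration.

```python
import math

# dB per unit of natural-log amplitude: 20 * log10(e), roughly 8.6859
DB_PER_LN_UNIT = 20.0 * math.log10(math.e)


def max_log_gain(max_gain_db: float, log10: bool) -> float:
    """Convert a dB limit into the log domain used by the mel statistics."""
    return max_gain_db / 20.0 if log10 else max_gain_db / DB_PER_LN_UNIT


def to_linear(log_gain: float, log10: bool) -> float:
    """Map a log-domain gain back to a linear amplitude multiplier."""
    return 10.0 ** log_gain if log10 else math.exp(log_gain)


def clamp_gain(target_mean, current_mean, strength, max_gain_db, log10):
    """Per-band gain: scaled difference of log means, clamped to +/- the limit."""
    limit = max_log_gain(max_gain_db, log10)
    raw = (target_mean - current_mean) * strength
    return max(-limit, min(limit, raw))


# A 12 dB limit is the same linear factor in both log domains.
limit_16k = to_linear(max_log_gain(12.0, True), True)
limit_44k = to_linear(max_log_gain(12.0, False), False)
```

Both limits evaluate to 10^(12/20), about 3.98x in amplitude, so the "16k" and "44k" modes clamp equally hard even though their stored target means use different logarithm bases.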
@@ -1,77 +0,0 @@
|
|||||||
"""SelVA BigVGAN Loader.
|
|
||||||
|
|
||||||
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer
|
|
||||||
and replaces the vocoder weights in the loaded SELVA_MODEL in-place.
|
|
||||||
|
|
||||||
The model is modified in-place so ComfyUI's model cache is updated — no need to
|
|
||||||
reload the full SelVA model. Subsequent Sampler runs will use the fine-tuned vocoder.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
from .selva_bigvgan_trainer import inject_gafilters
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaBigvganLoader:
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "load"
|
|
||||||
RETURN_TYPES = ("SELVA_MODEL",)
|
|
||||||
RETURN_NAMES = ("model",)
|
|
||||||
OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Loads a fine-tuned BigVGAN/BigVGANv2 vocoder checkpoint from SelVA BigVGAN Trainer "
|
|
||||||
"and replaces the vocoder weights in the SELVA_MODEL in-place. "
|
|
||||||
"Supports both 16k and 44k models. "
|
|
||||||
"Connect the output to SelVA Sampler instead of the base model loader."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"path": ("STRING", {
|
|
||||||
"default": "bigvgan_bj.pt",
|
|
||||||
"tooltip": "Path to fine-tuned vocoder checkpoint (.pt). "
|
|
||||||
"Relative paths resolve to ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def load(self, model, path):
|
|
||||||
p = Path(path.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {p}")
|
|
||||||
|
|
||||||
ckpt = torch.load(str(p), map_location="cpu", weights_only=False)
|
|
||||||
if "generator" not in ckpt:
|
|
||||||
raise ValueError(f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {list(ckpt.keys())}")
|
|
||||||
|
|
||||||
mode = model["mode"]
|
|
||||||
if mode == "16k":
|
|
||||||
vocoder = model["feature_utils"].tod.vocoder.vocoder # BigVGANVocoder
|
|
||||||
elif mode == "44k":
|
|
||||||
vocoder = model["feature_utils"].tod.vocoder # BigVGANv2 directly
|
|
||||||
else:
|
|
||||||
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
|
|
||||||
|
|
||||||
# Remember device before injecting new modules (which default to CPU)
|
|
||||||
target_device = next(vocoder.parameters()).device
|
|
||||||
|
|
||||||
if ckpt.get("has_gafilter", False):
|
|
||||||
kernel_size = ckpt.get("gafilter_kernel_size", 9)
|
|
||||||
n_gaf = inject_gafilters(vocoder, kernel_size)
|
|
||||||
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True)
|
|
||||||
|
|
||||||
vocoder.load_state_dict(ckpt["generator"])
|
|
||||||
vocoder.to(target_device)
|
|
||||||
vocoder.eval()
|
|
||||||
|
|
||||||
print(f"[BigVGAN] Loaded fine-tuned vocoder from: {p}", flush=True)
|
|
||||||
return (model,)
|
|
||||||
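Review note: the loader's path handling and checkpoint checks can be exercised in isolation. The sketch below mirrors the logic in `SelvaBigvganLoader.load` using only the standard library. It assumes the checkpoint is a plain dict with the keys the loader reads (`generator`, `has_gafilter`, `gafilter_kernel_size`). The helper names `resolve_checkpoint` and `read_gafilter_settings` are hypothetical, and the output directory is passed in explicitly rather than taken from `folder_paths`.

```python
from pathlib import Path


def resolve_checkpoint(path: str, output_dir: str) -> Path:
    """Relative paths resolve against output_dir; absolute paths pass through."""
    p = Path(path.strip())
    return p if p.is_absolute() else Path(output_dir) / p


def read_gafilter_settings(ckpt: dict) -> tuple:
    """Return (has_gafilter, kernel_size); kernel_size is None without GAFilter."""
    if "generator" not in ckpt:
        raise ValueError(
            f"expected a 'generator' entry, got keys: {list(ckpt.keys())}"
        )
    if ckpt.get("has_gafilter", False):
        # Older checkpoints may omit the kernel size; the loader falls back to 9.
        return True, ckpt.get("gafilter_kernel_size", 9)
    return False, None


relative = resolve_checkpoint("  bigvgan_bj.pt ", "/tmp/out")
settings = read_gafilter_settings({"generator": {}, "has_gafilter": True})
```

The order matters in the real loader: GAFilter modules have to be injected before `load_state_dict`, otherwise the checkpoint contains keys the vocoder does not yet have.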
@@ -1,625 +0,0 @@
|
|||||||
"""SelVA BigVGAN Vocoder Scheduler — runs a sweep of vocoder fine-tuning experiments.
|
|
||||||
|
|
||||||
Each experiment inherits from a shared `base` config and overrides specific keys.
|
|
||||||
Audio clips are loaded once and reused across all experiments. Results are written
|
|
||||||
to `experiment_summary.json` (updated after each completed run) and a comparison
|
|
||||||
loss-curve image.
|
|
||||||
|
|
||||||
JSON format:
|
|
||||||
{
|
|
||||||
"name": "bigvgan_sweep",
|
|
||||||
"description": "optional note",
|
|
||||||
"data_dir": "/path/to/audio/clips",
|
|
||||||
"output_root": "/path/to/output",
|
|
||||||
"base": { "train_mode": "snake_alpha_only", "steps": 2000, "lr": 1e-4, ... },
|
|
||||||
"experiments": [
|
|
||||||
{"id": "baseline", "description": "..."},
|
|
||||||
{"id": "all_5k", "train_mode": "all_params", "steps": 5000, "lr": 1e-5},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.model_management
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
from .selva_bigvgan_trainer import (
|
|
||||||
_do_train,
|
|
||||||
_pregenerate_lora_mels,
|
|
||||||
_load_wav,
|
|
||||||
)
|
|
||||||
from .selva_lora_trainer import _smooth_losses, _pil_to_tensor
|
|
||||||
from .selva_lora_scheduler import (
|
|
||||||
_get_system_info,
|
|
||||||
_resolve_path,
|
|
||||||
_draw_comparison_curves,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Defaults mirror SelvaBigvganTrainer INPUT_TYPES defaults
|
|
||||||
_PARAM_DEFAULTS = {
|
|
||||||
"train_mode": "snake_alpha_only",
|
|
||||||
"steps": 2000,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 4,
|
|
||||||
"segment_seconds": 2.0,
|
|
||||||
"lambda_l2sp": 1e-3,
|
|
||||||
"use_gafilter": True,
|
|
||||||
"gafilter_kernel_size": 9,
|
|
||||||
"lambda_phase": 1.0,
|
|
||||||
"save_every": 500,
|
|
||||||
"seed": 42,
|
|
||||||
"discriminator_path": "",
|
|
||||||
"lora_adapter": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_config(base: dict, experiment: dict) -> dict:
|
|
||||||
"""Merge param defaults + file base + experiment overrides."""
|
|
||||||
cfg = dict(_PARAM_DEFAULTS)
|
|
||||||
cfg.update(base)
|
|
||||||
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
|
|
||||||
return cfg
|
|
||||||
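Review note: the precedence in `_merge_config` is defaults, then the sweep file's `base`, then the experiment's own keys, with `id` and `description` excluded. The sketch below reproduces that order with a shortened defaults dict. `DEFAULTS` and `merge_config` are illustrative stand-ins for `_PARAM_DEFAULTS` and `_merge_config`, and the sample values are made up for the example.

```python
DEFAULTS = {"train_mode": "snake_alpha_only", "steps": 2000, "lr": 1e-4}


def merge_config(defaults: dict, base: dict, experiment: dict) -> dict:
    """Later sources win: defaults < base < experiment overrides."""
    cfg = dict(defaults)
    cfg.update(base)
    cfg.update({k: v for k, v in experiment.items()
                if k not in ("id", "description")})
    return cfg


base = {"steps": 3000}
baseline = merge_config(DEFAULTS, base, {"id": "baseline", "description": "..."})
all_5k = merge_config(
    DEFAULTS, base,
    {"id": "all_5k", "train_mode": "all_params", "steps": 5000, "lr": 1e-5},
)
```

Here `baseline` inherits `steps=3000` from `base` and everything else from the defaults, while `all_5k` overrides all three keys. Neither merged config carries the `id` or `description` fields.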
|
|
||||||
|
|
||||||
def _parse_training_log(log_path: Path) -> list:
|
|
||||||
"""Parse BigVGAN training CSV → list of total_loss values."""
|
|
||||||
losses = []
|
|
||||||
if not log_path.exists():
|
|
||||||
return losses
|
|
||||||
try:
|
|
||||||
with open(log_path) as f:
|
|
||||||
reader = csv.DictReader(f)
|
|
||||||
for row in reader:
|
|
||||||
losses.append(float(row["total_loss"]))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return losses
|
|
||||||
|
|
||||||
|
|
||||||
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
|
|
||||||
total_steps: int) -> dict:
|
|
||||||
"""Build {step: loss} at each save_every boundary.
|
|
||||||
|
|
||||||
Uses round-to-nearest to handle log_interval that doesn't divide
|
|
||||||
save_every evenly (e.g. steps=3000 → log_interval=150, save_every=1000).
|
|
||||||
"""
|
|
||||||
result = {}
|
|
||||||
for target in range(save_every, total_steps + 1, save_every):
|
|
||||||
# loss_history[i] = loss at step (i+1)*log_interval
|
|
||||||
idx = round(target / log_interval) - 1
|
|
||||||
if 0 <= idx < len(loss_history):
|
|
||||||
result[str(target)] = round(loss_history[idx], 6)
|
|
||||||
return result
|
|
||||||
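Review note: a worked example of the docstring case (`steps=3000`, `log_interval=150`, `save_every=1000`) shows which logged entries `_loss_at_steps` picks. The function body below is copied from the diff under the hypothetical name `loss_at_steps`, and the loss values are synthetic, chosen only so the selected indices are easy to read off.

```python
def loss_at_steps(loss_history, log_interval, save_every, total_steps):
    """Map each save_every boundary to the nearest logged loss value."""
    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        # loss_history[i] holds the loss logged at step (i + 1) * log_interval
        idx = round(target / log_interval) - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result


# 20 synthetic log entries at steps 150, 300, ..., 3000
history = [1.0 - 0.01 * i for i in range(20)]
picked = loss_at_steps(history, log_interval=150, save_every=1000,
                       total_steps=3000)
```

Step 1000 maps to index 6 (logged at step 1050), step 2000 to index 12 (step 1950), and step 3000 to index 19 (step 3000). Note that Python's built-in `round` resolves exact .5 ties to the nearest even integer, so a target that falls exactly halfway between two log entries may resolve to either neighbour depending on parity.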
|
|
||||||
|
|
||||||
class SelvaBigvganScheduler:
|
|
||||||
"""Runs a sweep of BigVGAN vocoder fine-tuning experiments from a JSON file.
|
|
||||||
|
|
||||||
Audio clips are loaded once and reused across all experiments. Each experiment
|
|
||||||
deep-copies the vocoder and trains independently. Results are written to
|
|
||||||
`experiment_summary.json` after every completed run so partial results are
|
|
||||||
preserved if the sweep is interrupted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "run"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("summary_path", "comparison_curves")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to experiment_summary.json — share this file to compare runs.",
|
|
||||||
"All smoothed loss curves overlaid on the same axes.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Runs a series of BigVGAN vocoder fine-tuning experiments defined in a JSON sweep file. "
|
|
||||||
"Audio clips are loaded once and reused across all experiments. "
|
|
||||||
"Results (loss, config, checkpoint paths) are collected in experiment_summary.json."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"experiments_file": ("STRING", {
|
|
||||||
"default": "bigvgan_experiments.json",
|
|
||||||
"tooltip": (
|
|
||||||
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
|
|
||||||
"models directory, then the output directory; absolute paths are used as-is."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def run(self, model, experiments_file):
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Read + validate the JSON file
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
exp_path = Path(experiments_file.strip())
|
|
||||||
if not exp_path.is_absolute():
|
|
||||||
candidate = Path(folder_paths.models_dir) / exp_path
|
|
||||||
if not candidate.exists():
|
|
||||||
candidate = Path(folder_paths.get_output_directory()) / exp_path
|
|
||||||
exp_path = candidate
|
|
||||||
if not exp_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"[BigVGAN Scheduler] Experiment file not found: {exp_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = json.loads(exp_path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
if "experiments" not in spec or not spec["experiments"]:
|
|
||||||
raise ValueError("[BigVGAN Scheduler] 'experiments' list is missing or empty.")
|
|
||||||
for i, exp in enumerate(spec["experiments"]):
|
|
||||||
if "id" not in exp:
|
|
||||||
raise ValueError(
|
|
||||||
f"[BigVGAN Scheduler] Experiment at index {i} is missing required 'id' field."
|
|
||||||
)
|
|
||||||
|
|
||||||
sweep_name = spec.get("name", exp_path.stem)
|
|
||||||
description = spec.get("description", "")
|
|
||||||
base_cfg = spec.get("base", {})
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. Resolve data_dir and output_root
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
if "data_dir" not in spec:
|
|
||||||
raise ValueError("[BigVGAN Scheduler] 'data_dir' is required in the sweep file.")
|
|
||||||
data_dir = _resolve_path(spec["data_dir"])
|
|
||||||
output_root = _resolve_path(spec.get("output_root", f"bigvgan_sweeps/{sweep_name}"))
|
|
||||||
output_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
mode = model["mode"]
|
|
||||||
dtype = model["dtype"]
|
|
||||||
feature_utils = model["feature_utils"]
|
|
||||||
mel_converter = feature_utils.mel_converter
|
|
||||||
strategy = model["strategy"]
|
|
||||||
|
|
||||||
if mode == "16k":
|
|
||||||
original_vocoder = feature_utils.tod.vocoder.vocoder
|
|
||||||
sample_rate = 16_000
|
|
||||||
elif mode == "44k":
|
|
||||||
original_vocoder = feature_utils.tod.vocoder
|
|
||||||
sample_rate = 44_100
|
|
||||||
else:
|
|
||||||
raise ValueError(f"[BigVGAN Scheduler] Unknown mode: {mode}")
|
|
||||||
|
|
||||||
print(f"\n[BigVGAN Scheduler] Sweep '{sweep_name}': "
|
|
||||||
f"{len(spec['experiments'])} experiment(s)", flush=True)
|
|
||||||
if description:
|
|
||||||
print(f"[BigVGAN Scheduler] {description}", flush=True)
|
|
||||||
print(f"[BigVGAN Scheduler] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[BigVGAN Scheduler] output_root = {output_root}\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. Load audio clips once
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Use the shortest segment length across all experiments as the load-time filter, so every clip usable by at least one experiment is kept
|
|
||||||
min_segment_seconds = float("inf")
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
min_segment_seconds = min(min_segment_seconds, float(cfg.get("segment_seconds", 2.0)))
|
|
||||||
min_segment_samples = int(min_segment_seconds * sample_rate)
|
|
||||||
|
|
||||||
audio_files = []
|
|
||||||
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg", "*.aac"):
|
|
||||||
audio_files.extend(data_dir.rglob(ext))
|
|
||||||
if not audio_files:
|
|
||||||
raise FileNotFoundError(f"[BigVGAN Scheduler] No audio files in {data_dir}")
|
|
||||||
|
|
||||||
print(f"[BigVGAN Scheduler] Loading {len(audio_files)} audio files...", flush=True)
|
|
||||||
clips = []
|
|
||||||
for af in audio_files:
|
|
||||||
try:
|
|
||||||
wav, sr = _load_wav(af)
|
|
||||||
if wav.shape[0] > 1:
|
|
||||||
wav = wav.mean(0, keepdim=True)
|
|
||||||
if sr != sample_rate:
|
|
||||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
|
||||||
wav = wav.squeeze(0) # [L]
|
|
||||||
if wav.shape[0] >= min_segment_samples:
|
|
||||||
clips.append(wav.cpu())
|
|
||||||
else:
|
|
||||||
print(f" [BigVGAN Scheduler] Skip {af.name}: "
|
|
||||||
f"shorter than {min_segment_seconds}s", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f" [BigVGAN Scheduler] Failed {af.name}: {e}", flush=True)
|
|
||||||
|
|
||||||
if not clips:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"[BigVGAN Scheduler] No usable clips (need audio >= {min_segment_seconds}s)"
|
|
||||||
)
|
|
||||||
print(f"[BigVGAN Scheduler] {len(clips)} clips ready\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. Offload unused components to free VRAM
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
comfy.model_management.unload_all_models()
|
|
||||||
feature_utils.to("cpu")
|
|
||||||
if "generator" in model:
|
|
||||||
model["generator"].to("cpu")
|
|
||||||
if "video_enc" in model:
|
|
||||||
model["video_enc"].to("cpu")
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. Pre-compute text CLIP embeddings if any experiment uses LoRA
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
text_clip_cache = {}
|
|
||||||
any_lora = any(
|
|
||||||
_merge_config(base_cfg, exp).get("lora_adapter", "")
|
|
||||||
for exp in spec["experiments"]
|
|
||||||
)
|
|
||||||
if any_lora:
|
|
||||||
npz_files = sorted(data_dir.glob("*.npz"))
|
|
||||||
if npz_files:
|
|
||||||
prompt_map = {}
|
|
||||||
prompts_file = data_dir / "prompts.txt"
|
|
||||||
if prompts_file.exists():
|
|
||||||
for line in prompts_file.read_text(encoding="utf-8").splitlines():
|
|
||||||
line = line.strip()
|
|
||||||
if not line or line.startswith("#"):
|
|
||||||
continue
|
|
||||||
if "|" in line:
|
|
||||||
fname, prompt = line.split("|", 1)
|
|
||||||
prompt_map[fname.strip()] = prompt.strip()
|
|
||||||
default_prompt = data_dir.name
|
|
||||||
|
|
||||||
clip_model = feature_utils.clip_model
|
|
||||||
if clip_model is not None:
|
|
||||||
clip_model.to(device)
|
|
||||||
try:
|
|
||||||
for npz_path in npz_files:
|
|
||||||
data = dict(np.load(str(npz_path), allow_pickle=False))
|
|
||||||
prompt = prompt_map.get(
|
|
||||||
npz_path.name, data.get("prompt", default_prompt)
|
|
||||||
)
|
|
||||||
if isinstance(prompt, np.ndarray):
|
|
||||||
prompt = str(prompt)
|
|
||||||
tc = feature_utils.encode_text_clip([prompt])
|
|
||||||
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
|
|
||||||
finally:
|
|
||||||
if clip_model is not None:
|
|
||||||
clip_model.to("cpu")
|
|
||||||
soft_empty_cache()
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
print(f"[BigVGAN Scheduler] Pre-encoded {len(text_clip_cache)} "
|
|
||||||
f"CLIP embeddings", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 6. Build or restore the summary (resume-aware)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary_path = output_root / "experiment_summary.json"
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = []
|
|
||||||
|
|
||||||
if summary_path.exists():
|
|
||||||
try:
|
|
||||||
existing = json.loads(summary_path.read_text(encoding="utf-8"))
|
|
||||||
for rec in existing.get("experiments", []):
|
|
||||||
if rec.get("results", {}).get("status") == "completed":
|
|
||||||
completed_ids.add(rec["id"])
|
|
||||||
lh = rec["results"].get("loss_history", [])
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": rec["id"],
|
|
||||||
"loss_history": lh,
|
|
||||||
"log_interval": rec["results"].get("log_interval", 100),
|
|
||||||
"start_step": 0,
|
|
||||||
})
|
|
||||||
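summary = existing
# Keep only completed records; failed or interrupted runs are re-run and re-appended below
summary["experiments"] = [r for r in existing.get("experiments", []) if r.get("id") in completed_ids]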
|
|
||||||
summary["completed_at"] = None
|
|
||||||
if completed_ids:
|
|
||||||
print(f"[BigVGAN Scheduler] Resuming — skipping "
|
|
||||||
f"{len(completed_ids)} completed: "
|
|
||||||
f"{sorted(completed_ids)}", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[BigVGAN Scheduler] Could not read existing summary "
|
|
||||||
f"({e}) — starting fresh", flush=True)
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = []
|
|
||||||
summary = None
|
|
||||||
|
|
||||||
if not completed_ids:
|
|
||||||
summary = {
|
|
||||||
"sweep_name": sweep_name,
|
|
||||||
"description": description,
|
|
||||||
"sweep_file": str(exp_path),
|
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"completed_at": None,
|
|
||||||
"system": _get_system_info(),
|
|
||||||
"data_dir": str(data_dir),
|
|
||||||
"n_clips": len(clips),
|
|
||||||
"experiments": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _write_summary():
|
|
||||||
summary_path.write_text(
|
|
||||||
json.dumps(summary, indent=2), encoding="utf-8"
|
|
||||||
)
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 7. Compute total steps for progress bar
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
total_steps = 0
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
if exp["id"] not in completed_ids:
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
total_steps += int(cfg.get("steps", 2000))
|
|
||||||
pbar = comfy.utils.ProgressBar(max(total_steps, 1))
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 8. Run experiments in a worker thread
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# BigVGAN training requires a fresh thread because ComfyUI runs nodes
|
|
||||||
# inside torch.inference_mode(). inference_mode is thread-local — a
|
|
||||||
# new thread starts with it OFF, so all tensor operations produce
|
|
||||||
# normal autograd-compatible tensors.
|
|
||||||
_exc = [None]
|
|
||||||
|
|
||||||
def _worker():
|
|
||||||
try:
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
exp_id = exp["id"]
|
|
||||||
exp_desc = exp.get("description", "")
|
|
||||||
|
|
||||||
if exp_id in completed_ids:
|
|
||||||
print(f"[BigVGAN Scheduler] Skipping '{exp_id}' "
|
|
||||||
f"(already completed)", flush=True)
|
|
||||||
continue
|
|
||||||
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
|
|
||||||
# ── Extract experiment parameters ────────────────────
|
|
||||||
train_mode = str(cfg.get("train_mode", "snake_alpha_only"))
|
|
||||||
exp_steps = int(cfg.get("steps", 2000))
|
|
||||||
exp_lr = float(cfg.get("lr", 1e-4))
|
|
||||||
exp_bs = int(cfg.get("batch_size", 4))
|
|
||||||
exp_seg_s = float(cfg.get("segment_seconds", 2.0))
|
|
||||||
exp_l2sp = float(cfg.get("lambda_l2sp", 1e-3))
|
|
||||||
exp_gafilter = bool(cfg.get("use_gafilter", True))
|
|
||||||
exp_gaf_ks = int(cfg.get("gafilter_kernel_size", 9))
|
|
||||||
exp_phase = float(cfg.get("lambda_phase", 1.0))
|
|
||||||
exp_save = int(cfg.get("save_every", 500))
|
|
||||||
exp_seed = int(cfg.get("seed", 42))
|
|
||||||
exp_disc = str(cfg.get("discriminator_path", ""))
|
|
||||||
exp_lora = str(cfg.get("lora_adapter", ""))
|
|
||||||
|
|
||||||
segment_samples = int(exp_seg_s * sample_rate)
|
|
||||||
|
|
||||||
# Filter clips long enough for this experiment
|
|
||||||
exp_clips = [c for c in clips if c.shape[0] >= segment_samples]
|
|
||||||
if not exp_clips:
|
|
||||||
print(f"[BigVGAN Scheduler] '{exp_id}' skipped: "
|
|
||||||
f"no clips >= {exp_seg_s}s", flush=True)
|
|
||||||
summary["experiments"].append({
|
|
||||||
"id": exp_id, "description": exp_desc,
|
|
||||||
"config": dict(cfg),
|
|
||||||
"results": {
|
|
||||||
"status": "failed",
|
|
||||||
"error": f"No clips >= {exp_seg_s}s",
|
|
||||||
"duration_seconds": 0,
|
|
||||||
},
|
|
||||||
"checkpoint_path": None,
|
|
||||||
"output_dir": str(output_root / exp_id),
|
|
||||||
})
|
|
||||||
_write_summary()
|
|
||||||
continue
|
|
||||||
|
|
||||||
# ── Resolve discriminator path ───────────────────────
|
|
||||||
disc_path = None
|
|
||||||
if exp_disc:
|
|
||||||
disc_path = Path(exp_disc.strip())
|
|
||||||
if not disc_path.is_absolute():
|
|
||||||
disc_path = (
|
|
||||||
Path(folder_paths.get_output_directory()) / disc_path
|
|
||||||
)
|
|
||||||
if not disc_path.exists():
|
|
||||||
print(f"[BigVGAN Scheduler] '{exp_id}': "
|
|
||||||
f"discriminator not found: {disc_path}",
|
|
||||||
flush=True)
|
|
||||||
disc_path = None
|
|
||||||
|
|
||||||
# ── Pre-generate LoRA mels (disk-cached) ─────────────
|
|
||||||
lora_mel_pairs = None
|
|
||||||
if exp_lora:
|
|
||||||
lora_path = Path(exp_lora.strip())
|
|
||||||
if not lora_path.is_absolute():
|
|
||||||
lora_path = Path(folder_paths.base_path) / lora_path
|
|
||||||
if lora_path.exists():
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
lora_mel_pairs = _pregenerate_lora_mels(
|
|
||||||
model, data_dir, str(lora_path),
|
|
||||||
device, dtype, sample_rate,
|
|
||||||
seq_cfg.duration, seed=exp_seed,
|
|
||||||
cache_dir=str(output_root),
|
|
||||||
text_clip_cache=text_clip_cache,
|
|
||||||
)
|
|
||||||
if not lora_mel_pairs:
|
|
||||||
print(f"[BigVGAN Scheduler] '{exp_id}': "
|
|
||||||
f"no LoRA mel pairs generated",
|
|
||||||
flush=True)
|
|
||||||
lora_mel_pairs = None
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
else:
|
|
||||||
print(f"[BigVGAN Scheduler] '{exp_id}': "
|
|
||||||
f"LoRA adapter not found: {lora_path}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
# ── Output dir ───────────────────────────────────────
|
|
||||||
exp_dir = output_root / exp_id
|
|
||||||
exp_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
out_path = exp_dir / f"bigvgan_{exp_id}.pt"
|
|
||||||
|
|
||||||
print(f"\n[BigVGAN Scheduler] ── Experiment '{exp_id}' ──",
|
|
||||||
flush=True)
|
|
||||||
if exp_desc:
|
|
||||||
print(f"[BigVGAN Scheduler] {exp_desc}", flush=True)
|
|
||||||
print(f"[BigVGAN Scheduler] mode={train_mode} "
|
|
||||||
f"steps={exp_steps} lr={exp_lr} bs={exp_bs} "
|
|
||||||
f"seg={exp_seg_s}s gafilter={exp_gafilter} "
|
|
||||||
f"phase={exp_phase} l2sp={exp_l2sp}", flush=True)
|
|
||||||
|
|
||||||
exp_record = {
|
|
||||||
"id": exp_id,
|
|
||||||
"description": exp_desc,
|
|
||||||
"config": {
|
|
||||||
"train_mode": train_mode, "steps": exp_steps,
|
|
||||||
"lr": exp_lr, "batch_size": exp_bs,
|
|
||||||
"segment_seconds": exp_seg_s,
|
|
||||||
"lambda_l2sp": exp_l2sp,
|
|
||||||
"use_gafilter": exp_gafilter,
|
|
||||||
"gafilter_kernel_size": exp_gaf_ks,
|
|
||||||
"lambda_phase": exp_phase,
|
|
||||||
"save_every": exp_save, "seed": exp_seed,
|
|
||||||
"discriminator_path": exp_disc,
|
|
||||||
"lora_adapter": exp_lora,
|
|
||||||
},
|
|
||||||
"results": {"status": "running"},
|
|
||||||
"checkpoint_path": None,
|
|
||||||
"output_dir": str(exp_dir),
|
|
||||||
}
|
|
||||||
summary["experiments"].append(exp_record)
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
t_start = time.monotonic()
|
|
||||||
try:
|
|
||||||
# Ensure mel_converter is on device for this experiment
|
|
||||||
mel_converter.to(device)
|
|
||||||
|
|
||||||
# Fresh vocoder copy — _do_train modifies it in-place
|
|
||||||
vocoder_copy = copy.deepcopy(original_vocoder)
|
|
||||||
|
|
||||||
checkpoint_path = _do_train(
|
|
||||||
vocoder_copy, mel_converter, exp_clips,
|
|
||||||
device, dtype, strategy, feature_utils,
|
|
||||||
segment_samples, sample_rate,
|
|
||||||
train_mode, exp_steps, exp_lr, exp_bs,
|
|
||||||
exp_l2sp, exp_gafilter, exp_gaf_ks,
|
|
||||||
exp_phase, exp_save, exp_seed,
|
|
||||||
out_path, disc_path, pbar,
|
|
||||||
lora_mel_pairs,
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
|
|
||||||
# Parse training CSV for loss history
|
|
||||||
log_path = exp_dir / f"bigvgan_{exp_id}_training_log.csv"
|
|
||||||
loss_history = _parse_training_log(log_path)
|
|
||||||
log_interval = max(1, exp_steps // 20)
|
|
||||||
smoothed = (
|
|
||||||
_smooth_losses(loss_history)
|
|
||||||
if loss_history else []
|
|
||||||
)
|
|
||||||
|
|
||||||
final_loss = (
|
|
||||||
round(smoothed[-1], 6) if smoothed else None
|
|
||||||
)
|
|
||||||
min_loss = (
|
|
||||||
round(min(smoothed), 6) if smoothed else None
|
|
||||||
)
|
|
||||||
min_idx = (
|
|
||||||
smoothed.index(min(smoothed))
|
|
||||||
if smoothed else None
|
|
||||||
)
|
|
||||||
min_loss_step = (
|
|
||||||
(min_idx + 1) * log_interval
|
|
||||||
if min_idx is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if loss_history:
|
|
||||||
quarter = max(1, len(loss_history) // 4)
|
|
||||||
loss_std = round(
|
|
||||||
float(np.std(loss_history[-quarter:])), 6
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loss_std = None
|
|
||||||
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "completed",
|
|
||||||
"final_loss": final_loss,
|
|
||||||
"min_loss": min_loss,
|
|
||||||
"min_loss_step": min_loss_step,
|
|
||||||
"loss_std_last_quarter": loss_std,
|
|
||||||
"loss_at_steps": _loss_at_steps(
|
|
||||||
loss_history, log_interval,
|
|
||||||
exp_save, exp_steps,
|
|
||||||
),
|
|
||||||
"loss_history": [
|
|
||||||
round(v, 6) for v in loss_history
|
|
||||||
],
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
exp_record["checkpoint_path"] = checkpoint_path
|
|
||||||
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": exp_id,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"start_step": 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
print(f"[BigVGAN Scheduler] Experiment '{exp_id}' "
|
|
||||||
f"failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "failed",
|
|
||||||
"error": str(e),
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
finally:
|
|
||||||
# Clean up vocoder copy to free VRAM
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
_exc[0] = e
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
t = threading.Thread(target=_worker, daemon=True)
|
|
||||||
t.start()
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
if _exc[0] is not None:
|
|
||||||
raise _exc[0]
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 9. Finalise summary
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
|
|
||||||
_write_summary()
|
|
||||||
print(f"\n[BigVGAN Scheduler] Sweep complete. "
|
|
||||||
f"Summary: {summary_path}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 10. Comparison image
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
comparison_img = _draw_comparison_curves(all_curve_data)
|
|
||||||
comparison_img.save(str(output_root / "loss_comparison.png"))
|
|
||||||
comparison_tensor = _pil_to_tensor(comparison_img)
|
|
||||||
|
|
||||||
return (str(summary_path), comparison_tensor)
|
|
||||||
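The per-experiment result fields computed in the sweep loop (`final_loss`, `min_loss`, `min_loss_step`, `loss_std_last_quarter`) can be sketched in isolation. This is a minimal illustration, not the scheduler's code: `summarize_losses` and `moving_average` are hypothetical names, and the trailing moving average is only a stand-in for `_smooth_losses`, which is not shown in this diff and may behave differently. The mapping from log index to training step assumes one logged value every `log_interval` steps, as the `(min_idx + 1) * log_interval` expression suggests.

```python
import statistics


def moving_average(values, window=3):
    """Trailing moving average; an assumed stand-in for the unseen _smooth_losses."""
    out = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        out.append(sum(chunk) / len(chunk))
    return out


def summarize_losses(loss_history, total_steps):
    """Compute the same summary fields the sweep stores per experiment."""
    log_interval = max(1, total_steps // 20)
    if not loss_history:
        return {
            "final_loss": None,
            "min_loss": None,
            "min_loss_step": None,
            "loss_std_last_quarter": None,
            "log_interval": log_interval,
        }
    smoothed = moving_average(loss_history)
    min_idx = smoothed.index(min(smoothed))
    quarter = max(1, len(loss_history) // 4)
    return {
        "final_loss": round(smoothed[-1], 6),
        "min_loss": round(min(smoothed), 6),
        # Assumption: entry i of the log was written at step (i + 1) * log_interval.
        "min_loss_step": (min_idx + 1) * log_interval,
        # Population std (what np.std computes by default) of the raw tail.
        "loss_std_last_quarter": round(
            statistics.pstdev(loss_history[-quarter:]), 6
        ),
        "log_interval": log_interval,
    }


summary = summarize_losses([4.0, 2.0, 1.0, 1.0], total_steps=40)
```

Note that the minimum and final loss are taken from the smoothed curve, while the stability figure uses the raw losses from the last quarter of the log.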
File diff suppressed because it is too large
@@ -1,106 +0,0 @@
|
|||||||
import json
from pathlib import Path

import folder_paths

from .utils import SELVA_CATEGORY


class SelvaDatasetBrowser:
    """Browse a dataset.json file entry by entry using an integer index.

    Each entry in the JSON is expected to have:
      - "path"  : base path (no extension); also the directory that holds frame images
      - "label" : text description of the clip (optional, defaults to "")

    Derived outputs (name = last component of path, features/ sits next to it):
      - video_path    : path + ".mp4"
      - audio_wav     : features/ + name + ".wav"
      - audio_flac    : features/ + name + ".flac"
      - features_path : features/ + name + ".npz"
      - frames_dir    : path (the directory itself, for image-sequence loaders)
      - mask_dir      : path + "_mask"
      - label         : entry["label"]
      - max_index     : number of entries in the file minus one
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset_json": ("STRING", {
                    "default": "",
                    "tooltip": "Absolute or ComfyUI-relative path to a dataset.json file.",
                }),
                "index": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 9999,
                    "step": 1,
                    "tooltip": "Zero-based index of the entry to inspect.",
                }),
            },
        }

    RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "INT")
    RETURN_NAMES = ("video_path", "audio_wav", "audio_flac", "features_path", "frames_dir", "mask_dir", "label", "max_index")
    OUTPUT_TOOLTIPS = (
        "path + '.mp4'",
        "features/ + name + '.wav'",
        "features/ + name + '.flac'",
        "features/ + name + '.npz' (pre-extracted SelVA features)",
        "path (image-sequence directory)",
        "path + '_mask' (mask image-sequence directory)",
        "Text label for this clip",
        "count - 1 — wire to a primitive INT's max to constrain the index widget",
    )
    FUNCTION = "browse"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Reads a dataset.json produced by the SelVA dataset preparation pipeline "
        "and exposes one entry at a time via an integer index. "
        "Outputs the video path, WAV/FLAC audio paths, features (.npz) path, "
        "frames and mask directories, label, and the last valid index (count - 1)."
    )

    # Re-read the file every call so edits are picked up without restarting ComfyUI.
    IS_CHANGED = classmethod(lambda cls, **_: float("nan"))

    def browse(self, dataset_json: str, index: int):
        p = Path(dataset_json.strip())
        if not p.is_absolute():
            p = Path(folder_paths.base_path) / p
        if not p.exists():
            raise FileNotFoundError(f"[SelVA Dataset Browser] File not found: {p}")

        with p.open("r", encoding="utf-8") as f:
            data = json.load(f)

        if not isinstance(data, list) or len(data) == 0:
            raise ValueError(f"[SelVA Dataset Browser] Expected a non-empty JSON array in {p}")

        count = len(data)
        if index >= count:
            raise IndexError(
                f"[SelVA Dataset Browser] index {index} is out of range "
                f"(dataset has {count} entries, last index is {count - 1})"
            )
        entry = data[index]

        base = entry["path"]
        label = entry.get("label", "")

        # Audio and feature files live in a "features" folder next to the clip.
        p_base = Path(base)
        feat_base = str(p_base.parent / "features" / p_base.name)

        print(
            f"[SelVA Dataset Browser] {index + 1}/{count} label='{label}' base={base}",
            flush=True,
        )

        return (
            base + ".mp4",
            feat_base + ".wav",
            feat_base + ".flac",
            feat_base + ".npz",
            base,
            base + "_mask",
            label,
            count - 1,
        )
|
|
||||||
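As a rough illustration of the path derivation in `browse`, the sketch below resolves one entry without ComfyUI. The helper name `derive_paths` and the sample entry are made up for the example; only the suffix and `features/` rules are taken from the node. It uses `PurePosixPath` so the result is the same on every OS, whereas the node itself uses `Path`.

```python
from pathlib import PurePosixPath


def derive_paths(entry: dict) -> dict:
    """Derive the per-clip paths that the browser node returns for one entry."""
    base = entry["path"]
    p_base = PurePosixPath(base)
    # Audio and features sit in a "features" folder next to the clip itself.
    feat_base = str(p_base.parent / "features" / p_base.name)
    return {
        "video_path": base + ".mp4",
        "audio_wav": feat_base + ".wav",
        "audio_flac": feat_base + ".flac",
        "features_path": feat_base + ".npz",
        "frames_dir": base,
        "mask_dir": base + "_mask",
        "label": entry.get("label", ""),
    }


# Hypothetical dataset.json entry, shown as the dict json.load would produce.
entry = {"path": "/data/clips/dog_bark_01", "label": "dog barking"}
paths = derive_paths(entry)
```

With this layout the video, frame directory and mask directory share the clip's base path, while audio and `.npz` features are looked up one level over in `features/`.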
@@ -1,788 +0,0 @@
|
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.

Typical chain:
    SelvaDatasetLoader
        ↓ AUDIO_DATASET
    SelvaDatasetResampler         (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetLUFSNormalizer    (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetCompressor        (optional)
        ↓ AUDIO_DATASET
    SelvaDatasetSpectralMatcher   (optional — batch spectral EQ)
        ↓ AUDIO_DATASET
    SelvaDatasetHfSmoother        (optional — batch HF attenuation)
        ↓ AUDIO_DATASET
    SelvaDatasetAugmenter         (optional — gain/pitch/stretch variants)
        ↓ AUDIO_DATASET
    SelvaDatasetInspector         (optional)
        ↓ AUDIO_DATASET + STRING report
    SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
"""

from pathlib import Path

import numpy as np
import torch
import torchaudio

from .utils import SELVA_CATEGORY

# ComfyUI custom type name — passed between all dataset pipeline nodes
AUDIO_DATASET = "AUDIO_DATASET"

_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
_SOUNDFILE_EXTS = {".wav", ".flac", ".ogg"}  # handled natively without FFmpeg


def _load_audio(path: Path):
    """Load audio file. Uses soundfile for WAV/FLAC/OGG to avoid torchcodec/FFmpeg issues."""
    if path.suffix.lower() in _SOUNDFILE_EXTS:
        import soundfile as sf
        wav_np, sr = sf.read(str(path), dtype="float32", always_2d=True)  # [L, C]
        wav = torch.from_numpy(wav_np).T.unsqueeze(0)  # [1, C, L]
    else:
        wav, sr = torchaudio.load(str(path))  # [C, L]
        wav = wav.unsqueeze(0).float()  # [1, C, L]
    return wav, sr
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetLoader:
    """Load all audio files in a folder into an in-memory AUDIO_DATASET."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "folder": ("STRING", {
                    "default": "",
                    "tooltip": "Absolute path to folder containing audio files. Searched recursively.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "load"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."

    def load(self, folder: str):
        folder = Path(folder.strip())
        if not folder.exists():
            raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")

        files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
        if not files:
            raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")

        dataset = []
        for f in sorted(files):
            try:
                wav, sr = _load_audio(f)
                dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
            except Exception as e:
                print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)

        print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
        return (dataset,)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetResampler:
    """Resample all clips in a dataset to a target sample rate using soxr VHQ."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_sr": ("INT", {
                    "default": 44100, "min": 8000, "max": 192000,
                    "tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "resample"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."

    def resample(self, dataset, target_sr: int):
        import soxr

        out = []
        changed = 0
        for item in dataset:
            sr = item["sample_rate"]
            if sr == target_sr:
                out.append(item)
                continue

            wav = item["waveform"][0]  # [C, L]
            # soxr expects [L, C] (time-first), float64
            wav_np = wav.permute(1, 0).double().numpy()  # [L, C]
            wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
            wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0)  # [1, C, L]
            out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
            changed += 1

        print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
        return (out,)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetLUFSNormalizer:
    """Normalize each clip to a target integrated LUFS level + true peak limit."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "target_lufs": ("FLOAT", {
                    "default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
                    "tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
                }),
                "true_peak_dbtp": ("FLOAT", {
                    "default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
                    "tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "normalize"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
        "Skips clips that are too short for LUFS measurement (< 0.4 s)."
    )

    def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
        import pyloudnorm as pyln

        tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
        out = []
        skipped = 0

        for item in dataset:
            wav = item["waveform"][0]  # [C, L]
            sr = item["sample_rate"]

            # pyloudnorm wants [L] mono or [L, C] multichannel, float64
            wav_np = wav.permute(1, 0).double().numpy()  # [L, C]
            if wav_np.shape[1] == 1:
                wav_np = wav_np[:, 0]  # [L] mono

            meter = pyln.Meter(sr)
            try:
                loudness = meter.integrated_loudness(wav_np)
            except Exception:
                # Measurement failed (e.g. clip too short): pass through unchanged.
                skipped += 1
                out.append(item)
                continue

            if not np.isfinite(loudness):
                # Non-finite loudness (e.g. a silent clip): pass through unchanged.
                skipped += 1
                out.append(item)
                continue

            gain_db = target_lufs - loudness
            gain_linear = 10.0 ** (gain_db / 20.0)

            wav_norm = wav * gain_linear

            # Peak ceiling. Note: this measures the sample peak, not an
            # oversampled true peak, so it only approximates a dBTP limit.
            peak = wav_norm.abs().max().item()
            if peak > tp_linear:
                wav_norm = wav_norm * (tp_linear / peak)

            out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})

        print(
            f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
            f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
            flush=True,
        )
        return (out,)
|
|
||||||
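The gain arithmetic in `normalize` can be checked by hand with a small standalone sketch. The function `lufs_gain` below is a hypothetical helper written for this illustration, not part of the node: it takes an already-measured loudness and sample peak instead of calling pyloudnorm, and returns the total linear gain after the peak ceiling is applied.

```python
def lufs_gain(measured_lufs, target_lufs=-23.0, peak=1.0, ceiling_db=-1.0):
    """Return the linear gain for one clip, including the peak ceiling.

    measured_lufs : integrated loudness of the clip (assumed already measured)
    peak          : absolute sample peak of the clip before any gain
    """
    # Loudness difference in dB, converted to a linear amplitude factor.
    gain = 10.0 ** ((target_lufs - measured_lufs) / 20.0)
    ceiling = 10.0 ** (ceiling_db / 20.0)
    peak_after = peak * gain
    if peak_after > ceiling:
        # Scale the whole clip down so its peak sits exactly on the ceiling.
        gain *= ceiling / peak_after
    return gain


# A clip at -17 LUFS needs -6 dB to reach -23 LUFS; the ceiling is not hit.
quiet_gain = lufs_gain(-17.0)
# A clip at -30 LUFS would need +7 dB, but its 0.9 peak would then exceed
# the -1 dB ceiling, so the final gain is limited by the peak instead.
limited_gain = lufs_gain(-30.0, peak=0.9)
```

When the ceiling engages, the clip ends up quieter than the requested LUFS target, which is the trade-off of limiting by plain scaling rather than by a dynamics processor.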
|
|
||||||
|
|
||||||
class SelvaDatasetCompressor:
    """Apply mild parallel compression to reduce within-clip loudness variance.

    Uses pedalboard.Compressor (2:1–3:1 ratio). Parallel (New York) style:
    blends compressed signal with dry so transients are preserved while
    the dynamic range is gently tightened. Apply after LUFS normalization.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "threshold_db": ("FLOAT", {
                    "default": -18.0, "min": -40.0, "max": -6.0, "step": 1.0,
                    "tooltip": "Compression kicks in above this level. -18 dB is a safe starting point after LUFS normalization.",
                }),
                "ratio": ("FLOAT", {
                    "default": 2.5, "min": 1.5, "max": 4.0, "step": 0.5,
                    "tooltip": "Compression ratio. 2:1–3:1 is mild; stay below 4:1 to avoid pumping.",
                }),
                "attack_ms": ("FLOAT", {
                    "default": 10.0, "min": 1.0, "max": 100.0, "step": 1.0,
                    "tooltip": "Attack time in ms. Slower attack preserves transients.",
                }),
                "release_ms": ("FLOAT", {
                    "default": 100.0, "min": 20.0, "max": 500.0, "step": 10.0,
                    "tooltip": "Release time in ms.",
                }),
                "mix": ("FLOAT", {
                    "default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05,
                    "tooltip": "Parallel blend: 0.0 = dry only, 1.0 = fully compressed. 0.3–0.5 is typical.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "compress"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Mild parallel compression to reduce within-clip dynamic range. "
        "Blends compressed signal with dry at the given mix ratio. "
        "Apply after LUFS normalization."
    )

    def compress(self, dataset, threshold_db: float, ratio: float,
                 attack_ms: float, release_ms: float, mix: float):
        from pedalboard import Compressor, Pedalboard

        board = Pedalboard([Compressor(
            threshold_db=threshold_db,
            ratio=ratio,
            attack_ms=attack_ms,
            release_ms=release_ms,
        )])

        out = []
        for item in dataset:
            wav = item["waveform"][0]  # [C, L]
            sr = item["sample_rate"]

            # pedalboard expects [C, L] float32 numpy
            wav_np = wav.float().numpy()  # [C, L]
            compressed = board(wav_np, sr)  # [C, L]
            mixed = (1.0 - mix) * wav_np + mix * compressed
            wav_out = torch.from_numpy(mixed).unsqueeze(0)  # [1, C, L]
            out.append({"waveform": wav_out, "sample_rate": sr, "name": item["name"]})

        print(
            f"[DatasetCompressor] {len(out)} clips compressed "
            f"thr={threshold_db}dB ratio={ratio}:1 mix={mix:.0%}",
            flush=True,
        )
        return (out,)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
    """Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).

    Method: compare mean energy in 1–5 kHz band vs 15–20 kHz band via STFT.
    A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
    """
    if sr < 32000:
        return False  # can't assess HF at low sample rates

    n_fft = 2048
    hop = 512
    mono = wav[0].mean(0)  # [L]
    window = torch.hann_window(n_fft, device=mono.device)
    stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
    mag_sq = stft.abs().pow(2).mean(-1)  # [n_freqs]

    freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device)
    band_lo = (freqs >= 1000) & (freqs < 5000)
    band_hi = (freqs >= 15000) & (freqs < 20000)

    if band_hi.sum() == 0:
        return False

    energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
    energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
    ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
    return ratio_db > 40.0


def _estimate_snr(wav: torch.Tensor) -> float:
    """Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
    mono = wav[0].mean(0)  # [L]
    if mono.shape[0] < 2048:
        return 60.0  # clip too short to frame — assume clean
    frames = mono.unfold(0, 2048, 512)  # [N, 2048]
    rms = frames.pow(2).mean(-1).sqrt()  # [N]
    p95 = torch.quantile(rms, 0.95).item()
    p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
    return 20.0 * np.log10(p95 / p05 + 1e-8)
|
|
||||||
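To see what the percentile-based estimate in `_estimate_snr` reports, here is a dependency-free sketch of the same idea on a synthetic signal. The names `frame_rms`, `quantile` and `estimate_snr_db` are hypothetical; the quantile uses linear interpolation, which is assumed here to match the default mode of `torch.quantile`.

```python
import math


def frame_rms(samples, frame=2048, hop=512):
    """RMS of each full frame, stepping by hop samples (no padding)."""
    out = []
    for start in range(0, len(samples) - frame + 1, hop):
        chunk = samples[start:start + frame]
        out.append(math.sqrt(sum(x * x for x in chunk) / frame))
    return out


def quantile(sorted_vals, q):
    """Linearly interpolated quantile of an already sorted list."""
    pos = q * (len(sorted_vals) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


def estimate_snr_db(samples):
    """Loud-frame RMS (95th pct) over quiet-frame RMS (5th pct), in dB."""
    if len(samples) < 2048:
        return 60.0  # too short to frame, treated as clean
    rms = sorted(frame_rms(samples))
    p95 = quantile(rms, 0.95)
    p05 = max(quantile(rms, 0.05), 1e-8)
    return 20.0 * math.log10(p95 / p05 + 1e-8)


# Synthetic clip: a quiet "noise floor" at 0.001 followed by "signal" at 0.1.
clip = [0.001] * 20480 + [0.1] * 20480
snr = estimate_snr_db(clip)
```

Because the estimate compares the loudest frames with the quietest ones, a clip with constant level throughout scores close to 0 dB even if it is perfectly clean, so the threshold is best read as a dynamic-range heuristic rather than a true SNR.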
|
|
||||||
|
|
||||||
class SelvaDatasetInspector:
    """Analyze each clip for quality issues and optionally filter out flagged clips."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "skip_rejected": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "If True, flagged clips are removed from the output dataset. "
                               "If False, all clips pass through but the report still lists issues.",
                }),
                "min_snr_db": ("FLOAT", {
                    "default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
                    "tooltip": "Clips with estimated SNR below this value are flagged.",
                }),
                "check_codec_artifacts": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
                }),
                "max_silence_fraction": ("FLOAT", {
                    "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
                    "tooltip": "Flag clips where more than this fraction of frames are near-silent "
                               "(< -60 dBFS). Set to 0 to disable silence detection.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET, "STRING")
    RETURN_NAMES = ("dataset", "report")
    FUNCTION = "inspect"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Analyze each clip for clipping, low SNR, codec artifacts, and excessive silence. "
        "Outputs a filtered AUDIO_DATASET and a text report. "
        "Connect report to a ShowText node to preview in the UI."
    )

    def inspect(self, dataset, skip_rejected: bool, min_snr_db: float,
                check_codec_artifacts: bool, max_silence_fraction: float = 0.5):
        clean = []
        flagged = []
        lines = ["SelVA Dataset Inspector Report", "=" * 40]

        for item in dataset:
            wav = item["waveform"]  # [1, C, L]
            sr = item["sample_rate"]
            name = item["name"]
            issues = []

            # Clipping
            peak = wav.abs().max().item()
            if peak > 0.99:
                issues.append(f"clipping (peak={peak:.3f})")

            # Low SNR
            snr = _estimate_snr(wav)
            if snr < min_snr_db:
                issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")

            # Codec artifacts
            if check_codec_artifacts and _check_hf_shelf(wav, sr):
                issues.append("codec artifact (HF shelf > 15 kHz)")

            # Silence detection
            if max_silence_fraction > 0:
                mono = wav[0].mean(0)
                if mono.shape[0] >= 2048:
                    frames = mono.unfold(0, 2048, 512)
                    rms = frames.pow(2).mean(-1).sqrt()
                    # A frame RMS of 1e-3 corresponds to -60 dBFS.
                    silent_frac = (rms < 1e-3).float().mean().item()
                    if silent_frac > max_silence_fraction:
                        issues.append(f"mostly silent ({silent_frac:.0%} of frames below -60 dBFS)")

            if issues:
                flagged.append(name)
                lines.append(f" FLAGGED {name}: {', '.join(issues)}")
                if not skip_rejected:
                    clean.append(item)
            else:
                clean.append(item)
                lines.append(f" OK {name}")

        lines.append("=" * 40)
        lines.append(
            f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
            + (" (removed)" if skip_rejected else " (kept)")
        )
        report = "\n".join(lines)
        print(f"[DatasetInspector]\n{report}", flush=True)
        return (clean, report)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetItemExtractor:
    """Extract a single AUDIO item from an AUDIO_DATASET by index.

    Bridges the dataset pipeline to any node that accepts a standard AUDIO
    input — save audio, HF Smoother, Spectral Matcher, etc.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "index": ("INT", {
                    "default": 0, "min": 0, "max": 9999,
                    "tooltip": "0-based index. Wraps around if index >= dataset length.",
                }),
            }
        }

    RETURN_TYPES = ("AUDIO", "STRING", "INT")
    RETURN_NAMES = ("audio", "name", "total")
    FUNCTION = "extract"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Extract one clip from an AUDIO_DATASET by index. "
        "Returns standard AUDIO (compatible with all audio nodes), "
        "the clip name, and the total dataset length."
    )

    def extract(self, dataset, index: int):
        if not dataset:
            raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
        idx = index % len(dataset)
        item = dataset[idx]
        audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
        print(
            f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
            f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
            flush=True,
        )
        return (audio, item["name"], len(dataset))
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetSaver:
    """Save all clips in an AUDIO_DATASET to disk as FLAC files.

    Optionally copies matching .npz feature files from a source directory,
    keeping FLAC/NPZ pairs in sync after the inspector has filtered clips.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "output_dir": ("STRING", {
                    "default": "",
                    "tooltip": "Absolute path to output folder. Created if it does not exist.",
                }),
            },
            "optional": {
                "npz_source_dir": ("STRING", {
                    "default": "",
                    "tooltip": "If set, copies {name}.npz from this folder alongside each saved FLAC. "
                               "Missing NPZs are warned but do not abort the save.",
                }),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("report",)
    OUTPUT_NODE = True
    FUNCTION = "save"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Save every clip in an AUDIO_DATASET to output_dir as FLAC. "
        "If npz_source_dir is provided, copies the matching .npz file for each clip — "
        "so rejected clips never get their NPZ copied."
    )

    def save(self, dataset, output_dir: str, npz_source_dir: str = ""):
        import shutil
        import soundfile as sf

        out = Path(output_dir.strip())
        out.mkdir(parents=True, exist_ok=True)

        npz_src = Path(npz_source_dir.strip()) if npz_source_dir.strip() else None

        saved = 0
        npz_copied = 0
        npz_missing = []

        for item in dataset:
            name = item["name"]
            wav = item["waveform"][0]  # [C, L]
            sr = item["sample_rate"]

            # soundfile wants [L] mono or [L, C] multichannel, float32
            wav_np = wav.permute(1, 0).float().numpy()  # [L, C]
            if wav_np.shape[1] == 1:
                wav_np = wav_np[:, 0]  # [L] mono

            flac_path = out / f"{name}.flac"
            sf.write(str(flac_path), wav_np, sr, subtype="PCM_24")
            saved += 1

            if npz_src is not None:
                # Augmented clips store their origin name — use it to find the .npz
                lookup = item.get("origin_name", name)
                npz_path = npz_src / f"{lookup}.npz"
                if npz_path.exists():
                    shutil.copy2(str(npz_path), str(out / f"{name}.npz"))
                    npz_copied += 1
                else:
                    npz_missing.append(name)

        lines = [
            f"[DatasetSaver] Saved {saved} clips → {out}",
        ]
        if npz_src is not None:
            lines.append(f" NPZ copied: {npz_copied} missing: {len(npz_missing)}")
            for n in npz_missing:
                lines.append(f" MISSING NPZ: {n}")

        report = "\n".join(lines)
        print(report, flush=True)
        return (report,)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Batch wrappers for audio preprocessors ───────────────────────────────────
|
|
||||||
|
|
||||||
class SelvaDatasetSpectralMatcher:
    """Apply SelVA Spectral Matcher to every clip in an AUDIO_DATASET.

    Wraps SelvaSpectralMatcher so it works on batch datasets instead of
    individual AUDIO items. Same parameters — see that node for details.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "dataset": (AUDIO_DATASET,),
                "mode": (["44k", "16k"], {
                    "tooltip": "Must match the SelVA model you are training. "
                               "44k = large model, 16k = small model.",
                }),
                "strength": ("FLOAT", {
                    "default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
                    "tooltip": "0 = no correction, 1 = full match to VAE distribution.",
                }),
                "max_gain_db": ("FLOAT", {
                    "default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
                    "tooltip": "Clamps per-band gain to ±dB.",
                }),
            }
        }

    RETURN_TYPES = (AUDIO_DATASET,)
    RETURN_NAMES = ("dataset",)
    FUNCTION = "process"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Apply adaptive spectral matching to every clip in a dataset. "
        "Batch version of SelVA Spectral Matcher — same per-band EQ toward the "
        "VAE's expected distribution."
    )

    def process(self, dataset, mode: str, strength: float, max_gain_db: float):
        from .selva_audio_preprocessors import SelvaSpectralMatcher
        matcher = SelvaSpectralMatcher()
        out = []
        for item in dataset:
            audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
            (result,) = matcher.process(audio, mode, strength, max_gain_db)
            new_item = dict(item)  # preserve origin_name and any extra keys
            new_item["waveform"] = result["waveform"]
            new_item["sample_rate"] = result["sample_rate"]
            out.append(new_item)
        print(f"[DatasetSpectralMatcher] {len(out)} clips processed "
              f"mode={mode} strength={strength}", flush=True)
        return (out,)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetHfSmoother:
|
|
||||||
"""Apply SelVA HF Smoother to every clip in an AUDIO_DATASET.
|
|
||||||
|
|
||||||
Wraps SelvaHfSmoother so it works on batch datasets instead of
|
|
||||||
individual AUDIO items. Same parameters — see that node for details.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"cutoff_hz": ("FLOAT", {
|
|
||||||
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
|
|
||||||
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
|
|
||||||
}),
|
|
||||||
"blend": ("FLOAT", {
|
|
||||||
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "0 = original, 1 = fully filtered.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "process"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Apply soft HF attenuation to every clip in a dataset. "
|
|
||||||
"Batch version of SelVA HF Smoother — blends a low-pass filtered copy "
|
|
||||||
"with the original to tame extreme HF content."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, dataset, cutoff_hz: float, blend: float):
|
|
||||||
from .selva_audio_preprocessors import SelvaHfSmoother
|
|
||||||
smoother = SelvaHfSmoother()
|
|
||||||
out = []
|
|
||||||
for item in dataset:
|
|
||||||
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
|
|
||||||
(result,) = smoother.process(audio, cutoff_hz, blend)
|
|
||||||
new_item = dict(item) # preserve origin_name and any extra keys
|
|
||||||
new_item["waveform"] = result["waveform"]
|
|
||||||
new_item["sample_rate"] = result["sample_rate"]
|
|
||||||
out.append(new_item)
|
|
||||||
print(f"[DatasetHfSmoother] {len(out)} clips processed "
|
|
||||||
f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True)
|
|
||||||
return (out,)
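Review note: the `blend` parameter above is described as mixing a low-pass filtered copy with the original (0 = original, 1 = fully filtered). The sketch below shows one way such a blend could work. It is an illustration only: the actual filter inside `SelvaHfSmoother` is not shown in this diff, so the one-pole filter and the names `one_pole_lowpass` and `hf_smooth` are hypothetical stand-ins.

```python
import numpy as np


def one_pole_lowpass(x, sample_rate, cutoff_hz):
    """One-pole low-pass (6 dB/octave). Hypothetical stand-in filter."""
    alpha = 1.0 - np.exp(-2.0 * np.pi * cutoff_hz / sample_rate)
    y = np.empty_like(x)
    state = 0.0
    for i, sample in enumerate(x):
        state += alpha * (sample - state)  # move part of the way toward the input
        y[i] = state
    return y


def hf_smooth(x, sample_rate, cutoff_hz=12000.0, blend=0.7):
    """Linear mix: blend=0 returns the original, blend=1 the filtered copy."""
    filtered = one_pole_lowpass(x, sample_rate, cutoff_hz)
    return (1.0 - blend) * x + blend * filtered


sr = 44100
t = np.arange(sr // 10) / sr
tone_hi = np.sin(2.0 * np.pi * 18000.0 * t)  # 18 kHz tone, above the cutoff

dry = hf_smooth(tone_hi, sr, blend=0.0)
wet = hf_smooth(tone_hi, sr, blend=1.0)
```

With `blend=0.0` the signal passes through unchanged; with `blend=1.0` the 18 kHz tone comes out quieter because it sits above the 12 kHz cutoff.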
|
|
||||||
|
|
||||||
|
|
||||||
# ── Dataset augmenter ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class SelvaDatasetAugmenter:
|
|
||||||
"""Create augmented variants of each clip to expand a small dataset.
|
|
||||||
|
|
||||||
Supports gain variation (always available) and optionally pitch shift
|
|
||||||
and time stretch via audiomentations. Install audiomentations for the
|
|
||||||
full feature set: ``pip install audiomentations``
|
|
||||||
|
|
||||||
Each original clip produces ``variants_per_clip`` augmented copies.
|
|
||||||
Originals are kept by default (toggle ``keep_originals``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"variants_per_clip": ("INT", {
|
|
||||||
"default": 2, "min": 1, "max": 20,
|
|
||||||
"tooltip": "Number of augmented copies per original clip.",
|
|
||||||
}),
|
|
||||||
"gain_range_db": ("FLOAT", {
|
|
||||||
"default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5,
|
|
||||||
"tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.",
|
|
||||||
}),
|
|
||||||
"seed": ("INT", {"default": 42}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"pitch_range_semitones": ("FLOAT", {
|
|
||||||
"default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25,
|
|
||||||
"tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.",
|
|
||||||
}),
|
|
||||||
"time_stretch_range": ("FLOAT", {
|
|
||||||
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05,
|
|
||||||
"tooltip": "Random time stretch ±fraction (0.1 = 90%–110% speed). "
|
|
||||||
"Requires audiomentations. 0 = disabled.",
|
|
||||||
}),
|
|
||||||
"keep_originals": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "Include the original unaugmented clips in the output.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "augment"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Create augmented variants of each clip (gain, pitch, time stretch) "
|
|
||||||
"to expand small training datasets. Gain is always available; pitch and "
|
|
||||||
"time stretch require audiomentations (pip install audiomentations)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def augment(self, dataset, variants_per_clip: int, gain_range_db: float,
|
|
||||||
seed: int, pitch_range_semitones: float = 0.0,
|
|
||||||
time_stretch_range: float = 0.0, keep_originals: bool = True):
|
|
||||||
rng = np.random.RandomState(seed)
|
|
||||||
|
|
||||||
# Try audiomentations for pitch/stretch
|
|
||||||
use_am = False
|
|
||||||
am_compose = None
|
|
||||||
needs_am = pitch_range_semitones > 0 or time_stretch_range > 0
|
|
||||||
if needs_am:
|
|
||||||
try:
|
|
||||||
import audiomentations as am
|
|
||||||
transforms = []
|
|
||||||
if pitch_range_semitones > 0:
|
|
||||||
transforms.append(am.PitchShift(
|
|
||||||
min_semitones=-pitch_range_semitones,
|
|
||||||
max_semitones=pitch_range_semitones,
|
|
||||||
p=0.5,
|
|
||||||
))
|
|
||||||
if time_stretch_range > 0:
|
|
||||||
transforms.append(am.TimeStretch(
|
|
||||||
min_rate=1.0 - time_stretch_range,
|
|
||||||
max_rate=1.0 + time_stretch_range,
|
|
||||||
leave_length_unchanged=True,
|
|
||||||
p=0.5,
|
|
||||||
))
|
|
||||||
am_compose = am.Compose(transforms)
|
|
||||||
use_am = True
|
|
||||||
except ImportError:
|
|
||||||
print("[DatasetAugmenter] audiomentations not installed — "
|
|
||||||
"pitch_shift and time_stretch disabled. "
|
|
||||||
"Install: pip install audiomentations", flush=True)
|
|
||||||
|
|
||||||
out = []
|
|
||||||
if keep_originals:
|
|
||||||
out.extend(dataset)
|
|
||||||
|
|
||||||
for item in dataset:
|
|
||||||
wav = item["waveform"] # [1, C, L]
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
name = item["name"]
|
|
||||||
|
|
||||||
for v in range(variants_per_clip):
|
|
||||||
# Gain variation (always applied)
|
|
||||||
gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0
|
|
||||||
gain_lin = 10.0 ** (gain_db / 20.0)
|
|
||||||
wav_aug = wav * gain_lin
|
|
||||||
|
|
||||||
# Pitch/stretch via audiomentations
|
|
||||||
if use_am and am_compose is not None:
|
|
||||||
wav_np = wav_aug[0].numpy() # [C, L] float32
|
|
||||||
if wav_np.shape[0] == 1:
|
|
||||||
wav_np = wav_np[0] # [L] mono for audiomentations
|
|
||||||
wav_np = am_compose(samples=wav_np, sample_rate=sr)
|
|
||||||
if wav_np.ndim == 1:
|
|
||||||
wav_np = wav_np[np.newaxis, :] # back to [1, L]
|
|
||||||
wav_aug = torch.from_numpy(wav_np).unsqueeze(0) # [1, C, L]
|
|
||||||
|
|
||||||
# Prevent clipping
|
|
||||||
peak = wav_aug.abs().max()
|
|
||||||
if peak > 1.0:
|
|
||||||
wav_aug = wav_aug / peak
|
|
||||||
|
|
||||||
out.append({
|
|
||||||
"waveform": wav_aug,
|
|
||||||
"sample_rate": sr,
|
|
||||||
"name": f"{name}_aug{v:02d}",
|
|
||||||
"origin_name": name,
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips "
|
|
||||||
f"gain=±{gain_range_db:.1f}dB"
|
|
||||||
+ (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "")
|
|
||||||
+ (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""),
|
|
||||||
flush=True)
|
|
||||||
return (out,)
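Review note: the gain step in `augment` converts a random dB offset to a linear factor and then rescales only if the result would clip. A minimal NumPy sketch of that arithmetic, assuming a plain float array in place of the `[1, C, L]` tensor (the helper name `apply_gain_db` is hypothetical):

```python
import numpy as np


def apply_gain_db(wav, gain_db):
    """Scale by a dB gain, then peak-normalize only if the result exceeds 1.0."""
    gain_lin = 10.0 ** (gain_db / 20.0)  # dB to linear amplitude factor
    out = wav * gain_lin
    peak = np.abs(out).max()
    if peak > 1.0:
        out = out / peak  # clip guard: bring the peak back to exactly 1.0
    return out


rng = np.random.RandomState(42)  # same style of seeded RNG as the node
wav = 0.9 * np.sin(np.linspace(0.0, 20.0 * np.pi, 1000))

gain_db = rng.uniform(-3.0, 3.0)  # one random draw, as done per variant
quiet = apply_gain_db(wav, -6.0)  # attenuated, no clip guard needed
loud = apply_gain_db(wav, 6.0)    # would clip, so it is rescaled
```

Note that the clip guard changes the effective gain: a variant that would have peaked above 1.0 ends up at full scale rather than at the drawn gain.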
|
|
||||||
@@ -1,515 +0,0 @@
|
|||||||
"""SelVA DITTO Optimizer.
|
|
||||||
|
|
||||||
Inference-time noise optimization: optimizes the initial noise latent x_0
|
|
||||||
using a style loss against target style reference clips, backpropagating through the
|
|
||||||
ODE solver. All model weights remain frozen — only x_0 changes.
|
|
||||||
|
|
||||||
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
|
|
||||||
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
|
|
||||||
|
|
||||||
Style loss: latent-statistics matching (per-channel mean + optional Gram matrix)
against VAE-encoded reference clips. The loss is computed on the unnormalized
latent at the ODE endpoint, so the optimization loop only needs the DiT. The VAE
encoder runs once to encode the references; the VAE decoder and BigVGAN run only
in the final generation pass.
|
|
||||||
|
|
||||||
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
|
|
||||||
forward pass activations) instead of O(N steps). Backward recomputes each
|
|
||||||
step's activations on demand.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchaudio
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.model_management
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
|
||||||
|
|
||||||
|
|
||||||
def _load_wav(path):
|
|
"""Load an audio file; return (waveform [channels, samples] float32, sample_rate)."""
|
|
||||||
try:
|
|
||||||
return torchaudio.load(str(path))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
import soundfile as sf
|
|
||||||
data, sr = sf.read(str(path), dtype="float32", always_2d=True)
|
|
||||||
wav = torch.from_numpy(data.T)
|
|
||||||
return wav, sr
|
|
||||||
|
|
||||||
|
|
||||||
def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
|
|
||||||
"""Style loss between generated mel and precomputed reference statistics.
|
|
||||||
|
|
||||||
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
|
|
||||||
ref_mean: [n_mels] mean spectrum of reference clips (detached)
|
|
||||||
ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
|
|
||||||
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
|
|
||||||
Start at 0; enable only if mean-only optimization converges
|
|
||||||
without noise, then increase slowly (0.01–0.1).
|
|
||||||
"""
|
|
||||||
m = mel_gen.squeeze(0) # [n_mels, T]
|
|
||||||
|
|
||||||
# Mean spectrum loss — captures spectral envelope
|
|
||||||
gen_mean = m.mean(dim=-1) # [n_mels]
|
|
||||||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
|
||||||
|
|
||||||
if gram_weight <= 0.0:
|
|
||||||
return loss_mean
|
|
||||||
|
|
||||||
# Gram matrix loss — captures timbral texture (can add noise if too high)
|
|
||||||
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
|
|
||||||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
|
||||||
|
|
||||||
return loss_mean + gram_weight * loss_gram
|
|
||||||
|
|
||||||
|
|
||||||
def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
|
|
||||||
"""Style loss computed directly in VAE latent space.
|
|
||||||
|
|
||||||
z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad)
|
|
||||||
ref_mean: [C_lat] mean latent vector of reference clips
|
|
||||||
ref_gram: [C_lat, C_lat] Gram matrix of reference latents
|
|
||||||
gram_weight: weight for Gram component — 0 = mean only (recommended start)
|
|
||||||
|
|
||||||
Operating in latent space avoids backprop through the VAE decoder, which
|
|
||||||
is @torch.inference_mode() and produces noisy, unstable gradients.
|
|
||||||
"""
|
|
||||||
# Mean latent loss — matches average activation per channel
|
|
||||||
gen_mean = z.mean(dim=0) # [C_lat]
|
|
||||||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
|
||||||
|
|
||||||
if gram_weight <= 0.0:
|
|
||||||
return loss_mean
|
|
||||||
|
|
||||||
# Gram matrix — inter-channel covariance, position-invariant
|
|
||||||
gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat]
|
|
||||||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
|
||||||
|
|
||||||
return loss_mean + gram_weight * loss_gram
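Review note: the two statistics used by `_latent_style_loss` (per-channel mean and Gram matrix) discard the time ordering of the frames. This is why they act as a style descriptor rather than a content one. The NumPy sketch below mirrors the same formulas on random data; the helper names `style_stats` and `style_loss` are hypothetical and the arrays stand in for real latents.

```python
import numpy as np


def style_stats(z):
    """z: [T, C] latent frames -> (mean [C], Gram matrix [C, C])."""
    return z.mean(axis=0), (z.T @ z) / z.shape[0]


def style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
    """L1 on the channel means, plus optional MSE on the Gram matrices."""
    mean, gram = style_stats(z)
    loss = np.abs(mean - ref_mean).mean()
    if gram_weight > 0.0:
        loss = loss + gram_weight * ((gram - ref_gram) ** 2).mean()
    return loss


rng = np.random.RandomState(0)
ref = rng.randn(200, 8)  # stand-in for an encoded reference clip
ref_mean, ref_gram = style_stats(ref)

shuffled = ref[rng.permutation(len(ref))]  # same frames, different order
shifted = ref + 0.5                        # every channel offset by 0.5

loss_same = style_loss(ref, ref_mean, ref_gram, gram_weight=0.1)
loss_shuffled = style_loss(shuffled, ref_mean, ref_gram, gram_weight=0.1)
loss_shifted = style_loss(shifted, ref_mean, ref_gram)
```

Shuffling the frames leaves the loss at (numerically) zero, while a constant offset of 0.5 per channel produces a mean-only loss of 0.5.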
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDittoOptimizer:
|
|
||||||
"""DITTO inference-time noise optimization.
|
|
||||||
|
|
||||||
Freezes all model weights and optimizes only the initial noise latent x_0
|
|
||||||
to make the generated audio sound like the target style reference clips.
|
|
||||||
No training data and no gradient updates to the model; the optimization is repeated for each video and each run.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"features": ("SELVA_FEATURES",),
|
|
||||||
"prompt": ("STRING", {
|
|
||||||
"default": "", "multiline": True,
|
|
||||||
"tooltip": "Sound description. Leave empty to use features prompt.",
|
|
||||||
}),
|
|
||||||
"negative_prompt": ("STRING", {
|
|
||||||
"default": "", "multiline": False,
|
|
||||||
}),
|
|
||||||
"reference_dir": ("STRING", {
|
|
||||||
"default": "",
|
|
"tooltip": "Directory with target style reference audio files (.wav/.flac/.mp3/.ogg). "
"Reference latent statistics are precomputed from these once.",
|
|
||||||
}),
|
|
||||||
"n_opt_steps": ("INT", {
|
|
||||||
"default": 50, "min": 5, "max": 500,
|
|
||||||
"tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
|
|
"each step runs n_ode_steps Euler steps plus a backward pass through the last n_grad_steps.",
|
|
||||||
}),
|
|
||||||
"opt_lr": ("FLOAT", {
|
|
||||||
"default": 0.02, "min": 0.001, "max": 2.0, "step": 0.001,
|
|
||||||
"tooltip": "Adam learning rate for x_0 optimization. "
|
|
||||||
"0.02–0.05 is recommended; 0.1 (paper default) causes oscillation.",
|
|
||||||
}),
|
|
||||||
"n_ode_steps": ("INT", {
|
|
||||||
"default": 10, "min": 5, "max": 50,
|
|
||||||
"tooltip": "Euler ODE steps run during each optimization iteration. "
|
|
||||||
"Lower = faster optimization (10–15 is a good trade-off). "
|
|
||||||
"Final generation always uses the steps parameter below.",
|
|
||||||
}),
|
|
||||||
"n_grad_steps": ("INT", {
|
|
||||||
"default": 5, "min": 1, "max": 50,
|
|
||||||
"tooltip": "ODE steps to differentiate through (truncated BPTT). "
|
|
||||||
"Higher = more accurate gradient, more VRAM. "
|
|
"Values above n_ode_steps are clamped to n_ode_steps. 5 is a good default.",
|
|
||||||
}),
|
|
||||||
"style_weight": ("FLOAT", {
|
|
||||||
"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
|
|
"tooltip": "Weight of the target style loss. High values push harder toward "
"the target style but add noise. Start at 0.1 and increase slowly.",
|
|
||||||
}),
|
|
||||||
"gram_weight": ("FLOAT", {
|
|
||||||
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
|
|
||||||
"tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
|
|
||||||
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
|
|
||||||
"0.1 adds texture matching but can introduce white noise.",
|
|
||||||
}),
|
|
||||||
"anchor_weight": ("FLOAT", {
|
|
||||||
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
|
|
||||||
"tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
|
|
||||||
"Prevents optimization from pushing x0 out of the flow's "
|
|
||||||
"expected distribution (which causes white noise). "
|
|
||||||
"Higher = cleaner audio, weaker style. 1.0 is a safe default.",
|
|
||||||
}),
|
|
||||||
"steps": ("INT", {
|
|
||||||
"default": 25, "min": 1, "max": 200,
|
|
||||||
"tooltip": "Euler steps for the final generation pass (after optimization).",
|
|
||||||
}),
|
|
||||||
"cfg_strength": ("FLOAT", {
|
|
||||||
"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
|
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"normalize": ("BOOLEAN", {"default": True}),
|
|
||||||
"target_lufs": ("FLOAT", {
|
|
||||||
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
OUTPUT_TOOLTIPS = ("DITTO-optimized audio: x_0 steered toward the target style.",)
|
|
||||||
FUNCTION = "optimize"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"DITTO inference-time noise optimization (arXiv:2401.12179). "
|
|
||||||
"Optimizes the initial noise latent x_0 to match target style reference clips "
|
|
"via a latent-statistics style loss, backpropagating through the ODE. "
"All model weights stay frozen, so the model itself is never modified."
|
|
||||||
)
|
|
||||||
|
|
||||||
def optimize(self, model, features, prompt, negative_prompt,
|
|
||||||
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
|
||||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
|
||||||
normalize=True, target_lufs=-27.0):
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
strategy = model["strategy"]
|
|
||||||
net_generator = model["generator"]
|
|
||||||
feature_utils = model["feature_utils"]
|
|
||||||
mel_converter = feature_utils.mel_converter
|
|
||||||
|
|
||||||
# Validate variant match
|
|
||||||
feat_variant = features.get("variant")
|
|
||||||
if feat_variant is not None and feat_variant != model["variant"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
|
|
||||||
f"Re-run Feature Extractor."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not prompt or not prompt.strip():
|
|
||||||
prompt = features.get("prompt", "")
|
|
||||||
|
|
||||||
# Resolve duration and seq_cfg
|
|
||||||
duration = features.get("duration", 0)
|
|
||||||
if duration <= 0:
|
|
||||||
raise ValueError("[DITTO] Features contain no duration field.")
|
|
||||||
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
|
|
||||||
sample_rate = seq_cfg.sampling_rate
|
|
||||||
|
|
||||||
# Load reference clips and encode to latent space.
|
|
||||||
# Style loss is computed in latent space (after net_generator.unnormalize)
|
|
||||||
# rather than mel space — this avoids backpropagating through the VAE
|
|
||||||
# decoder (which is @torch.inference_mode() and produces noisy gradients).
|
|
||||||
ref_dir = Path(reference_dir.strip())
|
|
||||||
if not ref_dir.is_absolute():
|
|
||||||
ref_dir = Path(folder_paths.models_dir) / ref_dir
|
|
||||||
if not ref_dir.exists():
|
|
||||||
raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")
|
|
||||||
|
|
||||||
ref_files = []
|
|
||||||
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"):
|
|
||||||
ref_files.extend(ref_dir.rglob(ext))
|
|
||||||
if not ref_files:
|
|
||||||
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
|
|
||||||
|
|
||||||
if not hasattr(feature_utils.tod.vae, "encoder"):
|
|
||||||
raise RuntimeError(
|
|
||||||
"[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. "
|
|
||||||
"Reload the model with the encoder enabled."
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
|
|
||||||
mel_converter.to(device, torch.float32) # cuFFT requires float32
|
|
||||||
|
|
||||||
ref_latents = []
|
|
||||||
with torch.no_grad():
|
|
||||||
for rf in ref_files:
|
|
||||||
try:
|
|
||||||
wav, sr = _load_wav(rf)
|
|
||||||
if wav.shape[0] > 1:
|
|
||||||
wav = wav.mean(0, keepdim=True)
|
|
||||||
if sr != sample_rate:
|
|
||||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
|
||||||
wav = wav.squeeze(0).to(device, torch.float32)
|
|
||||||
mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel]
|
|
||||||
# encode → sample → VAE latent space (matches unnormalize(x) in loss)
|
|
||||||
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
|
|
||||||
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
|
|
||||||
ref_latents.append(z_sample.to(dtype).squeeze(0).clone()) # [T_lat, C_lat]
|
|
||||||
except Exception as e:
|
|
||||||
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
|
|
||||||
|
|
||||||
if not ref_latents:
|
|
||||||
raise RuntimeError("[DITTO] No usable reference clips.")
|
|
||||||
|
|
||||||
# Precompute reference latent statistics (done once — detached, no grad)
|
|
||||||
with torch.no_grad():
|
|
||||||
all_means = torch.stack([z.mean(dim=0) for z in ref_latents])
|
|
||||||
ref_mean = all_means.mean(0) # [C_lat]
|
|
||||||
all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents]
|
|
||||||
ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat]
|
|
||||||
|
|
||||||
print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips "
|
|
||||||
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
|
|
||||||
f"grad_steps={n_grad_steps}", flush=True)
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
net_generator.to(device)
|
|
||||||
feature_utils.to(device)
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(n_opt_steps + steps)
|
|
||||||
|
|
||||||
_result = [None]
|
|
||||||
_exc = [None]
|
|
||||||
|
|
||||||
def _worker():
|
|
||||||
try:
|
|
||||||
_result[0] = _do_optimize(
|
|
||||||
net_generator, feature_utils, mel_converter,
|
|
||||||
features, prompt, negative_prompt,
|
|
||||||
ref_mean, ref_gram,
|
|
||||||
seq_cfg, sample_rate, device, dtype,
|
|
||||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
|
||||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
|
||||||
normalize, target_lufs, pbar,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
_exc[0] = e
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
t = threading.Thread(target=_worker, daemon=True)
|
|
||||||
t.start()
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
net_generator.to(get_offload_device())
|
|
||||||
feature_utils.to(get_offload_device())
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
if _exc[0] is not None:
|
|
||||||
raise _exc[0]
|
|
||||||
return (_result[0],)
|
|
||||||
|
|
||||||
|
|
||||||
def _do_optimize(net_generator, feature_utils, mel_converter,
|
|
||||||
features, prompt, negative_prompt,
|
|
||||||
ref_mean, ref_gram,
|
|
||||||
seq_cfg, sample_rate, device, dtype,
|
|
||||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
|
||||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
|
||||||
normalize, target_lufs, pbar):
|
|
||||||
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
|
||||||
|
|
||||||
# Strip inference flags from the ref stats (they came from the main thread) and
# make sure they are in the model dtype. The generated latent is in the model
# dtype (e.g. bfloat16); a mixed-dtype loss would produce a float32 gradient and
# fail in backward with "Found dtype Float but expected BFloat16".
|
|
||||||
ref_mean = ref_mean.clone().detach().to(dtype)
|
|
||||||
ref_gram = ref_gram.clone().detach().to(dtype)
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
clip_f = features["clip_features"].to(device, dtype).clone()
|
|
||||||
sync_f = features["sync_features"].to(device, dtype).clone()
|
|
||||||
|
|
||||||
# Strip inference-mode flags from all model weights and buffers BEFORE any
|
|
||||||
# forward pass. Parameters were loaded in ComfyUI's inference_mode context;
|
|
||||||
# operations on inference tensors produce inference tensors, so conditions
|
|
||||||
# computed from tainted weights would also be tainted. clone() outside
|
|
||||||
# inference_mode produces a normal tensor regardless of the source flag.
|
|
||||||
def _strip_inference(module):
|
|
||||||
for mod in module.modules():
|
|
||||||
for name, buf in list(mod._buffers.items()):
|
|
||||||
if buf is not None:
|
|
||||||
mod._buffers[name] = buf.clone()
|
|
||||||
for name, param in list(mod._parameters.items()):
|
|
||||||
if param is not None:
|
|
||||||
mod._parameters[name] = torch.nn.Parameter(
|
|
||||||
param.data.clone(), requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
_strip_inference(net_generator)
|
|
||||||
_strip_inference(feature_utils)
|
|
||||||
_strip_inference(mel_converter)
|
|
||||||
|
|
||||||
net_generator.update_seq_lengths(
|
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
|
||||||
clip_seq_len=clip_f.shape[1],
|
|
||||||
sync_seq_len=sync_f.shape[1],
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
text_clip = feature_utils.encode_text_clip([prompt])
|
|
||||||
|
|
||||||
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
|
||||||
if negative_prompt.strip() else None
|
|
||||||
|
|
||||||
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
|
||||||
empty_conditions = net_generator.get_empty_conditions(
|
|
||||||
bs=1, negative_text_features=neg_text_clip
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clone all tensors inside conditions/empty_conditions to ensure no inference
|
|
||||||
# flags survived from intermediate computations inside preprocess_conditions.
|
|
||||||
def _clone_nested(obj):
|
|
||||||
if isinstance(obj, torch.Tensor):
|
|
||||||
return obj.clone()
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
return {k: _clone_nested(v) for k, v in obj.items()}
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
return type(obj)(_clone_nested(v) for v in obj)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
conditions = _clone_nested(conditions)
|
|
||||||
empty_conditions = _clone_nested(empty_conditions)
|
|
||||||
|
|
||||||
# Initial noise — x_0 is the parameter we optimize
|
|
||||||
x0_init = torch.randn(
|
|
||||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
|
||||||
device=device, dtype=dtype,
|
|
||||||
)
|
|
||||||
x0 = torch.nn.Parameter(x0_init.clone())
|
|
||||||
x0_init = x0_init.detach() # anchor — kept fixed, no grad
|
|
||||||
optimizer = torch.optim.Adam([x0], lr=opt_lr)
|
|
||||||
|
|
||||||
# n_grad_steps must not exceed n_ode_steps
|
|
||||||
n_grad_steps = min(n_grad_steps, n_ode_steps)
|
|
||||||
n_free_steps = n_ode_steps - n_grad_steps # steps run without gradient
|
|
||||||
|
|
||||||
ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
print(f"[DITTO] Optimizing x_0 "
|
|
||||||
f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)
|
|
||||||
|
|
||||||
# Freeze all model weights (double-check — should already be frozen at inference)
|
|
||||||
net_generator.requires_grad_(False)
|
|
||||||
feature_utils.requires_grad_(False)
|
|
||||||
mel_converter.requires_grad_(False)
|
|
||||||
|
|
||||||
for opt_step in range(n_opt_steps):
|
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
|
||||||
|
|
||||||
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
|
|
||||||
# Detach from x0 so Phase 1 does not build a computation graph.
|
|
||||||
with torch.no_grad():
|
|
||||||
x = x0.detach()
|
|
||||||
for i in range(n_free_steps):
|
|
||||||
t = ts[i]
|
|
||||||
dt = ts[i + 1] - t
|
|
||||||
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
|
||||||
x = x + dt * flow
|
|
||||||
|
|
||||||
# Straight-through estimator: reconnect x to x0's gradient path by
|
|
||||||
# adding the zero tensor (x0 - x0.detach()). This adds zero value but
|
|
||||||
# creates a grad_fn pointing back to x0, so loss.backward() will
|
|
||||||
# propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
|
|
||||||
# The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
|
|
||||||
# treated as identity for gradient purposes (truncated BPTT).
|
|
||||||
#
|
|
||||||
# x may carry an inference tensor flag from Phase 1 (derived from
|
|
||||||
# conditions which were built outside inference_mode but may have
|
|
||||||
# propagated the flag). .clone() strips it so the STE addition does
|
|
||||||
# not try to save an inference tensor for backward.
|
|
||||||
x = x.clone()
|
|
||||||
x = x + (x0 - x0.detach())
|
|
||||||
|
|
||||||
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
|
|
||||||
for i in range(n_free_steps, n_ode_steps):
|
|
||||||
t = ts[i]
|
|
||||||
dt = ts[i + 1] - t
|
|
||||||
|
|
||||||
# Gradient checkpointing: recompute forward during backward,
|
|
||||||
# avoiding storage of DiT activations for each step.
|
|
||||||
def _ode_step(x_in, t=t):
|
|
||||||
return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)
|
|
||||||
|
|
||||||
flow = torch.utils.checkpoint.checkpoint(
|
|
||||||
_ode_step, x, use_reentrant=False
|
|
||||||
)
|
|
||||||
x = x + dt * flow
|
|
||||||
|
|
||||||
# ── Style loss in latent space ───────────────────────────────────────
|
|
||||||
# Unnormalize x back to VAE latent space — fully differentiable, no
|
|
||||||
# decode needed. ref_mean/ref_gram are computed from encoded reference
|
|
||||||
# clips in the same space. Avoids backprop through VAE decoder which
|
|
||||||
# is @torch.inference_mode() and produces noisy gradients.
|
|
||||||
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
|
|
||||||
style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
|
|
||||||
|
|
||||||
# Anchor regularization — penalize x0 drifting from its initial N(0,1)
|
|
||||||
# value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
|
|
||||||
# the ODE into an out-of-distribution region that decodes as white noise.
|
|
||||||
anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
|
|
||||||
loss = style_loss + anchor_loss
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
|
||||||
torch.nn.utils.clip_grad_norm_([x0], 1.0)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
|
|
||||||
print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
|
|
||||||
f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
|
|
||||||
f"x0_std={x0.data.std().item():.3f}", flush=True)
|
|
||||||
|
|
||||||
# ── Final generation with optimized x_0 ─────────────────────────────────
|
|
||||||
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
|
|
||||||
x = x0.detach()
|
|
||||||
for i in range(steps):
|
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
|
||||||
t = fm_ts[i]
|
|
||||||
dt = fm_ts[i + 1] - t
|
|
||||||
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
|
||||||
x = x + dt * flow
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
x1_unnorm = net_generator.unnormalize(x)
|
|
||||||
spec = feature_utils.decode(x1_unnorm)
|
|
||||||
audio = feature_utils.vocode(spec)
|
|
||||||
|
|
||||||
print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
audio = audio.float()
|
|
||||||
if audio.dim() == 2:
|
|
||||||
audio = audio.unsqueeze(1)
|
|
||||||
elif audio.dim() == 3 and audio.shape[1] != 1:
|
|
||||||
audio = audio.mean(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
if normalize:
|
|
||||||
target_rms = 10 ** (target_lufs / 20.0)
|
|
||||||
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
audio = audio * (target_rms / rms)
|
|
||||||
peak = audio.abs().max().clamp(min=1e-8)
|
|
||||||
if peak > 1.0:
|
|
||||||
audio = audio / peak
|
|
||||||
|
|
||||||
print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
|
||||||
return {"waveform": audio.cpu(), "sample_rate": sample_rate}
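The final generation loop above is a fixed-step Euler integration of the learned flow from t = 0 to t = 1. Here is a minimal, framework-free sketch of that update rule. The `velocity` callable is a hypothetical stand-in for `net_generator.ode_wrapper`, and plain Python lists stand in for tensors.

```python
def euler_integrate(velocity, x0, steps):
    """Integrate dx/dt = velocity(t, x) from t=0 to t=1 with `steps` Euler steps."""
    x = list(x0)
    for i in range(steps):
        t = i / steps
        dt = (i + 1) / steps - t          # same role as fm_ts[i + 1] - fm_ts[i]
        flow = velocity(t, x)             # stand-in for the model's predicted flow
        x = [xi + dt * fi for xi, fi in zip(x, flow)]
    return x


# Toy check: a constant velocity (target - start) carries start to target.
start = [0.0, 2.0]
target = [1.0, -1.0]
result = euler_integrate(lambda t, x: [b - a for a, b in zip(start, target)], start, 25)
```

With a constant velocity field the Euler steps are exact up to rounding, which makes this toy case a convenient sanity check before swapping in a real model.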
|
|
||||||
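The `normalize` branch above scales the waveform to a target RMS level and then rescales again if the peak exceeds 1.0. The sketch below mirrors that logic with the standard library only. Note the assumption it shares with the code above: the target is treated as a plain RMS level in dB relative to full scale, not as a gated ITU-R BS.1770 loudness measurement, even though the parameter is named `target_lufs`. The function name `normalize_rms` is hypothetical.

```python
import math


def normalize_rms(samples, target_db=-23.0):
    """Scale samples so their RMS matches target_db, then peak-limit to 1.0."""
    target_rms = 10.0 ** (target_db / 20.0)
    rms = max(math.sqrt(sum(s * s for s in samples) / len(samples)), 1e-8)
    scaled = [s * (target_rms / rms) for s in samples]
    peak = max(max(abs(s) for s in scaled), 1e-8)
    if peak > 1.0:
        # Avoiding clipping takes priority, so the final RMS lands below the target.
        scaled = [s / peak for s in scaled]
    return scaled


# A steady tone reaches the target RMS; a single click is peak-limited instead.
tone = [0.5 * math.sin(2 * math.pi * i / 100) for i in range(1000)]
tone_out = normalize_rms(tone, -20.0)
click_out = normalize_rms([1.0] + [0.0] * 999, -20.0)
```

For very transient material (the click case) the peak limit dominates, so the output can end up noticeably quieter than the requested level.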
@@ -35,6 +35,66 @@ def _resize_frames(frames, size):
|
|||||||
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_mask_bbox(mask, frame_h, frame_w, margin=0.1, square=True):
|
||||||
|
"""
|
||||||
|
Compute a bounding box around the union of all mask frames.
|
||||||
|
|
||||||
|
mask: [M, H', W'] float [0,1]
|
||||||
|
square: if True, expand bbox to a square and shift into frame bounds;
|
||||||
|
if False, apply margin independently on each axis (rect crop).
|
||||||
|
Returns (y0, x0, y1, x1) in pixel coords clamped to (frame_h, frame_w).
|
||||||
|
"""
|
||||||
|
if mask.shape[1] != frame_h or mask.shape[2] != frame_w:
|
||||||
|
m = F.interpolate(
|
||||||
|
mask.float().unsqueeze(1), size=(frame_h, frame_w), mode="nearest-exact"
|
||||||
|
).squeeze(1)
|
||||||
|
else:
|
||||||
|
m = mask.float()
|
||||||
|
|
||||||
|
union = (m > 0.5).max(dim=0).values # [H, W] bool
|
||||||
|
|
||||||
|
if not union.any():
|
||||||
|
if square:
|
||||||
|
# Empty mask — center square crop
|
||||||
|
side = min(frame_h, frame_w)
|
||||||
|
cy, cx = frame_h // 2, frame_w // 2
|
||||||
|
y0 = max(0, cy - side // 2)
|
||||||
|
x0 = max(0, cx - side // 2)
|
||||||
|
return y0, x0, min(frame_h, y0 + side), min(frame_w, x0 + side)
|
||||||
|
else:
|
||||||
|
# Empty mask — return full frame (no meaningful rect to crop to)
|
||||||
|
return 0, 0, frame_h, frame_w
|
||||||
|
|
||||||
|
ys = union.any(dim=1).nonzero(as_tuple=True)[0]
|
||||||
|
xs = union.any(dim=0).nonzero(as_tuple=True)[0]
|
||||||
|
y0, y1 = int(ys[0]), int(ys[-1]) + 1
|
||||||
|
x0, x1 = int(xs[0]), int(xs[-1]) + 1
|
||||||
|
|
||||||
|
if square:
|
||||||
|
side = max(y1 - y0, x1 - x0)
|
||||||
|
pad = int(side * margin)
|
||||||
|
side += 2 * pad
|
||||||
|
|
||||||
|
cy = (y0 + y1) // 2
|
||||||
|
cx = (x0 + x1) // 2
|
||||||
|
y0n = cy - side // 2
|
||||||
|
x0n = cx - side // 2
|
||||||
|
y1n = y0n + side
|
||||||
|
x1n = x0n + side
|
||||||
|
|
||||||
|
# Shift into frame bounds to preserve square shape
|
||||||
|
if y0n < 0: y1n -= y0n; y0n = 0
|
||||||
|
if y1n > frame_h: y0n -= y1n - frame_h; y1n = frame_h
|
||||||
|
if x0n < 0: x1n -= x0n; x0n = 0
|
||||||
|
if x1n > frame_w: x0n -= x1n - frame_w; x1n = frame_w
|
||||||
|
|
||||||
|
return max(0, int(y0n)), max(0, int(x0n)), min(frame_h, int(y1n)), min(frame_w, int(x1n))
|
||||||
|
else:
|
||||||
|
pad_y = int(max(1, y1 - y0) * margin)
|
||||||
|
pad_x = int(max(1, x1 - x0) * margin)
|
||||||
|
return max(0, y0 - pad_y), max(0, x0 - pad_x), min(frame_h, y1 + pad_y), min(frame_w, x1 + pad_x)
|
||||||
|
|
||||||
|
|
||||||
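To make the square-crop behaviour of `_compute_mask_bbox` easier to follow, here is a small sketch of the same steps on a plain 2-D list of 0/1 values (the union mask), without PyTorch. The function name `square_bbox` is hypothetical; the arithmetic follows the square branch above.

```python
def square_bbox(mask, margin=0.1):
    """mask: 2-D list of 0/1 rows. Returns (y0, x0, y1, x1), end-exclusive."""
    frame_h, frame_w = len(mask), len(mask[0])
    ys = [y for y, row in enumerate(mask) if any(row)]
    xs = [x for x in range(frame_w) if any(row[x] for row in mask)]
    if not ys:
        # Empty mask: fall back to a centered square crop.
        side = min(frame_h, frame_w)
        y0 = max(0, frame_h // 2 - side // 2)
        x0 = max(0, frame_w // 2 - side // 2)
        return y0, x0, min(frame_h, y0 + side), min(frame_w, x0 + side)

    y0, y1 = ys[0], ys[-1] + 1
    x0, x1 = xs[0], xs[-1] + 1
    side = max(y1 - y0, x1 - x0)
    side += 2 * int(side * margin)            # margin on both sides
    cy, cx = (y0 + y1) // 2, (x0 + x1) // 2
    y0, x0 = cy - side // 2, cx - side // 2
    y1, x1 = y0 + side, x0 + side

    # Shift the box back into the frame so it stays square where possible.
    if y0 < 0:
        y1 -= y0
        y0 = 0
    if y1 > frame_h:
        y0 -= y1 - frame_h
        y1 = frame_h
    if x0 < 0:
        x1 -= x0
        x0 = 0
    if x1 > frame_w:
        x0 -= x1 - frame_w
        x1 = frame_w
    return max(0, y0), max(0, x0), min(frame_h, y1), min(frame_w, x1)


# 10x20 frame with a 2x4 blob touching the top edge.
frame = [[0] * 20 for _ in range(10)]
for y in (0, 1):
    for x in (2, 3, 4, 5):
        frame[y][x] = 1
```

The final clamp means the result is only guaranteed to be square when the padded side fits inside the frame; a box larger than the shorter frame dimension comes back rectangular.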
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
||||||
"""
|
"""
|
||||||
Apply a ComfyUI MASK to resized frames.
|
Apply a ComfyUI MASK to resized frames.
|
||||||
@@ -68,20 +128,9 @@ def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
|||||||
return frames * alpha + 0.5 * (1.0 - alpha)
|
return frames * alpha + 0.5 * (1.0 - alpha)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_named_path(cache_dir: str, name: str) -> str:
|
|
||||||
"""Return the first unused cache_dir/name_NNN.npz path, starting at name_001.npz (the bare name.npz is never used)."""
|
|
||||||
# Sanitize: replace path separators so the name stays inside cache_dir
|
|
||||||
name = name.replace("/", "_").replace("\\", "_").replace("\x00", "_")
|
|
||||||
i = 1
|
|
||||||
while True:
|
|
||||||
p = os.path.join(cache_dir, f"{name}_{i:03d}.npz")
|
|
||||||
if not os.path.exists(p):
|
|
||||||
return p
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
|
|
||||||
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
||||||
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
mask_strength=1.0, mask_clip=True, mask_sync=True,
|
||||||
|
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
|
||||||
h = hashlib.sha256()
|
h = hashlib.sha256()
|
||||||
raw = video_tensor.cpu().numpy().tobytes()
|
raw = video_tensor.cpu().numpy().tobytes()
|
||||||
n = len(raw)
|
n = len(raw)
|
||||||
@@ -99,6 +148,10 @@ def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
|||||||
h.update(str(round(mask_strength, 4)).encode())
|
h.update(str(round(mask_strength, 4)).encode())
|
||||||
h.update(str(mask_clip).encode())
|
h.update(str(mask_clip).encode())
|
||||||
h.update(str(mask_sync).encode())
|
h.update(str(mask_sync).encode())
|
||||||
|
h.update(str(crop_to_mask).encode())
|
||||||
|
h.update(str(crop_rect).encode())
|
||||||
|
if crop_to_mask or crop_rect:
|
||||||
|
h.update(str(round(crop_margin, 4)).encode())
|
||||||
h.update(prompt.encode())
|
h.update(prompt.encode())
|
||||||
h.update(str(fps).encode())
|
h.update(str(fps).encode())
|
||||||
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
||||||
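The change to `_hash_inputs` folds the crop settings into the cache key, so toggling a crop mode invalidates previously cached features. A reduced sketch of that idea follows. The function name `cache_key` and the reduced parameter list are hypothetical, and the video and mask bytes are left out for brevity.

```python
import hashlib


def cache_key(prompt, fps, crop_to_mask=False, crop_rect=False, crop_margin=0.1):
    """Build a content hash in which the crop options participate."""
    h = hashlib.sha256()
    h.update(str(crop_to_mask).encode())
    h.update(str(crop_rect).encode())
    if crop_to_mask or crop_rect:
        # The margin only matters when a crop mode is active.
        h.update(str(round(crop_margin, 4)).encode())
    h.update(prompt.encode())
    h.update(str(fps).encode())
    return h.hexdigest()


plain = cache_key("dog barking", 30.0)
cropped = cache_key("dog barking", 30.0, crop_to_mask=True)
```

Hashing the margin only when a crop is enabled keeps old cache entries valid for users who change `crop_margin` without turning cropping on.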
@@ -128,8 +181,6 @@ class SelvaFeatureExtractor:
|
|||||||
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
|
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
|
||||||
"cache_dir": ("STRING", {"default": "",
|
"cache_dir": ("STRING", {"default": "",
|
||||||
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
|
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
|
||||||
"name": ("STRING", {"default": "",
|
|
||||||
"tooltip": "Optional filename for the saved .npz (without extension). If provided, features are always saved with this name instead of a content hash — useful for building a named training dataset. Auto-increments: dog_bark → dog_bark_001 → dog_bark_002 if the file already exists. Leave empty to use the default content-hash cache."}),
|
|
||||||
"mask": ("MASK", {
|
"mask": ("MASK", {
|
||||||
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
|
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
|
||||||
}),
|
}),
|
||||||
@@ -145,6 +196,18 @@ class SelvaFeatureExtractor:
|
|||||||
"default": True,
|
"default": True,
|
||||||
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
|
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
|
||||||
}),
|
}),
|
||||||
|
"crop_to_mask": ("BOOLEAN", {
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Experimental. Crops frames to a square region around the mask bounding box before resizing. The model sees an undistorted view of the subject. Requires mask. Takes priority over crop_rect.",
|
||||||
|
}),
|
||||||
|
"crop_rect": ("BOOLEAN", {
|
||||||
|
"default": False,
|
||||||
|
"tooltip": "Experimental. Crops frames to a rectangle around the mask bounding box (with margin) before resizing. The model still stretches the crop to a square, but only sees the region around the target element. Simpler than crop_to_mask. Requires mask.",
|
||||||
|
}),
|
||||||
|
"crop_margin": ("FLOAT", {
|
||||||
|
"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||||
|
"tooltip": "Margin added around the bounding box as a fraction of the bbox size. Shared by crop_to_mask and crop_rect. 0.1 = 10% on each side.",
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,14 +218,14 @@ class SelvaFeatureExtractor:
|
|||||||
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
|
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
|
||||||
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
|
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
|
||||||
)
|
)
|
||||||
OUTPUT_NODE = True # always execute: the node's purpose is saving .npz files to disk
|
|
||||||
FUNCTION = "extract_features"
|
FUNCTION = "extract_features"
|
||||||
CATEGORY = SELVA_CATEGORY
|
CATEGORY = SELVA_CATEGORY
|
||||||
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
|
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
|
||||||
|
|
||||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||||
duration=0.0, cache_dir="", name="", mask=None,
|
duration=0.0, cache_dir="", mask=None,
|
||||||
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
mask_strength=1.0, mask_clip=True, mask_sync=True,
|
||||||
|
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
|
||||||
if video_info is not None:
|
if video_info is not None:
|
||||||
fps = video_info["loaded_fps"]
|
fps = video_info["loaded_fps"]
|
||||||
|
|
||||||
@@ -178,15 +241,11 @@ class SelvaFeatureExtractor:
|
|||||||
if not cache_dir:
|
if not cache_dir:
|
||||||
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
if name.strip():
|
|
||||||
# Named mode: always extract and save to an incremented filename
|
|
||||||
cached_path = _resolve_named_path(cache_dir, name.strip())
|
|
||||||
else:
|
|
||||||
# Hash mode: skip extraction if identical inputs were already processed
|
|
||||||
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
|
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
|
||||||
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
|
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync,
|
||||||
|
crop_to_mask=crop_to_mask, crop_rect=crop_rect, crop_margin=crop_margin)
|
||||||
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
||||||
|
|
||||||
if os.path.exists(cached_path):
|
if os.path.exists(cached_path):
|
||||||
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
|
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
|
||||||
cached = _load_cached(cached_path)
|
cached = _load_cached(cached_path)
|
||||||
@@ -206,10 +265,24 @@ class SelvaFeatureExtractor:
|
|||||||
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
||||||
pbar = comfy.utils.ProgressBar(3)
|
pbar = comfy.utils.ProgressBar(3)
|
||||||
|
|
||||||
|
# Pre-compute crop bbox once from the original-resolution mask
|
||||||
|
crop_bbox = None
|
||||||
|
if mask is not None and (crop_to_mask or crop_rect):
|
||||||
|
H_vid, W_vid = video.shape[1], video.shape[2]
|
||||||
|
_square = crop_to_mask # crop_to_mask takes priority; crop_rect is rect-only
|
||||||
|
crop_bbox = _compute_mask_bbox(mask, H_vid, W_vid, crop_margin, square=_square)
|
||||||
|
cy0, cx0, cy1, cx1 = crop_bbox
|
||||||
|
_mode = "square" if _square else "rect"
|
||||||
|
print(f"[SelVA] Mask crop ({_mode}): y={cy0}:{cy1} x={cx0}:{cx1} "
|
||||||
|
f"({cy1-cy0}×{cx1-cx0}px from {H_vid}×{W_vid})", flush=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||||
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||||
|
if crop_bbox is not None:
|
||||||
|
cy0, cx0, cy1, cx1 = crop_bbox
|
||||||
|
clip_frames = clip_frames[:, cy0:cy1, cx0:cx1, :]
|
||||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||||
if mask is not None and mask_clip:
|
if mask is not None and mask_clip:
|
||||||
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
|
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
|
||||||
@@ -222,6 +295,9 @@ class SelvaFeatureExtractor:
|
|||||||
|
|
||||||
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
||||||
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
||||||
|
if crop_bbox is not None:
|
||||||
|
cy0, cx0, cy1, cx1 = crop_bbox
|
||||||
|
sync_frames = sync_frames[:, cy0:cy1, cx0:cx1, :]
|
||||||
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||||
if mask is not None and mask_sync:
|
if mask is not None and mask_sync:
|
||||||
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
|
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
|
||||||
|
|||||||
@@ -1,421 +0,0 @@
|
|||||||
"""SelVA LoRA Evaluator — generates eval samples from multiple adapters for comparison.
|
|
||||||
|
|
||||||
JSON format:
|
|
||||||
{
|
|
||||||
"name": "eval_batch_1",
|
|
||||||
"data_dir": "/path/to/features",
|
|
||||||
"output_dir": "/path/to/evals/batch1",
|
|
||||||
"steps": 25,
|
|
||||||
"seed": 42,
|
|
||||||
"adapters": [
|
|
||||||
{"id": "baseline"},
|
|
||||||
{"id": "lr_3e4_10k", "path": "/path/to/adapter_final.pt"},
|
|
||||||
{"id": "lr_5e4_10k", "path": "/path/to/adapter_final.pt"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
Empty / missing "path" = baseline (no LoRA applied).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
from .selva_lora_trainer import (
|
|
||||||
_prepare_dataset,
|
|
||||||
_eval_sample,
|
|
||||||
_spectral_metrics,
|
|
||||||
_save_spectrogram,
|
|
||||||
_pil_to_tensor,
|
|
||||||
_find_audio,
|
|
||||||
_load_audio,
|
|
||||||
)
|
|
||||||
from selva_core.model.lora import apply_lora, load_lora
|
|
||||||
|
|
||||||
|
|
||||||
def _avg_metrics(metrics_list: list) -> dict:
|
|
||||||
"""Average spectral metrics across multiple clips, ignoring None entries."""
|
|
||||||
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
|
|
||||||
"spectral_flatness", "temporal_variance"]
|
|
||||||
valid = [m for m in metrics_list if m]
|
|
||||||
if not valid:
|
|
||||||
return {}
|
|
||||||
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
|
|
||||||
|
|
||||||
|
|
||||||
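As a quick illustration of how `_avg_metrics` treats missing entries, the sketch below applies the same filtering and averaging to hand-written dictionaries. The function name `average_metrics` and the explicit `keys` argument are hypothetical simplifications.

```python
def average_metrics(metrics_list, keys):
    """Average each key across the non-empty metric dicts, rounded to 4 places."""
    valid = [m for m in metrics_list if m]   # drops None and empty dicts
    if not valid:
        return {}
    return {k: round(sum(m[k] for m in valid) / len(valid), 4) for k in keys}


clips = [
    {"hf_energy_ratio": 0.25, "spectral_flatness": 0.5},
    None,                                    # a clip whose metrics failed
    {"hf_energy_ratio": 0.75, "spectral_flatness": 0.25},
]
averaged = average_metrics(clips, ["hf_energy_ratio", "spectral_flatness"])
```

Every non-empty dictionary is expected to carry all requested keys; a partially filled entry would raise `KeyError` rather than being skipped.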
def _resolve_path(raw: str) -> Path:
|
|
||||||
p = Path(raw.strip())
|
|
||||||
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
|
|
||||||
if not p.is_absolute() or unix_style_on_windows:
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_stem(adapter_id: str) -> str:
|
|
||||||
"""Replace characters illegal in filenames."""
|
|
||||||
for ch in r'/\:*?"<>|':
|
|
||||||
adapter_id = adapter_id.replace(ch, "_")
|
|
||||||
return adapter_id
|
|
||||||
|
|
||||||
|
|
||||||
def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: Path):
|
|
||||||
"""Draw a 2×2 grid of horizontal bar charts comparing spectral metrics.
|
|
||||||
|
|
||||||
Saves a PNG to output_path and returns a ComfyUI IMAGE tensor.
|
|
||||||
"""
|
|
||||||
from matplotlib.figure import Figure
|
|
||||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
|
||||||
|
|
||||||
METRICS = [
|
|
||||||
("hf_energy_ratio", "HF Energy Ratio (>4 kHz)"),
|
|
||||||
("spectral_centroid_hz", "Spectral Centroid (Hz)"),
|
|
||||||
("spectral_flatness", "Spectral Flatness"),
|
|
||||||
("temporal_variance", "Temporal Variance"),
|
|
||||||
]
|
|
||||||
COLORS = [
|
|
||||||
"#4285F4", "#EA4335", "#34A853", "#FBBC05",
|
|
||||||
"#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6",
|
|
||||||
]
|
|
||||||
|
|
||||||
fig = Figure(figsize=(12, max(4, len(adapter_ids) * 0.6 + 2)), dpi=110, tight_layout=True)
|
|
||||||
axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]
|
|
||||||
|
|
||||||
for ax, (key, title) in zip(axes, METRICS):
|
|
||||||
values = []
|
|
||||||
colors = []
|
|
||||||
for i, m in enumerate(metrics_list):
|
|
||||||
v = m.get(key, 0.0) if m else 0.0
|
|
||||||
values.append(v)
|
|
||||||
colors.append(COLORS[i % len(COLORS)])
|
|
||||||
|
|
||||||
bars = ax.barh(adapter_ids, values, color=colors, height=0.6)
|
|
||||||
ax.set_title(title, fontsize=9)
|
|
||||||
ax.set_xlabel(key, fontsize=8)
|
|
||||||
ax.tick_params(axis="y", labelsize=7)
|
|
||||||
ax.tick_params(axis="x", labelsize=7)
|
|
||||||
|
|
||||||
# Value labels on bars
|
|
||||||
for bar, val in zip(bars, values):
|
|
||||||
w = bar.get_width()
|
|
||||||
ax.text(w * 1.01, bar.get_y() + bar.get_height() / 2,
|
|
||||||
f"{val:.3f}", va="center", ha="left", fontsize=6)
|
|
||||||
|
|
||||||
canvas = FigureCanvasAgg(fig)
|
|
||||||
canvas.draw()
|
|
||||||
canvas.print_figure(str(output_path), dpi=110)
|
|
||||||
|
|
||||||
buf = canvas.buffer_rgba()
|
|
||||||
w, h = canvas.get_width_height()
|
|
||||||
arr = np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 4)[:, :, :3]
|
|
||||||
from PIL import Image
|
|
||||||
return _pil_to_tensor(Image.fromarray(arr))
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaLoraEvaluator:
|
|
||||||
"""Evaluates a batch of LoRA adapters on every clip in the dataset.
|
|
||||||
|
|
||||||
Generates one audio sample per clip for each adapter, averages the spectral metrics,
|
|
||||||
and produces a comparison chart. Use this after a sweep to compare candidates
|
|
||||||
before running the next round of training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "run"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("summary_path", "comparison_image")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to eval_summary.json — contains spectral metrics per adapter.",
|
|
||||||
"Bar chart comparing spectral metrics across all evaluated adapters.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Evaluates multiple LoRA adapters by generating one audio sample per adapter "
|
|
||||||
"for every clip in the dataset, then collects spectral metrics for comparison. "
|
|
||||||
"Input is a JSON file listing adapter paths. Empty path = baseline (no LoRA)."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"eval_file": ("STRING", {
|
|
||||||
"default": "eval_batch.json",
|
|
||||||
"tooltip": (
|
|
||||||
"Path to the JSON evaluation spec. Relative paths are looked up in "
|
|
||||||
"the ComfyUI models directory first, then the output directory. "
|
|
||||||
"Each adapter entry requires an 'id'; 'path' is optional. "
|
|
||||||
"Omit 'path' for a no-LoRA baseline."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def run(self, model, eval_file):
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Resolve and parse the JSON file
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
eval_path = Path(eval_file.strip())
|
|
||||||
if not eval_path.is_absolute():
|
|
||||||
candidate = Path(folder_paths.models_dir) / eval_path
|
|
||||||
if not candidate.exists():
|
|
||||||
candidate = Path(folder_paths.get_output_directory()) / eval_path
|
|
||||||
eval_path = candidate
|
|
||||||
if not eval_path.exists():
|
|
||||||
raise FileNotFoundError(f"[LoRA Evaluator] Eval file not found: {eval_path}")
|
|
||||||
|
|
||||||
spec = json.loads(eval_path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
if "adapters" not in spec or not spec["adapters"]:
|
|
||||||
raise ValueError("[LoRA Evaluator] 'adapters' list is missing or empty.")
|
|
||||||
for i, a in enumerate(spec["adapters"]):
|
|
||||||
if "id" not in a:
|
|
||||||
raise ValueError(f"[LoRA Evaluator] Adapter at index {i} missing 'id'.")
|
|
||||||
|
|
||||||
if "data_dir" not in spec:
|
|
||||||
raise ValueError("[LoRA Evaluator] 'data_dir' is required.")
|
|
||||||
if "output_dir" not in spec:
|
|
||||||
raise ValueError("[LoRA Evaluator] 'output_dir' is required.")
|
|
||||||
|
|
||||||
name = spec.get("name", eval_path.stem)
|
|
||||||
data_dir = _resolve_path(spec["data_dir"])
|
|
||||||
output_dir = _resolve_path(spec["output_dir"])
|
|
||||||
steps = int(spec.get("steps", 25))
|
|
||||||
seed = int(spec.get("seed", 42))
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\n[LoRA Evaluator] '{name}': {len(spec['adapters'])} adapter(s)", flush=True)
|
|
||||||
print(f"[LoRA Evaluator] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[LoRA Evaluator] output_dir = {output_dir}\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. Prepare dataset (VAE encode once)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
|
|
||||||
feature_utils_orig = model["feature_utils"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. Collect reference metrics for all dataset clips
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
import shutil
|
|
||||||
npz_files = sorted(data_dir.glob("*.npz"))
|
|
||||||
ref_dir = output_dir / "reference"
|
|
||||||
ref_dir.mkdir(exist_ok=True)
|
|
||||||
ref_clips = [] # list of {clip, wav_path, spectral_metrics}
|
|
||||||
|
|
||||||
print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
|
|
||||||
flush=True)
|
|
||||||
for npz_path in npz_files:
|
|
||||||
audio_path = _find_audio(npz_path)
|
|
||||||
if audio_path is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
|
||||||
ref_wav = ref_wav.unsqueeze(0) # [1, L]
|
|
||||||
ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
|
|
||||||
shutil.copy2(str(audio_path), str(ref_out))
|
|
||||||
metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
|
|
||||||
ref_clips.append({
|
|
||||||
"clip": npz_path.stem,
|
|
||||||
"wav_path": str(ref_out),
|
|
||||||
"spectral_metrics": metrics,
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
|
|
||||||
|
|
||||||
# Average reference metrics across all clips
|
|
||||||
ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
|
|
||||||
print(f"[LoRA Evaluator] Reference avg — "
|
|
||||||
f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
|
|
||||||
f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
|
|
||||||
f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. Build summary skeleton
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary = {
|
|
||||||
"name": name,
|
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"completed_at": None,
|
|
||||||
"data_dir": str(data_dir),
|
|
||||||
"output_dir": str(output_dir),
|
|
||||||
"n_clips": len(ref_clips),
|
|
||||||
"steps": steps,
|
|
||||||
"seed": seed,
|
|
||||||
"reference_avg": ref_avg,
|
|
||||||
"reference_clips": ref_clips,
|
|
||||||
"adapters": [],
|
|
||||||
}
|
|
||||||
summary_path = output_dir / "eval_summary.json"
|
|
||||||
|
|
||||||
def _write_summary():
|
|
||||||
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. Per-adapter evaluation loop (all clips)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
n_clips = len(dataset)
|
|
||||||
pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
|
|
||||||
|
|
||||||
for adapter_spec in spec["adapters"]:
|
|
||||||
adapter_id = adapter_spec["id"]
|
|
||||||
adapter_path = (adapter_spec.get("path") or "").strip()
|
|
||||||
safe_id = _safe_stem(adapter_id)
|
|
||||||
clip_dir = output_dir / safe_id
|
|
||||||
clip_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
record = {
|
|
||||||
"id": adapter_id,
|
|
||||||
"path": adapter_path or None,
|
|
||||||
"meta": None,
|
|
||||||
"clips": [],
|
|
||||||
"avg_metrics": None,
|
|
||||||
"status": "running",
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.inference_mode(False):
|
|
||||||
generator = copy.deepcopy(model["generator"])
|
|
||||||
|
|
||||||
if adapter_path:
|
|
||||||
pt_path = Path(adapter_path)
|
|
||||||
if not pt_path.is_absolute():
|
|
||||||
pt_path = Path(folder_paths.base_path) / pt_path
|
|
||||||
if not pt_path.exists():
|
|
||||||
raise FileNotFoundError(f"Adapter not found: {pt_path}")
|
|
||||||
|
|
||||||
ckpt = torch.load(str(pt_path), map_location="cpu",
|
|
||||||
weights_only=False)
|
|
||||||
if isinstance(ckpt, dict) and "state_dict" in ckpt:
|
|
||||||
state_dict = ckpt["state_dict"]
|
|
||||||
meta = ckpt.get("meta", {})
|
|
||||||
else:
|
|
||||||
state_dict = ckpt
|
|
||||||
meta = {}
|
|
||||||
|
|
||||||
rank = int(meta.get("rank", 16))
|
|
||||||
alpha = float(meta.get("alpha", float(rank)))
|
|
||||||
target = list(meta.get("target", ["attn.qkv"]))
|
|
||||||
dropout = float(meta.get("lora_dropout", 0.0))
|
|
||||||
use_rslora = meta.get("use_rslora", False)
|
|
||||||
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
|
|
||||||
|
|
||||||
# Always use standard init for loading — PiSSA checkpoints
|
|
||||||
# include linear.weight (residual) in state_dict
|
|
||||||
n = apply_lora(generator, rank=rank, alpha=alpha,
|
|
||||||
target_suffixes=tuple(target), dropout=dropout,
|
|
||||||
init_mode="standard", use_rslora=use_rslora)
|
|
||||||
if n == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"apply_lora matched 0 layers (target={target})"
|
|
||||||
)
|
|
||||||
load_lora(generator, state_dict)
|
|
||||||
print(f"[LoRA Evaluator] Loaded {pt_path.name} "
|
|
||||||
f"(rank={rank}, {n} layers)", flush=True)
|
|
||||||
else:
|
|
||||||
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
|
|
||||||
|
|
||||||
generator = generator.to(device, dtype)
|
|
||||||
generator.update_seq_lengths(
|
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
|
||||||
clip_seq_len=seq_cfg.clip_seq_len,
|
|
||||||
sync_seq_len=seq_cfg.sync_seq_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
clip_metrics_list = []
|
|
||||||
for clip_idx in range(n_clips):
|
|
||||||
clip_stem = npz_files[clip_idx].stem
|
|
||||||
wav, sr = _eval_sample(
|
|
||||||
generator, feature_utils_orig, dataset,
|
|
||||||
seq_cfg, device, dtype,
|
|
||||||
num_steps=steps, seed=seed, clip_idx=clip_idx,
|
|
||||||
)
|
|
||||||
if wav is None:
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
wav_path = clip_dir / f"{clip_stem}.wav"
|
|
||||||
try:
|
|
||||||
torchaudio.save(str(wav_path), wav, sr)
|
|
||||||
except RuntimeError:
|
|
||||||
import soundfile as sf
|
|
||||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
|
||||||
|
|
||||||
metrics = _spectral_metrics(wav, sr)
|
|
||||||
clip_metrics_list.append(metrics)
|
|
||||||
record["clips"].append({
|
|
||||||
"clip": clip_stem,
|
|
||||||
"wav_path": str(wav_path),
|
|
||||||
"spectral_metrics": metrics,
|
|
||||||
})
|
|
||||||
print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
|
|
||||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
|
||||||
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
record["avg_metrics"] = _avg_metrics(clip_metrics_list)
|
|
||||||
record["status"] = "completed"
|
|
||||||
avg = record["avg_metrics"]
|
|
||||||
print(f"[LoRA Evaluator] '{adapter_id}' avg — "
|
|
||||||
f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
|
|
||||||
f"hf={avg.get('hf_energy_ratio', 0):.3f} "
|
|
||||||
f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
record["status"] = "failed"
|
|
||||||
record["error"] = str(e)
|
|
||||||
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
pbar.update(n_clips - len(record["clips"]))
|
|
||||||
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
del generator
|
|
||||||
except NameError:
|
|
||||||
pass
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
summary["adapters"].append(record)
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 6. Finalise summary
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
|
|
||||||
_write_summary()
|
|
||||||
print(f"\n[LoRA Evaluator] Done. Summary: {summary_path}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 7. Comparison chart
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
|
|
||||||
if completed:
|
|
||||||
ids = ["reference"] + [r["id"] for r in completed]
|
|
||||||
metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
|
|
||||||
chart_path = output_dir / "metric_comparison.png"
|
|
||||||
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
|
|
||||||
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
|
|
||||||
else:
|
|
||||||
from PIL import Image
|
|
||||||
comparison = _pil_to_tensor(Image.new("RGB", (400, 200), (255, 255, 255)))
|
|
||||||
|
|
||||||
return (str(summary_path), comparison)
|
|
||||||
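The evaluator loop above records a failed adapter in the summary and carries on, rewriting the summary file after every adapter so partial results survive an interruption. A minimal sketch of that pattern, under stated assumptions: `run_all`, `fake_evaluate`, and the record layout are hypothetical stand-ins for illustration, not the evaluator's actual helpers.

```python
import json
import tempfile
from pathlib import Path


def run_all(adapter_ids, evaluate, summary_path):
    """Evaluate each adapter, recording failures instead of aborting the run."""
    summary = {"adapters": []}
    for adapter_id in adapter_ids:
        record = {"id": adapter_id, "status": "running"}
        try:
            record["avg_metrics"] = evaluate(adapter_id)
            record["status"] = "completed"
        except Exception as exc:  # keep going; note the error in the record
            record["status"] = "failed"
            record["error"] = str(exc)
        summary["adapters"].append(record)
        # Rewrite the whole file after every adapter so partial results survive.
        summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    return summary


def fake_evaluate(adapter_id):
    # Hypothetical stand-in for the real generation + metric pass.
    if adapter_id == "broken":
        raise RuntimeError("could not load adapter")
    return {"spectral_centroid_hz": 1000.0}


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "summary.json"
    result = run_all(["a", "broken", "b"], fake_evaluate, path)
    on_disk = json.loads(path.read_text(encoding="utf-8"))

# Only completed records feed the comparison chart.
completed_ids = [r["id"] for r in result["adapters"] if r["status"] == "completed"]
```

Writing the full summary each time is simpler than appending and keeps the file valid JSON at every point, at the cost of rewriting a small file repeatedly.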
@@ -1,109 +0,0 @@
|
|||||||
import copy
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
from selva_core.model.lora import apply_lora, load_lora
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaLoraLoader:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"adapter_path": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.",
|
|
||||||
}),
|
|
||||||
"strength": ("FLOAT", {
|
|
||||||
"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05,
|
|
||||||
"tooltip": "Scale applied to all LoRA contributions. "
|
|
||||||
"1.0 = full adapter strength. "
|
|
||||||
"0.0 = effectively disables the adapter. "
|
|
||||||
"Values above 1.0 exaggerate the effect.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("SELVA_MODEL",)
|
|
||||||
RETURN_NAMES = ("model",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",)
|
|
||||||
FUNCTION = "load"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Loads a LoRA adapter produced by train_lora.py and applies it to the generator. "
|
|
||||||
"The base model is not modified — a shallow copy of the model bundle is returned."
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, model: dict, adapter_path: str, strength: float) -> tuple:
|
|
||||||
if not adapter_path.strip():
|
|
||||||
raise ValueError("[SelVA LoRA] adapter_path is empty.")
|
|
||||||
|
|
||||||
# Resolve path: allow absolute or relative to ComfyUI base
|
|
||||||
from pathlib import Path
|
|
||||||
p = Path(adapter_path)
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.base_path) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}")
|
|
||||||
|
|
||||||
checkpoint = torch.load(str(p), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
# Support both raw state_dict and {state_dict, meta} formats
|
|
||||||
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
|
||||||
state_dict = checkpoint["state_dict"]
|
|
||||||
meta = checkpoint.get("meta", {})
|
|
||||||
else:
|
|
||||||
state_dict = checkpoint
|
|
||||||
meta = {}
|
|
||||||
|
|
||||||
rank = int(meta.get("rank", 16))
|
|
||||||
alpha = float(meta.get("alpha", float(rank)))
|
|
||||||
target = list(meta.get("target", ["attn.qkv"]))
|
|
||||||
init_mode = meta.get("init_mode", "standard")
|
|
||||||
use_rslora = meta.get("use_rslora", False)
|
|
||||||
|
|
||||||
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
|
|
||||||
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} "
|
|
||||||
f"init={init_mode} rslora={use_rslora} strength={strength}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
# Shallow-copy the model bundle and deep-copy the generator so the original is not mutated
|
|
||||||
patched = {**model}
|
|
||||||
generator = copy.deepcopy(model["generator"])
|
|
||||||
|
|
||||||
# For PiSSA, use standard init (the base weights will be overwritten
|
|
||||||
# by load_state_dict since the checkpoint includes linear.weight)
|
|
||||||
n = apply_lora(generator, rank=rank, alpha=alpha,
|
|
||||||
target_suffixes=tuple(target),
|
|
||||||
init_mode="standard", use_rslora=use_rslora)
|
|
||||||
if n == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"[SelVA LoRA] No layers matched target={target}. "
|
|
||||||
"Check that the adapter was trained with the same target suffixes."
|
|
||||||
)
|
|
||||||
load_lora(generator, state_dict)
|
|
||||||
|
|
||||||
# Sanity check: confirm lora_A weights are non-zero (lora_B starts at zero by design)
|
|
||||||
norms = [p.norm().item() for name, p in generator.named_parameters()
|
|
||||||
if "lora_A" in name]
|
|
||||||
if norms:
|
|
||||||
print(f"[SelVA LoRA] lora_A weight norms: min={min(norms):.4f} "
|
|
||||||
f"max={max(norms):.4f} mean={sum(norms)/len(norms):.4f}", flush=True)
|
|
||||||
else:
|
|
||||||
print("[SelVA LoRA] WARNING: no lora_A params found after loading!", flush=True)
|
|
||||||
|
|
||||||
# Apply strength scaling: multiply all lora_B params by strength
|
|
||||||
# (the LoRA update is B @ A, so scaling lora_B alone scales the whole contribution)
|
|
||||||
if strength != 1.0:
|
|
||||||
with torch.no_grad():
|
|
||||||
for name, param in generator.named_parameters():
|
|
||||||
if "lora_B" in name:
|
|
||||||
param.mul_(strength)
|
|
||||||
|
|
||||||
generator.to(next(model["generator"].parameters()).device)
|
|
||||||
patched["generator"] = generator
|
|
||||||
|
|
||||||
print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True)
|
|
||||||
return (patched,)
|
|
||||||
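The loader applies `strength` by multiplying every `lora_B` parameter in place. Because the adapter's contribution is linear in `lora_B`, this should match scaling the adapter output itself. The sketch below checks that on tiny nested-list matrices; `matvec` and `lora_forward` are hypothetical helpers written for this illustration (not part of `selva_core`), and the `alpha / rank` factor is left out for brevity.

```python
def matvec(m, v):
    """Multiply a nested-list matrix by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, scale):
    """Return W x + scale * B (A x)."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + scale * d for b, d in zip(base, delta)]


W = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight
A = [[0.5, -1.0]]              # lora_A, rank 1: shape (1, 2)
B = [[2.0], [4.0]]             # lora_B: shape (2, 1)
x = [3.0, 1.0]
strength = 0.5

# Option 1: scale the adapter output at call time.
scaled_output = lora_forward(W, A, B, x, scale=strength)

# Option 2: bake the strength into lora_B once, as the loader does.
B_scaled = [[strength * v for v in row] for row in B]
scaled_weights = lora_forward(W, A, B_scaled, x, scale=1.0)

# A strength of 0.0 leaves only the base model's output.
disabled = lora_forward(W, A, B, x, scale=0.0)
```

Baking the factor into the weights means the sampler needs no knowledge of the strength setting, which fits the node returning a patched copy of the model.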
@@ -1,539 +0,0 @@
|
|||||||
"""SelVA LoRA Scheduler — runs a sweep of training experiments from a JSON file.
|
|
||||||
|
|
||||||
Each experiment inherits from a shared `base` config and overrides specific keys.
|
|
||||||
The dataset is loaded once and reused across all experiments. Results are written
|
|
||||||
to `experiment_summary.json` (updated after each completed run) and a comparison
|
|
||||||
loss-curve image showing all runs on the same axes.
|
|
||||||
|
|
||||||
JSON format:
|
|
||||||
{
|
|
||||||
"name": "tier1_sweep",
|
|
||||||
"description": "optional human note",
|
|
||||||
"data_dir": "dataset/dog_bark",
|
|
||||||
"output_root": "lora_output/tier1_sweep",
|
|
||||||
"base": { "rank": 16, "lr": 1e-4, "steps": 2000, ... },
|
|
||||||
"experiments": [
|
|
||||||
{"id": "baseline", "description": "..."},
|
|
||||||
{"id": "lora_plus_16", "lora_plus_ratio": 16.0},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image, ImageDraw
|
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device
|
|
||||||
from .selva_lora_trainer import (
|
|
||||||
SelvaLoraTrainer,
|
|
||||||
SkipExperiment,
|
|
||||||
_prepare_dataset,
|
|
||||||
_smooth_losses,
|
|
||||||
_pil_to_tensor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_system_info() -> dict:
|
|
||||||
"""Collect GPU / torch version info for the summary header."""
|
|
||||||
info: dict = {
|
|
||||||
"torch_version": torch.__version__,
|
|
||||||
"cuda_version": torch.version.cuda or "N/A",
|
|
||||||
"gpu_name": None,
|
|
||||||
"gpu_vram_gb": None,
|
|
||||||
}
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
try:
|
|
||||||
info["gpu_name"] = torch.cuda.get_device_name(0)
|
|
||||||
props = torch.cuda.get_device_properties(0)
|
|
||||||
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
# Defaults mirror SelvaLoraTrainer INPUT_TYPES defaults
|
|
||||||
_PARAM_DEFAULTS = {
|
|
||||||
"alpha": 0.0,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"batch_size": 4,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 500,
|
|
||||||
"resume_path": "",
|
|
||||||
"seed": 42,
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0,
|
|
||||||
"lr_schedule": "constant",
|
|
||||||
"init_mode": "pissa",
|
|
||||||
"use_rslora": True,
|
|
||||||
"latent_mixup_alpha": 0.0,
|
|
||||||
"latent_noise_sigma": 0.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Palette for comparison chart: one color per experiment (cycles if > 8)
|
|
||||||
_PALETTE = [
|
|
||||||
(66, 133, 244), # blue
|
|
||||||
(234, 67, 53), # red
|
|
||||||
(52, 168, 83), # green
|
|
||||||
(251, 188, 5), # yellow
|
|
||||||
(155, 89, 182), # purple
|
|
||||||
(26, 188, 156), # teal
|
|
||||||
(230, 126, 34), # orange
|
|
||||||
(149, 165, 166), # grey
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(raw: str) -> Path:
|
|
||||||
"""Resolve path the same way SelvaLoraTrainer does (relative → ComfyUI output dir)."""
|
|
||||||
p = Path(raw.strip())
|
|
||||||
unix_style_on_windows = (
|
|
||||||
sys.platform == "win32" and p.is_absolute() and not p.drive
|
|
||||||
)
|
|
||||||
if not p.is_absolute() or unix_style_on_windows:
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_config(base: dict, experiment: dict) -> dict:
|
|
||||||
"""Merge base defaults + file base + experiment overrides."""
|
|
||||||
cfg = dict(_PARAM_DEFAULTS)
|
|
||||||
cfg.update(base)
|
|
||||||
# Don't carry id/description into the training params
|
|
||||||
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
|
|
||||||
start_step: int, total_steps: int) -> dict:
|
|
||||||
"""Build a dict of {step: loss} at each save_every boundary.
|
|
||||||
|
|
||||||
loss_history[i] = average loss over steps [start + i*log_interval + 1 …
|
|
||||||
start + (i+1)*log_interval].
|
|
||||||
"""
|
|
||||||
result = {}
|
|
||||||
targets = range(save_every, total_steps + 1, save_every)
|
|
||||||
for target in targets:
|
|
||||||
# index of the loss entry nearest to this step
|
|
||||||
idx = (target - start_step) // log_interval - 1
|
|
||||||
if 0 <= idx < len(loss_history):
|
|
||||||
result[str(target)] = round(loss_history[idx], 6)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _draw_comparison_curves(
|
|
||||||
experiments_data: list, # list of dicts: {id, loss_history, log_interval, start_step}
|
|
||||||
) -> Image.Image:
|
|
||||||
"""Draw all smoothed loss curves on the same axes, one color per experiment."""
|
|
||||||
W, H = 900, 420
|
|
||||||
pl, pr, pt, pb = 75, 160, 30, 50 # wider right margin for legend
|
|
||||||
|
|
||||||
img = Image.new("RGB", (W, H), (255, 255, 255))
|
|
||||||
draw = ImageDraw.Draw(img)
|
|
||||||
|
|
||||||
pw = W - pl - pr
|
|
||||||
ph = H - pt - pb
|
|
||||||
|
|
||||||
# Collect all smoothed series
|
|
||||||
series = []
|
|
||||||
for i, ed in enumerate(experiments_data):
|
|
||||||
lh = ed.get("loss_history") or []
|
|
||||||
if len(lh) < 2:
|
|
||||||
continue
|
|
||||||
sm = _smooth_losses(lh)
|
|
||||||
series.append({
|
|
||||||
"id": ed["id"],
|
|
||||||
"smoothed": sm,
|
|
||||||
"log_interval": ed.get("log_interval", 50),
|
|
||||||
"start_step": ed.get("start_step", 0),
|
|
||||||
"color": _PALETTE[i % len(_PALETTE)],
|
|
||||||
})
|
|
||||||
|
|
||||||
if not series:
|
|
||||||
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
|
|
||||||
return img
|
|
||||||
|
|
||||||
all_vals = [v for s in series for v in s["smoothed"]]
|
|
||||||
lo, hi = min(all_vals), max(all_vals)
|
|
||||||
if hi == lo:
|
|
||||||
hi = lo + 1e-6
|
|
||||||
rng = hi - lo
|
|
||||||
|
|
||||||
# Horizontal grid + y-axis labels
|
|
||||||
for i in range(5):
|
|
||||||
y = pt + int(i * ph / 4)
|
|
||||||
val = hi - i * rng / 4
|
|
||||||
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
|
|
||||||
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
|
|
||||||
|
|
||||||
# Draw each curve
|
|
||||||
for s in series:
|
|
||||||
n = len(s["smoothed"])
|
|
||||||
pts = []
|
|
||||||
for j, v in enumerate(s["smoothed"]):
|
|
||||||
x = pl + int(j * pw / max(n - 1, 1))
|
|
||||||
y = pt + int((1.0 - (v - lo) / rng) * ph)
|
|
||||||
pts.append((x, y))
|
|
||||||
draw.line(pts, fill=s["color"], width=2)
|
|
||||||
|
|
||||||
# Axes
|
|
||||||
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
|
|
||||||
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
|
|
||||||
draw.text((pl + 4, 8), "Loss comparison (smoothed)", fill=(40, 40, 40))
|
|
||||||
|
|
||||||
# Legend (right side)
|
|
||||||
lx = W - pr + 10
|
|
||||||
ly = pt
|
|
||||||
for s in series:
|
|
||||||
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
|
|
||||||
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
|
|
||||||
ly += 20
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaLoraScheduler:
|
|
||||||
"""Runs a sweep of LoRA training experiments defined in a JSON file.
|
|
||||||
|
|
||||||
The dataset (VAE encoding + .npz loading) is performed once and shared
|
|
||||||
across all experiments. Each experiment deep-copies the generator and trains
|
|
||||||
independently. Results are written to `experiment_summary.json` after every
|
|
||||||
completed run so partial results are preserved if the sweep is interrupted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "run"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("summary_path", "comparison_curves")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to experiment_summary.json — share this file to compare runs.",
|
|
||||||
"All smoothed loss curves overlaid on the same axes.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Runs a series of LoRA training experiments defined in a JSON sweep file. "
|
|
||||||
"The dataset is encoded once and reused across all experiments. "
|
|
||||||
"Results (loss, config, adapter paths) are collected in experiment_summary.json."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"experiments_file": ("STRING", {
|
|
||||||
"default": "experiments.json",
|
|
||||||
"tooltip": (
|
|
||||||
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
|
|
||||||
"models directory; absolute paths are used as-is. "
|
|
||||||
"See LORA_TRAINING.md for the file format."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def run(self, model, experiments_file):
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Read + validate the JSON file
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
exp_path = Path(experiments_file.strip())
|
|
||||||
if not exp_path.is_absolute():
|
|
||||||
# Try relative to ComfyUI models dir first, then output dir
|
|
||||||
candidate = Path(folder_paths.models_dir) / exp_path
|
|
||||||
if not candidate.exists():
|
|
||||||
candidate = Path(folder_paths.get_output_directory()) / exp_path
|
|
||||||
exp_path = candidate
|
|
||||||
if not exp_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"[LoRA Scheduler] Experiment file not found: {exp_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = json.loads(exp_path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
if "experiments" not in spec or not spec["experiments"]:
|
|
||||||
raise ValueError("[LoRA Scheduler] 'experiments' list is missing or empty.")
|
|
||||||
for i, exp in enumerate(spec["experiments"]):
|
|
||||||
if "id" not in exp:
|
|
||||||
raise ValueError(
|
|
||||||
f"[LoRA Scheduler] Experiment at index {i} is missing required 'id' field."
|
|
||||||
)
|
|
||||||
|
|
||||||
sweep_name = spec.get("name", exp_path.stem)
|
|
||||||
description = spec.get("description", "")
|
|
||||||
base_cfg = spec.get("base", {})
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. Resolve data_dir and output_root
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
if "data_dir" not in spec:
|
|
||||||
raise ValueError("[LoRA Scheduler] 'data_dir' is required in the sweep file.")
|
|
||||||
data_dir = _resolve_path(spec["data_dir"])
|
|
||||||
output_root = _resolve_path(spec.get("output_root", f"lora_sweeps/{sweep_name}"))
|
|
||||||
output_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
|
|
||||||
print(f"\n[LoRA Scheduler] Sweep '{sweep_name}': "
|
|
||||||
f"{len(spec['experiments'])} experiment(s)", flush=True)
|
|
||||||
if description:
|
|
||||||
print(f"[LoRA Scheduler] {description}", flush=True)
|
|
||||||
print(f"[LoRA Scheduler] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[LoRA Scheduler] output_root = {output_root}\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. Load + encode dataset once
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
n_clips = len(list(data_dir.glob("*.npz")))
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. Build or restore the summary (resume-aware)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary_path = output_root / "experiment_summary.json"
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = [] # collected for comparison image
|
|
||||||
|
|
||||||
if summary_path.exists():
|
|
||||||
try:
|
|
||||||
existing = json.loads(summary_path.read_text(encoding="utf-8"))
|
|
||||||
for rec in existing.get("experiments", []):
|
|
||||||
if rec.get("results", {}).get("status") == "completed":
|
|
||||||
completed_ids.add(rec["id"])
|
|
||||||
lh = rec["results"].get("loss_history", [])
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": rec["id"],
|
|
||||||
"loss_history": lh,
|
|
||||||
"log_interval": rec["results"].get("log_interval", 50),
|
|
||||||
"start_step": 0,
|
|
||||||
})
|
|
||||||
# Restore the original summary, clear completed_at so it gets set again
|
|
||||||
summary = existing
|
|
||||||
summary["completed_at"] = None
|
|
||||||
if completed_ids:
|
|
||||||
print(f"[LoRA Scheduler] Resuming — skipping {len(completed_ids)} "
|
|
||||||
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[LoRA Scheduler] Could not read existing summary ({e}) — starting fresh",
|
|
||||||
flush=True)
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = []
|
|
||||||
summary = None
|
|
||||||
|
|
||||||
if not completed_ids:
|
|
||||||
summary = {
|
|
||||||
"sweep_name": sweep_name,
|
|
||||||
"description": description,
|
|
||||||
"sweep_file": str(exp_path),
|
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"completed_at": None,
|
|
||||||
"system": _get_system_info(),
|
|
||||||
"data_dir": str(data_dir),
|
|
||||||
"n_clips": n_clips,
|
|
||||||
"experiments": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _write_summary():
|
|
||||||
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. Run each experiment
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
trainer = SelvaLoraTrainer()
|
|
||||||
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
|
|
||||||
log_interval = 50 # matches _train_inner
|
|
||||||
|
|
||||||
feature_utils_orig = model["feature_utils"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
variant = model["variant"]
|
|
||||||
mode = model["mode"]
|
|
||||||
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
exp_id = exp["id"]
|
|
||||||
exp_desc = exp.get("description", "")
|
|
||||||
|
|
||||||
if exp_id in completed_ids:
|
|
||||||
print(f"[LoRA Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
|
|
||||||
pbar_outer.update(1)
|
|
||||||
continue
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
|
|
||||||
# Core training params (each falls back to a default when absent)
|
|
||||||
steps = int(cfg.get("steps", 2000))
|
|
||||||
rank = int(cfg.get("rank", 16))
|
|
||||||
lr = float(cfg.get("lr", 1e-4))
|
|
||||||
alpha = float(cfg.get("alpha", 0.0))
|
|
||||||
target = str(cfg.get("target", "attn.qkv"))
|
|
||||||
batch_size = int(cfg.get("batch_size", 4))
|
|
||||||
warmup = int(cfg.get("warmup_steps", 100))
|
|
||||||
grad_accum = int(cfg.get("grad_accum", 1))
|
|
||||||
save_every = int(cfg.get("save_every", 500))
|
|
||||||
resume_path = str(cfg.get("resume_path", ""))
|
|
||||||
seed = int(cfg.get("seed", 42))
|
|
||||||
ts_mode = str(cfg.get("timestep_mode", "uniform"))
|
|
||||||
ln_sigma = float(cfg.get("logit_normal_sigma", 1.0))
|
|
||||||
curr_switch = float(cfg.get("curriculum_switch", 0.6))
|
|
||||||
dropout = float(cfg.get("lora_dropout", 0.0))
|
|
||||||
plus_ratio = float(cfg.get("lora_plus_ratio", 1.0))
|
|
||||||
lr_schedule = str(cfg.get("lr_schedule", "constant"))
|
|
||||||
init_mode = str(cfg.get("init_mode", "pissa"))
|
|
||||||
use_rslora = bool(cfg.get("use_rslora", True))
|
|
||||||
alpha_val = alpha if alpha > 0.0 else float(2 * rank)
|
|
||||||
target_suffixes = tuple(target.strip().split())
|
|
||||||
|
|
||||||
output_dir = output_root / exp_id
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\n[LoRA Scheduler] ── Experiment '{exp_id}' ──", flush=True)
|
|
||||||
if exp_desc:
|
|
||||||
print(f"[LoRA Scheduler] {exp_desc}", flush=True)
|
|
||||||
|
|
||||||
exp_record = {
|
|
||||||
"id": exp_id,
|
|
||||||
"description": exp_desc,
|
|
||||||
"config": {
|
|
||||||
"rank": rank, "alpha": alpha_val, "lr": lr, "steps": steps,
|
|
||||||
"batch_size": batch_size, "warmup_steps": warmup,
|
|
||||||
"grad_accum": grad_accum, "save_every": save_every,
|
|
||||||
"seed": seed, "target": list(target_suffixes),
|
|
||||||
"timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma,
|
|
||||||
"curriculum_switch": curr_switch,
|
|
||||||
"lora_dropout": dropout, "lora_plus_ratio": plus_ratio,
|
|
||||||
"lr_schedule": lr_schedule,
|
|
||||||
"init_mode": init_mode, "use_rslora": use_rslora,
|
|
||||||
},
|
|
||||||
"results": {"status": "running"},
|
|
||||||
"adapter_path": None,
|
|
||||||
"output_dir": str(output_dir),
|
|
||||||
}
|
|
||||||
summary["experiments"].append(exp_record)
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
t_start = time.monotonic()
|
|
||||||
try:
|
|
||||||
with torch.inference_mode(False), torch.enable_grad():
|
|
||||||
r = trainer._train_inner(
|
|
||||||
model, dataset, feature_utils_orig, seq_cfg,
|
|
||||||
device, dtype, variant, mode,
|
|
||||||
data_dir, output_dir, steps, rank, lr,
|
|
||||||
alpha_val, target_suffixes, batch_size, warmup,
|
|
||||||
grad_accum, save_every, resume_path, seed,
|
|
||||||
ts_mode, ln_sigma, curr_switch, dropout, plus_ratio,
|
|
||||||
lr_schedule, init_mode, use_rslora,
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
loss_history = r["loss_history"]
|
|
||||||
grad_norm_history = r.get("grad_norm_history", [])
|
|
||||||
spectral_metrics = r.get("spectral_metrics", {})
|
|
||||||
run_start_step = r.get("start_step", 0)
|
|
||||||
smoothed = _smooth_losses(loss_history) if loss_history else []
|
|
||||||
|
|
||||||
# Scalar summary metrics
|
|
||||||
final_loss = round(smoothed[-1], 6) if smoothed else None
|
|
||||||
min_loss = round(min(smoothed), 6) if smoothed else None
|
|
||||||
min_idx = smoothed.index(min(smoothed)) if smoothed else None
|
|
||||||
min_loss_step = (
|
|
||||||
run_start_step + (min_idx + 1) * log_interval
|
|
||||||
if min_idx is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stability: std-dev of raw loss over last 25% of steps
|
|
||||||
if loss_history:
|
|
||||||
quarter = max(1, len(loss_history) // 4)
|
|
||||||
last_q = loss_history[-quarter:]
|
|
||||||
loss_std_last_quarter = round(float(np.std(last_q)), 6)
|
|
||||||
else:
|
|
||||||
loss_std_last_quarter = None
|
|
||||||
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "completed",
|
|
||||||
"final_loss": final_loss,
|
|
||||||
"min_loss": min_loss,
|
|
||||||
"min_loss_step": min_loss_step,
|
|
||||||
"loss_std_last_quarter": loss_std_last_quarter,
|
|
||||||
"loss_at_steps": _loss_at_steps(
|
|
||||||
loss_history, log_interval, save_every, run_start_step, steps
|
|
||||||
),
|
|
||||||
"loss_history": [round(v, 6) for v in loss_history],
|
|
||||||
"grad_norm_history": grad_norm_history,
|
|
||||||
"spectral_metrics": {str(k): v for k, v in spectral_metrics.items()},
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
exp_record["adapter_path"] = r["adapter_path"]
|
|
||||||
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": exp_id,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"start_step": 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
except SkipExperiment as e:
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
|
|
||||||
partial = getattr(e, "partial", {})
|
|
||||||
lh = partial.get("loss_history", [])
|
|
||||||
smoothed = _smooth_losses(lh) if lh else []
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "skipped",
|
|
||||||
"stopped_at_step": partial.get("stopped_at_step"),
|
|
||||||
"final_loss": round(smoothed[-1], 6) if smoothed else None,
|
|
||||||
"loss_history": [round(v, 6) for v in lh],
|
|
||||||
"grad_norm_history": partial.get("grad_norm_history", []),
|
|
||||||
"spectral_metrics": {str(k): v for k, v in partial.get("spectral_metrics", {}).items()},
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
_write_summary()
|
|
||||||
pbar_outer.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "failed",
|
|
||||||
"error": str(e),
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
_write_summary()
|
|
||||||
pbar_outer.update(1)
|
|
||||||
# Continue to next experiment rather than aborting the whole sweep
|
|
||||||
continue
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
pbar_outer.update(1)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 6. Finalise summary
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
|
|
||||||
_write_summary()
|
|
||||||
print(f"\n[LoRA Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 7. Comparison image
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
comparison_img = _draw_comparison_curves(all_curve_data)
|
|
||||||
comparison_img.save(str(output_root / "loss_comparison.png"))
|
|
||||||
comparison_tensor = _pil_to_tensor(comparison_img)
|
|
||||||
|
|
||||||
return (str(summary_path), comparison_tensor)
|
|
||||||
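The index arithmetic in `_loss_at_steps` is easy to misread, so here is a standalone sketch of the same calculation with a worked example. It assumes, as the docstring states, that entry `i` of the loss history covers steps `start + i*log_interval + 1` through `start + (i+1)*log_interval`; the function name `loss_at_steps` and the sample data are illustrative only.

```python
def loss_at_steps(loss_history, log_interval, save_every, start_step, total_steps):
    """Map each save_every boundary to the logged loss that ends at that step."""
    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        # Entry i ends at step start_step + (i + 1) * log_interval.
        idx = (target - start_step) // log_interval - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result


# Fresh run: 2000 steps logged every 50 steps gives 40 entries.
history = [float(40 - i) for i in range(40)]
checkpoints = loss_at_steps(history, 50, 500, 0, 2000)

# Resumed run: training restarts at step 1000, so only 20 entries exist and
# boundaries at or before the resume point have no matching entry.
resumed = loss_at_steps(history[:20], 50, 500, 1000, 2000)
```

For a resumed run, boundaries at or before `start_step` produce a negative index and are skipped rather than wrapping around to the end of the list, which is why the lower bound check matters.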
File diff suppressed because it is too large
@@ -149,7 +149,7 @@ class SelvaModelLoader:
|
|||||||
enable_conditions=True,
|
enable_conditions=True,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
bigvgan_vocoder_ckpt=bigvgan_path,
|
bigvgan_vocoder_ckpt=bigvgan_path,
|
||||||
need_vae_encoder=True,
|
need_vae_encoder=False,
|
||||||
).to(device, dtype).eval()
|
).to(device, dtype).eval()
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
if strategy == "offload_to_cpu":
|
||||||
|
|||||||
+3
-107
@@ -3,7 +3,6 @@ import comfy.utils
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||||
from .selva_textual_inversion_trainer import _inject_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaSampler:
|
class SelvaSampler:
|
||||||
@@ -32,31 +31,9 @@ class SelvaSampler:
|
|||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"steering_vectors": ("STEERING_VECTORS", {
|
|
||||||
"tooltip": "Activation steering bundle from SelVA Activation Steering Loader. "
|
|
||||||
"Nudges each DiT block's hidden state toward the extracted pattern.",
|
|
||||||
}),
|
|
||||||
"steering_strength": ("FLOAT", {
|
|
||||||
"default": 0.1, "min": 0.0, "max": 2.0, "step": 0.05,
|
|
||||||
"tooltip": "Scale applied to each steering vector before adding to block output. "
|
|
||||||
"Start around 0.1–0.3; higher values risk destabilizing the ODE.",
|
|
||||||
}),
|
|
||||||
"normalize": ("BOOLEAN", {
|
"normalize": ("BOOLEAN", {
|
||||||
"default": True,
|
"default": True,
|
||||||
"tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.",
|
"tooltip": "Peak-normalize output to [-1, 1]. Disable to preserve the raw decoder output level.",
|
||||||
}),
|
|
||||||
"target_lufs": ("FLOAT", {
|
|
||||||
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0,
|
|
||||||
"tooltip": "Target RMS level in dBFS when normalize=True. -27 matches the measured RMS of LUFS-normalized training clips. Increase toward -20 for louder output.",
|
|
||||||
}),
|
|
||||||
"textual_inversion": ("TEXTUAL_INVERSION", {
|
|
||||||
"tooltip": "Learned token embeddings from SelVA Textual Inversion Loader. "
|
|
||||||
"Injects style tokens into CLIP conditioning without modifying model weights.",
|
|
||||||
}),
|
|
||||||
"ti_strength": ("FLOAT", {
|
|
||||||
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "Blends between original CLIP conditioning (0.0) and full TI injection (1.0). "
|
|
||||||
"Reduce toward 0.3–0.5 if TI produces buzz artifacts.",
|
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -68,7 +45,7 @@ class SelvaSampler:
|
|||||||
CATEGORY = SELVA_CATEGORY
|
CATEGORY = SELVA_CATEGORY
|
||||||
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
|
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
|
||||||
|
|
||||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, steering_vectors=None, steering_strength=0.1, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0):
|
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True):
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from selva_core.model.flow_matching import FlowMatching
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
|
||||||
@@ -133,19 +110,6 @@ class SelvaSampler:
|
|||||||
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
||||||
if negative_prompt.strip() else None
|
if negative_prompt.strip() else None
|
||||||
|
|
||||||
# Inject textual inversion tokens into CLIP conditioning
|
|
||||||
if textual_inversion is not None:
|
|
||||||
emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024]
|
|
||||||
K = emb.shape[0]
|
|
||||||
inject_mode = textual_inversion.get("inject_mode", "suffix")
|
|
||||||
ti_text = _inject_tokens(text_clip, emb, K, inject_mode)
|
|
||||||
text_clip = torch.lerp(text_clip, ti_text, ti_strength)
|
|
||||||
if neg_text_clip is not None:
|
|
||||||
ti_neg = _inject_tokens(neg_text_clip, emb, K, inject_mode)
|
|
||||||
neg_text_clip = torch.lerp(neg_text_clip, ti_neg, ti_strength)
|
|
||||||
print(f"[SelVA] Textual inversion: {K} tokens mode={inject_mode} strength={ti_strength}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||||
empty_conditions = net_generator.get_empty_conditions(
|
empty_conditions = net_generator.get_empty_conditions(
|
||||||
bs=1, negative_text_features=neg_text_clip
|
bs=1, negative_text_features=neg_text_clip
|
||||||
@@ -159,63 +123,6 @@ class SelvaSampler:
|
|||||||
device=gen_device, dtype=dtype, generator=rng,
|
device=gen_device, dtype=dtype, generator=rng,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
# Activation steering: apply only during the conditional predict_flow pass
|
|
||||||
# so steering gets amplified by cfg_strength rather than canceling out.
|
|
||||||
steering_handles = []
|
|
||||||
_orig_predict_flow = None
|
|
||||||
if steering_vectors is not None and steering_strength > 0.0:
|
|
||||||
vecs = steering_vectors["steering_vectors"]
|
|
||||||
n_joint = steering_vectors["n_joint"]
|
|
||||||
|
|
||||||
# Patch predict_flow to flag which pass is conditional.
|
|
||||||
# ode_wrapper calls predict_flow(conditions) and predict_flow(empty_conditions);
|
|
||||||
# identity check tells us which is which.
|
|
||||||
_is_cond_pass = [False]
|
|
||||||
_orig_predict_flow = net_generator.predict_flow
|
|
||||||
|
|
||||||
def _tracked_predict_flow(latent, t, cond):
|
|
||||||
_is_cond_pass[0] = (cond is conditions)
|
|
||||||
return _orig_predict_flow(latent, t, cond)
|
|
||||||
|
|
||||||
net_generator.predict_flow = _tracked_predict_flow
|
|
||||||
|
|
||||||
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
|
|
||||||
vec = vec_cpu.to(dev, dt) # [seq, hidden]
|
|
||||||
def hook(module, input, output):
|
|
||||||
if not _is_cond_pass[0]:
|
|
||||||
return # skip unconditional pass; steering amplified by cfg_strength
|
|
||||||
# Interpolate steering vec to match actual output seq length
|
|
||||||
# (handles generation at different duration than extraction)
|
|
||||||
if is_joint:
|
|
||||||
out_seq = output[0].shape[1]
|
|
||||||
else:
|
|
||||||
out_seq = output.shape[1]
|
|
||||||
v = vec
|
|
||||||
if v.shape[0] != out_seq:
|
|
||||||
v = torch.nn.functional.interpolate(
|
|
||||||
v.T.unsqueeze(0), # [1, hidden, seq_orig]
|
|
||||||
size=out_seq,
|
|
||||||
mode="linear",
|
|
||||||
align_corners=False,
|
|
||||||
).squeeze(0).T # [seq_new, hidden]
|
|
||||||
if is_joint:
|
|
||||||
latent_out = output[0] + strength * v
|
|
||||||
return (latent_out,) + output[1:]
|
|
||||||
else:
|
|
||||||
return output + strength * v
|
|
||||||
return hook
|
|
||||||
|
|
||||||
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
|
|
||||||
for i, block in enumerate(blocks):
|
|
||||||
is_joint = i < n_joint
|
|
||||||
if i < len(vecs):
|
|
||||||
h = block.register_forward_hook(
|
|
||||||
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
|
|
||||||
)
|
|
||||||
steering_handles.append(h)
|
|
||||||
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks "
|
|
||||||
f"strength={steering_strength} (conditional pass only)", flush=True)
|
|
||||||
|
|
||||||
# Flow matching ODE (Euler)
|
# Flow matching ODE (Euler)
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
@@ -232,11 +139,6 @@ class SelvaSampler:
|
|||||||
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
||||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
if _orig_predict_flow is not None:
|
|
||||||
net_generator.predict_flow = _orig_predict_flow
|
|
||||||
for h in steering_handles:
|
|
||||||
h.remove()
|
|
||||||
|
|
||||||
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||||
|
|
||||||
@@ -266,14 +168,8 @@ class SelvaSampler:
|
|||||||
audio = audio.mean(dim=1, keepdim=True) # stereo → mono
|
audio = audio.mean(dim=1, keepdim=True) # stereo → mono
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
target_rms = 10 ** (target_lufs / 20.0)
|
|
||||||
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
audio = audio * (target_rms / rms)
|
|
||||||
# If RMS normalization pushes peaks into clipping, scale back to
|
|
||||||
# preserve dynamics rather than hard-clipping (no saturation)
|
|
||||||
peak = audio.abs().max().clamp(min=1e-8)
|
peak = audio.abs().max().clamp(min=1e-8)
|
||||||
if peak > 1.0:
|
audio = (audio / peak).clamp(-1, 1)
|
||||||
audio = audio / peak
|
|
||||||
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
||||||
|
|
||||||
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
|
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
|
||||||
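The loudness step touched by this hunk (scale to a target RMS, then back off if peaks would clip) can be looked at in isolation. The following is a minimal pure-Python sketch, not the node's code: `normalize_rms` and `target_db` are hypothetical names, the signal is assumed to be a non-empty mono list of floats, and plain RMS in dBFS stands in for loudness. The diff applies `target_lufs` the same way, which is an RMS approximation rather than a full LUFS measurement with K-weighting and gating.

```python
import math


def normalize_rms(samples, target_db=-27.0):
    """Scale samples to a target RMS level, then guard against clipping.

    Hypothetical helper; assumes a non-empty list of floats.
    """
    target_rms = 10 ** (target_db / 20.0)          # dB -> linear amplitude
    rms = max(math.sqrt(sum(s * s for s in samples) / len(samples)), 1e-8)
    scaled = [s * (target_rms / rms) for s in samples]

    # If the gain pushed peaks past full scale, scale the whole signal back
    # down instead of hard-clipping, which preserves the dynamics.
    peak = max(max(abs(s) for s in scaled), 1e-8)
    if peak > 1.0:
        scaled = [s / peak for s in scaled]
    return scaled


# A quiet sine wave is brought up to the target RMS.
quiet = [0.001 * math.sin(0.1 * n) for n in range(1000)]
out = normalize_rms(quiet)
out_rms = math.sqrt(sum(s * s for s in out) / len(out))

# A single spike has a high crest factor, so the peak guard takes over.
spiky = normalize_rms([0.0] * 999 + [1.0])
spiky_peak = max(abs(s) for s in spiky)
```

The peak guard means very spiky material ends up quieter than the RMS target, which is the trade-off the removed comment describes: dynamics are kept at the cost of exact level matching.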
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaSkipExperiment:
|
|
||||||
"""Writes skip_current.flag into a sweep output_root.
|
|
||||||
|
|
||||||
Queue this node while a SelVA LoRA Scheduler sweep is running to skip
|
|
||||||
the current experiment and move to the next one. The trainer picks up
|
|
||||||
the flag within 50 steps (~a few seconds).
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"output_root": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "output_root of the running sweep — same value as in your experiments JSON.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING",)
|
|
||||||
RETURN_NAMES = ("flag_path",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Path where the flag was written.",)
|
|
||||||
FUNCTION = "skip"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Signals the running SelVA LoRA Scheduler to skip the current experiment "
|
|
||||||
"and move to the next one. Queue this node while the scheduler is running. "
|
|
||||||
"Partial scalars collected so far are saved in the summary."
|
|
||||||
)
|
|
||||||
|
|
||||||
def skip(self, output_root: str):
|
|
||||||
p = Path(output_root.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[SelVA Skip] output_root not found: {p}")
|
|
||||||
|
|
||||||
flag = p / "skip_current.flag"
|
|
||||||
flag.touch()
|
|
||||||
print(f"[SelVA Skip] Flag written: {flag}", flush=True)
|
|
||||||
return (str(flag),)
|
|
||||||
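The skip node above only writes `skip_current.flag`; the consuming side lives in the trainer and is not part of this file. As a rough sketch of how such a flag-file handshake can work, assuming the trainer polls every `check_every` steps and deletes the flag once seen (both `consume_skip` and `run_experiment` are hypothetical, as is the deletion behaviour):

```python
import tempfile
from pathlib import Path

FLAG_NAME = "skip_current.flag"


def request_skip(output_root: Path) -> Path:
    """Writer side: mirrors what the node does with Path.touch()."""
    flag = output_root / FLAG_NAME
    flag.touch()
    return flag


def consume_skip(output_root: Path) -> bool:
    """Reader side (assumed): report a pending skip once and clear the flag."""
    flag = output_root / FLAG_NAME
    if flag.exists():
        flag.unlink()
        return True
    return False


def run_experiment(output_root: Path, steps: int, check_every: int = 50,
                   skip_at=None) -> int:
    """Toy training loop; returns the step at which it stopped."""
    for step in range(1, steps + 1):
        if skip_at is not None and step == skip_at:
            request_skip(output_root)      # simulates the node being queued mid-run
        # Polling only every `check_every` steps keeps filesystem access cheap.
        if step % check_every == 0 and consume_skip(output_root):
            return step
    return steps


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    stopped_at = run_experiment(root, steps=500, skip_at=120)
    flag_left = (root / FLAG_NAME).exists()
```

Clearing the flag on read keeps one request from skipping every remaining experiment in the sweep; whether the real scheduler does this is not visible in this diff.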
@@ -1,70 +0,0 @@
|
|||||||
"""SelVA Textual Inversion Loader.
|
|
||||||
|
|
||||||
Loads a .pt file produced by SelvaTextualInversionTrainer and returns a
|
|
||||||
TEXTUAL_INVERSION bundle that the SelVA Sampler can inject into text conditioning.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaTextualInversionLoader:
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"path": ("STRING", {
|
|
||||||
"default": "textual_inversion.pt",
|
|
||||||
"tooltip": "Path to a .pt file produced by SelVA Textual Inversion Trainer. "
|
|
||||||
"Relative paths resolve to the ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("TEXTUAL_INVERSION",)
|
|
||||||
RETURN_NAMES = ("textual_inversion",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Learned token embeddings — connect to SelVA Sampler's textual_inversion input.",)
|
|
||||||
FUNCTION = "load"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Loads learned CLIP token embeddings produced by SelVA Textual Inversion Trainer. "
|
|
||||||
"Connect the output to the SelVA Sampler's optional textual_inversion input to guide "
|
|
||||||
"generation toward the training data style without degrading audio quality."
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, path: str) -> tuple:
|
|
||||||
p = Path(path.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[TI Loader] File not found: {p}")
|
|
||||||
|
|
||||||
data = torch.load(str(p), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
embeddings = data["embeddings"] # [K, 1024]
|
|
||||||
n_tokens = int(data.get("n_tokens", embeddings.shape[0]))
|
|
||||||
|
|
||||||
print(f"[TI Loader] Loaded '{p.name}' n_tokens={n_tokens} "
|
|
||||||
f"shape={tuple(embeddings.shape)}", flush=True)
|
|
||||||
if data.get("init_text"):
|
|
||||||
print(f"[TI Loader] init_text='{data['init_text']}'", flush=True)
|
|
||||||
if data.get("step"):
|
|
||||||
print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps "
|
|
||||||
f"lr={data.get('lr', '?')}", flush=True)
|
|
||||||
|
|
||||||
inject_mode = data.get("inject_mode", "suffix")
|
|
||||||
print(f"[TI Loader] inject_mode='{inject_mode}'", flush=True)
|
|
||||||
|
|
||||||
bundle = {
|
|
||||||
"embeddings": embeddings, # [K, 1024] float32 on CPU
|
|
||||||
"n_tokens": n_tokens,
|
|
||||||
"inject_mode": inject_mode,
|
|
||||||
"path": str(p),
|
|
||||||
"init_text": data.get("init_text", ""),
|
|
||||||
}
|
|
||||||
return (bundle,)
|
|
||||||
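The bundle returned by the loader carries `embeddings` and `inject_mode`, which the sampler uses to overwrite K positions of the 77-position text conditioning and then blend with `ti_strength`. The sketch below shows the position arithmetic only, using nested Python lists in place of tensors and a toy width of 2 instead of 1024. `inject_tokens` and `lerp` are illustrative stand-ins for the `torch.cat` and `torch.lerp` calls in the diff, not the project's functions.

```python
def inject_tokens(seq, tokens, mode="suffix"):
    """Return a copy of seq with K positions replaced by the learned tokens.

    seq:    list of per-position vectors (length 77 in the diff)
    tokens: list of K learned vectors
    """
    k = len(tokens)
    if not 0 < k < len(seq):
        raise ValueError("need 0 < K < sequence length")
    if mode == "prefix":
        # Keep position 0 (BOS), overwrite positions 1..K, keep the rest.
        return seq[:1] + list(tokens) + seq[1 + k:]
    if mode == "suffix":
        # Overwrite the last K positions (the EOS/padding zone).
        return seq[:-k] + list(tokens)
    raise ValueError(f"unknown inject_mode: {mode!r}")


def lerp(a, b, w):
    """Elementwise a + w * (b - a), the same formula torch.lerp uses."""
    return [[x + w * (y - x) for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


seq = [[float(i)] * 2 for i in range(77)]      # toy conditioning, width 2
toks = [[-1.0, -1.0] for _ in range(4)]        # K = 4 learned tokens

suffix = inject_tokens(seq, toks, "suffix")
prefix = inject_tokens(seq, toks, "prefix")
half = lerp(seq, suffix, 0.5)                  # ti_strength = 0.5
```

Both modes preserve the sequence length, and positions that were not replaced are unaffected by the blend because the original and injected values are identical there.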
@@ -1,450 +0,0 @@
|
|||||||
"""SelVA Textual Inversion Trainer.
|
|
||||||
|
|
||||||
Learns K token embedding vectors in CLIP space that guide the base model
|
|
||||||
to generate audio in the style of the training clips — without modifying
|
|
||||||
any model weights.
|
|
||||||
|
|
||||||
Key difference from LoRA:
|
|
||||||
- ALL generator parameters are frozen (requires_grad=False)
|
|
||||||
- Only K×1024 token embeddings receive gradients
|
|
||||||
- Latents stay on the decoder's natural manifold → no quality degradation
|
|
||||||
- The learned tokens shift WHICH latents are generated, not HOW
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
1. Train on a directory of .npz feature files with paired audio
|
|
||||||
2. Load result with SelVA Textual Inversion Loader
|
|
||||||
3. Connect to SelVA Sampler optional input
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import random
|
|
||||||
import traceback
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
from selva_core.model.flow_matching import FlowMatching
|
|
||||||
from .selva_lora_trainer import (
|
|
||||||
_prepare_dataset,
|
|
||||||
_eval_sample,
|
|
||||||
_spectral_metrics,
|
|
||||||
_save_spectrogram,
|
|
||||||
_smooth_losses,
|
|
||||||
_draw_loss_curve,
|
|
||||||
_pil_to_tensor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Eval helper with token injection
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _inject_tokens(text_clip: torch.Tensor, tokens: torch.Tensor,
|
|
||||||
n_tokens: int, inject_mode: str) -> torch.Tensor:
|
|
||||||
"""Build a text_clip tensor with learned tokens injected.
|
|
||||||
|
|
||||||
inject_mode:
|
|
||||||
"suffix" — replace last n_tokens positions (EOS/padding zone)
|
|
||||||
"prefix" — replace positions 1:1+n_tokens (after BOS, before content)
|
|
||||||
|
|
||||||
Always uses torch.cat so gradient flows to `tokens` when tokens.requires_grad.
|
|
||||||
Works for both training (tokens is a Parameter) and eval (tokens is detached).
|
|
||||||
"""
|
|
||||||
if inject_mode == "prefix":
|
|
||||||
bos = text_clip[:, :1, :].detach() # [B, 1, D]
|
|
||||||
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
|
|
||||||
rest = text_clip[:, 1 + n_tokens:, :].detach() # [B, 76-K, D]
|
|
||||||
return torch.cat([bos, toks, rest], dim=1) # [B, 77, D]
|
|
||||||
else: # suffix (default)
|
|
||||||
front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D]
|
|
||||||
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
|
|
||||||
return torch.cat([front, toks], dim=1) # [B, 77, D]
|
|
||||||
|
|
||||||
|
|
||||||
def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
|
|
||||||
feature_utils_orig, dataset, seq_cfg,
|
|
||||||
device, dtype, num_steps=25, seed=42, clip_idx=0):
|
|
||||||
"""Inference pass with learned tokens injected into text conditioning."""
|
|
||||||
generator.eval()
|
|
||||||
try:
|
|
||||||
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
|
|
||||||
clip_f = clip_f_cpu.to(device, dtype)
|
|
||||||
sync_f = sync_f_cpu.to(device, dtype)
|
|
||||||
text_clip = text_clip_cpu.to(device, dtype).clone()
|
|
||||||
|
|
||||||
emb = learned_tokens.detach().to(device, dtype)
|
|
||||||
text_input = _inject_tokens(text_clip, emb, n_tokens, inject_mode)
|
|
||||||
|
|
||||||
rng = torch.Generator(device=device).manual_seed(seed)
|
|
||||||
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
|
||||||
device=device, dtype=dtype, generator=rng)
|
|
||||||
|
|
||||||
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
|
||||||
|
|
||||||
def velocity_fn(t, x):
|
|
||||||
return generator.forward(x, clip_f, sync_f, text_input,
|
|
||||||
t.reshape(1).to(device, dtype))
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
x1_pred = eval_fm.to_data(velocity_fn, x0)
|
|
||||||
x1_unnorm = generator.unnormalize(x1_pred)
|
|
||||||
|
|
||||||
tod = feature_utils_orig.tod
|
|
||||||
tod_orig_dev = next(tod.parameters()).device
|
|
||||||
tod.to(device)
|
|
||||||
try:
|
|
||||||
spec = feature_utils_orig.decode(x1_unnorm)
|
|
||||||
audio = feature_utils_orig.vocode(spec)
|
|
||||||
finally:
|
|
||||||
tod.to(tod_orig_dev)
|
|
||||||
|
|
||||||
audio = audio.float().cpu()
|
|
||||||
if audio.dim() == 2:
|
|
||||||
audio = audio.unsqueeze(1)
|
|
||||||
elif audio.dim() == 3 and audio.shape[1] != 1:
|
|
||||||
audio = audio.mean(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
target_rms = 10 ** (-27.0 / 20.0)
|
|
||||||
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
audio = (audio * (target_rms / rms))
|
|
||||||
peak = audio.abs().max().clamp(min=1e-8)
|
|
||||||
if peak > 1.0:
|
|
||||||
audio = audio / peak
|
|
||||||
return audio.squeeze(0), seq_cfg.sampling_rate
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[TI Trainer] Eval sample failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
return None, None
|
|
||||||
finally:
|
|
||||||
generator.train()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Node
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class SelvaTextualInversionTrainer:
|
|
||||||
"""Learns K CLIP token embeddings that steer SelVA toward a target audio style.
|
|
||||||
|
|
||||||
Unlike LoRA, all model weights are frozen. Only the K×1024 embedding tensor
|
|
||||||
receives gradients, keeping generated latents on the decoder's natural manifold
|
|
||||||
and preserving base model audio quality while shifting generation style.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "train"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("embeddings_path", "loss_curve")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to saved .pt embeddings — load with SelVA Textual Inversion Loader.",
|
|
||||||
"Smoothed training loss curve.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Trains K learnable CLIP token embeddings against your audio dataset "
|
|
||||||
"with all model weights frozen. The tokens are then injected into the "
|
|
||||||
"sampler to guide generation toward the training data style without "
|
|
||||||
"degrading audio quality."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"data_dir": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Directory containing .npz feature files and paired audio files (same as LoRA trainer).",
|
|
||||||
}),
|
|
||||||
"output_path": ("STRING", {
|
|
||||||
"default": "textual_inversion.pt",
|
|
||||||
"tooltip": "Where to save the learned embeddings. Relative paths resolve to ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
"n_tokens": ("INT", {
|
|
||||||
"default": 4, "min": 1, "max": 16,
|
|
||||||
"tooltip": "Number of learnable token vectors. More tokens = more expressive but slower to train. 4 is a good default.",
|
|
||||||
}),
|
|
||||||
"steps": ("INT", {
|
|
||||||
"default": 3000, "min": 100, "max": 50000,
|
|
||||||
"tooltip": "Training steps. 3000 is a reasonable starting point.",
|
|
||||||
}),
|
|
||||||
"lr": ("FLOAT", {
|
|
||||||
"default": 2e-4, "min": 1e-5, "max": 1e-1, "step": 1e-5,
|
|
||||||
"tooltip": "Learning rate. 2e-4 matches the LoRA working regime. Higher LR (1e-3) causes token norm to drift without plateauing on small datasets.",
|
|
||||||
}),
|
|
||||||
"batch_size": ("INT", {
|
|
||||||
"default": 4, "min": 1, "max": 64,
|
|
||||||
"tooltip": "Clips sampled per training step. Smaller batch (4–8) gives more diverse gradients and helps token norm saturate rather than drift.",
|
|
||||||
}),
|
|
||||||
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
"save_every": ("INT", {
|
|
||||||
"default": 1000, "min": 100, "max": 10000,
|
|
||||||
"tooltip": "Save a checkpoint and generate an eval sample every N steps.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"inject_mode": (["suffix", "prefix"], {
|
|
||||||
"default": "suffix",
|
|
||||||
"tooltip": (
|
|
||||||
"Where to inject the learned tokens in the 77-token CLIP sequence. "
|
|
||||||
"'suffix' replaces the last K positions (EOS/padding — may be ignored by the model). "
|
|
||||||
"'prefix' replaces positions 1:1+K right after BOS — higher attention weight, stronger style signal."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
"init_text": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.",
|
|
||||||
}),
|
|
||||||
"warmup_steps": ("INT", {
|
|
||||||
"default": 100, "min": 0, "max": 1000,
|
|
||||||
"tooltip": "Linear LR warmup steps.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def train(self, model, data_dir, output_path, n_tokens, steps, lr,
|
|
||||||
batch_size, seed, save_every,
|
|
||||||
inject_mode="suffix", init_text="", warmup_steps=100):
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
mode = model["mode"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
feature_utils_orig = model["feature_utils"]
|
|
||||||
|
|
||||||
# --- Resolve paths ---
|
|
||||||
data_dir = Path(data_dir.strip())
|
|
||||||
if not data_dir.is_absolute():
|
|
||||||
data_dir = Path(folder_paths.models_dir) / data_dir
|
|
||||||
if not data_dir.exists():
|
|
||||||
raise FileNotFoundError(f"[TI Trainer] data_dir not found: {data_dir}")
|
|
||||||
|
|
||||||
out_path = Path(output_path.strip())
|
|
||||||
if not out_path.is_absolute():
|
|
||||||
out_path = Path(folder_paths.get_output_directory()) / out_path
|
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\n[TI Trainer] n_tokens={n_tokens} steps={steps} lr={lr:.2e}", flush=True)
|
|
||||||
print(f"[TI Trainer] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[TI Trainer] output = {out_path}\n", flush=True)
|
|
||||||
|
|
||||||
# --- Load dataset (reuse LoRA trainer helper) ---
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
|
|
||||||
# Training must run outside inference_mode so autograd works
|
|
||||||
with torch.inference_mode(False), torch.enable_grad():
|
|
||||||
r = self._train_inner(
|
|
||||||
model, dataset, feature_utils_orig, seq_cfg,
|
|
||||||
device, dtype, mode,
|
|
||||||
data_dir, out_path,
|
|
||||||
n_tokens, steps, lr, batch_size,
|
|
||||||
warmup_steps, seed, save_every, init_text, inject_mode,
|
|
||||||
)
|
|
||||||
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
|
|
||||||
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
|
|
||||||
return (r["embeddings_path"], _pil_to_tensor(curve_img))
|
|
||||||
|
|
||||||
def _train_inner(
|
|
||||||
self, model, dataset, feature_utils_orig, seq_cfg,
|
|
||||||
device, dtype, mode,
|
|
||||||
data_dir, out_path,
|
|
||||||
n_tokens, steps, lr, batch_size,
|
|
||||||
warmup_steps, seed, save_every, init_text, inject_mode="suffix",
|
|
||||||
):
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
# --- Generator (frozen) ---
|
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
|
||||||
generator.requires_grad_(False)
|
|
||||||
generator.update_seq_lengths(
|
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
|
||||||
clip_seq_len=seq_cfg.clip_seq_len,
|
|
||||||
sync_seq_len=seq_cfg.sync_seq_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Init learned tokens ---
|
|
||||||
# encode_text_clip is decorated with @inference_mode, so its output cannot take
|
|
||||||
# part in autograd: copy the values out (detach + clone), then wrap as nn.Parameter.
|
|
||||||
if init_text.strip():
|
|
||||||
with torch.no_grad():
|
|
||||||
init_embed = feature_utils_orig.encode_text_clip([init_text.strip()])
|
|
||||||
# Positions 1:1+n_tokens — after BOS, before EOS — have actual content
|
|
||||||
init_vals = init_embed[0, 1:1 + n_tokens, :].detach().clone().float()
|
|
||||||
if init_vals.shape[0] < n_tokens:
|
|
||||||
# Defensive: pad with small noise if the slice returned fewer than n_tokens rows
|
|
||||||
pad = torch.randn(n_tokens - init_vals.shape[0], init_vals.shape[1]) * 0.02
|
|
||||||
init_vals = torch.cat([init_vals, pad], dim=0)
|
|
||||||
learned_tokens = torch.nn.Parameter(init_vals.to(device, dtype))
|
|
||||||
print(f"[TI Trainer] Init from '{init_text.strip()}' (positions 1–{n_tokens})", flush=True)
|
|
||||||
else:
|
|
||||||
learned_tokens = torch.nn.Parameter(
|
|
||||||
torch.randn(n_tokens, 1024, device=device, dtype=dtype) * 0.02
|
|
||||||
)
|
|
||||||
print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True)
|
|
||||||
|
|
||||||
# --- Measure CLIP token norm from the dataset (positions 1–19, first 20 clips) ---
|
|
||||||
# Learned tokens must stay within this range or the model treats them as
|
|
||||||
# out-of-distribution and produces buzz artifacts instead of style shift.
|
|
||||||
with torch.no_grad():
|
|
||||||
sample_norms = []
|
|
||||||
for item in dataset[:min(len(dataset), 20)]:
|
|
||||||
tc = item[3].squeeze(0) # [77, 1024]
|
|
||||||
sample_norms.append(tc[1:20].norm(dim=-1)) # positions 1–19: skips BOS (EOS/padding may still fall in range for short prompts)
|
|
||||||
clip_norm_ref = torch.cat(sample_norms).mean().item()
|
|
||||||
clip_norm_limit = clip_norm_ref * 1.5 # 50% headroom above real tokens
|
|
||||||
print(f"[TI Trainer] CLIP token norm ref={clip_norm_ref:.4f} "
|
|
||||||
f"limit={clip_norm_limit:.4f}", flush=True)
|
|
||||||
|
|
||||||
# --- Optimizer + scheduler ---
|
|
||||||
optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2)
|
|
||||||
|
|
||||||
def lr_lambda(s):
|
|
||||||
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
|
|
||||||
|
|
||||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
|
||||||
|
|
||||||
# --- Checkpoint dir ---
|
|
||||||
ckpt_dir = out_path.parent / out_path.stem
|
|
||||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# --- Baseline sample (once, before any training) ---
|
|
||||||
print(f"[TI Trainer] Generating baseline sample...", flush=True)
|
|
||||||
baseline_wav, baseline_sr = _eval_sample(
|
|
||||||
generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed,
|
|
||||||
)
|
|
||||||
if baseline_wav is not None:
|
|
||||||
baseline_path = ckpt_dir / "baseline.wav"
|
|
||||||
try:
|
|
||||||
torchaudio.save(str(baseline_path), baseline_wav, baseline_sr)
|
|
||||||
except RuntimeError:
|
|
||||||
import soundfile as sf
|
|
||||||
sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr)
|
|
||||||
try:
|
|
||||||
_save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True)
|
|
||||||
|
|
||||||
# --- Training loop ---
|
|
||||||
generator.train()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
log_interval = 50
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
|
||||||
loss_history = []
|
|
||||||
running_loss = 0.0
|
|
||||||
|
|
||||||
print(f"[TI Trainer] Training {steps} steps batch_size={batch_size}\n", flush=True)
|
|
||||||
|
|
||||||
for step in range(1, steps + 1):
|
|
||||||
batch = random.choices(dataset, k=batch_size)
|
|
||||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
|
||||||
|
|
||||||
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
|
||||||
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
|
||||||
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
|
||||||
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
|
|
||||||
|
|
||||||
# Inject learned tokens — gradient flows via torch.cat (not in-place assignment).
|
|
||||||
text_input = _inject_tokens(text_clip, learned_tokens, n_tokens, inject_mode)
|
|
||||||
|
|
||||||
x1 = generator.normalize(x1)
|
|
||||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
|
||||||
x0 = torch.randn_like(x1)
|
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
|
||||||
|
|
||||||
v_pred = generator.forward(xt, clip_f, sync_f, text_input, t)
|
|
||||||
loss = fm.loss(v_pred, x0, x1).mean()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
torch.nn.utils.clip_grad_norm_([learned_tokens], max_norm=1.0)
|
|
||||||
optimizer.step()
|
|
||||||
scheduler.step()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Clamp token norm to CLIP manifold — prevents out-of-distribution
|
|
||||||
# embeddings that cause buzz artifacts instead of style shift.
|
|
||||||
with torch.no_grad():
|
|
||||||
norms = learned_tokens.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
|
||||||
scale = (clip_norm_limit / norms).clamp(max=1.0)
|
|
||||||
learned_tokens.data.mul_(scale)
|
|
||||||
|
|
||||||
running_loss += loss.item()
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
if step % log_interval == 0:
|
|
||||||
avg = running_loss / log_interval
|
|
||||||
loss_history.append(round(avg, 6))
|
|
||||||
running_loss = 0.0
|
|
||||||
lr_now = scheduler.get_last_lr()[0]
|
|
||||||
norm = learned_tokens.norm(dim=-1).mean().item()
|
|
||||||
print(f"[TI Trainer] step {step:5d}/{steps} "
|
|
||||||
f"loss={avg:.4f} lr={lr_now:.2e} "
|
|
||||||
f"token_norm={norm:.4f}/{clip_norm_limit:.4f}", flush=True)
|
|
||||||
|
|
||||||
if step % save_every == 0 or step == steps:
|
|
||||||
# Save checkpoint
|
|
||||||
ckpt = {
|
|
||||||
"embeddings": learned_tokens.detach().cpu(),
|
|
||||||
"n_tokens": n_tokens,
|
|
||||||
"inject_mode": inject_mode,
|
|
||||||
"step": step,
|
|
||||||
"init_text": init_text,
|
|
||||||
"lr": lr,
|
|
||||||
"steps": steps,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
}
|
|
||||||
ckpt_path = ckpt_dir / f"step_{step:05d}.pt"
|
|
||||||
torch.save(ckpt, str(ckpt_path))
|
|
||||||
|
|
||||||
# Eval sample
|
|
||||||
wav, sr = _eval_sample_ti(
|
|
||||||
generator, learned_tokens, n_tokens, inject_mode,
|
|
||||||
feature_utils_orig, dataset, seq_cfg,
|
|
||||||
device, dtype, seed=seed,
|
|
||||||
)
|
|
||||||
if wav is not None:
|
|
||||||
wav_path = ckpt_dir / f"step_{step:05d}.wav"
|
|
||||||
try:
|
|
||||||
torchaudio.save(str(wav_path), wav, sr)
|
|
||||||
except RuntimeError:
|
|
||||||
import soundfile as sf
|
|
||||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
|
||||||
|
|
||||||
try:
|
|
||||||
metrics = _spectral_metrics(wav, sr)
|
|
||||||
_save_spectrogram(wav, sr, ckpt_dir / f"step_{step:05d}.png")
|
|
||||||
print(f"[TI Trainer] step {step} "
|
|
||||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
|
||||||
f"flatness={metrics['spectral_flatness']:.4f} "
|
|
||||||
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[TI Trainer] Spectral/spectrogram failed: {e}", flush=True)
|
|
||||||
|
|
||||||
print(f"[TI Trainer] Checkpoint: {ckpt_path}", flush=True)
|
|
||||||
|
|
||||||
# --- Final save ---
|
|
||||||
final = {
|
|
||||||
"embeddings": learned_tokens.detach().cpu(),
|
|
||||||
"n_tokens": n_tokens,
|
|
||||||
"inject_mode": inject_mode,
|
|
||||||
"step": steps,
|
|
||||||
"init_text": init_text,
|
|
||||||
"lr": lr,
|
|
||||||
"steps": steps,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
}
|
|
||||||
torch.save(final, str(out_path))
|
|
||||||
print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True)
|
|
||||||
|
|
||||||
soft_empty_cache()
|
|
||||||
return {
|
|
||||||
"embeddings_path": str(out_path),
|
|
||||||
"loss_history": loss_history,
|
|
||||||
}
|
|
||||||
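Two small pieces of the training loop above are easy to check by hand: the per-token norm clamp applied after each optimizer step, and the linear warmup factor passed to `LambdaLR`. The sketch below restates both in pure Python for a single token vector. `clamp_norm` and `warmup_factor` are hypothetical names, and the reference norm of 2.0 is an arbitrary value chosen for the example, not a measured CLIP statistic.

```python
import math


def clamp_norm(vec, limit):
    """Rescale vec so its L2 norm does not exceed limit; never scales up."""
    norm = max(math.sqrt(sum(x * x for x in vec)), 1e-8)
    scale = min(limit / norm, 1.0)
    return [x * scale for x in vec]


def warmup_factor(step, warmup_steps):
    """LR multiplier: linear ramp from 0 to 1, then constant 1."""
    return step / max(1, warmup_steps) if step < warmup_steps else 1.0


def l2(vec):
    return math.sqrt(sum(x * x for x in vec))


ref_norm = 2.0                 # arbitrary stand-in for the measured mean token norm
limit = ref_norm * 1.5         # 50% headroom, as in the trainer

big = clamp_norm([3.0, 4.0], limit)      # norm 5.0 exceeds the limit
small = clamp_norm([0.3, 0.4], limit)    # norm 0.5 is already inside it

factors = [warmup_factor(s, 100) for s in (0, 50, 100)]
```

Note that the factor is 0 at step 0, so with this schedule the very first optimizer step has no effect; that matches the `lr_lambda` in the diff.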
@@ -1,479 +0,0 @@
"""SelVA Textual Inversion Scheduler — sweeps TI training experiments from a JSON file.

Each experiment inherits from a shared `base` config and overrides specific keys.
The dataset is loaded once and reused across all experiments. Results are written
to `experiment_summary.json` (updated after each completed run) and a comparison
loss-curve image showing all runs on the same axes.

JSON format:
{
  "name": "ti_sweep_1",
  "description": "optional human note",
  "data_dir": "dataset/bj_sounds",
  "output_root": "ti_output/sweep_1",
  "base": {
    "n_tokens": 4,
    "lr": 1e-3,
    "steps": 3000,
    "batch_size": 16,
    "warmup_steps": 100,
    "seed": 42,
    "save_every": 1000
  },
  "experiments": [
    {"id": "baseline", "description": "default 4 tokens"},
    {"id": "n8_tokens", "n_tokens": 8},
    {"id": "lr_5e4", "lr": 5e-4},
    {"id": "warm_init", "init_text": "industrial sound design"},
    {"id": "n4_more_steps", "steps": 5000}
  ]
}
"""

import json
import sys
import time
import traceback
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import torch

import comfy.utils
import folder_paths

from .utils import SELVA_CATEGORY, get_device
from .selva_lora_trainer import (
    _prepare_dataset,
    _smooth_losses,
    _pil_to_tensor,
)
from .selva_textual_inversion_trainer import SelvaTextualInversionTrainer


# ---------------------------------------------------------------------------
# Helpers (shared with LoRA scheduler, inlined to keep modules independent)
# ---------------------------------------------------------------------------

def _get_system_info() -> dict:
    info: dict = {
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda or "N/A",
        "gpu_name": None,
        "gpu_vram_gb": None,
    }
    if torch.cuda.is_available():
        try:
            info["gpu_name"] = torch.cuda.get_device_name(0)
            props = torch.cuda.get_device_properties(0)
            info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
        except Exception:
            pass
    return info


_PARAM_DEFAULTS = {
    "n_tokens": 4,
    "lr": 2e-4,
    "steps": 3000,
    "batch_size": 4,
    "warmup_steps": 100,
    "seed": 42,
    "save_every": 1000,
    "init_text": "",
    "inject_mode": "suffix",
}

_PALETTE = [
    (66, 133, 244),
    (234, 67, 53),
    (52, 168, 83),
    (251, 188, 5),
    (155, 89, 182),
    (26, 188, 156),
    (230, 126, 34),
    (149, 165, 166),
]


def _resolve_path(raw: str) -> Path:
    p = Path(raw.strip())
    unix_style_on_windows = (
        sys.platform == "win32" and p.is_absolute() and not p.drive
    )
    if not p.is_absolute() or unix_style_on_windows:
        p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
    return p


def _merge_config(base: dict, experiment: dict) -> dict:
    cfg = dict(_PARAM_DEFAULTS)
    cfg.update(base)
    cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
    return cfg


def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
                   total_steps: int) -> dict:
    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        idx = target // log_interval - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result


def _draw_comparison_curves(experiments_data: list) -> "Image.Image":
    from PIL import Image, ImageDraw

    W, H = 900, 420
    pl, pr, pt, pb = 75, 160, 30, 50

    img = Image.new("RGB", (W, H), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    pw = W - pl - pr
    ph = H - pt - pb

    series = []
    for i, ed in enumerate(experiments_data):
        lh = ed.get("loss_history") or []
        if len(lh) < 2:
            continue
        sm = _smooth_losses(lh)
        series.append({
            "id": ed["id"],
            "smoothed": sm,
            "color": _PALETTE[i % len(_PALETTE)],
        })

    if not series:
        draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
        return img

    all_vals = [v for s in series for v in s["smoothed"]]
    lo, hi = min(all_vals), max(all_vals)
    if hi == lo:
        hi = lo + 1e-6
    rng = hi - lo

    for i in range(5):
        y = pt + int(i * ph / 4)
        val = hi - i * rng / 4
        draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
        draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))

    for s in series:
        n = len(s["smoothed"])
        pts = []
        for j, v in enumerate(s["smoothed"]):
            x = pl + int(j * pw / max(n - 1, 1))
            y = pt + int((1.0 - (v - lo) / rng) * ph)
            pts.append((x, y))
        draw.line(pts, fill=s["color"], width=2)

    draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
    draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
    draw.text((pl + 4, 8), "TI loss comparison (smoothed)", fill=(40, 40, 40))

    lx, ly = W - pr + 10, pt
    for s in series:
        draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
        draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
        ly += 20

    return img


# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------

class SelvaTiScheduler:
    """Runs a sweep of Textual Inversion experiments defined in a JSON file.

    The dataset is loaded once and reused. Each experiment calls
    SelvaTextualInversionTrainer._train_inner() with its own config.
    Results are written to experiment_summary.json after every completed run.
    """

    OUTPUT_NODE = True
    CATEGORY = SELVA_CATEGORY
    FUNCTION = "run"
    RETURN_TYPES = ("STRING", "IMAGE")
    RETURN_NAMES = ("summary_path", "comparison_curves")
    OUTPUT_TOOLTIPS = (
        "Path to experiment_summary.json — compare runs across sweeps.",
        "All smoothed loss curves overlaid on the same axes.",
    )
    DESCRIPTION = (
        "Runs a series of Textual Inversion experiments from a JSON sweep file. "
        "The dataset is encoded once and reused. Results (loss, config, embeddings "
        "paths) are collected in experiment_summary.json after each run."
    )

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "experiments_file": ("STRING", {
                    "default": "ti_experiments.json",
                    "tooltip": (
                        "Path to JSON sweep file. Relative paths resolve to the ComfyUI "
                        "output directory. See node description for the file format."
                    ),
                }),
            }
        }

    def run(self, model, experiments_file):
        # ------------------------------------------------------------------
        # 1. Read + validate JSON
        # ------------------------------------------------------------------
        exp_path = Path(experiments_file.strip())
        if not exp_path.is_absolute():
            candidate = Path(folder_paths.models_dir) / exp_path
            if not candidate.exists():
                candidate = Path(folder_paths.get_output_directory()) / exp_path
            exp_path = candidate
        if not exp_path.exists():
            raise FileNotFoundError(
                f"[TI Scheduler] Experiment file not found: {exp_path}"
            )

        spec = json.loads(exp_path.read_text(encoding="utf-8"))

        if "experiments" not in spec or not spec["experiments"]:
            raise ValueError("[TI Scheduler] 'experiments' list is missing or empty.")
        for i, exp in enumerate(spec["experiments"]):
            if "id" not in exp:
                raise ValueError(
                    f"[TI Scheduler] Experiment at index {i} is missing required 'id' field."
                )

        sweep_name = spec.get("name", exp_path.stem)
        description = spec.get("description", "")
        base_cfg = spec.get("base", {})

        # ------------------------------------------------------------------
        # 2. Resolve data_dir and output_root
        # ------------------------------------------------------------------
        if "data_dir" not in spec:
            raise ValueError("[TI Scheduler] 'data_dir' is required in the sweep file.")
        data_dir = _resolve_path(spec["data_dir"])
        output_root = _resolve_path(spec.get("output_root", f"ti_sweeps/{sweep_name}"))
        output_root.mkdir(parents=True, exist_ok=True)

        device = get_device()
        dtype = model["dtype"]
        mode = model["mode"]
        seq_cfg = model["seq_cfg"]
        feature_utils_orig = model["feature_utils"]

        print(f"\n[TI Scheduler] Sweep '{sweep_name}': "
              f"{len(spec['experiments'])} experiment(s)", flush=True)
        if description:
            print(f"[TI Scheduler] {description}", flush=True)
        print(f"[TI Scheduler] data_dir = {data_dir}", flush=True)
        print(f"[TI Scheduler] output_root = {output_root}\n", flush=True)

        # ------------------------------------------------------------------
        # 3. Load dataset once
        # ------------------------------------------------------------------
        n_clips = len(list(data_dir.glob("*.npz")))
        dataset = _prepare_dataset(model, data_dir, device)

        # ------------------------------------------------------------------
        # 4. Build or restore summary (resume-aware)
        # ------------------------------------------------------------------
        summary_path = output_root / "experiment_summary.json"
        completed_ids = set()
        all_curve_data = []

        if summary_path.exists():
            try:
                existing = json.loads(summary_path.read_text(encoding="utf-8"))
                for rec in existing.get("experiments", []):
                    if rec.get("results", {}).get("status") == "completed":
                        completed_ids.add(rec["id"])
                        all_curve_data.append({
                            "id": rec["id"],
                            "loss_history": rec["results"].get("loss_history", []),
                        })
                summary = existing
                # Keep only completed records: failed or interrupted experiments
                # are re-run below and would otherwise be recorded twice.
                summary["experiments"] = [
                    rec for rec in existing.get("experiments", [])
                    if rec.get("results", {}).get("status") == "completed"
                ]
                summary["completed_at"] = None
                if completed_ids:
                    print(f"[TI Scheduler] Resuming — skipping {len(completed_ids)} "
                          f"completed experiment(s): {sorted(completed_ids)}", flush=True)
            except Exception as e:
                print(f"[TI Scheduler] Could not read existing summary ({e}) — starting fresh",
                      flush=True)
                completed_ids = set()
                all_curve_data = []
                summary = None

        if not completed_ids:
            summary = {
                "sweep_name": sweep_name,
                "description": description,
                "sweep_file": str(exp_path),
                "started_at": datetime.now(timezone.utc).isoformat(),
                "completed_at": None,
                "system": _get_system_info(),
                "data_dir": str(data_dir),
                "n_clips": n_clips,
                "experiments": [],
            }

        comparison_img_path = output_root / "loss_comparison.png"

        def _write_summary():
            summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

        def _save_comparison():
            try:
                img = _draw_comparison_curves(all_curve_data)
                img.save(str(comparison_img_path))
            except Exception as e:
                print(f"[TI Scheduler] Comparison image failed: {e}", flush=True)

        _write_summary()

        # ------------------------------------------------------------------
        # 5. Run each experiment
        # ------------------------------------------------------------------
        trainer = SelvaTextualInversionTrainer()
        pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
        log_interval = 50  # matches _train_inner

        for exp in spec["experiments"]:
            exp_id = exp["id"]
            exp_desc = exp.get("description", "")

            if exp_id in completed_ids:
                print(f"[TI Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
                pbar_outer.update(1)
                continue

            cfg = _merge_config(base_cfg, exp)

            n_tokens = int(cfg["n_tokens"])
            lr = float(cfg["lr"])
            steps = int(cfg["steps"])
            batch_size = int(cfg["batch_size"])
            warmup = int(cfg["warmup_steps"])
            seed = int(cfg["seed"])
            save_every = int(cfg["save_every"])
            init_text = str(cfg["init_text"])
            inject_mode = str(cfg["inject_mode"])

            output_dir = output_root / exp_id
            output_dir.mkdir(parents=True, exist_ok=True)
            out_path = output_dir / "embeddings.pt"

            print(f"\n[TI Scheduler] ── Experiment '{exp_id}' ──", flush=True)
            if exp_desc:
                print(f"[TI Scheduler] {exp_desc}", flush=True)
            print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
                  f"batch_size={batch_size} warmup={warmup} seed={seed} "
                  f"inject_mode={inject_mode}", flush=True)
            if init_text:
                print(f"[TI Scheduler] init_text='{init_text}'", flush=True)

            exp_record = {
                "id": exp_id,
                "description": exp_desc,
                "config": {
                    "n_tokens": n_tokens,
                    "lr": lr,
                    "steps": steps,
                    "batch_size": batch_size,
                    "warmup_steps": warmup,
                    "seed": seed,
                    "save_every": save_every,
                    "init_text": init_text,
                    "inject_mode": inject_mode,
                },
                "results": {"status": "running"},
                "embeddings_path": None,
                "output_dir": str(output_dir),
            }
            summary["experiments"].append(exp_record)
            _write_summary()

            t_start = time.monotonic()
            try:
                with torch.inference_mode(False), torch.enable_grad():
                    r = trainer._train_inner(
                        model, dataset, feature_utils_orig, seq_cfg,
                        device, dtype, mode,
                        data_dir, out_path,
                        n_tokens, steps, lr, batch_size,
                        warmup, seed, save_every, init_text, inject_mode,
                    )

                duration = time.monotonic() - t_start
                loss_history = r["loss_history"]
                smoothed = _smooth_losses(loss_history) if loss_history else []

                final_loss = round(smoothed[-1], 6) if smoothed else None
                min_loss = round(min(smoothed), 6) if smoothed else None
                min_idx = smoothed.index(min(smoothed)) if smoothed else None
                min_loss_step = (min_idx + 1) * log_interval if min_idx is not None else None

                loss_std_last_quarter = None
                if loss_history:
                    quarter = max(1, len(loss_history) // 4)
                    loss_std_last_quarter = round(float(np.std(loss_history[-quarter:])), 6)

                exp_record["results"] = {
                    "status": "completed",
                    "final_loss": final_loss,
                    "min_loss": min_loss,
                    "min_loss_step": min_loss_step,
                    "loss_std_last_quarter": loss_std_last_quarter,
                    "loss_at_steps": _loss_at_steps(
                        loss_history, log_interval, save_every, steps
                    ),
                    "loss_history": [round(v, 6) for v in loss_history],
                    "log_interval": log_interval,
                    "duration_seconds": round(duration, 1),
                }
                exp_record["embeddings_path"] = r["embeddings_path"]

                all_curve_data.append({
                    "id": exp_id,
                    "loss_history": loss_history,
                })

            except Exception as e:
                duration = time.monotonic() - t_start
                print(f"[TI Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
                traceback.print_exc()
                exp_record["results"] = {
                    "status": "failed",
                    "error": str(e),
                    "duration_seconds": round(duration, 1),
                }
                _write_summary()
                pbar_outer.update(1)
                continue

            _write_summary()
            _save_comparison()
            pbar_outer.update(1)

        # ------------------------------------------------------------------
        # 6. Finalise
        # ------------------------------------------------------------------
        summary["completed_at"] = datetime.now(timezone.utc).isoformat()
        _write_summary()
        print(f"\n[TI Scheduler] Sweep complete. Summary: {summary_path}", flush=True)

        # ------------------------------------------------------------------
        # 7. Comparison image (final update, then return to ComfyUI)
        # ------------------------------------------------------------------
        _save_comparison()
        comparison_img = _draw_comparison_curves(all_curve_data)
        return (str(summary_path), _pil_to_tensor(comparison_img))
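The scheduler's config precedence (built-in defaults, then `base`, then the per-experiment overrides) and its loss sampling at checkpoint steps can be tried out in isolation. The sketch below is a standalone illustration, not part of the node: `DEFAULTS` is a trimmed, hypothetical subset of `_PARAM_DEFAULTS`, and the loss history is synthetic.

```python
# Standalone sketch of the sweep-config merge and checkpoint-loss lookup.
# DEFAULTS is a hypothetical, trimmed stand-in for _PARAM_DEFAULTS.
DEFAULTS = {"n_tokens": 4, "lr": 2e-4, "steps": 3000, "batch_size": 4}


def merge_config(base: dict, experiment: dict) -> dict:
    """Later sources win: defaults < base < experiment (minus metadata keys)."""
    cfg = dict(DEFAULTS)
    cfg.update(base)
    cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
    return cfg


def loss_at_steps(loss_history, log_interval, save_every, total_steps):
    """Pick the logged loss at each checkpoint step.

    Assumes one loss entry is recorded every `log_interval` steps, so the
    entry for step S sits at index S // log_interval - 1.
    """
    result = {}
    for target in range(save_every, total_steps + 1, save_every):
        idx = target // log_interval - 1
        if 0 <= idx < len(loss_history):
            result[str(target)] = round(loss_history[idx], 6)
    return result


base = {"lr": 1e-3, "batch_size": 16}
cfg = merge_config(base, {"id": "n8_tokens", "n_tokens": 8})

# Synthetic history: 60 entries, i.e. 3000 steps logged every 50 steps.
history = [float(60 - i) for i in range(60)]
checkpoints = loss_at_steps(history, log_interval=50, save_every=1000, total_steps=3000)
```

Because `id` and `description` are filtered out before the final update, they never leak into the training config. Checkpoint steps that fall beyond the recorded history are silently omitted from the result.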
@@ -1,157 +0,0 @@
"""SelVA VAE Roundtrip — encode audio through the VAE then decode straight back.

Useful for diagnosing codec reconstruction quality: if the output sounds
saturated/degraded compared to the input, the VAE/DAC is the bottleneck,
not the diffusion model or LoRA.
"""

import torch
import torch.nn.functional as F
import torchaudio
from pathlib import Path

import folder_paths

from .utils import SELVA_CATEGORY, get_device, soft_empty_cache


_SELVA_DIR = Path(folder_paths.models_dir) / "selva"


class SelvaVaeRoundtrip:

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "audio": ("AUDIO",),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio_reconstructed",)
    OUTPUT_TOOLTIPS = (
        "Audio after VAE encode → decode roundtrip. "
        "Compare to the input to hear codec reconstruction quality.",
    )
    FUNCTION = "roundtrip"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = (
        "Encodes the input audio through the SelVA VAE then decodes it straight back. "
        "Use this to isolate codec reconstruction quality from generation quality. "
        "If the output sounds degraded compared to the input, the VAE/DAC is the "
        "bottleneck — not the model or LoRA."
    )

    def roundtrip(self, model, audio):
        from selva_core.model.utils.features_utils import FeaturesUtils

        mode = model["mode"]
        seq_cfg = model["seq_cfg"]
        dtype = model["dtype"]
        device = get_device()
        generator = model["generator"]
        feature_utils = model["feature_utils"]

        vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
        vae_path = _SELVA_DIR / "ext" / vae_name
        if not vae_path.exists():
            raise FileNotFoundError(
                f"[VAE Roundtrip] VAE weight not found: {vae_path}. "
                "Run SelVA Model Loader first to auto-download weights."
            )

        # Load encoder only — decoder/vocoder come from model["feature_utils"]
        # to mirror exactly what the sampler uses.
        # AutoEncoderModule requires vocoder_ckpt_path even when only encoding,
        # so pass the BigVGAN path (weights won't actually be used for decode here).
        bigvgan_path = _SELVA_DIR / "ext" / "best_netG.pt"
        print("[VAE Roundtrip] Loading VAE encoder...", flush=True)
        vae_enc = FeaturesUtils(
            tod_vae_ckpt=str(vae_path),
            enable_conditions=False,
            mode=mode,
            need_vae_encoder=True,
            bigvgan_vocoder_ckpt=str(bigvgan_path) if bigvgan_path.exists() else None,
        ).to(device).eval()

        try:
            # Prepare input audio
            waveform = audio["waveform"]  # [1, C, L]
            sr_in = audio["sample_rate"]

            wav = waveform[0].mean(0)  # mono [L]

            if sr_in != seq_cfg.sampling_rate:
                wav = torchaudio.functional.resample(
                    wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate
                ).squeeze(0)
                print(f"[VAE Roundtrip] Resampled {sr_in} → {seq_cfg.sampling_rate} Hz",
                      flush=True)

            target_len = int(seq_cfg.duration * seq_cfg.sampling_rate)
            if wav.shape[0] > target_len:
                wav = wav[:target_len]
            elif wav.shape[0] < target_len:
                wav = F.pad(wav, (0, target_len - wav.shape[0]))

            wav_b = wav.unsqueeze(0).to(device).float()  # [1, L]

            with torch.no_grad():
                # Encode: audio → raw latent [1, latent_dim, T]
                dist = vae_enc.encode_audio(wav_b)
                latent = dist.mode().clone()

                # Trim/pad to exact model sequence length (same as _prepare_dataset)
                tgt = seq_cfg.latent_seq_len
                if latent.shape[2] < tgt:
                    latent = F.pad(latent, (0, tgt - latent.shape[2]))
                elif latent.shape[2] > tgt:
                    latent = latent[:, :, :tgt]

                # To [B, T, latent_dim] — layout the generator uses
                latent_t = latent.transpose(1, 2).to(dtype)
                print(f"[VAE Roundtrip] Encoded: mean={latent_t.mean():.4f} std={latent_t.std():.4f}",
                      flush=True)

                # Normalize → unnormalize mirrors the training/inference pipeline:
                # training normalizes encoded latents; sampler unnormalizes before decode.
                # This ensures the latent is in the same space the decoder expects.
                latent_norm = generator.normalize(latent_t.clone())
                latent_unnorm = generator.unnormalize(latent_norm)
                print(f"[VAE Roundtrip] Norm→unnorm: mean={latent_unnorm.mean():.4f} std={latent_unnorm.std():.4f}",
                      flush=True)

                # Decode using model's feature_utils — same path as the sampler
                tod = feature_utils.tod
                tod_orig_device = next(tod.parameters()).device
                tod.to(device)
                try:
                    spec = feature_utils.decode(latent_unnorm)
                    out = feature_utils.vocode(spec)
                finally:
                    tod.to(tod_orig_device)

            out = out.float().cpu()
            if out.dim() == 1:
                out = out.unsqueeze(0).unsqueeze(0)
            elif out.dim() == 2:
                out = out.unsqueeze(1)
            elif out.dim() == 3 and out.shape[1] != 1:
                out = out.mean(dim=1, keepdim=True)

            rms = out.pow(2).mean().sqrt().clamp(min=1e-8)
            target_rms = 10 ** (-27.0 / 20.0)
            out = out * (target_rms / rms)
            out = out.clamp(-1.0, 1.0)

            print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} "
                  f"peak={out.abs().max():.4f} rms={out.pow(2).mean().sqrt():.4f}",
                  flush=True)

        finally:
            del vae_enc
            soft_empty_cache()

        return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)
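The final loudness step of the roundtrip rescales the decoded audio to an RMS of -27 dB relative to full scale, then clamps to [-1, 1]. The sketch below reproduces that arithmetic in plain Python on a synthetic sine tone. `normalize_rms` is a hypothetical helper written for this illustration; the node itself does the same computation with torch tensors.

```python
import math


def normalize_rms(samples, target_db=-27.0, eps=1e-8):
    """Scale samples so their RMS equals 10 ** (target_db / 20), then clamp.

    `eps` guards against division by zero on silent input, mirroring the
    clamp(min=1e-8) used on the RMS in the node.
    """
    rms = max(math.sqrt(sum(s * s for s in samples) / len(samples)), eps)
    gain = 10 ** (target_db / 20.0) / rms
    return [max(-1.0, min(1.0, s * gain)) for s in samples]


# One second of a 440 Hz sine at 16 kHz with peak 0.5 (synthetic test signal).
tone = [0.5 * math.sin(2 * math.pi * 440 * n / 16000) for n in range(16000)]
out = normalize_rms(tone)

target_rms = 10 ** (-27.0 / 20.0)
out_rms = math.sqrt(sum(s * s for s in out) / len(out))
```

For a sine wave the peak is the RMS times the square root of 2, so at this target level the clamp never engages. It only matters for signals with a very high crest factor, where clipping would pull the measured RMS slightly below the target.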
@@ -1,309 +0,0 @@
"""
LoRA (Low-Rank Adaptation) for SelVA / MMAudio generator.

Supports two initialization modes:
  - **standard**: Kaiming-uniform A, zero B (classic LoRA).
  - **pissa**: A and B from the top-r SVD of the pretrained weight.
    Starts on-manifold, eliminates intruder dimensions at init
    (arXiv:2404.02948, NeurIPS 2024 Spotlight).

Supports two scaling modes:
  - **standard**: alpha / rank
  - **rslora**: alpha / sqrt(rank) — rank-stabilized scaling that prevents
    gradient collapse at high ranks (arXiv:2312.03732).

Usage:
    from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora

    n = apply_lora(net_generator, rank=16, alpha=16.0)
    print(f"Wrapped {n} linear layers with LoRA")

    # ... train only LoRA params ...

    torch.save(get_lora_state_dict(net_generator), "adapter.pt")

    # Later, at inference:
    apply_lora(net_generator, rank=16, alpha=16.0)
    load_lora(net_generator, torch.load("adapter.pt"))
"""

import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """nn.Linear with a frozen base weight and trainable low-rank A/B matrices.

    Output: base(x) + (dropout(x) @ A.T @ B.T) * scale

    Standard init: A is Kaiming uniform, B is zero → adapter starts at zero.
    PiSSA init: A and B from top-r SVD of pretrained weight → adapter starts
    at the principal components, base weight stores the residual.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float,
                 dropout: float = 0.0, init_mode: str = "standard",
                 use_rslora: bool = False):
        super().__init__()
        in_f = linear.in_features
        out_f = linear.out_features

        self.linear = linear
        linear.weight.requires_grad_(False)
        if linear.bias is not None:
            linear.bias.requires_grad_(False)

        ref_dtype = linear.weight.dtype
        ref_device = linear.weight.device

        if use_rslora:
            self.scale = alpha / math.sqrt(rank)
        else:
            self.scale = alpha / rank

        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()

        if init_mode == "pissa":
            # PiSSA: init from top-r SVD of pretrained weight.
            # SVD in float32 for numerical stability, then cast back.
            W = linear.weight.data.float()  # [out_f, in_f]
            U, S, Vt = torch.linalg.svd(W, full_matrices=False)

            sqrt_S = S[:rank].sqrt()
            # A: [rank, in_f], B: [out_f, rank]
            A_init = sqrt_S.unsqueeze(1) * Vt[:rank, :]
            B_init = U[:, :rank] * sqrt_S.unsqueeze(0)

            # Residual: W_res = W - B_init @ A_init * scale
            # so that base(x) + LoRA(x) = W_res@x + (B@A)*scale@x = W@x at init
            linear.weight.data = (W - B_init @ A_init * self.scale).to(ref_dtype)

            self.lora_A = nn.Parameter(A_init.to(dtype=ref_dtype, device=ref_device))
            self.lora_B = nn.Parameter(B_init.to(dtype=ref_dtype, device=ref_device))
        else:
            # Standard LoRA: Kaiming A, zero B → starts at identity
            self.lora_A = nn.Parameter(torch.empty(rank, in_f, dtype=ref_dtype, device=ref_device))
            self.lora_B = nn.Parameter(torch.zeros(out_f, rank, dtype=ref_dtype, device=ref_device))
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + (self.dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scale

    def extra_repr(self) -> str:
        rank = self.lora_A.shape[0]
        p = self.dropout.p if isinstance(self.dropout, nn.Dropout) else 0.0
        return (f"in={self.linear.in_features}, out={self.linear.out_features}, "
                f"rank={rank}, scale={self.scale:.4f}, dropout={p}")


def apply_lora(
    model: nn.Module,
    rank: int = 16,
    alpha: float = None,
    target_suffixes: tuple = ("attn.qkv",),
    dropout: float = 0.0,
    init_mode: str = "standard",
    use_rslora: bool = False,
) -> int:
    """Replace matching nn.Linear layers with LoRALinear in-place.

    Args:
        model: The module to modify (typically net_generator).
        rank: LoRA rank.
        alpha: LoRA alpha (scaling). Defaults to rank (scale = 1.0).
        target_suffixes: Tuple of module name suffixes to wrap. Default is
            ("attn.qkv",) which targets all SelfAttention QKV
            projections in the MM-DiT generator.
            Add "linear1" to also wrap post-attention output projections.
        dropout: Dropout probability on the LoRA path (not the base linear).
            0.05–0.1 helps regularize on small datasets.
            Must be 0 when using PiSSA (principal components shouldn't be dropped).
        init_mode: "standard" (Kaiming/zero) or "pissa" (SVD-based).
        use_rslora: If True, scale by alpha/sqrt(rank) instead of alpha/rank.

    Returns:
        Number of linear layers wrapped.
    """
    if alpha is None:
        alpha = float(rank)

    if init_mode == "pissa" and dropout > 0.0:
        print("[LoRA] Warning: dropout forced to 0 for PiSSA init "
              "(principal components should not be dropped).")
        dropout = 0.0

    count = 0
    for name, module in list(model.named_modules()):
        if not any(name.endswith(s) for s in target_suffixes):
            continue
        if not isinstance(module, nn.Linear):
            continue

        parts = name.split(".")
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], LoRALinear(
            module, rank, alpha, dropout=dropout,
            init_mode=init_mode, use_rslora=use_rslora,
        ))
        count += 1

    return count


def get_lora_state_dict(model: nn.Module) -> dict:
    """Return a state dict containing only LoRA parameters (lora_A and lora_B)."""
    return {k: v for k, v in model.state_dict().items() if "lora_" in k}


def get_lora_and_base_state_dict(model: nn.Module) -> dict:
    """Return state dict with LoRA params AND base linear weights.

    Needed for PiSSA checkpoints where the base weight stores the residual
    (W - top_r(W)*scale), not the original pretrained weight.
    """
    result = {}
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            prefix = name + "."
            result[prefix + "lora_A"] = module.lora_A.data
            result[prefix + "lora_B"] = module.lora_B.data
            result[prefix + "linear.weight"] = module.linear.weight.data
            if module.linear.bias is not None:
                result[prefix + "linear.bias"] = module.linear.bias.data
    return result


def spectral_surgery(
    model: nn.Module,
    calibration_fn,
    n_calibration: int = 128,
    policy: str = "smooth_abs",
):
    """Post-training Spectral Surgery: reweight LoRA singular values to suppress
    intruder dimensions and amplify useful components (arXiv:2603.03995).

    Args:
        model: Model with LoRA applied.
        calibration_fn: Callable that takes (model, step_idx) and runs one forward+backward
            pass on a calibration sample. Must call loss.backward().
        n_calibration: Number of calibration samples to average gradients over.
        policy: Reweighting policy: "smooth_abs" (recommended), "hard" (binary).

    Modifies LoRA A and B in-place. Returns number of layers processed.
    """
    model.eval()
    lora_layers = [(name, mod) for name, mod in model.named_modules()
                   if isinstance(mod, LoRALinear)]

    if not lora_layers:
        return 0

    # Accumulate per-layer gradient sensitivity: g_k = u_k^T * (dL/dΔW) * v_k
    sensitivities = {}
    for name, mod in lora_layers:
        sensitivities[name] = None

    for step in range(n_calibration):
        model.zero_grad()
        # Enable grad temporarily on LoRA params
        for _, mod in lora_layers:
            mod.lora_A.requires_grad_(True)
            mod.lora_B.requires_grad_(True)

        calibration_fn(model, step)

        for name, mod in lora_layers:
            A = mod.lora_A.data.float()  # [rank, in_f]
            B = mod.lora_B.data.float()  # [out_f, rank]
            # ΔW = B @ A * scale → gradient dL/dΔW ≈ (dL/dB @ A + B @ dL/dA) / 2
            # (both terms have shape [out_f, in_f], matching ΔW)
            # Per-component sensitivity: project onto SVD directions
            delta_W = (B @ A * mod.scale).detach()
            U, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
            r = A.shape[0]
            U_r, S_r, Vt_r = U[:, :r], S[:r], Vt[:r, :]
|
|
||||||
|
|
||||||
# Compute sensitivity from LoRA gradients
|
|
||||||
if mod.lora_A.grad is not None and mod.lora_B.grad is not None:
|
|
||||||
grad_A = mod.lora_A.grad.float() # [rank, in_f]
|
|
||||||
grad_B = mod.lora_B.grad.float() # [out_f, rank]
|
|
||||||
# dL/d(ΔW) ≈ grad_B @ A + B^T @ grad_A (chain rule through B@A)
|
|
||||||
grad_dW = grad_B @ A + B.T @ grad_A # approximate
|
|
||||||
# Per-component: g_k = u_k^T @ grad_dW @ v_k
|
|
||||||
g = torch.einsum("ik,ij,jk->k", U_r, grad_dW, Vt_r.T) # [r]
|
|
||||||
else:
|
|
||||||
g = torch.zeros(r, device=A.device)
|
|
||||||
|
|
||||||
if sensitivities[name] is None:
|
|
||||||
sensitivities[name] = g
|
|
||||||
else:
|
|
||||||
sensitivities[name] += g
|
|
||||||
|
|
||||||
# Disable grad again
|
|
||||||
for _, mod in lora_layers:
|
|
||||||
mod.lora_A.requires_grad_(False)
|
|
||||||
mod.lora_B.requires_grad_(False)
|
|
||||||
|
|
||||||
# Apply reweighting per layer
|
|
||||||
count = 0
|
|
||||||
for name, mod in lora_layers:
|
|
||||||
g = sensitivities[name] / n_calibration
|
|
||||||
A = mod.lora_A.data.float()
|
|
||||||
B = mod.lora_B.data.float()
|
|
||||||
|
|
||||||
delta_W = B @ A * mod.scale
|
|
||||||
U, S, Vt = torch.linalg.svd(delta_W, full_matrices=False)
|
|
||||||
r = A.shape[0]
|
|
||||||
S_r = S[:r]
|
|
||||||
|
|
||||||
if policy == "hard":
|
|
||||||
# Keep components with positive sensitivity, zero out negative
|
|
||||||
mask = (g > 0).float()
|
|
||||||
else:
|
|
||||||
# smooth_abs: sigmoid-weighted by sensitivity magnitude
|
|
||||||
# Normalize g to [-1, 1] range, apply sigmoid
|
|
||||||
g_norm = g / (g.abs().max() + 1e-8)
|
|
||||||
mask = torch.sigmoid(5.0 * g_norm) # steep sigmoid
|
|
||||||
|
|
||||||
# L1 norm preservation: scale mask so total nuclear norm is preserved
|
|
||||||
mask = mask * (S_r.sum() / (mask * S_r).sum().clamp(min=1e-8))
|
|
||||||
|
|
||||||
# Reconstruct: ΔW' = U_r @ diag(mask * S_r) @ Vt_r
|
|
||||||
S_new = mask * S_r
|
|
||||||
delta_W_new = U[:, :r] @ torch.diag(S_new) @ Vt[:r, :]
|
|
||||||
|
|
||||||
# Factor back into B' @ A' * scale: use SVD of ΔW'/scale
|
|
||||||
dW_unscaled = delta_W_new / mod.scale
|
|
||||||
U2, S2, Vt2 = torch.linalg.svd(dW_unscaled, full_matrices=False)
|
|
||||||
sqrt_S2 = S2[:r].sqrt()
|
|
||||||
A_new = sqrt_S2.unsqueeze(1) * Vt2[:r, :]
|
|
||||||
B_new = U2[:, :r] * sqrt_S2.unsqueeze(0)
|
|
||||||
|
|
||||||
ref_dtype = mod.lora_A.dtype
|
|
||||||
mod.lora_A.data = A_new.to(ref_dtype)
|
|
||||||
mod.lora_B.data = B_new.to(ref_dtype)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
kept = (mask > 0.5).sum().item()
|
|
||||||
print(f"[Spectral Surgery] {name}: kept {kept}/{r} components, "
|
|
||||||
f"sensitivity range [{g.min():.3f}, {g.max():.3f}]", flush=True)
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
def load_lora(model: nn.Module, state_dict: dict) -> None:
    """Load LoRA weights into a model that has already had apply_lora() called.

    Non-LoRA keys in state_dict are ignored (strict=False). Non-LoRA model
    parameters are not modified.
    """
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    bad = [k for k in unexpected if "lora_" not in k]
    if bad:
        print(f"[LoRA] Warning: unexpected non-LoRA keys ignored: {bad}")
    lora_missing = [k for k in missing if "lora_" in k]
    if lora_missing:
        print(f"[LoRA] Warning: missing LoRA keys (wrong rank/target?): {lora_missing}")
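The `smooth_abs` policy in `spectral_surgery` can be illustrated without tensors. The sketch below applies the same three steps to plain Python lists: normalize the sensitivities by their largest magnitude, pass them through a steep sigmoid, then rescale so the sum of singular values (the nuclear norm) is unchanged. The function name `smooth_abs_reweight` and the sample numbers are assumptions made up for illustration; they are not part of the module.

```python
import math


def smooth_abs_reweight(singular_values, sensitivities, steepness=5.0):
    """Reweight singular values by sigmoid(steepness * g / max|g|), then
    rescale so the sum of singular values stays the same."""
    g_max = max(abs(g) for g in sensitivities) + 1e-8
    mask = [1.0 / (1.0 + math.exp(-steepness * g / g_max)) for g in sensitivities]
    weighted = sum(m * s for m, s in zip(mask, singular_values))
    rescale = sum(singular_values) / max(weighted, 1e-8)
    return [m * rescale * s for m, s in zip(mask, singular_values)]


# Hypothetical values: three components, the second has negative sensitivity.
S = [4.0, 2.0, 1.0]
g = [0.8, -0.5, 0.1]
S_new = smooth_abs_reweight(S, g)
print([round(s, 3) for s in S_new])
```

Components with negative sensitivity shrink toward zero while the rest are scaled up to compensate, so the total magnitude of the update stays fixed and only its distribution across directions changes.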
@@ -1,465 +0,0 @@
#!/usr/bin/env python3
"""
LoRA fine-tuning for SelVA / MMAudio generator.

Teaches the model new or partially-known sound classes from custom video+audio pairs.
Only the LoRA adapter weights are trained (~10 MB vs ~4.4 GB for the full model).

Data layout:
    data/my_sound/
        clip01.npz   # visual features extracted by SelvaFeatureExtractor in ComfyUI
        clip01.wav   # paired clean audio (same filename stem, any format)
        prompts.txt  # optional: "clip01.npz: description" — overrides embedded prompt

If prompts.txt is absent, the prompt embedded in each .npz is used.
If the .npz has no embedded prompt, the directory name is used as fallback.

Usage:
    python train_lora.py \\
        --data_dir data/my_sound \\
        --output_dir lora_output \\
        --variant large_44k \\
        --selva_dir /path/to/ComfyUI/models/selva \\
        --rank 16 --steps 2000 --lr 1e-4
"""

import argparse
import os
import sys
import random
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import open_clip
from open_clip import create_model_from_pretrained

sys.path.insert(0, os.path.dirname(__file__))

from selva_core.model.networks_generator import get_my_mmaudio
from selva_core.model.utils.features_utils import FeaturesUtils, patch_clip
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.lora import apply_lora, get_lora_state_dict

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

_VARIANTS = {
    "small_16k": ("generator_small_16k_sup_5.pth", "16k"),
    "small_44k": ("generator_small_44k_sup_5.pth", "44k"),
    "medium_44k": ("generator_medium_44k_sup_5.pth", "44k"),
    "large_44k": ("generator_large_44k_sup_5.pth", "44k"),
}

_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}

# ---------------------------------------------------------------------------
# Data helpers
# ---------------------------------------------------------------------------

def load_prompts(data_dir: Path) -> dict:
    """Load filename → prompt overrides from prompts.txt."""
    p = data_dir / "prompts.txt"
    if not p.exists():
        return {}
    mapping = {}
    for line in p.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        if ":" in line:
            fname, prompt = line.split(":", 1)
            mapping[fname.strip()] = prompt.strip()
    return mapping


def find_audio_for_npz(npz_path: Path) -> Path | None:
    """Find a paired audio file with the same stem as the .npz."""
    for ext in _AUDIO_EXTS:
        candidate = npz_path.with_suffix(ext)
        if candidate.exists():
            return candidate
    return None


def load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
    """Load an audio file → [L] float32 [-1, 1], resampled and trimmed/padded to duration."""
    waveform, sr = torchaudio.load(str(path))

    # Stereo → mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(0, keepdim=True)
    waveform = waveform.squeeze(0).float()

    # Resample
    if sr != target_sr:
        waveform = torchaudio.functional.resample(
            waveform.unsqueeze(0), sr, target_sr
        ).squeeze(0)

    target_len = int(duration * target_sr)
    if waveform.shape[0] >= target_len:
        return waveform[:target_len]
    return F.pad(waveform, (0, target_len - waveform.shape[0]))


def load_npz(path: Path) -> dict:
    """Load a feature bundle produced by SelvaFeatureExtractor."""
    data = np.load(str(path), allow_pickle=False)
    bundle = {
        "clip_features": torch.from_numpy(data["clip_features"]),  # [1, N, 1024]
        "sync_features": torch.from_numpy(data["sync_features"]),  # [1, T, 768]
    }
    if "prompt" in data:
        bundle["prompt"] = str(data["prompt"])
    if "variant" in data:
        bundle["variant"] = str(data["variant"])
    return bundle


# ---------------------------------------------------------------------------
# Feature extraction (audio + text only — visual features come from .npz)
# ---------------------------------------------------------------------------

def encode_text_clip(clip_model, tokenizer, text: list[str], device) -> torch.Tensor:
    tokens = tokenizer(text).to(device)
    with torch.inference_mode():
        return clip_model.encode_text(tokens, normalize=True)


def extract_audio_latent(audio: torch.Tensor, feature_utils, device, dtype) -> torch.Tensor:
    """Encode a waveform to the generator's latent space via the VAE.

    encode_audio is @inference_mode — .clone() is required before the autograd path.
    """
    audio_b = audio.unsqueeze(0).to(device, dtype)  # [1, L]
    dist = feature_utils.encode_audio(audio_b)
    # VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
    return dist.mode().clone().transpose(1, 2).cpu()  # [1, seq_len, latent_dim]


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    parser = argparse.ArgumentParser(description="LoRA fine-tuning for SelVA generator")
    parser.add_argument("--data_dir", required=True, help="Directory with .npz + audio pairs and optional prompts.txt")
    parser.add_argument("--output_dir", default="lora_output")
    parser.add_argument("--variant", default="large_44k", choices=list(_VARIANTS.keys()))
    parser.add_argument("--selva_dir", required=True, help="Path to selva model weights (ComfyUI/models/selva)")
    parser.add_argument("--rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--alpha", type=float, default=None, help="LoRA alpha (default: rank)")
    parser.add_argument("--target", nargs="+", default=["attn.qkv"],
                        help="Module name suffixes to wrap with LoRA. Also try 'linear1'.")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--steps", type=int, default=2000)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=4, help="Clips per training step")
    parser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--save_every", type=int, default=500)
    parser.add_argument("--resume", default=None,
                        help="Path to a step checkpoint (.pt) to resume training from.")
    parser.add_argument("--precision", default="bf16", choices=["bf16", "fp16", "fp32"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--timestep_mode", default="uniform", choices=["uniform", "logit_normal", "curriculum"],
                        help="Timestep sampling. uniform=original MMAudio, logit_normal=concentrated near t=0.5, curriculum=logit_normal then uniform.")
    parser.add_argument("--logit_normal_sigma", type=float, default=1.0,
                        help="Spread of logit-normal distribution.")
    parser.add_argument("--curriculum_switch", type=float, default=0.6,
                        help="Fraction of steps to use logit_normal before switching to uniform (curriculum mode only).")
    parser.add_argument("--lora_dropout", type=float, default=0.0,
                        help="Dropout on the LoRA path only. 0.05–0.1 helps on small datasets.")
    parser.add_argument("--lora_plus_ratio", type=float, default=1.0,
                        help="LoRA+ LR ratio: lr_B = lr * ratio. 1.0=standard, 16.0=LoRA+.")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
        print("[LoRA] bf16 not supported on this GPU — falling back to fp16")
        args.precision = "fp16"
    dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[args.precision]

    data_dir = Path(args.data_dir)
    output_dir = Path(args.output_dir)
    selva_dir = Path(args.selva_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    gen_filename, mode = _VARIANTS[args.variant]
    seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
    duration = seq_cfg.duration
    sample_rate = seq_cfg.sampling_rate

    # --- Weight paths ---
    def w(name): return str(selva_dir / name)
    def wext(name): return str(selva_dir / "ext" / name)

    vae_weight = wext("v1-16.pth" if mode == "16k" else "v1-44.pth")
    gen_weight = w(gen_filename)
    for path, label in [(vae_weight, "VAE"), (gen_weight, "generator")]:
        if not Path(path).exists():
            print(f"[LoRA] Missing weight: {path} ({label})")
            print("[LoRA] Run ComfyUI with SelvaModelLoader first to auto-download weights.")
            sys.exit(1)

    # --- Load CLIP text encoder (separate from FeaturesUtils to avoid loading Synchformer/T5) ---
    print("[LoRA] Loading CLIP text encoder...")
    clip_model = create_model_from_pretrained(
        'hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False
    ).to(device, dtype).eval()
    clip_model = patch_clip(clip_model)
    tokenizer_clip = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')

    # --- Load VAE (FeaturesUtils with enable_conditions=False — no Synchformer/T5) ---
    print("[LoRA] Loading VAE encoder...")
    feature_utils = FeaturesUtils(
        tod_vae_ckpt=vae_weight,
        enable_conditions=False,
        mode=mode,
        need_vae_encoder=True,
    ).to(device, dtype).eval()

    # --- Load generator ---
    print(f"[LoRA] Loading generator ({args.variant})...")
    net_generator = get_my_mmaudio(args.variant).to(device, dtype).eval()
    net_generator.load_weights(
        torch.load(gen_weight, map_location="cpu", weights_only=False)
    )

    # --- Apply LoRA ---
    n_lora = apply_lora(
        net_generator,
        rank=args.rank,
        alpha=args.alpha,
        target_suffixes=tuple(args.target),
        dropout=args.lora_dropout,
    )
    print(f"[LoRA] Wrapped {n_lora} linear layers (rank={args.rank}, target={args.target}, dropout={args.lora_dropout})")
    if n_lora == 0:
        print("[LoRA] ERROR: no layers were wrapped — check --target names.")
        sys.exit(1)

    # Freeze everything except LoRA params
    for name, p in net_generator.named_parameters():
        p.requires_grad_("lora_" in name)

    trainable = sum(p.numel() for p in net_generator.parameters() if p.requires_grad)
    total = sum(p.numel() for p in net_generator.parameters())
    print(f"[LoRA] Trainable: {trainable:,} / {total:,} params "
          f"({100 * trainable / total:.2f}%)")

    net_generator.update_seq_lengths(
        latent_seq_len=seq_cfg.latent_seq_len,
        clip_seq_len=seq_cfg.clip_seq_len,
        sync_seq_len=seq_cfg.sync_seq_len,
    )

    # --- Dataset ---
    npz_files = sorted(data_dir.glob("*.npz"))
    if not npz_files:
        print(f"[LoRA] No .npz files found in {data_dir}")
        sys.exit(1)

    prompt_map = load_prompts(data_dir)
    default_prompt = data_dir.name

    print(f"[LoRA] Pre-loading {len(npz_files)} clip(s)...")
    dataset = []
    for npz_path in npz_files:
        audio_path = find_audio_for_npz(npz_path)
        if audio_path is None:
            print(f" [LoRA] Warning: no audio file found for {npz_path.name} — skipping")
            continue

        bundle = load_npz(npz_path)
        # Prompt priority: prompts.txt override > embedded in .npz > directory name
        prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))

        print(f" {npz_path.name} + {audio_path.name}: '{prompt}'")

        try:
            audio = load_audio(audio_path, sample_rate, duration)
            x1 = extract_audio_latent(audio, feature_utils, device, dtype)
            # STFT rounding can produce ±1 frame — pad or trim to exact seq length
            tgt = seq_cfg.latent_seq_len
            if x1.shape[1] < tgt:
                x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
            elif x1.shape[1] > tgt:
                x1 = x1[:, :tgt, :]
            text_clip = encode_text_clip(clip_model, tokenizer_clip, [prompt], device).cpu()

            # Pad/trim clip and sync features to fixed seq lengths — shorter clips
            # have fewer frames and would cause stack() to fail during batching
            clip_f = bundle["clip_features"]  # [1, N_clip, 1024]
            c_tgt = seq_cfg.clip_seq_len
            if clip_f.shape[1] < c_tgt:
                clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
            elif clip_f.shape[1] > c_tgt:
                clip_f = clip_f[:, :c_tgt, :]

            sync_f = bundle["sync_features"]  # [1, N_sync, 768]
            s_tgt = seq_cfg.sync_seq_len
            if sync_f.shape[1] < s_tgt:
                sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
            elif sync_f.shape[1] > s_tgt:
                sync_f = sync_f[:, :s_tgt, :]

            dataset.append((x1, clip_f, sync_f, text_clip))
        except Exception as e:
            print(f" [LoRA] Warning: failed to process {npz_path.name}: {e}")

    if not dataset:
        print("[LoRA] No clips could be loaded.")
        sys.exit(1)
    print(f"[LoRA] {len(dataset)} clip(s) ready.")

    # --- Optimizer + LR scheduler ---
    # LoRA+: separate param groups for A and B with different LRs.
    # ratio=1.0 = standard LoRA. ratio=16 = LoRA+ (arXiv:2402.12354).
    lora_A_params = [p for n, p in net_generator.named_parameters() if "lora_A" in n and p.requires_grad]
    lora_B_params = [p for n, p in net_generator.named_parameters() if "lora_B" in n and p.requires_grad]
    optimizer = torch.optim.AdamW([
        {"params": lora_A_params, "lr": args.lr},
        {"params": lora_B_params, "lr": args.lr * args.lora_plus_ratio},
    ], weight_decay=1e-2)
    if args.lora_plus_ratio != 1.0:
        print(f"[LoRA] LoRA+: lr_A={args.lr:.2e} lr_B={args.lr * args.lora_plus_ratio:.2e}")

    def lr_lambda(step):
        if step < args.warmup_steps:
            return step / max(1, args.warmup_steps)
        return 1.0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)

    # --- Resume ---
    start_step = 0
    if args.resume:
        ckpt = torch.load(args.resume, map_location="cpu", weights_only=False)
        if "step" not in ckpt:
            print("[LoRA] ERROR: checkpoint has no step info — was it saved by this script?")
            sys.exit(1)
        start_step = ckpt["step"]
        if start_step >= args.steps:
            print(f"[LoRA] Checkpoint is already at step {start_step} >= --steps {args.steps}. Nothing to do.")
            sys.exit(0)
        net_generator.load_state_dict(ckpt["state_dict"], strict=False)
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])
        print(f"[LoRA] Resumed from {Path(args.resume).name} (step {start_step} → {args.steps})")

    # --- Training loop ---
    net_generator.train()
    optimizer.zero_grad()

    remaining = args.steps - start_step
    print(f"\n[LoRA] Training: {remaining} steps (step {start_step + 1} → {args.steps}), "
          f"batch_size={args.batch_size}, lr={args.lr}, grad_accum={args.grad_accum}")
    print(f"[LoRA] Checkpoints every {args.save_every} steps → {output_dir}\n")

    curriculum_switch_step = start_step + int((args.steps - start_step) * args.curriculum_switch)
    _curriculum_switched = False

    total_loss = 0.0
    for step in range(start_step + 1, args.steps + 1):
        batch = random.choices(dataset, k=args.batch_size)
        x1_list, clip_list, sync_list, text_list = zip(*batch)

        x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
        clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
        sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
        text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)

        net_generator.normalize(x1)

        if args.timestep_mode == "logit_normal" or (
            args.timestep_mode == "curriculum" and step <= curriculum_switch_step
        ):
            u = torch.randn(args.batch_size, device=device, dtype=dtype) * args.logit_normal_sigma
            t = torch.sigmoid(u)
        else:
            t = torch.rand(args.batch_size, device=device, dtype=dtype)

        if args.timestep_mode == "curriculum" and step == curriculum_switch_step + 1 and not _curriculum_switched:
            print(f"[LoRA] Curriculum switch: logit_normal → uniform at step {step}")
            _curriculum_switched = True

        x0 = torch.randn_like(x1)
        xt = fm.get_conditional_flow(x0, x1, t)

        v_pred = net_generator.forward(xt, clip_f, sync_f, text_clip, t)

        loss = fm.loss(v_pred, x0, x1).mean() / args.grad_accum
        loss.backward()
        total_loss += loss.item() * args.grad_accum

        if step % args.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(lora_A_params + lora_B_params, max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % 50 == 0:
            avg = total_loss / 50
            lr_now = scheduler.get_last_lr()[0]
            print(f"[LoRA] step {step:5d}/{args.steps} loss={avg:.4f} lr={lr_now:.2e}")
            total_loss = 0.0

        if step % args.save_every == 0 or step == args.steps:
            ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
            torch.save({
                "state_dict": get_lora_state_dict(net_generator),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "step": step,
                "meta": {
                    "variant": args.variant,
                    "rank": args.rank,
                    "alpha": args.alpha if args.alpha is not None else float(args.rank),
                    "target": args.target,
                    "steps": args.steps,
                    "timestep_mode": args.timestep_mode,
                    "logit_normal_sigma": args.logit_normal_sigma,
                    "curriculum_switch": args.curriculum_switch,
                    "lora_dropout": args.lora_dropout,
                    "lora_plus_ratio": args.lora_plus_ratio,
                },
            }, ckpt_path)
            print(f"[LoRA] Saved {ckpt_path}")

    # Save final adapter with embedded metadata
    # Increment filename if a previous final already exists (resume case)
    final = output_dir / "adapter_final.pt"
    if final.exists():
        i = 1
        while (output_dir / f"adapter_final_{i:03d}.pt").exists():
            i += 1
        final = output_dir / f"adapter_final_{i:03d}.pt"
    meta = {
        "variant": args.variant,
        "rank": args.rank,
        "alpha": args.alpha if args.alpha is not None else float(args.rank),
        "target": args.target,
        "steps": args.steps,
        "timestep_mode": args.timestep_mode,
        "logit_normal_sigma": args.logit_normal_sigma,
        "curriculum_switch": args.curriculum_switch,
        "lora_dropout": args.lora_dropout,
        "lora_plus_ratio": args.lora_plus_ratio,
    }
    torch.save({"state_dict": get_lora_state_dict(net_generator), "meta": meta}, final)
    (output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    print(f"\n[LoRA] Training complete. Adapter saved to {final}")


if __name__ == "__main__":
    main()
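The `--timestep_mode` options in the training loop can be tried in isolation. The sketch below reproduces the sampling rule in pure Python for a single timestep: logit-normal draws a Gaussian and squashes it through a sigmoid, which concentrates samples around t=0.5, and `curriculum` uses that rule for the first `switch_fraction` of steps before falling back to uniform. The helper name `sample_timestep` is hypothetical, and the sketch assumes a fresh run (no `--resume`), so the switch step is simply `int(total_steps * switch_fraction)`.

```python
import math
import random


def sample_timestep(step, total_steps, mode="curriculum", sigma=1.0,
                    switch_fraction=0.6, rng=random):
    """Draw one flow-matching timestep t for the given 1-based training step."""
    switch_step = int(total_steps * switch_fraction)
    use_logit_normal = mode == "logit_normal" or (
        mode == "curriculum" and step <= switch_step
    )
    if use_logit_normal:
        # Logit-normal: squash a Gaussian draw through a sigmoid.
        u = rng.gauss(0.0, sigma)
        return 1.0 / (1.0 + math.exp(-u))
    return rng.random()  # uniform on [0, 1)


def mid_fraction(ts):
    """Fraction of timesteps that land in the middle band [0.25, 0.75]."""
    return sum(0.25 <= t <= 0.75 for t in ts) / len(ts)


rng = random.Random(0)
early = [sample_timestep(100, 2000, rng=rng) for _ in range(10000)]   # logit-normal phase
late = [sample_timestep(1900, 2000, rng=rng) for _ in range(10000)]   # uniform phase
print(f"early: {mid_fraction(early):.2f}, late: {mid_fraction(late):.2f}")
```

With the default `sigma=1.0`, the early phase places noticeably more samples in the middle band than the uniform phase does, which is the effect the curriculum relies on.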