Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0f60a9b2bf | |||
| 51f93f9688 | |||
| a315093743 | |||
| e49f760b77 | |||
| 4f40e15db3 | |||
| 08d73773c5 |
@@ -1,459 +0,0 @@
|
|||||||
# LoRA Training for SelVA
|
|
||||||
|
|
||||||
LoRA lets you teach the model new or partially-known sound classes using a small set of video+audio pairs. Only ~10 MB of adapter weights are trained instead of the full 4.4 GB model.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
Training is split into two steps:
|
|
||||||
|
|
||||||
1. **Dataset preparation** (in ComfyUI) — extract visual features from your video clips using the `SelVA Feature Extractor` node, and collect clean matching audio files.
|
|
||||||
2. **Training** (in ComfyUI or command line) — run the `SelVA LoRA Trainer` node or `train_lora.py`.
|
|
||||||
|
|
||||||
The training script only loads the generator and the VAE encoder. CLIP visual features and sync features come pre-computed from the `.npz` files, so Synchformer and T5 are not loaded during training, saving 3–4 GB of VRAM.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
Same environment as SelVA inference. Additional Python packages:
|
|
||||||
|
|
||||||
```
|
|
||||||
torchaudio
|
|
||||||
soundfile
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 1 — Prepare the dataset
|
|
||||||
|
|
||||||
### 1.1 Video format
|
|
||||||
|
|
||||||
The feature extractor accepts any input but internally resamples frames to fixed square resolutions (384×384 for CLIP, 224×224 for Synchformer). Both encoders were trained on standard video datasets — predominantly landscape footage. This has two practical implications:
|
|
||||||
|
|
||||||
**Aspect ratio** — use **16:9 landscape** whenever possible. Portrait clips (9:16) are mechanically supported but the bicubic stretch into square distorts the image relative to the encoders' training distribution, which can degrade sync feature quality. If your source is portrait, center-crop to square before extraction. Square (1:1) is also fine.
|
|
||||||
|
|
||||||
**Resolution** — anything ≥ 480p is sufficient. The extractor downscales to 384px and 224px regardless of source resolution; higher resolution adds no benefit.
|
|
||||||
|
|
||||||
**Frame rate** — any. Connect `VHS_VIDEOINFO` from VHS LoadVideo to the feature extractor so fps is read automatically from the file instead of being entered manually.
|
|
||||||
|
|
||||||
| Format | Recommendation |
|
|
||||||
|---|---|
|
|
||||||
| Aspect ratio | 16:9 landscape (preferred) or 1:1 square |
|
|
||||||
| Resolution | ≥ 480p (720p+ is fine, no upper limit that matters) |
|
|
||||||
| Frame rate | Any — set via VHS_VIDEOINFO |
|
|
||||||
| Portrait (9:16) | Center-crop to square before extraction |
|
|
||||||
|
|
||||||
### 1.2 Extract visual features in ComfyUI
|
|
||||||
|
|
||||||
For each video clip you want to train on:
|
|
||||||
|
|
||||||
1. Load the video with a VHS LoadVideo node.
|
|
||||||
2. Connect it to **SelVA Feature Extractor**.
|
|
||||||
3. Set **`cache_dir`** to a dedicated dataset folder, e.g. `dataset/my_sound`.
|
|
||||||
4. Set **`name`** to a short descriptive label, e.g. `dog_bark`. The node will save `dog_bark_001.npz`, then `dog_bark_002.npz`, etc. automatically as you process more clips.
|
|
||||||
5. Set the **`prompt`** to describe the sound (e.g. `a dog barking`). This prompt conditions the Synchformer sync features — be as specific as possible (see prompt guide below).
|
|
||||||
6. Optionally connect a **mask** to isolate the sound source in frame (strongly recommended when multiple objects are visible — see masking note below).
|
|
||||||
|
|
||||||
> **Tip:** The prompt used for feature extraction conditions the *visual sync features*. You can use a different, more precise prompt at training time — see Step 2.
|
|
||||||
|
|
||||||
### Prompt guide
|
|
||||||
|
|
||||||
The prompt is not just a label — it directly shapes what the Synchformer pays attention to in the video. Imprecise prompts produce unfocused sync features, which the LoRA then has to compensate for, often introducing noise.
|
|
||||||
|
|
||||||
**Good prompts are specific about:**
|
|
||||||
- The sound source (what object is making the sound)
|
|
||||||
- The acoustic character (loud/quiet, sharp/soft, wet/dry)
|
|
||||||
- The action producing the sound (if applicable)
|
|
||||||
|
|
||||||
| Sound | Weak prompt | Strong prompt |
|
|
||||||
|---|---|---|
|
|
||||||
| Dog bark | `dog` | `a large dog barking loudly` |
|
|
||||||
| Footsteps | `walking` | `heavy boots on a wooden floor` |
|
|
||||||
| Water | `water` | `water dripping into a metal bucket` |
|
|
||||||
| Explosion | `explosion` | `a large explosion with deep bass rumble` |
|
|
||||||
| Door | `door` | `a heavy wooden door slamming shut` |
|
|
||||||
|
|
||||||
**Rules of thumb:**
|
|
||||||
- Describe the *sound*, not the visual scene. `a person hitting a drum` is better than `a drummer on stage`.
|
|
||||||
- Keep prompts consistent across all clips for the same sound class. Mixing `a dog barking` and `loud barking dog` in the same dataset creates conflicting sync features.
|
|
||||||
- Avoid negations (`no background noise`) — the model does not understand negations in sync feature conditioning.
|
|
||||||
|
|
||||||
### Masking note
|
|
||||||
|
|
||||||
If the video frame contains multiple moving objects, CLIP and sync features will be diluted by irrelevant motion. Use a segmentation mask (SAM2 or Grounding DINO+SAM) to isolate the sound source:
|
|
||||||
|
|
||||||
- Connect the mask to the **`mask`** input on SelVA Feature Extractor.
|
|
||||||
- Leave `mask_strength` at `1.0` for clean isolation; lower it only if the masked region is very small and the model loses context.
|
|
||||||
- Re-extract features with a mask even if you already have `.npz` files — better features directly reduce training noise.
|
|
||||||
|
|
||||||
### 1.3 Collect clean audio
|
|
||||||
|
|
||||||
For each `.npz` file, place a matching audio file with the **same filename stem** in the same directory:
|
|
||||||
|
|
||||||
```
|
|
||||||
dataset/my_sound/
|
|
||||||
dog_bark_001.npz ← from SelVA Feature Extractor
|
|
||||||
dog_bark_001.wav ← clean isolated audio recording
|
|
||||||
dog_bark_002.npz
|
|
||||||
dog_bark_002.wav
|
|
||||||
dog_bark_003.npz
|
|
||||||
dog_bark_003.wav
|
|
||||||
```
|
|
||||||
|
|
||||||
Supported audio formats: `.wav`, `.flac`, `.ogg`, `.aiff`, `.aif`
|
|
||||||
|
|
||||||
> `.mp3` is not recommended — lossy compression degrades training quality. Use `.flac` or `.wav`.
|
|
||||||
|
|
||||||
The audio will be automatically resampled and trimmed/padded to match the model's expected duration. Use clean, isolated recordings — no background noise.
|
|
||||||
|
|
||||||
### 1.4 Optional: prompts.txt
|
|
||||||
|
|
||||||
If you want a different prompt at training time than the one embedded in the `.npz`, create a `prompts.txt` file in the dataset directory:
|
|
||||||
|
|
||||||
```
|
|
||||||
# One line per file: filename: prompt text
|
|
||||||
dog_bark.npz: a large dog barking aggressively
|
|
||||||
dog_bark_001.npz: a dog barking in the distance
|
|
||||||
```
|
|
||||||
|
|
||||||
Priority: `prompts.txt` > prompt embedded in `.npz` > directory name as fallback.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 2 — Train
|
|
||||||
|
|
||||||
### Option A — SelVA LoRA Trainer node (ComfyUI)
|
|
||||||
|
|
||||||
Connect the node and set parameters directly in the UI. The node outputs the trained model ready to wire into the Sampler, and saves loss curve images to the output directory.
|
|
||||||
|
|
||||||
```
|
|
||||||
SelVA Model Loader → SelVA LoRA Trainer → SelVA Sampler
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option B — Command line
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python train_lora.py \
|
|
||||||
--data_dir dataset/my_sound \
|
|
||||||
--output_dir lora_output/my_sound \
|
|
||||||
--variant large_44k \
|
|
||||||
--selva_dir /path/to/ComfyUI/models/selva \
|
|
||||||
--rank 16 \
|
|
||||||
--steps 4000 \
|
|
||||||
--batch_size 4 \
|
|
||||||
--lr 1e-4
|
|
||||||
```
|
|
||||||
|
|
||||||
The script will:
|
|
||||||
1. Load the VAE, CLIP text encoder, and generator.
|
|
||||||
2. Pre-load all clips (audio encoded to latents, features loaded from `.npz`).
|
|
||||||
3. Train LoRA adapters for the specified number of steps.
|
|
||||||
4. Save a checkpoint every `--save_every` steps, a final `adapter_final.pt`, and loss curve images.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## CLI Reference
|
|
||||||
|
|
||||||
| Argument | Default | Description |
|
|
||||||
|---|---|---|
|
|
||||||
| `--data_dir` | required | Directory containing `.npz` + audio pairs |
|
|
||||||
| `--output_dir` | `lora_output` | Where to save adapter checkpoints |
|
|
||||||
| `--variant` | `large_44k` | Model variant: `small_16k`, `small_44k`, `medium_44k`, `large_44k` |
|
|
||||||
| `--selva_dir` | required | Path to SelVA model weights directory |
|
|
||||||
| `--rank` | `16` | LoRA rank — higher = more capacity, more VRAM |
|
|
||||||
| `--alpha` | `rank` | LoRA alpha scaling. Default (= rank) means scale = 1.0 |
|
|
||||||
| `--target` | `attn.qkv` | Which layers to adapt. Add `linear1` for post-attention projections |
|
|
||||||
| `--lr` | `1e-4` | Learning rate |
|
|
||||||
| `--steps` | `2000` | Total training steps |
|
|
||||||
| `--warmup_steps` | `100` | Linear LR warmup steps |
|
|
||||||
| `--batch_size` | `4` | Clips per training step — higher is more stable, uses more VRAM |
|
|
||||||
| `--grad_accum` | `1` | Gradient accumulation steps (use when batch_size is already > 1) |
|
|
||||||
| `--save_every` | `500` | Save a checkpoint every N steps |
|
|
||||||
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
|
|
||||||
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
|
|
||||||
| `--seed` | `42` | Random seed |
|
|
||||||
| `--timestep_mode` | `uniform` | Timestep sampling: `uniform`, `logit_normal`, or `curriculum` |
|
|
||||||
| `--logit_normal_sigma` | `1.0` | Spread of the logit-normal distribution. Only used with `logit_normal` / `curriculum` |
|
|
||||||
| `--curriculum_switch` | `0.6` | Fraction of steps to use logit_normal before switching to uniform. Only with `curriculum` |
|
|
||||||
| `--lora_dropout` | `0.0` | Dropout on the LoRA path only. `0.05`–`0.1` helps regularize on small datasets |
|
|
||||||
| `--lora_plus_ratio` | `1.0` | LoRA+ LR ratio: `lr_B = lr × ratio`. `1.0` = standard LoRA, `16.0` = LoRA+ |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step 3 — Load the adapter in ComfyUI
|
|
||||||
|
|
||||||
Connect **SelVA LoRA Loader** between the model loader and the sampler:
|
|
||||||
|
|
||||||
```
|
|
||||||
SelVA Model Loader → SelVA LoRA Loader → SelVA Sampler
|
|
||||||
```
|
|
||||||
|
|
||||||
> **Important:** Wire the LoRA Loader output to the **Sampler**, not the Feature Extractor. The LoRA adapts the generator which only runs in the Sampler.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|---|---|
|
|
||||||
| `model` | SELVA_MODEL from the model loader |
|
|
||||||
| `adapter_path` | Path to `adapter_final.pt` or any `adapter_stepXXXXX.pt` |
|
|
||||||
| `strength` | 0.0 = adapter disabled, 1.0 = full strength, >1.0 = exaggerated |
|
|
||||||
|
|
||||||
The loader reads rank, alpha, and target layers from the metadata embedded in the `.pt` file — no need to set them manually.
|
|
||||||
|
|
||||||
> The base model is not modified. The loader returns a shallow copy with a deep-copied generator so the original stays intact.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tuning Guide
|
|
||||||
|
|
||||||
### Clip length
|
|
||||||
|
|
||||||
The model has a **fixed input duration of 8 seconds** for all variants (both 16k and 44k). This is not a parameter you can change.
|
|
||||||
|
|
||||||
- Audio shorter than 8 s is **zero-padded** (silence appended). The model will learn the sound but may also learn silence as part of the pattern — keep in mind for very short sounds.
|
|
||||||
- Audio longer than 8 s is **trimmed** at 8 s. Content beyond that is lost.
|
|
||||||
- Video shorter than 8 s has its **last frame repeated** to fill the clip.
|
|
||||||
|
|
||||||
**Practical recommendations:**
|
|
||||||
|
|
||||||
| Sound type | Clip strategy |
|
|
||||||
|---|---|
|
|
||||||
| Continuous sound (rain, engine, wind) | 8 s recordings, as many positions in the audio as possible |
|
|
||||||
| Single event < 2 s (click, bark, knock) | Center the event — pad deliberately with silence before/after, or loop the event 2–3 times per clip |
|
|
||||||
| Repeating event (footsteps, dripping) | Record full 8 s with natural repetition at the intended cadence |
|
|
||||||
| Sound with a clear onset (explosion, splash) | Put the onset at ~1–2 s from the start, not at 0 s — gives the model context |
|
|
||||||
|
|
||||||
> **Tip:** When extracting features in ComfyUI, set `duration` to 0 to use the full video length up to 8 s. Clips longer than 8 s are automatically clamped.
|
|
||||||
|
|
||||||
### How many clips do I need?
|
|
||||||
|
|
||||||
The table below gives a rough scaling guide. Quality and diversity of recordings matter more than raw count.
|
|
||||||
|
|
||||||
| Dataset size | Scenario | Expected result |
|
|
||||||
|---|---|---|
|
|
||||||
| **5–10 clips** | Quick test / proof of concept | May work if the model already partially knows the sound; often underfits |
|
|
||||||
| **15–30 clips** | Fine-tuning a sound the model knows but gets wrong | Good starting point — covers the main variations |
|
|
||||||
| **30–60 clips** | Teaching a new but acoustically simple sound class | Reliable convergence with default hyperparameters |
|
|
||||||
| **60–150 clips** | Unusual or complex sounds, strong style shift | Needed for stable generalization across video contexts |
|
|
||||||
| **150–300 clips** | Sounds the model has never encountered | Required to avoid overfitting; increase rank to 32 |
|
|
||||||
| **300+** | Large-scale domain shift | Consider also targeting `linear1` in addition to `attn.qkv` |
|
|
||||||
|
|
||||||
**Diversity beats quantity.** Ten clips of a dog barking in different environments (indoors, outdoors, distant, close) train better than fifty clips of the same recording. Vary: distance, room acoustics, intensity, speed.
|
|
||||||
|
|
||||||
### Batch size
|
|
||||||
|
|
||||||
| Batch size | VRAM (large_44k) | Use case |
|
|
||||||
|---|---|---|
|
|
||||||
| `1` | ~9 GB | Minimal VRAM, noisy gradients |
|
|
||||||
| `4` | ~12 GB | Good default — stable gradients, reasonable speed |
|
|
||||||
| `8` | ~15 GB | Better convergence on larger datasets |
|
|
||||||
| `16` | ~20 GB | Best gradient quality when VRAM allows |
|
|
||||||
|
|
||||||
Higher batch size gives smoother loss curves and faster convergence. If you have headroom, prefer larger batches over more steps.
|
|
||||||
|
|
||||||
**Observed results:** batch 16 reaches the same loss in ~2600 steps that batch 1 needed 8000+ steps to reach, with a near-perfectly smooth curve. On a 24 GB GPU, batch 16 is the recommended default for `large_44k`.
|
|
||||||
|
|
||||||
### Rank
|
|
||||||
|
|
||||||
| Rank | Use case |
|
|
||||||
|---|---|
|
|
||||||
| `8` | Fine details on a sound the model already knows well |
|
|
||||||
| `16` | Default — good balance of capacity and VRAM |
|
|
||||||
| `32` | Harder sounds or larger style shifts (30+ clips recommended) |
|
|
||||||
|
|
||||||
Higher rank increases VRAM usage and overfitting risk on small datasets.
|
|
||||||
|
|
||||||
### Steps
|
|
||||||
|
|
||||||
With `batch_size=4` as the default, these are rough guidelines:
|
|
||||||
|
|
||||||
| Dataset size | Recommended steps |
|
|
||||||
|---|---|
|
|
||||||
| 10–20 clips | 2000–4000 |
|
|
||||||
| 20–50 clips | 4000–8000 |
|
|
||||||
| 50+ clips | 6000–15000 |
|
|
||||||
|
|
||||||
Watch the loss curve — if the smoothed line has been flat for 2000+ steps, training has converged for your dataset size. Adding more clips will let it go lower.
|
|
||||||
|
|
||||||
### Learning rate
|
|
||||||
|
|
||||||
`1e-4` is the recommended default for any batch size. If training is unstable (loss spikes in the first 200 steps), try `5e-5`. If convergence is very slow, try `2e-4`.
|
|
||||||
|
|
||||||
Warmup (default 100 steps) ramps the LR from 0 to avoid instability at the start.
|
|
||||||
|
|
||||||
### Target layers
|
|
||||||
|
|
||||||
`attn.qkv` (default) adapts only the self-attention QKV projections. This is the recommended starting point for all dataset sizes.
|
|
||||||
|
|
||||||
Add `linear1` to also adapt post-attention projections for large-scale domain shifts or when `attn.qkv` alone plateaus too early:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
--target attn.qkv linear1
|
|
||||||
```
|
|
||||||
|
|
||||||
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
|
|
||||||
|
|
||||||
### Timestep sampling mode
|
|
||||||
|
|
||||||
Controls how training timesteps are sampled at each step.
|
|
||||||
|
|
||||||
`uniform` (default) samples all timesteps equally — equivalent to original MMAudio training.
|
|
||||||
|
|
||||||
`logit_normal` concentrates more steps near t=0.5 via `sigmoid(N(0, σ))`. This is the semantically rich mid-noise region. Consistently reaches a lower loss floor but the perceptual improvement on small datasets is marginal.
|
|
||||||
|
|
||||||
`curriculum` uses logit_normal for the first `curriculum_switch` fraction of steps (default 60%), then switches to uniform for the remainder. The motivation: logit_normal accelerates early structure learning but undertrains the high-t boundary region; uniform then fills in the fine detail. A switch message is logged when the transition happens.
|
|
||||||
|
|
||||||
| Mode | When to use |
|
|
||||||
|---|---|
|
|
||||||
| `uniform` (default) | Baseline — safe, equivalent to original training |
|
|
||||||
| `logit_normal` | When you want a lower loss floor; marginal on small datasets |
|
|
||||||
| `curriculum` | Experimental — may improve convergence quality on small datasets |
|
|
||||||
|
|
||||||
The `logit_normal_sigma` parameter controls the width of the logit-normal distribution (used by both `logit_normal` and the first phase of `curriculum`):
|
|
||||||
- σ=1.0: moderate peak at t=0.5, balanced coverage (default)
|
|
||||||
- σ=0.5: sharper peak, less coverage of extremes
|
|
||||||
- σ=2.0: broader, approaches uniform
|
|
||||||
|
|
||||||
### LoRA dropout
|
|
||||||
|
|
||||||
`lora_dropout` applies dropout to the input of the LoRA path (not the frozen base linear). It regularizes the low-rank update without disturbing pretrained weights — helpful on small datasets where the LoRA would otherwise overfit to the training clips.
|
|
||||||
|
|
||||||
| Value | Use case |
|
|
||||||
|---|---|
|
|
||||||
| `0.0` (default) | No regularization — fine for 30+ clips |
|
|
||||||
| `0.05` | Light regularization — recommended starting point on 10–20 clips |
|
|
||||||
| `0.1` | Stronger regularization — use if loss plateaus but audio is still noisy |
|
|
||||||
|
|
||||||
Dropout is not saved in the adapter file — it only affects training. Loading the adapter at inference does not require setting dropout.
|
|
||||||
|
|
||||||
### LoRA+ (asymmetric learning rate)
|
|
||||||
|
|
||||||
`lora_plus_ratio` splits the learning rate between LoRA A and B matrices: `lr_B = lr × ratio`. The B matrix is the output-side projection and benefits from a higher LR. Setting ratio to 16 enables the LoRA+ scheme from arXiv:2402.12354.
|
|
||||||
|
|
||||||
| Ratio | Effect |
|
|
||||||
|---|---|
|
|
||||||
| `1.0` (default) | Standard LoRA — identical A and B learning rates |
|
|
||||||
| `4.0` | Mild asymmetry |
|
|
||||||
| `16.0` | LoRA+ — faster convergence, especially on early steps |
|
|
||||||
|
|
||||||
LoRA+ is orthogonal to dropout and curriculum sampling — all three can be combined.
|
|
||||||
|
|
||||||
### Adapter strength at inference
|
|
||||||
|
|
||||||
| Strength | Effect |
|
|
||||||
|---|---|
|
|
||||||
| `0.5–0.7` | Conservative — blends adapter with base model, less noise |
|
|
||||||
| `1.0` | Full adapter strength (default) |
|
|
||||||
| `>1.0` | Exaggerated effect, may introduce artifacts |
|
|
||||||
|
|
||||||
If the generated audio has noticeable white noise or artifacts, lower the strength to `0.6–0.7` before adjusting anything else. Also try lowering CFG scale in the Sampler.
|
|
||||||
|
|
||||||
### Loss interpretation
|
|
||||||
|
|
||||||
A typical loss curve:
|
|
||||||
- Starts around `0.8–1.0`
|
|
||||||
- Should reach `0.55–0.65` after convergence on a clean sound class with 10–30 clips
|
|
||||||
- Below `0.4` indicates strong learning — usually requires 50+ diverse clips
|
|
||||||
- Below `0.1` on a small dataset means overfitting
|
|
||||||
|
|
||||||
The smoothed curve flattening for 2000+ steps is the clearest sign to stop or add more data.
|
|
||||||
|
|
||||||
### Precision
|
|
||||||
|
|
||||||
Use `bf16` on Ampere+ GPUs (RTX 3xxx/4xxx, A100). Fall back to `fp16` on older GPUs. `fp32` is only needed for debugging — 2× more VRAM.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Output files
|
|
||||||
|
|
||||||
```
|
|
||||||
lora_output/my_sound/
|
|
||||||
adapter_step00500.pt ← step checkpoint (includes optimizer state for resume)
|
|
||||||
adapter_step01000.pt
|
|
||||||
...
|
|
||||||
adapter_final.pt ← final adapter with embedded metadata (inference only)
|
|
||||||
meta.json ← human-readable metadata
|
|
||||||
sample_step00500.wav ← quick eval sample at each checkpoint
|
|
||||||
loss_raw.png ← raw loss curve
|
|
||||||
loss_smoothed.png ← EMA-smoothed loss curve
|
|
||||||
```
|
|
||||||
|
|
||||||
`adapter_final.pt` format:
|
|
||||||
```python
|
|
||||||
{
|
|
||||||
"state_dict": { "blocks.0.attn.qkv.lora_A": ..., ... },
|
|
||||||
"meta": {
|
|
||||||
"variant": "large_44k",
|
|
||||||
"rank": 16,
|
|
||||||
"alpha": 16.0,
|
|
||||||
"target": ["attn.qkv"],
|
|
||||||
"steps": 2000
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Step checkpoints (e.g. `adapter_step01000.pt`) additionally contain `optimizer` and `scheduler` state for resuming.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
**`No layers matched target=...`**
|
|
||||||
The `--target` suffixes do not match any layer names. The default `attn.qkv` targets `SelfAttention.qkv` in all transformer blocks. If you changed `--target`, verify the layer names with `model.named_modules()`.
|
|
||||||
|
|
||||||
**`No .npz files found in ...`**
|
|
||||||
The `--data_dir` path is wrong or no `.npz` files were extracted there yet. Run SelVA Feature Extractor in ComfyUI first with the matching `cache_dir`.
|
|
||||||
|
|
||||||
**`No audio file found for clip.npz`**
|
|
||||||
Place an audio file with the exact same stem next to the `.npz`: `clip.wav`, `clip.flac`, etc.
|
|
||||||
|
|
||||||
**The sound is audible but there is white noise on top**
|
|
||||||
Lower the adapter strength to `0.6–0.7` in SelVA LoRA Loader. Also try lowering CFG scale in the Sampler. This is normal when the model hasn't fully converged — more clips and more steps will reduce it.
|
|
||||||
|
|
||||||
**LoRA appears to have no effect**
|
|
||||||
Make sure the SelVA LoRA Loader output is wired to the **Sampler** input, not the Feature Extractor. The Feature Extractor does not use the generator.
|
|
||||||
|
|
||||||
**Loss does not decrease**
|
|
||||||
- Increase `batch_size` for more stable gradients.
|
|
||||||
- Try a higher learning rate (`2e-4`) or check that warmup isn't too long.
|
|
||||||
- Check that the audio files are clean and actually contain the target sound.
|
|
||||||
- Check that the `.npz` features were extracted with a relevant prompt.
|
|
||||||
|
|
||||||
**Loss explodes or NaN**
|
|
||||||
- Lower the learning rate (`5e-5`).
|
|
||||||
- Make sure audio is normalized to `[-1, 1]`. PCM files with 16-bit integer encoding may need to be converted: `ffmpeg -i input.wav -ar 44100 -sample_fmt s16 output.wav`
|
|
||||||
|
|
||||||
**Loss plateaus early (above 0.7)**
|
|
||||||
Dataset is the bottleneck. Add more clips — diversity matters more than quantity.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Observations (work in progress)
|
|
||||||
|
|
||||||
These are empirical findings from ongoing experiments. They will be promoted to the main guide once more validated.
|
|
||||||
|
|
||||||
### Precision and batch size
|
|
||||||
|
|
||||||
| Config | Smoothed loss at step 2000 | Notes |
|
|
||||||
|---|---|---|
|
|
||||||
| bf16 batch 1 | ~0.73 | Noisy gradients, slow |
|
|
||||||
| bf16 batch 16 | ~0.65 | Stable, plateaued around step 6000–8000 at ~0.59 |
|
|
||||||
| bf16 batch 16 logit_normal | ~0.47 | Lower loss floor, similar or marginally better audio |
|
|
||||||
| fp32 batch 32 | ~0.58 | Matches bf16 batch 16 at step 6000 already at step 2000 |
|
|
||||||
|
|
||||||
**Key finding:** fp32 batch 32 converges to the same perceptual quality point in ~2000 steps that bf16 batch 16 needs 6000+ steps to reach. However, fp32 batch 32 continues descending well past that point on small datasets (10 clips), eventually overfitting. **Stop fp32 batch 32 around step 2000 on a 10-clip dataset** — later checkpoints sound worse despite lower loss.
|
|
||||||
|
|
||||||
**Lower loss ≠ better audio.** Once overfitting begins the model memorizes training clips rather than generalizing to new video inputs. Test intermediate checkpoints (e.g. step 500, 1000, 2000) to find the perceptual sweet spot.
|
|
||||||
|
|
||||||
### logit_normal vs uniform
|
|
||||||
|
|
||||||
logit_normal consistently reaches a lower loss floor than uniform. However perceptual improvement is dataset-dependent — on 10 clips the difference is marginal. May be more impactful with larger datasets. No conclusion yet.
|
|
||||||
|
|
||||||
### White noise
|
|
||||||
|
|
||||||
Residual white noise on generated audio is primarily a **dataset** problem, not a training one. Observed with all configs on 10 clips. Likely causes:
|
|
||||||
- Too few clips for the model to confidently predict the target sound
|
|
||||||
- Imprecise extraction prompts producing unfocused sync features
|
|
||||||
- Missing mask when multiple objects are in frame
|
|
||||||
|
|
||||||
CFG scale amplifies any adapter noise bias. Reducing CFG to 3.0–3.5 or adapter strength to 0.6–0.7 helps at inference.
|
|
||||||
@@ -1,380 +1,156 @@
|
|||||||
# ComfyUI-SelVA
|
# ComfyUI-PrismAudio
|
||||||
|
|
||||||
Custom nodes for [SelVA](https://github.com/jnwnlee/selva) — video-to-audio generation driven by text prompts. SelVA conditions audio synthesis on both visual content and natural language, letting you describe *what* sounds to generate rather than just *when*.
|
Custom nodes for [PrismAudio](https://huggingface.co/FunAudioLLM/PrismAudio) (ICLR 2026) — video-to-audio and text-to-audio generation using decomposed Chain-of-Thought reasoning with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
|
||||||
|
|
||||||
Built on [MMAudio](https://github.com/hkchengrex/MMAudio) with a TextSynchformer encoder that injects text guidance directly into the visual sync stream.
|
## Installation
|
||||||
|
|
||||||
---
|
Clone into your ComfyUI custom nodes directory:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ComfyUI/custom_nodes
|
||||||
|
git clone https://github.com/Ethanfel/ComfyUI-Prismaudio.git ComfyUI-PrismAudio
|
||||||
|
pip install -r ComfyUI-PrismAudio/requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
**flash-attn** is optional — detected at runtime, falls back to PyTorch SDPA if unavailable.
|
||||||
|
|
||||||
## Nodes
|
## Nodes
|
||||||
|
|
||||||
### SelVA Model Loader
|
### PrismAudio Model Loader
|
||||||
|
|
||||||
Loads the generator, TextSynchformer encoder, and all feature utilities (CLIP, T5, Synchformer, VAE). Weights are auto-downloaded from HuggingFace on first use.
|
Loads the DiT diffusion model and VAE. Auto-downloads weights from HuggingFace on first use.
|
||||||
|
|
||||||
| Input | Options | Description |
|
| Input | Options | Description |
|
||||||
|-------|---------|-------------|
|
|-------|---------|-------------|
|
||||||
| `variant` | small_16k / small_44k / medium_44k / large_44k | Model size and output sample rate |
|
| `precision` | auto / fp32 / fp16 / bf16 | DiT and conditioner dtype. VAE is always fp32. |
|
||||||
| `precision` | bf16 / fp16 / fp32 | Compute dtype |
|
| `offload_strategy` | auto / keep_in_vram / offload_to_cpu | Memory management. |
|
||||||
| `offload_strategy` | auto / keep_in_vram / offload_to_cpu | Memory management |
|
|
||||||
|
|
||||||
**Output:** `model` (SELVA_MODEL)
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### SelVA Feature Extractor
|
### PrismAudio Feature Extractor
|
||||||
|
|
||||||
Extracts CLIP visual features and text-guided sync features from a video. Results are cached on disk — re-running with the same inputs is instant.
|
Extracts video features (VideoPrism LvT, Synchformer) and text features (T5-Gemma) from a video in a subprocess. Results are cached on disk.
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| `model` | From SelVA Model Loader |
|
|
||||||
| `video` | IMAGE tensor from any ComfyUI video loader |
|
| `video` | IMAGE tensor from any ComfyUI video loader |
|
||||||
| `prompt` | Text description of the audio to generate |
|
| `caption_cot` | Chain-of-thought description of the audio scene |
|
||||||
| `video_info` | *(optional)* VHS_VIDEOINFO from VHS LoadVideo — sets fps automatically |
|
| `video_info` | *(optional)* `VHS_VIDEOINFO` from VHS LoadVideo — sets fps automatically |
|
||||||
| `fps` | Source fps — ignored if `video_info` is connected |
|
| `fps` | Source fps — ignored if `video_info` is connected |
|
||||||
| `duration` | Override clip duration in seconds. `0` = infer from video length |
|
| `python_env` | `managed_env` (auto-created isolated venv, recommended) or `comfyui_env` (current Python, see warning below) |
|
||||||
| `cache_dir` | Directory for cached `.npz` files. Empty = system temp dir |
|
| `cache_dir` | Directory for cached `.npz` files. Empty = system temp dir. |
|
||||||
| `mask` | *(optional)* Segmentation mask `[T,H,W]` float [0,1] — static (1 frame) or per-frame |
|
| `hf_token` | HuggingFace token for gated models. Prefer `HF_TOKEN` env var instead. |
|
||||||
| `mask_strength` | Background suppression strength. `1.0` = full neutral fill, `0.0` = no effect |
|
|
||||||
| `mask_clip` | Apply mask to CLIP features (384px path). Disable to let CLIP see the full scene |
|
|
||||||
| `mask_sync` | Apply mask to TextSynchformer sync features (224px path) |
|
|
||||||
|
|
||||||
**Outputs:** `features` (SELVA_FEATURES), `fps` (FLOAT), `prompt` (STRING)
|
**Outputs:** `features` (PRISMAUDIO_FEATURES), `fps` (FLOAT)
|
||||||
|
|
||||||
Connect `prompt` output to the Sampler's `prompt` input to avoid entering it twice.
|
**`managed_env`** auto-creates a venv at `_extract_env/` inside the plugin directory on first use and installs JAX, TF, VideoPrism, and Synchformer. This takes several minutes the first time.
|
||||||
|
|
||||||
#### Masking
|
**`comfyui_env`** uses the current ComfyUI Python — JAX/TF/videoprism must already be installed. Installing them into the ComfyUI environment may conflict with existing packages.
|
||||||
|
|
||||||
Connect a segmentation mask (SAM2, Grounding DINO+SAM, or any ComfyUI mask node) to isolate a specific object's motion before encoding. Background pixels are filled with a neutral value (0.5) rather than zeroed — this keeps them in-distribution for CLIP and maps to exactly 0 after sync's `[-1,1]` normalization, minimising the influence of background motion on the generated audio.
|
|
||||||
|
|
||||||
Use `mask_sync=true, mask_clip=false` if you want sync features focused on the target object while CLIP still sees the full scene for broader context. Changing any mask parameter correctly busts the feature cache.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### SelVA Sampler
|
### PrismAudio Feature Loader
|
||||||
|
|
||||||
Generates audio from video features. Runs the rectified flow ODE with classifier-free guidance.
|
Loads a pre-computed `.npz` feature file. Use this to re-use extracted features without re-running the extractor.
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| `model` | From SelVA Model Loader (or any loader/loader chain) |
|
| `npz_path` | Path to a `.npz` file produced by the Feature Extractor |
|
||||||
| `features` | From SelVA Feature Extractor |
|
|
||||||
| `prompt` | Text description — leave empty to use the prompt stored in features |
|
|
||||||
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
|
|
||||||
| `duration` | Audio duration in seconds. `0` = use duration from features |
|
|
||||||
| `steps` | Sampling steps (default: 25) |
|
|
||||||
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
|
|
||||||
| `seed` | RNG seed |
|
|
||||||
| `normalize` | RMS-normalize output to `target_lufs` (default: true) |
|
|
||||||
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
|
|
||||||
| `steering_vectors` | *(optional)* From SelVA Activation Steering Loader |
|
|
||||||
| `steering_strength` | *(optional)* Scale for steering vectors (default: 0.1) |
|
|
||||||
| `textual_inversion` | *(optional)* From SelVA Textual Inversion Loader |
|
|
||||||
| `ti_strength` | *(optional)* Blend strength for TI tokens (default: 1.0) |
|
|
||||||
|
|
||||||
**Output:** `AUDIO`
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### SelVA LoRA Loader
|
### PrismAudio Sampler
|
||||||
|
|
||||||
Injects a trained LoRA adapter into the generator. Connect between Model Loader and Sampler.
|
Video-to-audio generation. Takes model + features, produces AUDIO.
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| `model` | SELVA_MODEL from Model Loader |
|
| `model` | From Model Loader |
|
||||||
| `adapter_path` | Path to `adapter_final.pt` or any step checkpoint |
|
| `features` | From Feature Extractor or Feature Loader |
|
||||||
| `strength` | 0.0 = disabled, 1.0 = full, >1.0 = exaggerated |
|
| `duration` | Audio duration in seconds. Set to `0` to use the video duration from features automatically. |
|
||||||
|
| `steps` | Sampling steps (default: 100) |
|
||||||
**Output:** `model` (SELVA_MODEL with adapter injected)
|
| `cfg_scale` | Classifier-free guidance scale (default: 7.0) |
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA LoRA Trainer
|
|
||||||
|
|
||||||
Fine-tunes LoRA adapters on a `.npz` feature dataset. See [LORA_TRAINING.md](LORA_TRAINING.md) for the full guide.
|
|
||||||
|
|
||||||
**Output:** `adapter` (SELVA_LORA) and `summary_path` (STRING)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA LoRA Scheduler
|
|
||||||
|
|
||||||
Runs a series of LoRA experiments from a JSON sweep file. The dataset is encoded once and reused across all runs. Results are collected in `experiment_summary.json` with overlaid loss curves.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| `model` | SELVA_MODEL |
|
|
||||||
| `experiments_file` | Path to JSON sweep config |
|
|
||||||
|
|
||||||
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Skip Experiment
|
|
||||||
|
|
||||||
Signals a running SelVA LoRA Scheduler to skip the current experiment and move to the next. Queue this node while the scheduler is running.
|
|
||||||
|
|
||||||
**Output:** `flag_path` (STRING)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA LoRA Evaluator
|
|
||||||
|
|
||||||
Evaluates multiple LoRA adapters by generating audio from a fixed reference clip, then reports spectral metrics per adapter for comparison. Input is a JSON file listing adapter paths; an empty path means baseline (no LoRA).
|
|
||||||
|
|
||||||
**Outputs:** `summary_path` (STRING), `comparison_image` (IMAGE)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Dataset Browser
|
|
||||||
|
|
||||||
Reads a `dataset.json` produced by the SelVA dataset preparation pipeline and exposes one entry at a time via an index. Useful for previewing and iterating through a prepared dataset.
|
|
||||||
|
|
||||||
**Outputs:** video path, audio path, frames directory, label, total count
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA VAE Roundtrip
|
|
||||||
|
|
||||||
Encodes audio through the SelVA VAE then decodes it back. Use this to measure codec reconstruction quality in isolation — if the output sounds degraded relative to the input, the codec ceiling will limit any downstream fine-tuning approach.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| `model` | SELVA_MODEL |
|
|
||||||
| `audio` | AUDIO to test |
|
|
||||||
|
|
||||||
**Output:** `audio_reconstructed` (AUDIO)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA HF Smoother
|
|
||||||
|
|
||||||
Attenuates high-frequency content that the SelVA codec handles poorly, by blending a low-pass filtered version of the audio with the original. Use before feature extraction to improve LoRA training targets.
|
|
||||||
|
|
||||||
**Output:** `audio` (AUDIO)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Spectral Matcher
|
|
||||||
|
|
||||||
Applies a per-band gain correction to bring audio's spectral profile in line with the MMAudio VAE's expected distribution, derived from the normalization statistics baked into the VAE weights. Use on training audio to reduce codec mismatch.
|
|
||||||
|
|
||||||
**Output:** `audio` (AUDIO)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Textual Inversion Trainer
|
|
||||||
|
|
||||||
Trains K learnable CLIP token embeddings against an audio dataset with all model weights frozen. The tokens are injected into the Sampler to guide generation toward a target style.
|
|
||||||
|
|
||||||
> **Note:** Textual inversion via the text conditioning path has limited effectiveness for fine-grained timbral style transfer in SelVA due to mean-pooling in the text conditioning path. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the current recommended approach.
|
|
||||||
|
|
||||||
**Outputs:** `embeddings_path` (STRING), `loss_curve` (IMAGE)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Textual Inversion Loader
|
|
||||||
|
|
||||||
Loads CLIP token embeddings from a `.pt` file produced by the Textual Inversion Trainer. Connect to the Sampler's `textual_inversion` input.
|
|
||||||
|
|
||||||
**Output:** `textual_inversion` (TEXTUAL_INVERSION)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA TI Scheduler
|
|
||||||
|
|
||||||
Runs a series of Textual Inversion experiments from a JSON sweep file, reusing the encoded dataset across runs.
|
|
||||||
|
|
||||||
**Outputs:** `summary_path` (STRING), `comparison_curves` (IMAGE)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA Activation Steering Extractor
|
|
||||||
|
|
||||||
Computes per-block activation steering vectors from a training dataset by comparing DiT hidden states under BJ conditioning vs. empty conditioning. The resulting vectors can nudge the denoising trajectory toward the target style at inference.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| `model` | SELVA_MODEL |
|
|
||||||
| `data_dir` | Directory with `.npz` feature files |
|
|
||||||
| `output_path` | Where to save `steering_vectors.pt` |
|
|
||||||
| `n_samples` | Clips to average over (default: 16) |
|
|
||||||
| `seed` | RNG seed |
|
| `seed` | RNG seed |
|
||||||
|
|
||||||
**Output:** `steering_path` (STRING)
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### SelVA Activation Steering Loader
|
### PrismAudio Text Only
|
||||||
|
|
||||||
Loads steering vectors from a `.pt` file produced by the Extractor. Connect to the Sampler's `steering_vectors` input.
|
Text-to-audio generation without video. Uses the T5-Gemma encoder.
|
||||||
|
|
||||||
**Output:** `steering_vectors` (STEERING_VECTORS)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA BigVGAN Trainer
|
|
||||||
|
|
||||||
Fine-tunes the BigVGAN vocoder (mel → waveform) on a set of target-style audio clips. Only the vocoder is modified — the DiT generator and VAE are completely untouched.
|
|
||||||
|
|
||||||
Default mode (`snake_alpha_only`) tunes only the ~27K per-channel α parameters in Snake/SnakeBeta activations, which directly control harmonic periodicity. With 0.024% of parameters trainable the model cannot produce spectral averaging artifacts regardless of loss function. See [STYLE_TRANSFER.md](STYLE_TRANSFER.md) for the full rationale.
|
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| `model` | SELVA_MODEL |
|
| `model` | From Model Loader |
|
||||||
| `data_dir` | Directory with target-style audio files (searched recursively) |
|
| `text_prompt` | Chain-of-thought audio scene description. Longer, more detailed prompts produce better results. |
|
||||||
| `output_path` | Where to save the fine-tuned vocoder `.pt` |
|
| `duration` | Audio duration in seconds |
|
||||||
| `train_mode` | `snake_alpha_only` (default) or `all_params` |
|
| `steps` | Sampling steps (default: 100) |
|
||||||
| `steps` | Training steps (default: 2000) |
|
| `cfg_scale` | Classifier-free guidance scale (default: 7.0) |
|
||||||
| `lr` | Learning rate (default: 1e-4 for snake_alpha_only) |
|
|
||||||
| `batch_size` | Clips per step (default: 4) |
|
|
||||||
| `segment_seconds` | Audio segment length per training sample (default: 1.0 s) |
|
|
||||||
| `lambda_l2sp` | L2-SP anchor regularization strength — penalizes drift from pretrained weights (default: 1e-3) |
|
|
||||||
| `save_every` | Checkpoint interval in steps (default: 500) |
|
|
||||||
| `seed` | RNG seed |
|
| `seed` | RNG seed |
|
||||||
| `discriminator_path` | *(optional)* Path to `bigvgan_discriminator_optimizer.pt` — when provided, frozen MPD+MRD feature matching replaces mel L1, directly penalizing harmonic smearing |
|
|
||||||
|
|
||||||
**Output:** `checkpoint_path` (STRING) — load with SelVA BigVGAN Loader
|
|
||||||
|
|
||||||
Saves eval samples and mel spectrogram PNGs at baseline, each checkpoint, and final.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA BigVGAN Loader
|
|
||||||
|
|
||||||
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer and replaces the vocoder weights in a SELVA_MODEL in-place. Connect the output to SelVA Sampler instead of the base Model Loader.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| `model` | SELVA_MODEL from Model Loader |
|
|
||||||
| `path` | Path to fine-tuned vocoder `.pt` (relative = ComfyUI output directory) |
|
|
||||||
|
|
||||||
**Output:** `model` (SELVA_MODEL with fine-tuned vocoder)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### SelVA DITTO Optimizer
|
|
||||||
|
|
||||||
Inference-time noise optimization ([arXiv:2401.12179](https://arxiv.org/abs/2401.12179), ICML 2024 Oral). Optimizes the initial noise latent x₀ to make the generated audio match a set of BJ reference clips, by backpropagating a mel style loss through the ODE solver. All model weights remain frozen — zero quality degradation risk.
|
|
||||||
|
|
||||||
Style loss: mean spectrum + Gram matrix computed against reference mels. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference clips. Optimization runs only through the DiT + VAE decoder; the vocoder is only invoked for the final output pass.
|
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| `model` | SELVA_MODEL |
|
|
||||||
| `features` | From SelVA Feature Extractor |
|
|
||||||
| `prompt` | Sound description (leave empty to use features prompt) |
|
|
||||||
| `negative_prompt` | Sounds to suppress |
|
|
||||||
| `reference_dir` | Directory with BJ reference audio clips (.wav/.flac/.mp3) |
|
|
||||||
| `n_opt_steps` | Gradient optimization steps on x₀ (default: 50) |
|
|
||||||
| `opt_lr` | Adam LR for x₀ optimization (default: 0.1) |
|
|
||||||
| `n_ode_steps` | ODE steps per optimization iteration (default: 10; lower = faster) |
|
|
||||||
| `n_grad_steps` | ODE steps to differentiate through — truncated BPTT (default: 5) |
|
|
||||||
| `style_weight` | Style loss weight (default: 1.0; increase for stronger BJ shift) |
|
|
||||||
| `steps` | Euler steps for the final generation pass (default: 25) |
|
|
||||||
| `cfg_strength` | CFG scale (default: 4.5) |
|
|
||||||
| `seed` | RNG seed |
|
|
||||||
| `normalize` | *(optional)* RMS normalize output (default: true) |
|
|
||||||
| `target_lufs` | *(optional)* Target RMS level in dBFS (default: -27) |
|
|
||||||
|
|
||||||
**Output:** `AUDIO`
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Workflows
|
## Workflows
|
||||||
|
|
||||||
### Basic generation
|
### Video-to-Audio
|
||||||
|
|
||||||
```
|
```
|
||||||
VHS LoadVideo ──► SelVA Feature Extractor ─────────────────────► SelVA Sampler ──► Save Audio
|
VHS LoadVideo ──► PrismAudio Feature Extractor ──► PrismAudio Sampler ──► Save Audio
|
||||||
│ (video_info) ▲
|
(video_info) ──────────────────► (fps auto)
|
||||||
│ (features) ──────────────────────────────────►│
|
(features) ────────────────────► (features)
|
||||||
│ (prompt) ────────────────────────────────────►│
|
duration=0 ─────────────────────► (auto from features)
|
||||||
```
|
```
|
||||||
|
|
||||||
### DITTO style transfer (recommended first approach)
|
### Pre-computed Features
|
||||||
|
|
||||||
```
|
```
|
||||||
SelVA Model Loader ─────────────────────────────────────────────► SelVA DITTO Optimizer ──► Save Audio
|
PrismAudio Feature Loader (.npz) ──► PrismAudio Sampler ──► Save Audio
|
||||||
▲
|
|
||||||
SelVA Feature Extractor ──(features)────────────────────────────────────►│
|
|
||||||
(prompt) ──────────────────────────────────────►│
|
|
||||||
BJ reference_dir ───────────────────────────────────────────────────────►│
|
|
||||||
```
|
```
|
||||||
|
|
||||||
No training required. Each run optimizes x₀ independently for the current video and reference set.
|
### Text-to-Audio
|
||||||
|
|
||||||
### Vocoder fine-tuning
|
|
||||||
|
|
||||||
```
|
```
|
||||||
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► (checkpoint .pt)
|
PrismAudio Text Only ──► Save Audio
|
||||||
▲
|
|
||||||
BJ audio clips ──(data_dir)──►│
|
|
||||||
|
|
||||||
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler ──► Save Audio
|
|
||||||
▲ ▲
|
|
||||||
checkpoint .pt SelVA Feature Extractor
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### LoRA training
|
## HuggingFace Authentication
|
||||||
|
|
||||||
See [LORA_TRAINING.md](LORA_TRAINING.md).
|
Required for T5-Gemma (gated model) and PrismAudio weights.
|
||||||
|
|
||||||
---
|
1. Visit <https://huggingface.co/FunAudioLLM/PrismAudio> and accept the license.
|
||||||
|
2. Authenticate via one of:
|
||||||
|
- **Environment variable:** `export HF_TOKEN=hf_...`
|
||||||
|
- **CLI login:** `huggingface-cli login`
|
||||||
|
|
||||||
## Installation
|
There is no `hf_token` widget on the main nodes by design — ComfyUI saves all STRING values to workflow JSON, which would expose your token. The Feature Extractor has an `hf_token` input as a convenience but using `HF_TOKEN` env var is preferred.
|
||||||
|
|
||||||
```bash
|
## Model Files
|
||||||
cd ComfyUI/custom_nodes
|
|
||||||
git clone https://github.com/Ethanfel/ComfyUI-SelVA.git
|
|
||||||
pip install -r ComfyUI-SelVA/requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
Weights are auto-downloaded to `ComfyUI/models/prismaudio/`:
|
||||||
|
|
||||||
## Model Weights
|
|
||||||
|
|
||||||
Weights are auto-downloaded to `ComfyUI/models/selva/` on first load. No manual setup required.
|
|
||||||
|
|
||||||
| File | Size | Description |
|
| File | Size | Description |
|
||||||
|------|------|-------------|
|
|------|------|-------------|
|
||||||
| `video_enc_sup_5.pth` | ~300 MB | TextSynchformer encoder |
|
| `prismaudio.ckpt` | ~2.7 GB | Diffusion model (DiT) |
|
||||||
| `generator_small_16k_sup_5.pth` | ~340 MB | Small generator, 16 kHz output |
|
| `vae.ckpt` | ~2.5 GB | Stable Audio 2.0 VAE |
|
||||||
| `generator_small_44k_sup_5.pth` | ~340 MB | Small generator, 44.1 kHz output |
|
| `synchformer_state_dict.pth` | ~950 MB | Synchformer visual encoder |
|
||||||
| `generator_medium_44k_sup_5.pth` | ~860 MB | Medium generator, 44.1 kHz output |
|
|
||||||
| `generator_large_44k_sup_5.pth` | ~2.0 GB | Large generator, 44.1 kHz output |
|
|
||||||
| `v1-16.pth` | ~1.1 GB | VAE for 16 kHz |
|
|
||||||
| `v1-44.pth` | ~1.1 GB | VAE for 44.1 kHz |
|
|
||||||
| `best_netG.pt` | ~90 MB | BigVGAN vocoder for 16 kHz |
|
|
||||||
| `synchformer_state_dict.pth` | ~950 MB | Synchformer (shared with PrismAudio if present) |
|
|
||||||
|
|
||||||
CLIP (DFN5B-ViT-H-14-384) and T5 (flan-t5-base) are downloaded automatically from HuggingFace to `~/.cache/huggingface/`.
|
T5-Gemma and VideoPrism LvT are cached in `~/.cache/huggingface/`.
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## VRAM Requirements
|
## VRAM Requirements
|
||||||
|
|
||||||
| VRAM | Recommended settings |
|
| VRAM | Recommended settings |
|
||||||
|------|----------------------|
|
|------|----------------------|
|
||||||
| 24 GB+ | `keep_in_vram`, any variant |
|
| 24 GB+ | `keep_in_vram`, any precision |
|
||||||
| 12–24 GB | `offload_to_cpu`, medium or smaller |
|
| 12–24 GB | `offload_to_cpu`, bf16/fp16 |
|
||||||
| 8–12 GB | `offload_to_cpu`, small variant, fp16 |
|
| 8–12 GB | `offload_to_cpu`, fp16 |
|
||||||
|
| < 8 GB | May work with `offload_to_cpu` + fp16 |
|
||||||
|
|
||||||
The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available, otherwise `offload_to_cpu`.
|
## Troubleshooting
|
||||||
|
|
||||||
---
|
- **Gated model errors** — Accept the license at <https://huggingface.co/FunAudioLLM/PrismAudio> and set `HF_TOKEN`.
|
||||||
|
- **VRAM errors** — Switch `offload_strategy` to `offload_to_cpu` and/or use `fp16` precision.
|
||||||
## Style Transfer
|
- **Feature extraction fails** — Ensure `synchformer_state_dict.pth` is in `models/prismaudio/`. On first run with `managed_env`, installation takes several minutes.
|
||||||
|
- **flash-attn** — Optional. Auto-detected at runtime; falls back to PyTorch SDPA.
|
||||||
For adapting SelVA to a specific audio style (e.g. BJ / Bladee / Jersey Club), see [STYLE_TRANSFER.md](STYLE_TRANSFER.md).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Credits
|
## Credits
|
||||||
|
|
||||||
- [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training
|
PrismAudio by [FunAudioLLM](https://github.com/FunAudioLLM) (ICLR 2026). [Model & weights](https://huggingface.co/FunAudioLLM/PrismAudio).
|
||||||
- [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework
|
|
||||||
- [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis
|
|
||||||
- [DITTO](https://arxiv.org/abs/2401.12179) by Novack et al. — inference-time diffusion optimization
|
|
||||||
|
|||||||
@@ -1,158 +0,0 @@
|
|||||||
# Style Transfer for SelVA
|
|
||||||
|
|
||||||
This document covers approaches for adapting SelVA's audio output to a specific timbral style using a small reference dataset (~50 clips). The context here is BJ / Bladee / Jersey Club style — sharp metallic transients, saturated harmonics, 808 sub bass, glassy high-frequency content — but the methods apply to any style target.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Why standard fine-tuning is hard
|
|
||||||
|
|
||||||
SelVA's generation quality depends on the DiT (generator) outputting latents that fall in the high-density region of the VAE decoder's training distribution. BJ's audio maps to a sparse, tail region of that space — the VAE roundtrip already shows ~10–15 dB elevated HF noise floor on BJ material. Any training that pushes the generator toward exact BJ encoder outputs is training toward an already-degraded target.
|
|
||||||
|
|
||||||
**LoRA** makes this worse: it introduces "intruder dimensions" — new high-rank singular vectors absent from the pretrained weight spectrum — that push DiT outputs further off-manifold. This mechanism is LR- and scale-independent. Reducing LoRA scale does not fix the direction, only the magnitude. Empirically: spectral flatness degrades to ~0.21–0.26 (vs. baseline 0.013) at every scale from 0.0625 to 1.0.
|
|
||||||
|
|
||||||
**Textual inversion** via the text conditioning path suffers from mean-pooling: SelVA's text features are pooled into a single global vector before injection into the DiT. The optimizer finds a spectral bias (noise/buzz) as the cheapest way to reduce reconstruction loss — not a semantic style shift.
|
|
||||||
|
|
||||||
The approaches below are ordered by expected quality and ease of use.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tier 1 — DITTO (recommended first try)
|
|
||||||
|
|
||||||
**Node: SelVA DITTO Optimizer**
|
|
||||||
|
|
||||||
Inference-time noise optimization. Keeps all model weights frozen and only optimizes the initial noise latent x₀ using a style loss computed against the reference clips. Since the weights never change, there is zero risk of quality degradation — the model still generates from its original manifold, just from a better starting point.
|
|
||||||
|
|
||||||
**Style loss:** mean spectrum + Gram matrix of mel spectrograms. The Gram matrix captures covariance between frequency bands (timbral texture) without requiring temporal alignment with the reference. Optimization runs entirely before the vocoder — BigVGAN is only called for the final output pass.
|
|
||||||
|
|
||||||
**How it works:**
|
|
||||||
|
|
||||||
For each video clip you want to process:
|
|
||||||
1. Run SelVA Feature Extractor as usual.
|
|
||||||
2. Instead of SelVA Sampler, connect to **SelVA DITTO Optimizer** with your BJ `reference_dir`.
|
|
||||||
3. The node runs N optimization steps, each backpropagating through the last few ODE Euler steps to compute `∂loss/∂x₀`.
|
|
||||||
4. After optimization, one final full-ODE pass generates the output audio from the refined x₀.
|
|
||||||
|
|
||||||
```
|
|
||||||
SelVA Model Loader ────────────────────────────────► SelVA DITTO Optimizer ──► audio
|
|
||||||
▲
|
|
||||||
SelVA Feature Extractor ──(features)────────────────────────►│
|
|
||||||
(prompt) ──────────────────────────►│
|
|
||||||
BJ clips ───────────────────────────(reference_dir) ─────────►│
|
|
||||||
```
|
|
||||||
|
|
||||||
**Tuning guide:**
|
|
||||||
|
|
||||||
| Parameter | Starting value | When to adjust |
|
|
||||||
|---|---|---|
|
|
||||||
| `n_opt_steps` | 50 | Increase to 100–200 if style shift is too subtle |
|
|
||||||
| `opt_lr` | 0.1 | Lower to 0.05 if coherence breaks; raise to 0.3 for stronger shift |
|
|
||||||
| `n_ode_steps` | 10 | Lower = faster optimization, less accurate gradient |
|
|
||||||
| `n_grad_steps` | 5 | Number of ODE steps to differentiate through — must be ≤ n_ode_steps |
|
|
||||||
| `style_weight` | 1.0 | Increase to 2–5 for stronger BJ character; watch for incoherence |
|
|
||||||
|
|
||||||
**Memory:** Each opt step stores activations for `n_grad_steps` DiT forward passes with gradient checkpointing. At n_grad_steps=5, expect ~4–6 GB additional VRAM over baseline inference.
|
|
||||||
|
|
||||||
**Time per video clip:** ~50 opt steps × (10 ODE steps × 2 passes for checkpointing) + 25 final steps ≈ 5–15 minutes depending on GPU.
|
|
||||||
|
|
||||||
**Limitations:** DITTO with mel Gram matrix loss shifts timbral statistics but cannot precisely match the BJ transient sharpness — the Gram matrix is a texture descriptor, not a transient detector. See Tier 2 (vocoder fine-tuning) for that.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tier 2 — Vocoder Fine-tuning
|
|
||||||
|
|
||||||
**Nodes: SelVA BigVGAN Trainer → SelVA BigVGAN Loader**
|
|
||||||
|
|
||||||
The BigVGAN vocoder (mel → waveform) is the component most responsible for the final timbral character of the output. Fine-tuning only the vocoder keeps the DiT completely untouched — latents stay on-manifold, only the waveform rendering changes.
|
|
||||||
|
|
||||||
### Why plain mel L1 loss fails
|
|
||||||
|
|
||||||
BigVGAN was trained with `L_G = Σ[L_adv + 2·L_fm] + 45·L_mel`. The adversarial and feature-matching terms do the perceptual heavy lifting — they prevent the generator from averaging over high-variance harmonic content. Dropping them for a plain mel L1 loss is a loss-function topology problem: the model minimizes expected reconstruction error by averaging over harmonic uncertainty, eroding the saturated 3–8 kHz harmonics visible as "green smear" in spectrograms. This happens regardless of LR or step count.
|
|
||||||
|
|
||||||
### `snake_alpha_only` mode (default, recommended)
|
|
||||||
|
|
||||||
BigVGAN's AMP blocks use Snake/SnakeBeta activations: `y = x + (1/α)·sin²(α·x)` where α is a per-channel learnable scalar. Alpha parameters directly control the harmonic periodicity of each layer's output — they are the "harmonic tuning knobs" of the vocoder.
|
|
||||||
|
|
||||||
With `train_mode=snake_alpha_only`, only the ~27K alpha parameters (0.024% of the 112M parameter model) are trained. The conv weights encoding waveform structure remain frozen. With this few trainable parameters the model physically cannot reshape the spectrum significantly regardless of loss function — no capacity for the green smear.
|
|
||||||
|
|
||||||
**Loss in snake_alpha_only mode:** mel L1 + multi-resolution STFT L1 are still used but can only shift harmonic emphasis, not spectral shape.
|
|
||||||
|
|
||||||
### `all_params` mode with discriminator
|
|
||||||
|
|
||||||
For a stronger shift — or to use proper perceptual losses — run with `train_mode=all_params` and provide a `discriminator_path` (the `bigvgan_discriminator_optimizer.pt` from the BigVGAN pretrained release):
|
|
||||||
|
|
||||||
1. The frozen pretrained MPD and MRD discriminators are loaded and used as fixed perceptual feature extractors.
|
|
||||||
2. Loss becomes `2·L_fm(frozen_D) + 0.1·L_mel` — feature matching directly penalizes harmonic smearing through the discriminator's learned perceptual space.
|
|
||||||
3. `lambda_l2sp` (default 1e-3) anchors all parameters to their pretrained values — prevents catastrophic drift on 50 clips.
|
|
||||||
|
|
||||||
This is the highest-quality vocoder fine-tuning path but requires the discriminator checkpoint.
|
|
||||||
|
|
||||||
### Workflow
|
|
||||||
|
|
||||||
```
|
|
||||||
SelVA Model Loader ──► SelVA BigVGAN Trainer ──► bigvgan_bj.pt
|
|
||||||
▲
|
|
||||||
BJ audio clips ──(data_dir)──►│
|
|
||||||
|
|
||||||
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA Sampler
|
|
||||||
▲ ▲
|
|
||||||
bigvgan_bj.pt SelVA Feature Extractor
|
|
||||||
```
|
|
||||||
|
|
||||||
### Tuning guide
|
|
||||||
|
|
||||||
| Parameter | Default | Notes |
|
|
||||||
|---|---|---|
|
|
||||||
| `train_mode` | snake_alpha_only | Safe default; use all_params only with discriminator_path |
|
|
||||||
| `steps` | 2000 | 1000–2000 for snake_alpha_only; 3000–5000 for all_params |
|
|
||||||
| `lr` | 1e-4 | For snake_alpha_only; lower to 1e-5 for all_params |
|
|
||||||
| `lambda_l2sp` | 1e-3 | Increase to 1e-2 for all_params to limit drift |
|
|
||||||
| `batch_size` | 4 | 4–8 for stable gradients |
|
|
||||||
| `segment_seconds` | 1.0 | 1–2 s segments recommended |
|
|
||||||
|
|
||||||
**Eval samples:** The trainer saves `.wav` and mel spectrogram `.png` files at baseline, each checkpoint, and final. Compare the spectrograms — saturation (red values in high-frequency bands) should increase relative to baseline.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tier 3 — DITTO + Vocoder (combined)
|
|
||||||
|
|
||||||
Stack both:
|
|
||||||
|
|
||||||
```
|
|
||||||
SelVA Model Loader ──► SelVA BigVGAN Loader ──► SelVA DITTO Optimizer ──► audio
|
|
||||||
▲ ▲
|
|
||||||
bigvgan_bj.pt SelVA Feature Extractor + reference_dir
|
|
||||||
```
|
|
||||||
|
|
||||||
The fine-tuned vocoder handles waveform rendering; DITTO shifts the latent trajectory. Each addresses a different aspect of style transfer.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## What doesn't work (and why)
|
|
||||||
|
|
||||||
### Standard LoRA
|
|
||||||
|
|
||||||
LoRA introduces "intruder dimensions" — high-rank singular vectors absent from the pretrained weight spectrum — at initialization. These push DiT outputs into decoder-hostile latent regions regardless of scale or LR. The failure is direction-based, not magnitude-based, so reducing LoRA scale does not fix it.
|
|
||||||
|
|
||||||
PiSSA initialization (`init_lora_weights="pissa"`) and rsLoRA scaling (`use_rslora=True`) reduce intruder dimension formation by starting in the pretrained weight subspace. These are planned as future improvements.
|
|
||||||
|
|
||||||
### Textual inversion
|
|
||||||
|
|
||||||
SelVA mean-pools all 77 CLIP tokens into a single AdaLN bias vector. Every token contributes equally to a scalar offset; the optimizer finds spectral buzz as the minimum-cost way to reduce flow-matching reconstruction loss. More tokens make it worse.
|
|
||||||
|
|
||||||
### Activation steering (global mean difference)
|
|
||||||
|
|
||||||
The raw mean difference between BJ and empty conditions is not a clean style basis — it carries noise from the diversity of the training clips and the many attention blocks that have nothing to do with timbral character. Global injection (all blocks at any strength) kills the sound. Targeted layer injection (only the 3–6 blocks most predictive of BJ style) is theoretically sound but requires per-layer delta magnitude ranking to identify the right layers first.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Reference dataset preparation
|
|
||||||
|
|
||||||
Use the same audio clips for both DITTO and vocoder fine-tuning:
|
|
||||||
|
|
||||||
- **Minimum:** 20–30 clips. DITTO works from 5+; vocoder benefits from 40+.
|
|
||||||
- **Format:** `.wav` or `.flac` at native sample rate. The trainer resamples automatically.
|
|
||||||
- **Length:** Any length ≥ 1 s. Longer is fine — the trainer segments internally.
|
|
||||||
- **Quality:** Clean, full-mix BJ clips. Avoid heavily compressed or streaming-ripped files. Use HF Smoother if HF content sounds brittle after VAE roundtrip.
|
|
||||||
- **Diversity:** Vary tempo, key, vocal density. 20 diverse clips > 50 copies of the same 8-bar loop.
|
|
||||||
|
|
||||||
Normalize all clips to consistent loudness (e.g. -14 LUFS) before training. Inconsistent levels increase loss variance and slow convergence.
|
|
||||||
+1
-1
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
ComfyUI-SelVA: Text-guided video-to-audio generation using SelVA / MMAudio.
|
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -0,0 +1,337 @@
|
|||||||
|
"""
|
||||||
|
PrismAudio feature extraction utilities.
|
||||||
|
|
||||||
|
Implements FeaturesUtils used by scripts/extract_features.py to extract:
|
||||||
|
- Text features via T5-Gemma (transformers)
|
||||||
|
- Video features via VideoPrism (JAX/Flax, google-deepmind/videoprism)
|
||||||
|
- Sync features via Synchformer visual encoder (PyTorch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class FeaturesUtils:
|
||||||
|
def __init__(self, vae_config_path=None, synchformer_ckpt=None, device=None):
|
||||||
|
self.device = device or torch.device("cpu")
|
||||||
|
self._t5_tokenizer = None
|
||||||
|
self._t5_encoder = None
|
||||||
|
self._vp_model = None
|
||||||
|
self._vp_state = None
|
||||||
|
self._vp_text_tokenizer = None
|
||||||
|
self._sync_model = None
|
||||||
|
|
||||||
|
self._synchformer_ckpt = synchformer_ckpt
|
||||||
|
self._load_synchformer()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# T5-Gemma text encoding
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _ensure_t5(self):
|
||||||
|
if self._t5_encoder is not None:
|
||||||
|
return
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||||
|
model_id = "google/t5gemma-l-l-ul2-it"
|
||||||
|
print(f"[FeaturesUtils] Loading T5-Gemma: {model_id}")
|
||||||
|
self._t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
self._t5_encoder = (
|
||||||
|
AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||||
|
.get_encoder()
|
||||||
|
.to(self.device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_t5_text(self, texts):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
texts: list of str
|
||||||
|
Returns:
|
||||||
|
Tensor [seq_len, 1024]
|
||||||
|
"""
|
||||||
|
self._ensure_t5()
|
||||||
|
tokens = self._t5_tokenizer(
|
||||||
|
texts, return_tensors="pt", padding=True
|
||||||
|
).to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
out = self._t5_encoder(**tokens)
|
||||||
|
# Move encoder off GPU to save VRAM
|
||||||
|
self._t5_encoder.to("cpu")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return out.last_hidden_state.squeeze(0) # [seq_len, 1024]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# VideoPrism video + text encoding (JAX)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _ensure_videoprism(self):
|
||||||
|
if self._vp_model is not None:
|
||||||
|
return
|
||||||
|
from videoprism import models as vp
|
||||||
|
import jax
|
||||||
|
model_name = "videoprism_lvt_public_v1_large"
|
||||||
|
print(f"[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...")
|
||||||
|
self._vp_model = vp.get_model(model_name)
|
||||||
|
self._vp_state = vp.load_pretrained_weights(model_name)
|
||||||
|
self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en")
|
||||||
|
jax_dev = jax.devices()[0]
|
||||||
|
self._jax_forward = jax.jit(
|
||||||
|
lambda x, y, z: self._vp_model.apply(
|
||||||
|
self._vp_state, x, y, z, train=False, return_intermediate=True
|
||||||
|
),
|
||||||
|
device=jax_dev,
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_video_and_text_with_videoprism(self, clip_input, texts):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
|
||||||
|
texts: list of str — CoT captions, passed to VideoPrism LvT text tower
|
||||||
|
Returns:
|
||||||
|
global_video_features: Tensor [1, D]
|
||||||
|
video_features: Tensor [T, D] — per-frame L2-normalized embeddings
|
||||||
|
global_text_features: Tensor [1, D]
|
||||||
|
"""
|
||||||
|
self._ensure_videoprism()
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from videoprism import models as vp
|
||||||
|
|
||||||
|
# Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
|
||||||
|
frames = clip_input.squeeze(0) # [T, C, H, W]
|
||||||
|
frames = (frames + 1.0) / 2.0 # [-1,1] → [0,1]
|
||||||
|
frames = frames.permute(0, 2, 3, 1) # [T, H, W, C]
|
||||||
|
frames_np = frames.cpu().numpy().astype(np.float32)
|
||||||
|
frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C]
|
||||||
|
|
||||||
|
# Tokenize text (padding value 1.0 = pad, 0.0 = real token)
|
||||||
|
text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts)
|
||||||
|
|
||||||
|
# Joint video+text forward with intermediate outputs
|
||||||
|
video_embeddings, text_embeddings, outputs = self._jax_forward(
|
||||||
|
frames_jax, text_ids, text_paddings
|
||||||
|
)
|
||||||
|
|
||||||
|
# Per-frame features: [B, T, 1024] L2-normalized
|
||||||
|
frame_embed_np = np.array(outputs["frame_embeddings"]) # [1, T, 1024]
|
||||||
|
per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device) # [T, 1024]
|
||||||
|
|
||||||
|
# Global video embedding: [1024] → [1, 1024]
|
||||||
|
global_video = torch.from_numpy(
|
||||||
|
np.array(video_embeddings[0])
|
||||||
|
).unsqueeze(0).to(self.device) # [1, 1024]
|
||||||
|
|
||||||
|
# Global text embedding: [1024] → [1, 1024]
|
||||||
|
global_text = torch.from_numpy(
|
||||||
|
np.array(text_embeddings[0])
|
||||||
|
).unsqueeze(0).to(self.device) # [1, 1024]
|
||||||
|
|
||||||
|
return global_video, per_frame, global_text
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Synchformer sync feature encoding
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _load_synchformer(self):
|
||||||
|
if not self._synchformer_ckpt or not os.path.exists(self._synchformer_ckpt):
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"[FeaturesUtils] Loading Synchformer from: {self._synchformer_ckpt}")
|
||||||
|
state = torch.load(self._synchformer_ckpt, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
# Checkpoint may be raw state_dict or wrapped in {"model": ...}
|
||||||
|
if isinstance(state, dict) and "model" in state:
|
||||||
|
state_dict = state["model"]
|
||||||
|
else:
|
||||||
|
state_dict = state
|
||||||
|
|
||||||
|
self._sync_model = _SynchformerVisualEncoder(state_dict, self.device)
|
||||||
|
self._sync_model.eval()
|
||||||
|
|
||||||
|
def encode_video_with_sync(self, sync_input):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
sync_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
|
||||||
|
Returns:
|
||||||
|
sync_features: Tensor [num_segments, 768]
|
||||||
|
"""
|
||||||
|
if self._sync_model is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"[FeaturesUtils] Synchformer checkpoint not loaded. "
|
||||||
|
"Pass synchformer_ckpt to FeaturesUtils or set --synchformer_ckpt."
|
||||||
|
)
|
||||||
|
frames = sync_input.squeeze(0).to(self.device) # [T, C, H, W]
|
||||||
|
with torch.no_grad():
|
||||||
|
return self._sync_model(frames)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Synchformer visual encoder — TimeSformer-style ViT-B/16
|
||||||
|
# Architecture reverse-engineered from synchformer_state_dict.pth
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class _PatchEmbed(nn.Module):
|
||||||
|
"""2D patch embedding: [B, 3, 224, 224] → [B, 196, 768]."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.proj(x).flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class _ViTAttn(nn.Module):
|
||||||
|
"""ViT-style QKV attention (timm convention: qkv as single Linear)."""
|
||||||
|
def __init__(self, dim=768, num_heads=12):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, D = x.shape
|
||||||
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv.unbind(0)
|
||||||
|
attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
|
||||||
|
return self.proj((attn @ v).transpose(1, 2).reshape(B, N, D))
|
||||||
|
|
||||||
|
|
||||||
|
class _BlockMLP(nn.Module):
|
||||||
|
"""Two-layer MLP with GELU, keys fc1/fc2 to match checkpoint."""
|
||||||
|
def __init__(self, dim=768, mlp_dim=3072):
|
||||||
|
super().__init__()
|
||||||
|
self.fc1 = nn.Linear(dim, mlp_dim)
|
||||||
|
self.fc2 = nn.Linear(mlp_dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fc2(F.gelu(self.fc1(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class _TimeSformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Factorized space-time attention block.
|
||||||
|
norm1 → spatial attn → norm3 → temporal attn → norm2 → MLP
|
||||||
|
"""
|
||||||
|
def __init__(self, dim=768, num_heads=12):
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.attn = _ViTAttn(dim, num_heads)
|
||||||
|
self.norm3 = nn.LayerNorm(dim)
|
||||||
|
self.timeattn = _ViTAttn(dim, num_heads)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
self.mlp = _BlockMLP(dim)
|
||||||
|
|
||||||
|
def forward(self, x, T):
|
||||||
|
# x: [T, N, D] (T frames treated as batch, N=197 spatial tokens)
|
||||||
|
x = x + self.attn(self.norm1(x))
|
||||||
|
# Temporal attention: for each spatial position, attend across T frames
|
||||||
|
# [T, N, D] → [N, T, D] → attend → [N, T, D] → [T, N, D]
|
||||||
|
xt = x.permute(1, 0, 2)
|
||||||
|
xt = xt + self.timeattn(self.norm3(xt))
|
||||||
|
x = xt.permute(1, 0, 2)
|
||||||
|
x = x + self.mlp(self.norm2(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class _SpatialAttnAgg(nn.Module):
|
||||||
|
"""
|
||||||
|
Aggregates 196 spatial patches → 1 feature per frame using a
|
||||||
|
TransformerEncoderLayer with a learnable CLS token.
|
||||||
|
Key names match nn.TransformerEncoderLayer: self_attn, linear1, linear2, norm1, norm2.
|
||||||
|
"""
|
||||||
|
def __init__(self, dim=768, num_heads=12):
|
||||||
|
super().__init__()
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
|
||||||
|
self.linear1 = nn.Linear(dim, dim * 4)
|
||||||
|
self.linear2 = nn.Linear(dim * 4, dim)
|
||||||
|
self.norm1 = nn.LayerNorm(dim)
|
||||||
|
self.norm2 = nn.LayerNorm(dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x: [T, 196, 768] — spatial patches (CLS stripped)
|
||||||
|
T = x.shape[0]
|
||||||
|
cls = self.cls_token.expand(T, -1, -1)
|
||||||
|
x = torch.cat([cls, x], dim=1) # [T, 197, 768]
|
||||||
|
xn = self.norm1(x)
|
||||||
|
x = x + self.self_attn(xn, xn, xn)[0]
|
||||||
|
x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
|
||||||
|
return x[:, 0, :] # [T, 768] — CLS per frame
|
||||||
|
|
||||||
|
|
||||||
|
class _SynchformerVisualEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.
|
||||||
|
Processes video in segments of 8 frames → [T_aligned, 768] per-frame features.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state_dict, device):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
|
self.segment_frames = 8
|
||||||
|
|
||||||
|
self.patch_embed = _PatchEmbed()
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
|
||||||
|
self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
|
||||||
|
self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
|
||||||
|
self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
|
||||||
|
self.norm = nn.LayerNorm(768)
|
||||||
|
self.spatial_attn_agg = _SpatialAttnAgg()
|
||||||
|
|
||||||
|
# Load weights from vfeat_extractor.* prefix
|
||||||
|
prefix = "vfeat_extractor."
|
||||||
|
sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
|
||||||
|
# Exclude 3D patch embed (we use 2D only)
|
||||||
|
sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
|
||||||
|
missing, unexpected = self.load_state_dict(sub, strict=False)
|
||||||
|
print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
|
||||||
|
if missing:
|
||||||
|
print(f"[FeaturesUtils] missing keys (first 5): {missing[:5]}")
|
||||||
|
|
||||||
|
self.to(device)
|
||||||
|
|
||||||
|
def forward(self, frames):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
frames: [T, C, H, W] float32 in [-1, 1], at 25fps
|
||||||
|
Returns:
|
||||||
|
[T_aligned, 768] — per-frame features (T_aligned = floor(T/8)*8)
|
||||||
|
"""
|
||||||
|
T = frames.shape[0]
|
||||||
|
seg = self.segment_frames
|
||||||
|
num_seg = max(1, T // seg)
|
||||||
|
T_aligned = num_seg * seg
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i in range(num_seg):
|
||||||
|
chunk = frames[i * seg:(i + 1) * seg] # [8, C, H, W]
|
||||||
|
results.append(self._forward_segment(chunk))
|
||||||
|
return torch.cat(results, dim=0) # [T_aligned, 768]
|
||||||
|
|
||||||
|
def _forward_segment(self, x):
|
||||||
|
# x: [8, 3, 224, 224]
|
||||||
|
T = x.shape[0] # 8
|
||||||
|
|
||||||
|
# Patch embedding + CLS token
|
||||||
|
x = self.patch_embed(x) # [8, 196, 768]
|
||||||
|
cls = self.cls_token.expand(T, -1, -1)
|
||||||
|
x = torch.cat([cls, x], dim=1) # [8, 197, 768]
|
||||||
|
|
||||||
|
# Positional + temporal embeddings
|
||||||
|
x = x + self.pos_embed # broadcast (1,197,768)
|
||||||
|
x = x + self.temp_embed.squeeze(0).unsqueeze(1) # (8,1,768) broadcast
|
||||||
|
|
||||||
|
# Transformer blocks (factorized space-time)
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, T)
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
# Aggregate spatial patches → 1 feature per frame
|
||||||
|
return self.spatial_attn_agg(x[:, 1:, :]) # [8, 768]
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
# Audio Dataset Pipeline for Generative Model Training
|
|
||||||
|
|
||||||
Research notes on audio cleaning, augmentation, and quality metrics for LoRA fine-tuning of MMAudio/SelVA. Based on papers and tooling survey (April 2026).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Core Principle
|
|
||||||
|
|
||||||
Augmentation for generative models ≠ augmentation for classifiers.
|
|
||||||
The goal is **not invariance** — it is expanding the training manifold so the model learns the distribution of a sound rather than memorizing a fixed set of waveforms.
|
|
||||||
|
|
||||||
With 10 clips, velocity field collapse (arXiv:2410.23594) is mathematically expected: the flow-matching model memorizes the training trajectories instead of generalizing. More diverse data is the only real fix.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Recommended Pipeline
|
|
||||||
|
|
||||||
### Step 1 — Quality Screening
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Clipping check
|
|
||||||
clip_ratio = np.sum(np.abs(audio) >= 0.99) / len(audio) # flag if > 0.1%
|
|
||||||
|
|
||||||
# DC offset check + removal
|
|
||||||
dc = np.mean(audio)
|
|
||||||
audio -= dc
|
|
||||||
|
|
||||||
# LUFS normalization to -14 LUFS (essential for training consistency)
|
|
||||||
# pip install pyloudnorm
|
|
||||||
import pyloudnorm as pyln
|
|
||||||
meter = pyln.Meter(sr)
|
|
||||||
loudness = meter.integrated_loudness(audio)
|
|
||||||
audio = pyln.normalize.loudness(audio, loudness, -14.0)
|
|
||||||
# Or via ffmpeg: ffmpeg -af loudnorm=I=-14:LRA=7:TP=-1
|
|
||||||
|
|
||||||
# DNSMOS quality gate (discard if OVRL < 3.5 for training; < 2.5 is unusable)
|
|
||||||
# from Microsoft DNS-Challenge repo
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 2 — Cleaning
|
|
||||||
|
|
||||||
| Tool | Install | Use |
|
|
||||||
|---|---|---|
|
|
||||||
| **AudioSep** | `pip install audiosep` | Isolate target sound from background — most impactful tool |
|
|
||||||
| **noisereduce** | `pip install noisereduce` | Light stationary/non-stationary denoising, preserves character |
|
|
||||||
| **librosa** | `pip install librosa` | Silence trimming: `librosa.effects.trim(audio, top_db=30)` |
|
|
||||||
| **torchaudio.transforms.Fade** | (torchaudio) | Prevent click artifacts at clip edges |
|
|
||||||
| **DeepFilterNet** | `pip install deepfilternet` | Heavy denoising — good for speech, may alter tonal sounds |
|
|
||||||
|
|
||||||
**AudioSep usage:**
|
|
||||||
```python
|
|
||||||
from audiosep import AudioSep
|
|
||||||
model = AudioSep.from_pretrained("audio-agi/audiosep")
|
|
||||||
# ~1.5 GB checkpoint, ~4 GB VRAM
|
|
||||||
model.inference(audio_path, "a dog barking loudly", output_path)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 3 — Waveform Augmentation (10 clips → 50–100)
|
|
||||||
|
|
||||||
Apply stochastically per clip:
|
|
||||||
|
|
||||||
| Transform | Params | Notes |
|
|
||||||
|---|---|---|
|
|
||||||
| **PitchShift** | ±1–3 semitones | 3 variants per clip. Limit to ±1 st for tonal/pitched sounds |
|
|
||||||
| **ApplyImpulseResponse** | 5 different RIRs | 5 variants per clip — EchoThief (~150 free IRs) or pyroomacoustics |
|
|
||||||
| **LoudnessNormalization** | ±2 dB random | Subtle level variation |
|
|
||||||
| **SevenBandParametricEQ** | ±3 dB | Gentle spectral variation |
|
|
||||||
| **TimeStretch** | 0.9–1.1× only | Do NOT use 2× to pad short clips — breaks video sync |
|
|
||||||
|
|
||||||
```python
|
|
||||||
# pip install audiomentations pedalboard pyroomacoustics
|
|
||||||
import audiomentations as A
|
|
||||||
|
|
||||||
augment = A.Compose([
|
|
||||||
A.PitchShift(min_semitones=-2, max_semitones=2, p=0.5),
|
|
||||||
A.ApplyImpulseResponse(ir_paths="path/to/irs/", p=0.5),
|
|
||||||
A.SevenBandParametricEQ(min_gain_db=-3, max_gain_db=3, p=0.3),
|
|
||||||
A.LoudnessNormalization(min_lufs=-16, max_lufs=-12, p=0.5),
|
|
||||||
A.TimeStretch(min_rate=0.9, max_rate=1.1, p=0.3),
|
|
||||||
])
|
|
||||||
audio_aug = augment(samples=audio, sample_rate=sr)
|
|
||||||
```
|
|
||||||
|
|
||||||
**RIR sources:**
|
|
||||||
- EchoThief: ~150 free real-world IRs (churches, caves, parking garages)
|
|
||||||
- pyroomacoustics: synthetic room simulation, fully controllable
|
|
||||||
|
|
||||||
### Step 4 — Latent Augmentation (at training time)
|
|
||||||
|
|
||||||
After VAE encoding:
|
|
||||||
|
|
||||||
**Latent mixup** between same-category pairs:
|
|
||||||
```python
|
|
||||||
# Mix latents BEFORE flow-matching noise is added
|
|
||||||
# Only mix clips from the same sound category — cross-category mixing produces garbage
|
|
||||||
lam = torch.distributions.Beta(0.4, 0.4).sample()
|
|
||||||
z_mix = lam * z1 + (1 - lam) * z2
|
|
||||||
```
|
|
||||||
With 10 clips: C(10,2) = 45 possible pairs → significant expansion without new recordings.
|
|
||||||
|
|
||||||
**Small Gaussian noise:**
|
|
||||||
```python
|
|
||||||
z_noised = z + torch.randn_like(z) * 0.02 * z.std()
|
|
||||||
```
|
|
||||||
Prevents trivial memorization of exact latent coordinates.
|
|
||||||
|
|
||||||
MusicLDM (arXiv:2308.01546) shows latent mixup > waveform mixup for generative quality.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Transforms to AVOID for Generative Training
|
|
||||||
|
|
||||||
| Transform | Why |
|
|
||||||
|---|---|
|
|
||||||
| ClippingDistortion, BitCrush, TanhDistortion, Mp3Compression | Model learns the artifact |
|
|
||||||
| Reverse | Breaks temporal structure for video-to-audio task |
|
|
||||||
| TimeMask (creating silence gaps) | Unnatural — model learns to produce silence |
|
|
||||||
| TimeStretch > 1.3× | Phase vocoder artifacts become part of the target distribution |
|
|
||||||
| Heavy background noise (< 15 dB SNR) | Model learns to reproduce the noise |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Quality Metrics
|
|
||||||
|
|
||||||
| Metric | Tool | Threshold |
|
|
||||||
|---|---|---|
|
|
||||||
| DNSMOS P.835 (SIG/BAK/OVRL) | Microsoft DNS-Challenge | OVRL > 3.5 for training |
|
|
||||||
| LUFS | pyloudnorm | Normalize all clips to -14 LUFS |
|
|
||||||
| WADA-SNR | (standalone) | No-reference SNR estimate |
|
|
||||||
| Clipping ratio | NumPy | Flag if > 0.1% of samples at ±0.99 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tool Reference
|
|
||||||
|
|
||||||
| Tool | Install | Purpose |
|
|
||||||
|---|---|---|
|
|
||||||
| audiomentations | `pip install audiomentations` | Primary augmentation library |
|
|
||||||
| pedalboard | `pip install pedalboard` | Higher quality pitch shift, IR convolution |
|
|
||||||
| AudioSep | `pip install audiosep` | Source separation / isolation |
|
|
||||||
| noisereduce | `pip install noisereduce` | Non-stationary denoising |
|
|
||||||
| DeepFilterNet | `pip install deepfilternet` | Heavy denoising (speech-optimized) |
|
|
||||||
| pyloudnorm | `pip install pyloudnorm` | LUFS normalization |
|
|
||||||
| Silero VAD | `pip install silero-vad` | Voice/silence detection |
|
|
||||||
| pyroomacoustics | `pip install pyroomacoustics` | Synthetic RIR generation |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Integration with PrismAudio / SelVA
|
|
||||||
|
|
||||||
No established ComfyUI audio preprocessing ecosystem as of early 2026. Build thin wrapper nodes around the tools above. PrismAudio already has all required patterns (subprocess isolation, AUDIO type transport).
|
|
||||||
|
|
||||||
**Target node set:**
|
|
||||||
- `SelVA Dataset Cleaner` — wraps noisereduce + LUFS normalization + trim + DNSMOS gate
|
|
||||||
- `SelVA Dataset Augmenter` — wraps audiomentations Compose pipeline
|
|
||||||
|
|
||||||
Steps 1–3 are preprocessing (run once before feature extraction).
|
|
||||||
Step 4 (latent mixup) is a training loop modification — integrate into `selva_lora_trainer.py`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Key Papers
|
|
||||||
|
|
||||||
| Paper | ArXiv | Finding |
|
|
||||||
|---|---|---|
|
|
||||||
| MusicLDM | 2308.01546 | Latent mixup > waveform mixup for generative quality |
|
|
||||||
| EDMSound | 2311.08667 | Memorization documented — same failure mode as 10-clip training |
|
|
||||||
| Synthio | 2410.02056 | Synthetic audio as augmentation data (ICLR 2025) |
|
|
||||||
| HunyuanVideo-Foley | 2508.16930 | V2A data pipeline at scale (100K hrs) |
|
|
||||||
| FM memorization | 2410.23594 | Velocity field collapse theory — proves early overfitting on small datasets |
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
# AudioX vs SelVA — Evaluation
|
|
||||||
|
|
||||||
AudioX (arXiv:2503.10522, ICLR 2026) is a unified multimodal audio generation model from HKUST.
|
|
||||||
This document compares it against SelVA/MMAudio and assesses the cost of adding it to PrismAudio.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Quick Decision Guide
|
|
||||||
|
|
||||||
| Situation | Use |
|
|
||||||
|---|---|
|
|
||||||
| Video → realistic sound effects | **SelVA** — faster, purpose-built, MIT license |
|
|
||||||
| Music generation from video or text | **AudioX** — SelVA cannot do this |
|
|
||||||
| Audio inpainting / music continuation | **AudioX** — SelVA cannot do this |
|
|
||||||
| LoRA fine-tuning on a custom sound | **SelVA** — full training infrastructure already exists |
|
|
||||||
| Variable output duration | **AudioX** — SelVA is fixed at 8 s |
|
|
||||||
| Inference speed matters | **SelVA** — 25 steps vs 250 (10× faster) |
|
|
||||||
| Non-commercial research | Either |
|
|
||||||
| Any commercial use | **SelVA only** — AudioX is CC-BY-NC-4.0 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
| Dimension | SelVA (MMAudio) | AudioX-MAF |
|
|
||||||
|---|---|---|
|
|
||||||
| Core paradigm | Flow matching | Diffusion (k-diffusion / DPM++) |
|
|
||||||
| Inference steps | 25 ODE steps (Euler) | 250 diffusion steps (DPM++ 3M SDE) |
|
|
||||||
| Sample rate | 44.1 kHz (large) / 16 kHz (small) | 48 kHz (fixed) |
|
|
||||||
| Generator | MM-DiT, velocity prediction | ContinuousMMDiTTransformer |
|
|
||||||
| Video encoder | Synchformer | Synchformer (AudioX custom re-impl, same concept) |
|
|
||||||
| VAE / codec | DAC (descript-audio-codec) | DAC + AudioCraft options |
|
|
||||||
| Text encoder | T5-large | T5 (configurable small → XXL) |
|
|
||||||
| Video-audio fusion | Cross-attention in MM-DiT | MAF: dual-projection (dim alignment + seq length alignment) |
|
|
||||||
| Output duration | Fixed 8 s | Configurable via `sample_size` (default ~44 s at 48kHz) |
|
|
||||||
| Training data | ~2 M samples (MMAudio paper) | 7 M samples (IF-caps dataset, curated) |
|
|
||||||
| License | MIT | CC-BY-NC-4.0 |
|
|
||||||
|
|
||||||
**MAF (Multimodal Adaptive Fusion):** AudioX's key architectural contribution. Instead of directly
|
|
||||||
concatenating multimodal tokens into the DiT's cross-attention, MAF projects each modality to
|
|
||||||
match the latent's sequence length via a dedicated linear + transposed-conv stack, then applies
|
|
||||||
`MMDitSingleBlock` layers for cross-modal fusion. The paper reports this improves cross-modal
|
|
||||||
alignment particularly for video-to-audio tasks.
|
|
||||||
|
|
||||||
**Flow matching vs diffusion:** Flow matching (SelVA) trains a single velocity field to move
|
|
||||||
directly from noise to data along a straight trajectory — this is why 25 steps suffice. Standard
|
|
||||||
diffusion (AudioX) approximates a longer stochastic path, requiring 250 steps for quality output.
|
|
||||||
This is not a quality difference per se; flow matching is simply more efficient.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Capabilities
|
|
||||||
|
|
||||||
| Task | SelVA | AudioX |
|
|
||||||
|---|---|---|
|
|
||||||
| Video → sound effects | ✓ (primary use case) | ✓ |
|
|
||||||
| Text → sound effects | Partial (T5 conditions quality but not primary) | ✓ (strong benchmark scores) |
|
|
||||||
| Video → music | ✗ | ✓ |
|
|
||||||
| Text → music | ✗ | ✓ |
|
|
||||||
| Audio inpainting | ✗ | ✓ (mask_args parameter) |
|
|
||||||
| Music continuation | ✗ | ✓ (init_audio parameter) |
|
|
||||||
| Variable output duration | ✗ (fixed 8 s) | ✓ |
|
|
||||||
| Multiple input modalities simultaneously | Partial | ✓ (text + video + audio at once) |
|
|
||||||
|
|
||||||
AudioX benchmarks claim superior results on text-to-audio (AudioCaps) and text-to-music
|
|
||||||
(MusicCaps) vs prior models. Video-to-audio comparison against MMAudio specifically is not
|
|
||||||
prominently featured in the paper. Perceptual evaluation confirms this: AudioX does not sound
|
|
||||||
noticeably better than SelVA on video-to-audio tasks. AudioX's advantage is **breadth**
|
|
||||||
(music, inpainting, variable duration), not raw video-to-audio quality.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Integration Cost
|
|
||||||
|
|
||||||
Adding AudioX inference-only nodes to PrismAudio would require:
|
|
||||||
|
|
||||||
### New nodes (3 files)
|
|
||||||
|
|
||||||
```
|
|
||||||
nodes/
|
|
||||||
audiox_model_loader.py AUDIOX_MODEL loader — get_pretrained_model("HKUSTAudio/AudioX-MAF")
|
|
||||||
audiox_sampler.py wraps generate_diffusion_cond(), inputs: model + text + video + audio
|
|
||||||
audiox_feature_extractor.py optional — pre-extract Synchformer sync features (caching)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install git+https://github.com/ZeyueT/AudioX.git
|
|
||||||
```
|
|
||||||
|
|
||||||
New dependencies not currently in PrismAudio:
|
|
||||||
- `pytorch-lightning==2.4.0`
|
|
||||||
- `k-diffusion==0.1.1`
|
|
||||||
- `v-diffusion-pytorch==0.0.2`
|
|
||||||
- `descript-audio-codec==1.0.0` (already used by SelVA — no conflict, same package)
|
|
||||||
- `gradio==4.44.1` (optional — only for the upstream Gradio UI)
|
|
||||||
|
|
||||||
Model weights: `HKUSTAudio/AudioX-MAF` on HuggingFace (~several GB).
|
|
||||||
|
|
||||||
### Inference API surface
|
|
||||||
|
|
||||||
```python
|
|
||||||
from audiox import get_pretrained_model
|
|
||||||
from audiox.inference.generation import generate_diffusion_cond
|
|
||||||
|
|
||||||
model, config = get_pretrained_model("HKUSTAudio/AudioX-MAF")
|
|
||||||
|
|
||||||
output = generate_diffusion_cond(
|
|
||||||
model,
|
|
||||||
steps=250,
|
|
||||||
cfg_scale=6.0,
|
|
||||||
conditioning={
|
|
||||||
"text_prompt": "a dog barking",
|
|
||||||
"video_prompt": {"video": frames_tensor, "sync_features": sync_feat},
|
|
||||||
"seconds_total": 8.0,
|
|
||||||
},
|
|
||||||
sample_size=384000, # 8 s at 48kHz
|
|
||||||
sample_rate=48000,
|
|
||||||
device="cuda",
|
|
||||||
)
|
|
||||||
# output: torch.Tensor (batch, channels, num_samples) float32 [-1, 1]
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## LoRA Training
|
|
||||||
|
|
||||||
Adding AudioX LoRA training to PrismAudio is **significantly harder** than the SelVA trainer:
|
|
||||||
|
|
||||||
| Aspect | SelVA LoRA | AudioX LoRA |
|
|
||||||
|---|---|---|
|
|
||||||
| Loss function | Single MSE velocity loss | Diffusion loss over 250-step schedule |
|
|
||||||
| Training steps needed | ~2000 steps practical | Unknown — likely much more |
|
|
||||||
| Step cost | Fast (1 velocity prediction) | Slow (full diffusion forward pass per step) |
|
|
||||||
| Existing infrastructure | Full trainer + scheduler + experiments | Nothing — would need to build from scratch |
|
|
||||||
| Noise schedule | Trivial (linear interpolation) | Cosine alpha-sigma schedule |
|
|
||||||
| Prior art for LoRA | LoRA on flow matching well-studied | Less explored; closer to Stable Diffusion LoRA |
|
|
||||||
|
|
||||||
**Conclusion:** AudioX LoRA training is feasible (it would follow SD-style LoRA with the DPM++
|
|
||||||
noise schedule) but would be a substantial new project. Not worth building until inference nodes
|
|
||||||
are stable and there is a clear use case that SelVA cannot serve.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
AudioX weights are released under **CC-BY-NC-4.0** (Creative Commons Non-Commercial).
|
|
||||||
|
|
||||||
- Free for personal use, research, and non-commercial projects
|
|
||||||
- **Cannot be used in commercial products or services** without a separate agreement
|
|
||||||
- Attribution required
|
|
||||||
- SelVA/MMAudio: MIT (no restrictions)
|
|
||||||
|
|
||||||
If PrismAudio is ever distributed as part of a commercial tool, AudioX nodes must be clearly
|
|
||||||
opt-in with a license warning, or excluded entirely.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Recommendation
|
|
||||||
|
|
||||||
**Short term:** AudioX is not a replacement for SelVA for the current use case (video → custom
|
|
||||||
sound effects with LoRA fine-tuning). SelVA is faster, has full training infrastructure, and
|
|
||||||
is MIT licensed.
|
|
||||||
|
|
||||||
**When AudioX becomes worth integrating:**
|
|
||||||
- If you need to generate background music synchronized to video
|
|
||||||
- If you need audio inpainting (fill a gap in an existing audio track)
|
|
||||||
- If you need text-to-audio generation without a video input
|
|
||||||
- After verifying the CC-BY-NC-4.0 license is acceptable for your use
|
|
||||||
|
|
||||||
**Estimated integration effort for inference nodes only:** 2–3 days of work (3 new node files,
|
|
||||||
dependency management, testing). No changes to existing SelVA nodes required — they would
|
|
||||||
coexist in the same package.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## References
|
|
||||||
|
|
||||||
- Paper: arXiv:2503.10522 — *AudioX: Diffusion Transformer for Anything-to-Audio Generation*
|
|
||||||
- GitHub: https://github.com/ZeyueT/AudioX
|
|
||||||
- Model weights: https://huggingface.co/HKUSTAudio/AudioX-MAF
|
|
||||||
- Demo: https://huggingface.co/spaces/Zeyue7/AudioX
|
|
||||||
- Project page: https://zeyuet.github.io/AudioX/
|
|
||||||
@@ -0,0 +1,194 @@
|
|||||||
|
# ComfyUI-PrismAudio Design Document
|
||||||
|
|
||||||
|
**Date:** 2026-03-27
|
||||||
|
**Status:** Approved
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. No JAX/TensorFlow in the ComfyUI environment. Feature extraction via separate isolated environment.
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
ComfyUI-PrismAudio/
|
||||||
|
├── __init__.py # Node registration
|
||||||
|
├── nodes/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── model_loader.py # PrismAudioModelLoader
|
||||||
|
│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz)
|
||||||
|
│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge)
|
||||||
|
│ ├── sampler.py # PrismAudioSampler
|
||||||
|
│ ├── text_only.py # PrismAudioTextOnly
|
||||||
|
│ └── utils.py # Shared helpers
|
||||||
|
├── prismaudio_core/ # Extracted inference code from PrismAudio
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── configs/
|
||||||
|
│ │ └── prismaudio.json
|
||||||
|
│ ├── models/ # DiT, conditioners, autoencoders, etc.
|
||||||
|
│ ├── inference/ # sampling.py, generation.py
|
||||||
|
│ └── factory.py # create_model_from_config
|
||||||
|
├── scripts/
|
||||||
|
│ ├── extract_features.py # Standalone VideoPrism feature extraction
|
||||||
|
│ └── environment.yml # Conda env for extraction (JAX + TF)
|
||||||
|
├── requirements.txt # PyTorch-only deps (no JAX/TF)
|
||||||
|
└── README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
## Nodes
|
||||||
|
|
||||||
|
### PrismAudioModelLoader
|
||||||
|
|
||||||
|
Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally.
|
||||||
|
|
||||||
|
| Field | Type | Details |
|
||||||
|
|-------|------|---------|
|
||||||
|
| **Inputs** | | |
|
||||||
|
| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability |
|
||||||
|
| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] |
|
||||||
|
| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | |
|
||||||
|
| **Output** | | |
|
||||||
|
| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config |
|
||||||
|
|
||||||
|
**Token resolution order** (no widget — env/CLI only for security):
|
||||||
|
1. `HF_TOKEN` environment variable
|
||||||
|
2. `huggingface-cli login` cached token
|
||||||
|
3. None — fails on gated models with clear error message linking to license page
|
||||||
|
|
||||||
|
**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually.
|
||||||
|
|
||||||
|
### PrismAudioFeatureLoader
|
||||||
|
|
||||||
|
Loads pre-computed `.npz` feature files for maximum quality video-to-audio.
|
||||||
|
|
||||||
|
| Field | Type | Details |
|
||||||
|
|-------|------|---------|
|
||||||
|
| **Inputs** | | |
|
||||||
|
| npz_path | STRING | Path to .npz file |
|
||||||
|
| **Output** | | |
|
||||||
|
| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features |
|
||||||
|
|
||||||
|
### PrismAudioFeatureExtractor
|
||||||
|
|
||||||
|
Subprocess bridge — extracts features from video using VideoPrism in an isolated environment.
|
||||||
|
|
||||||
|
| Field | Type | Details |
|
||||||
|
|-------|------|---------|
|
||||||
|
| **Inputs** | | |
|
||||||
|
| video | IMAGE | ComfyUI video frames tensor |
|
||||||
|
| caption_cot | STRING | CoT description text |
|
||||||
|
| python_env | STRING | Path to python binary with JAX/TF (default: "python") |
|
||||||
|
| output_dir | STRING | Cache directory for .npz files (default: temp dir) |
|
||||||
|
| **Output** | | |
|
||||||
|
| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output |
|
||||||
|
|
||||||
|
**Caching:** Hashes video + text to avoid re-extraction on repeated runs.
|
||||||
|
|
||||||
|
### PrismAudioSampler
|
||||||
|
|
||||||
|
Main generation node — takes model + features, produces audio.
|
||||||
|
|
||||||
|
| Field | Type | Details |
|
||||||
|
|-------|------|---------|
|
||||||
|
| **Inputs** | | |
|
||||||
|
| model | PRISMAUDIO_MODEL | From ModelLoader |
|
||||||
|
| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor |
|
||||||
|
| cot_description | STRING | Multiline CoT text |
|
||||||
|
| duration | FLOAT | 1.0-30.0, defaults to video length |
|
||||||
|
| steps | INT | 1-100, default 24 |
|
||||||
|
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
|
||||||
|
| seed | INT | Controls noise generation |
|
||||||
|
| **Output** | | |
|
||||||
|
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
|
||||||
|
|
||||||
|
**Pipeline:**
|
||||||
|
1. Encode CoT text via T5-Gemma -> text_features
|
||||||
|
2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond)
|
||||||
|
3. Compute latent_seq_len = round(44100 / 2048 * duration)
|
||||||
|
4. Generate noise [1, 64, latent_seq_len] from seed
|
||||||
|
5. Discrete Euler sampling (rectified flow) with CFG
|
||||||
|
6. VAE decode -> stereo waveform at 44100 Hz
|
||||||
|
7. Normalize to [-1, 1], return as AUDIO
|
||||||
|
|
||||||
|
### PrismAudioTextOnly
|
||||||
|
|
||||||
|
Text-to-audio without video input.
|
||||||
|
|
||||||
|
| Field | Type | Details |
|
||||||
|
|-------|------|---------|
|
||||||
|
| **Inputs** | | |
|
||||||
|
| model | PRISMAUDIO_MODEL | From ModelLoader |
|
||||||
|
| text_prompt | STRING | Text description |
|
||||||
|
| duration | FLOAT | 1.0-30.0 |
|
||||||
|
| steps | INT | 1-100, default 24 |
|
||||||
|
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
|
||||||
|
| seed | INT | Controls noise generation |
|
||||||
|
| **Output** | | |
|
||||||
|
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
|
||||||
|
|
||||||
|
Uses empty tensors for video/sync features, T5-Gemma encodes the text prompt.
|
||||||
|
|
||||||
|
## VRAM Management
|
||||||
|
|
||||||
|
Adaptive strategy using `comfy.model_management`:
|
||||||
|
|
||||||
|
| Available VRAM | Behavior |
|
||||||
|
|---|---|
|
||||||
|
| 24GB+ | Keep diffusion + VAE in VRAM |
|
||||||
|
| 12-24GB | Sequential offload between stages |
|
||||||
|
| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced |
|
||||||
|
| <8GB | Warn user, attempt with aggressive offload + fp16 |
|
||||||
|
|
||||||
|
Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()`
|
||||||
|
|
||||||
|
## Feature Extraction Paths
|
||||||
|
|
||||||
|
### Path 1: Pre-computed .npz (FeatureLoader)
|
||||||
|
User runs `scripts/extract_features.py` externally in the extraction conda env. Loads result into ComfyUI. Original VideoPrism quality, zero ComfyUI env risk.
|
||||||
|
|
||||||
|
### Path 2: Subprocess bridge (FeatureExtractor)
|
||||||
|
Node calls extraction script as subprocess using a user-specified Python binary. Seamless in-ComfyUI experience, JAX runs isolated. Caches results by content hash.
|
||||||
|
|
||||||
|
### Path 3: Text-only (TextOnly node)
|
||||||
|
No video features needed. T5-Gemma text encoding only (PyTorch-native).
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
### ComfyUI environment (`requirements.txt`)
|
||||||
|
```
|
||||||
|
einops>=0.7.0
|
||||||
|
safetensors
|
||||||
|
huggingface_hub
|
||||||
|
transformers>=4.52.3
|
||||||
|
k-diffusion>=0.1.1
|
||||||
|
```
|
||||||
|
|
||||||
|
flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`.
|
||||||
|
|
||||||
|
### Extraction environment (`scripts/environment.yml`)
|
||||||
|
Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as ready-made conda env file for one-command setup.
|
||||||
|
|
||||||
|
## Model Files
|
||||||
|
|
||||||
|
Stored in `ComfyUI/models/prismaudio/`:
|
||||||
|
|
||||||
|
| File | Size | Source |
|
||||||
|
|------|------|--------|
|
||||||
|
| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio |
|
||||||
|
| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio |
|
||||||
|
| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio |
|
||||||
|
|
||||||
|
T5-Gemma (`google/t5gemma-l-l-ul2-it`) cached in standard HuggingFace cache.
|
||||||
|
|
||||||
|
Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)`
|
||||||
|
|
||||||
|
## Design Decisions
|
||||||
|
|
||||||
|
- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes.
|
||||||
|
- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env.
|
||||||
|
- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes — better models available than bundled Qwen2.5-VL.
|
||||||
|
- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`.
|
||||||
|
- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback.
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,167 @@
|
|||||||
|
# SelVA Integration Design
|
||||||
|
|
||||||
|
**Date:** 2026-04-04
|
||||||
|
**Branch:** feature/selva-integration (new from master)
|
||||||
|
**Status:** Approved, ready for implementation
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Problem
|
||||||
|
|
||||||
|
PrismAudio's sync conditioning is text-agnostic: Synchformer extracts features from
|
||||||
|
all visual motion equally. In multi-source videos (person walking near a car), the DiT
|
||||||
|
receives unfocused sync guidance and struggles to match audio events to the correct
|
||||||
|
visual source.
|
||||||
|
|
||||||
|
SelVA (CVPR 2026, arXiv:2512.02650) solves this with TextSynchformer — text conditioning
|
||||||
|
is injected inside the Synchformer encoder via cross-attention, so sync features only
|
||||||
|
encode motion relevant to the requested sound. This is the core architectural improvement
|
||||||
|
needed for reliable V2A sync.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### New directory layout
|
||||||
|
|
||||||
|
```
|
||||||
|
selva_core/ ← vendored SelVA source (model + ext + utils)
|
||||||
|
nodes/
|
||||||
|
selva_model_loader.py
|
||||||
|
selva_feature_extractor.py
|
||||||
|
selva_sampler.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### New custom types
|
||||||
|
|
||||||
|
- `SELVA_MODEL` — `{generator, video_enc, feature_utils, variant, strategy, dtype}`
|
||||||
|
- `SELVA_FEATURES` — `{clip_features, sync_features, duration}`
|
||||||
|
|
||||||
|
### No subprocess
|
||||||
|
|
||||||
|
SelVA is pure PyTorch. Feature extraction runs inline in ComfyUI — no managed venv,
|
||||||
|
no JAX/TF, no pip install on first run.
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
Zero new pip packages. ComfyUI already ships:
|
||||||
|
- `open_clip_torch` (CLIP ViT-H-14-384, auto-downloads via `hf-hub:` on first use)
|
||||||
|
- `transformers` (flan-t5-base, auto-downloads from HuggingFace on first use)
|
||||||
|
- `torch`, `torchaudio`, `einops`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Nodes
|
||||||
|
|
||||||
|
### `SelvaModelLoader` → `SELVA_MODEL`
|
||||||
|
|
||||||
|
| Input | Type | Default | Notes |
|
||||||
|
|---|---|---|---|
|
||||||
|
| variant | dropdown | medium_44k | small_16k / small_44k / medium_44k / large_44k |
|
||||||
|
| precision | dropdown | bf16 | bf16 / fp16 / fp32 |
|
||||||
|
| offload_strategy | dropdown | auto | auto / keep_in_vram / offload_to_cpu |
|
||||||
|
|
||||||
|
Resolves weights from `models/selva/`. Raises descriptive errors with download
|
||||||
|
instructions if files are missing.
|
||||||
|
|
||||||
|
### `SelvaFeatureExtractor` → `SELVA_FEATURES`, `FLOAT` (fps)
|
||||||
|
|
||||||
|
| Input | Type | Default | Notes |
|
||||||
|
|---|---|---|---|
|
||||||
|
| video | IMAGE | — | ComfyUI video tensor [T,H,W,C] |
|
||||||
|
| prompt | STRING | — | Used by TextSynchformer to select relevant motion |
|
||||||
|
| video_info | VHS_VIDEOINFO | opt | Auto-sets fps when connected |
|
||||||
|
| fps | FLOAT | 30.0 | Fallback fps if video_info not connected |
|
||||||
|
| cache_dir | STRING | "" | Empty = system temp dir |
|
||||||
|
|
||||||
|
Feature extraction steps (all inline, no subprocess):
|
||||||
|
1. Resize frames to 384×384 → CLIP video features `[B, T, 1024]`
|
||||||
|
2. Resize frames to 224×224 + encode prompt with flan-T5 → TextSynchformer → text-conditioned sync features `[B, T, 768]`
|
||||||
|
3. Save to `.npz` cache keyed by hash(frames[:1MB] + prompt + fps)
|
||||||
|
|
||||||
|
### `SelvaSampler` → `AUDIO`
|
||||||
|
|
||||||
|
| Input | Type | Default | Notes |
|
||||||
|
|---|---|---|---|
|
||||||
|
| model | SELVA_MODEL | — | |
|
||||||
|
| features | SELVA_FEATURES | — | |
|
||||||
|
| prompt | STRING | — | Should match extractor prompt; drives CLIP text guidance |
|
||||||
|
| negative_prompt | STRING | "" | Steers away from unwanted sounds |
|
||||||
|
| duration | FLOAT | 0.0 | 0 = auto from features duration |
|
||||||
|
| steps | INT | 25 | Euler steps (25 is SelVA default, fast) |
|
||||||
|
| cfg_strength | FLOAT | 4.5 | CFG scale (SelVA default) |
|
||||||
|
| seed | INT | 0 | |
|
||||||
|
|
||||||
|
Generation steps:
|
||||||
|
1. Encode prompt → CLIP text features (for MMAudio)
|
||||||
|
2. Encode negative prompt → empty conditions for CFG
|
||||||
|
3. `net_generator.preprocess_conditions(clip_f, sync_f, text_clip)`
|
||||||
|
4. Flow matching Euler ODE (`num_steps` iterations) with CFG
|
||||||
|
5. `feature_utils.decode(latent)` → mel spectrogram
|
||||||
|
6. `feature_utils.vocode(spec)` → waveform (BigVGAN for 16k, direct for 44k)
|
||||||
|
|
||||||
|
**Note on dual prompt:** The extractor prompt is baked into sync_features via
|
||||||
|
TextSynchformer at extraction time. The sampler prompt drives CLIP text conditioning
|
||||||
|
at generation time. They should match — a tooltip explains this.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
[VHS LoadVideo] ──► [SelvaFeatureExtractor]
|
||||||
|
│ prompt: "dog barking"
|
||||||
|
│ video_info: (fps auto)
|
||||||
|
▼
|
||||||
|
SELVA_FEATURES
|
||||||
|
{clip_features [B,T,1024],
|
||||||
|
sync_features [B,T,768], ← text-conditioned
|
||||||
|
duration: 8.2s}
|
||||||
|
│
|
||||||
|
[SelvaModelLoader] ──► [SelvaSampler]
|
||||||
|
variant: medium_44k │ prompt: "dog barking"
|
||||||
|
precision: bf16 │ negative: "wind noise"
|
||||||
|
│ cfg_strength: 4.5, steps: 25
|
||||||
|
▼
|
||||||
|
AUDIO (44.1kHz or 16kHz)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Model Weights
|
||||||
|
|
||||||
|
Location: `models/selva/`
|
||||||
|
|
||||||
|
```
|
||||||
|
video_enc_sup_5.pth ← TextSynch, shared across all variants
|
||||||
|
generator_small_16k_sup_5.pth
|
||||||
|
generator_small_44k_sup_5.pth
|
||||||
|
generator_medium_44k_sup_5.pth
|
||||||
|
generator_large_44k_sup_5.pth
|
||||||
|
ext/
|
||||||
|
v1-16.pth ← VAE for 16k variants
|
||||||
|
v1-44.pth ← VAE for 44k variants
|
||||||
|
best_netG.pt ← BigVGAN vocoder (16k only)
|
||||||
|
```
|
||||||
|
|
||||||
|
`synchformer_state_dict.pth` is reused from `models/prismaudio/` — no duplicate.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## selva_core vendoring
|
||||||
|
|
||||||
|
Copy from `jnwnlee/selva` (pinned to a specific commit for stability):
|
||||||
|
- `selva_core/model/` — MMAudio, TextSynch, transformer layers, embeddings, flow matching
|
||||||
|
- `selva_core/ext/` — autoencoder, BigVGAN, synchformer, rotary embeddings, mel converters
|
||||||
|
- `selva_core/utils/` — transforms, generate() helper
|
||||||
|
|
||||||
|
Rename all internal imports from `selva.*` → `selva_core.*`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## What stays the same
|
||||||
|
|
||||||
|
- All PrismAudio nodes unchanged
|
||||||
|
- `models/prismaudio/` unchanged
|
||||||
|
- Synchformer checkpoint shared (not duplicated)
|
||||||
|
- Branch: new `feature/selva-integration` off master (LoRA work stays separate)
|
||||||
@@ -0,0 +1,738 @@
|
|||||||
|
# SelVA Integration Implementation Plan
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** Add three new ComfyUI nodes (SelvaModelLoader, SelvaFeatureExtractor, SelvaSampler) that run SelVA's text-conditioned V2A pipeline inline — no subprocess, no JAX, pure PyTorch.
|
||||||
|
|
||||||
|
**Architecture:** Vendor SelVA source into `selva_core/`, implement three nodes that mirror the PrismAudio pattern. `SelvaFeatureExtractor` takes `SELVA_MODEL` (needs TextSynchformer + CLIP/T5 from FeaturesUtils). `SelvaSampler` runs flow matching ODE with CFG and negative prompts.
|
||||||
|
|
||||||
|
**Tech Stack:** PyTorch, open_clip (already in ComfyUI), transformers (already in ComfyUI), torchaudio, einops, torchvision
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Design reference
|
||||||
|
|
||||||
|
`docs/plans/2026-04-04-selva-integration-design.md`
|
||||||
|
|
||||||
|
**Key facts from SelVA source:**
|
||||||
|
- CLIP input: `[B, T, C, 384, 384]` float32 `[0,1]` — normalization applied inside FeaturesUtils
|
||||||
|
- Sync input: `[B, T, C, 224, 224]` float32 `[-1,1]` — normalize with `mean=std=[0.5,0.5,0.5]` before passing
|
||||||
|
- CLIP frame rate: 8fps, Sync frame rate: 25fps
|
||||||
|
- CONFIG_16K: latent=250, clip=64, sync=192 at 8s
|
||||||
|
- CONFIG_44K: latent=345, clip=64, sync=192 at 8s
|
||||||
|
- Sync segments: 16-frame windows, 8-frame stride (overlapping, unlike PrismAudio's 8-frame non-overlapping)
|
||||||
|
- `net_generator.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)` must be called before each generation when duration ≠ 8s
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 1: Create branch and vendor selva_core
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `selva_core/` (full directory tree)
|
||||||
|
|
||||||
|
**Step 1: Create new branch off master (not off feature/lora-trainer)**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git checkout master
|
||||||
|
git checkout -b feature/selva-integration
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Clone SelVA and copy source**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/jnwnlee/selva.git /tmp/selva_src
|
||||||
|
cp -r /tmp/selva_src/selva /media/p5/Comfyui-Prismaudio/selva_core
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Rename all internal imports**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /media/p5/Comfyui-Prismaudio/selva_core
|
||||||
|
find . -name "*.py" -exec sed -i \
|
||||||
|
's/from selva\./from selva_core./g;
|
||||||
|
s/import selva\./import selva_core./g' {} \;
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: Record the pinned commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /tmp/selva_src && git rev-parse HEAD
|
||||||
|
# Paste the hash into a comment at the top of selva_core/__init__.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Edit `selva_core/__init__.py` to add at the top:
|
||||||
|
```python
|
||||||
|
# Vendored from https://github.com/jnwnlee/selva
|
||||||
|
# Pinned commit: <PASTE_HASH_HERE>
|
||||||
|
# Imports rewritten from selva.* → selva_core.*
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: Verify imports work**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /media/p5/Comfyui-Prismaudio
|
||||||
|
python -c "
|
||||||
|
from selva_core.model.networks_generator import MMAudio, get_my_mmaudio
|
||||||
|
from selva_core.model.networks_video_enc import TextSynch, get_my_textsynch
|
||||||
|
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||||
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
|
||||||
|
print('selva_core imports OK')
|
||||||
|
print(f'CONFIG_16K: latent={CONFIG_16K.latent_seq_len} clip={CONFIG_16K.clip_seq_len} sync={CONFIG_16K.sync_seq_len}')
|
||||||
|
print(f'CONFIG_44K: latent={CONFIG_44K.latent_seq_len} clip={CONFIG_44K.clip_seq_len} sync={CONFIG_44K.sync_seq_len}')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected:
|
||||||
|
```
|
||||||
|
selva_core imports OK
|
||||||
|
CONFIG_16K: latent=250 clip=64 sync=192
|
||||||
|
CONFIG_44K: latent=345 clip=64 sync=192
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add selva_core/
|
||||||
|
git commit -m "chore: vendor selva_core from jnwnlee/selva@<HASH>
|
||||||
|
|
||||||
|
Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
|
||||||
|
Imports rewritten from selva.* to selva_core.*. No training code included."
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 2: Implement SelvaModelLoader
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `nodes/selva_model_loader.py`
|
||||||
|
- Modify: `nodes/__init__.py`
|
||||||
|
|
||||||
|
**Step 1: Create `nodes/selva_model_loader.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY, get_offload_device, determine_offload_strategy
|
||||||
|
|
||||||
|
# Variant → (generator filename, mode, has_bigvgan)
|
||||||
|
_VARIANTS = {
|
||||||
|
"small_16k": ("generator_small_16k_sup_5.pth", "16k", True),
|
||||||
|
"small_44k": ("generator_small_44k_sup_5.pth", "44k", False),
|
||||||
|
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
|
||||||
|
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
|
||||||
|
}
|
||||||
|
|
||||||
|
_SELVA_DIR = os.path.join(folder_paths.models_dir, "selva")
|
||||||
|
|
||||||
|
|
||||||
|
def _selva_path(*parts):
|
||||||
|
return os.path.join(_SELVA_DIR, *parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _require(path, hint):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[SelVA] Missing: {path}\n{hint}"
|
||||||
|
)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaModelLoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"variant": (list(_VARIANTS.keys()),),
|
||||||
|
"precision": (["bf16", "fp16", "fp32"],),
|
||||||
|
"offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("SELVA_MODEL",)
|
||||||
|
RETURN_NAMES = ("model",)
|
||||||
|
FUNCTION = "load_model"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def load_model(self, variant, precision, offload_strategy):
|
||||||
|
from selva_core.model.networks_generator import get_my_mmaudio
|
||||||
|
from selva_core.model.networks_video_enc import get_my_textsynch
|
||||||
|
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||||
|
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
|
||||||
|
|
||||||
|
gen_filename, mode, has_bigvgan = _VARIANTS[variant]
|
||||||
|
|
||||||
|
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||||
|
strategy = determine_offload_strategy(offload_strategy)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# Resolve weight paths
|
||||||
|
video_enc_path = _require(
|
||||||
|
_selva_path("video_enc_sup_5.pth"),
|
||||||
|
"Download from https://huggingface.co/jnwnlee/selva and place in models/selva/"
|
||||||
|
)
|
||||||
|
gen_path = _require(
|
||||||
|
_selva_path(gen_filename),
|
||||||
|
f"Download {gen_filename} from https://huggingface.co/jnwnlee/selva and place in models/selva/"
|
||||||
|
)
|
||||||
|
vae_path = _require(
|
||||||
|
_selva_path("ext", f"v1-{mode}.pth"),
|
||||||
|
f"Download v1-{mode}.pth from MMAudio/SelVA release and place in models/selva/ext/"
|
||||||
|
)
|
||||||
|
synch_path = _require(
|
||||||
|
os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth"),
|
||||||
|
"Synchformer checkpoint missing from models/prismaudio/ — download from FunAudioLLM/PrismAudio"
|
||||||
|
)
|
||||||
|
bigvgan_path = None
|
||||||
|
if has_bigvgan:
|
||||||
|
bigvgan_path = _require(
|
||||||
|
_selva_path("ext", "best_netG.pt"),
|
||||||
|
"Download best_netG.pt (BigVGAN 16k vocoder) from MMAudio release and place in models/selva/ext/"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
|
||||||
|
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
|
||||||
|
net_video_enc.load_weights(
|
||||||
|
torch.load(video_enc_path, map_location="cpu", weights_only=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
|
||||||
|
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
|
||||||
|
net_generator = get_my_mmaudio(variant).to(device, dtype).eval()
|
||||||
|
net_generator.load_weights(
|
||||||
|
torch.load(gen_path, map_location="cpu", weights_only=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
|
||||||
|
feature_utils = FeaturesUtils(
|
||||||
|
tod_vae_ckpt=vae_path,
|
||||||
|
synchformer_ckpt=synch_path,
|
||||||
|
enable_conditions=True,
|
||||||
|
mode=mode,
|
||||||
|
bigvgan_vocoder_ckpt=bigvgan_path,
|
||||||
|
).to(device, dtype).eval()
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
net_generator.to(get_offload_device())
|
||||||
|
net_video_enc.to(get_offload_device())
|
||||||
|
feature_utils.to(get_offload_device())
|
||||||
|
|
||||||
|
print(f"[SelVA] Model ready: variant={variant} dtype={dtype} strategy={strategy}", flush=True)
|
||||||
|
|
||||||
|
return ({
|
||||||
|
"generator": net_generator,
|
||||||
|
"video_enc": net_video_enc,
|
||||||
|
"feature_utils": feature_utils,
|
||||||
|
"variant": variant,
|
||||||
|
"mode": mode,
|
||||||
|
"strategy": strategy,
|
||||||
|
"dtype": dtype,
|
||||||
|
"seq_cfg": seq_cfg,
|
||||||
|
},)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Register in `nodes/__init__.py`**
|
||||||
|
|
||||||
|
In the `NODE_CLASS_MAPPINGS` dict, add:
|
||||||
|
```python
|
||||||
|
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Verify node registers**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /media/p5/Comfyui-Prismaudio
|
||||||
|
python -c "
|
||||||
|
import sys; sys.path.insert(0, '.')
|
||||||
|
from nodes.selva_model_loader import SelvaModelLoader
|
||||||
|
print('inputs:', list(SelvaModelLoader.INPUT_TYPES()['required'].keys()))
|
||||||
|
print('outputs:', SelvaModelLoader.RETURN_TYPES)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `inputs: ['variant', 'precision', 'offload_strategy']`
|
||||||
|
|
||||||
|
**Step 4: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add nodes/selva_model_loader.py nodes/__init__.py
|
||||||
|
git commit -m "feat: SelvaModelLoader node — loads TextSynch + MMAudio + FeaturesUtils"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 3: Implement SelvaFeatureExtractor
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `nodes/selva_feature_extractor.py`
|
||||||
|
- Modify: `nodes/__init__.py`
|
||||||
|
|
||||||
|
**Step 1: Create `nodes/selva_feature_extractor.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||||
|
|
||||||
|
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
|
||||||
|
_CLIP_SIZE = 384
|
||||||
|
_SYNC_SIZE = 224
|
||||||
|
_CLIP_FPS = 8
|
||||||
|
_SYNC_FPS = 25
|
||||||
|
|
||||||
|
# Sync normalization: [-1, 1] (from selva/utils/eval_utils.py load_video)
|
||||||
|
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
||||||
|
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_frames(video, source_fps, target_fps, duration):
|
||||||
|
"""Sample frames from [T,H,W,C] float32 [0,1] at target_fps."""
|
||||||
|
T = video.shape[0]
|
||||||
|
n_out = max(1, int(duration * target_fps))
|
||||||
|
indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
|
||||||
|
return video[indices] # [N, H, W, C]
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_frames(frames, size):
|
||||||
|
"""Resize [N,H,W,C] float32 [0,1] → [N,C,H,W] at target size."""
|
||||||
|
x = frames.permute(0, 3, 1, 2) # [N, C, H, W]
|
||||||
|
x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False)
|
||||||
|
return x.clamp(0, 1) # [N, C, H, W] float32
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_inputs(video_tensor, prompt, fps, variant):
|
||||||
|
h = hashlib.sha256()
|
||||||
|
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024])
|
||||||
|
h.update(prompt.encode())
|
||||||
|
h.update(str(fps).encode())
|
||||||
|
h.update(variant.encode())
|
||||||
|
return h.hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaFeatureExtractor:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("SELVA_MODEL",),
|
||||||
|
"video": ("IMAGE",),
|
||||||
|
"prompt": ("STRING", {"default": "", "multiline": True,
|
||||||
|
"tooltip": "Text prompt used by TextSynchformer to focus sync features on the relevant sound source. Should match the prompt used in SelvaSampler."}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info to auto-set fps."}),
|
||||||
|
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001}),
|
||||||
|
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
|
||||||
|
"tooltip": "Override duration in seconds. 0 = infer from video length and fps."}),
|
||||||
|
"cache_dir": ("STRING", {"default": "", "tooltip": "Directory for cached .npz features. Empty = temp dir."}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT")
|
||||||
|
RETURN_NAMES = ("features", "fps")
|
||||||
|
FUNCTION = "extract_features"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||||
|
duration=0.0, cache_dir=""):
|
||||||
|
if video_info is not None:
|
||||||
|
fps = video_info["loaded_fps"]
|
||||||
|
|
||||||
|
T = video.shape[0]
|
||||||
|
if duration <= 0:
|
||||||
|
duration = T / fps
|
||||||
|
duration = min(duration, T / fps) # clamp to actual video length
|
||||||
|
|
||||||
|
if not prompt.strip():
|
||||||
|
print("[SelVA] Warning: empty prompt — TextSynchformer sync features will be unfocused.", flush=True)
|
||||||
|
|
||||||
|
# Cache
|
||||||
|
if not cache_dir:
|
||||||
|
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
cache_key = _hash_inputs(video, prompt, fps, model["variant"])
|
||||||
|
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
||||||
|
|
||||||
|
if os.path.exists(cached_path):
|
||||||
|
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
|
||||||
|
return (_load_cached(cached_path), float(fps))
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
feature_utils = model["feature_utils"]
|
||||||
|
net_video_enc = model["video_enc"]
|
||||||
|
|
||||||
|
# Move feature models to device
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
feature_utils.to(device)
|
||||||
|
net_video_enc.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# --- CLIP frames: 384×384, [0,1], 8fps ---
|
||||||
|
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||||
|
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||||
|
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
||||||
|
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps", flush=True)
|
||||||
|
|
||||||
|
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
||||||
|
|
||||||
|
# --- Sync frames: 224×224, [-1,1], 25fps ---
|
||||||
|
n_sync = max(16, int(duration * _SYNC_FPS)) # minimum 16 for segmentation
|
||||||
|
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration)
|
||||||
|
if sync_frames.shape[0] < 16:
|
||||||
|
# Pad by repeating last frame to reach minimum 16
|
||||||
|
pad = 16 - sync_frames.shape[0]
|
||||||
|
sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
|
||||||
|
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||||
|
# Normalize to [-1, 1]
|
||||||
|
mean = _SYNC_MEAN.to(sync_frames.device)
|
||||||
|
std = _SYNC_STD.to(sync_frames.device)
|
||||||
|
sync_frames = (sync_frames - mean) / std
|
||||||
|
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
|
||||||
|
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps", flush=True)
|
||||||
|
|
||||||
|
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
|
||||||
|
text_f_t5, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, 768], [1, L]
|
||||||
|
text_f_t5, text_mask = net_video_enc.prepend_sup_text_tokens(text_f_t5, text_mask)
|
||||||
|
sync_features = net_video_enc.encode_video_with_sync(
|
||||||
|
sync_input, text_f=text_f_t5, text_mask=text_mask
|
||||||
|
) # [1, T_sync, 768]
|
||||||
|
|
||||||
|
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
|
||||||
|
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
|
||||||
|
|
||||||
|
# Offload back if needed
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
feature_utils.to(get_offload_device())
|
||||||
|
net_video_enc.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
# Save cache
|
||||||
|
np.savez(
|
||||||
|
cached_path,
|
||||||
|
clip_features=clip_features.cpu().float().numpy(),
|
||||||
|
sync_features=sync_features.cpu().float().numpy(),
|
||||||
|
duration=duration,
|
||||||
|
)
|
||||||
|
print(f"[SelVA] Features cached: {cached_path}", flush=True)
|
||||||
|
|
||||||
|
features = {
|
||||||
|
"clip_features": clip_features.cpu(),
|
||||||
|
"sync_features": sync_features.cpu(),
|
||||||
|
"duration": duration,
|
||||||
|
}
|
||||||
|
return (features, float(fps))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cached(path):
|
||||||
|
data = np.load(path, allow_pickle=False)
|
||||||
|
return {
|
||||||
|
"clip_features": torch.from_numpy(data["clip_features"]),
|
||||||
|
"sync_features": torch.from_numpy(data["sync_features"]),
|
||||||
|
"duration": float(data["duration"]),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Register in `nodes/__init__.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Verify node registers**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
import sys; sys.path.insert(0, '.')
|
||||||
|
from nodes.selva_feature_extractor import SelvaFeatureExtractor
|
||||||
|
inputs = SelvaFeatureExtractor.INPUT_TYPES()
|
||||||
|
print('required:', list(inputs['required'].keys()))
|
||||||
|
print('optional:', list(inputs['optional'].keys()))
|
||||||
|
print('outputs:', SelvaFeatureExtractor.RETURN_TYPES)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `required: ['model', 'video', 'prompt']`
|
||||||
|
|
||||||
|
**Step 4: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add nodes/selva_feature_extractor.py nodes/__init__.py
|
||||||
|
git commit -m "feat: SelvaFeatureExtractor — inline CLIP + TextSynchformer feature extraction"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 4: Implement SelvaSampler
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `nodes/selva_sampler.py`
|
||||||
|
- Modify: `nodes/__init__.py`
|
||||||
|
|
||||||
|
**Step 1: Create `nodes/selva_sampler.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
PRISMAUDIO_CATEGORY,
|
||||||
|
get_device, get_offload_device, soft_empty_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seq_cfg(duration, mode):
|
||||||
|
"""Compute sequence lengths for a given duration and mode."""
|
||||||
|
from selva_core.model.sequence_config import SequenceConfig
|
||||||
|
if mode == "16k":
|
||||||
|
return SequenceConfig(duration=duration, sampling_rate=16000, spectrogram_frame_rate=256)
|
||||||
|
else:
|
||||||
|
return SequenceConfig(duration=duration, sampling_rate=44100, spectrogram_frame_rate=512)
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaSampler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("SELVA_MODEL",),
|
||||||
|
"features": ("SELVA_FEATURES",),
|
||||||
|
"prompt": ("STRING", {"default": "", "multiline": True,
|
||||||
|
"tooltip": "Should match the prompt used in SelvaFeatureExtractor."}),
|
||||||
|
"negative_prompt": ("STRING", {"default": "", "multiline": True,
|
||||||
|
"tooltip": "Sounds to steer away from, e.g. 'wind noise, background music'."}),
|
||||||
|
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
|
||||||
|
"tooltip": "Audio duration in seconds. 0 = use duration from features."}),
|
||||||
|
"steps": ("INT", {"default": 25, "min": 1, "max": 200}),
|
||||||
|
"cfg_strength": ("FLOAT", {"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO",)
|
||||||
|
RETURN_NAMES = ("audio",)
|
||||||
|
FUNCTION = "generate"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed):
|
||||||
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
net_generator = model["generator"]
|
||||||
|
feature_utils = model["feature_utils"]
|
||||||
|
mode = model["mode"]
|
||||||
|
|
||||||
|
# Resolve duration
|
||||||
|
if duration <= 0:
|
||||||
|
if "duration" not in features:
|
||||||
|
raise ValueError("[SelVA] duration=0 but features contain no duration field.")
|
||||||
|
duration = features["duration"]
|
||||||
|
print(f"[SelVA] Using video duration from features: {duration:.2f}s", flush=True)
|
||||||
|
|
||||||
|
seq_cfg = _make_seq_cfg(duration, mode)
|
||||||
|
sample_rate = seq_cfg.sampling_rate
|
||||||
|
|
||||||
|
# Move models to device
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
net_generator.to(device)
|
||||||
|
feature_utils.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
|
||||||
|
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
|
||||||
|
|
||||||
|
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
|
||||||
|
print(f"[SelVA] seq_cfg: latent={seq_cfg.latent_seq_len} clip={seq_cfg.clip_seq_len} sync={seq_cfg.sync_seq_len}", flush=True)
|
||||||
|
|
||||||
|
# Update model sequence lengths for this duration
|
||||||
|
net_generator.update_seq_lengths(
|
||||||
|
latent_seq_len=seq_cfg.latent_seq_len,
|
||||||
|
clip_seq_len=seq_cfg.clip_seq_len,
|
||||||
|
sync_seq_len=seq_cfg.sync_seq_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Encode text
|
||||||
|
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
|
||||||
|
|
||||||
|
# Build empty (negative) conditions for CFG
|
||||||
|
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
||||||
|
if negative_prompt.strip() else None
|
||||||
|
|
||||||
|
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||||
|
empty_conditions = net_generator.get_empty_conditions(
|
||||||
|
bs=1, negative_text_features=neg_text_clip
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sample initial noise
|
||||||
|
rng = torch.Generator(device=device).manual_seed(seed)
|
||||||
|
x0 = torch.randn(
|
||||||
|
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||||
|
device=device, dtype=dtype, generator=rng
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flow matching ODE (Euler)
|
||||||
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||||
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
|
_step_count = [0]
|
||||||
|
orig_to_data = fm.to_data
|
||||||
|
|
||||||
|
def tracked_to_data(fn, x0_):
|
||||||
|
# ProgressBar update via step counting in ode_wrapper
|
||||||
|
return orig_to_data(fn, x0_)
|
||||||
|
|
||||||
|
# Wrap ODE to update progress bar
|
||||||
|
def ode_wrapper_tracked(t, x):
|
||||||
|
_step_count[0] += 1
|
||||||
|
pbar.update(1)
|
||||||
|
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
||||||
|
|
||||||
|
x1 = fm.to_data(ode_wrapper_tracked, x0)
|
||||||
|
|
||||||
|
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||||
|
|
||||||
|
# Decode: latent → mel → audio
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
feature_utils.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
x1_unnorm = net_generator.unnormalize(x1)
|
||||||
|
spec = feature_utils.decode(x1_unnorm)
|
||||||
|
audio = feature_utils.vocode(spec) # [1, samples] or [1, 1, samples]
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
net_generator.to(get_offload_device())
|
||||||
|
feature_utils.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
# Normalise to [-1, 1]
|
||||||
|
audio = audio.float()
|
||||||
|
if audio.dim() == 2:
|
||||||
|
audio = audio.unsqueeze(1) # [1, 1, samples]
|
||||||
|
elif audio.dim() == 3 and audio.shape[1] != 1:
|
||||||
|
audio = audio.mean(dim=1, keepdim=True) # stereo → mono
|
||||||
|
|
||||||
|
peak = audio.abs().max().clamp(min=1e-8)
|
||||||
|
audio = (audio / peak).clamp(-1, 1)
|
||||||
|
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
||||||
|
|
||||||
|
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Register in `nodes/__init__.py`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Verify node registers**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
import sys; sys.path.insert(0, '.')
|
||||||
|
from nodes.selva_sampler import SelvaSampler
|
||||||
|
inputs = SelvaSampler.INPUT_TYPES()
|
||||||
|
print('inputs:', list(inputs['required'].keys()))
|
||||||
|
print('outputs:', SelvaSampler.RETURN_TYPES)
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `inputs: ['model', 'features', 'prompt', 'negative_prompt', 'duration', 'steps', 'cfg_strength', 'seed']`
|
||||||
|
|
||||||
|
**Step 4: Commit**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add nodes/selva_sampler.py nodes/__init__.py
|
||||||
|
git commit -m "feat: SelvaSampler — flow matching ODE with CFG + negative prompts"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 5: Create example workflow and push
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `workflows/selva_video_to_audio.json`
|
||||||
|
|
||||||
|
**Step 1: Create workflow JSON**
|
||||||
|
|
||||||
|
Create `workflows/selva_video_to_audio.json` with this node graph:
|
||||||
|
- LoadVideo (VHS) → IMAGE + VHS_VIDEOINFO
|
||||||
|
- SelvaModelLoader → SELVA_MODEL
|
||||||
|
- SelvaFeatureExtractor (takes IMAGE + VHS_VIDEOINFO + SELVA_MODEL, prompt) → SELVA_FEATURES
|
||||||
|
- SelvaSampler (takes SELVA_MODEL + SELVA_FEATURES, prompt, negative_prompt) → AUDIO
|
||||||
|
- PreviewAudio (takes AUDIO)
|
||||||
|
|
||||||
|
Set defaults: variant=medium_44k, precision=bf16, steps=25, cfg_strength=4.5, duration=0.
|
||||||
|
|
||||||
|
**Step 2: Push branch**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git push -u origin feature/selva-integration
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 6: Smoke test
|
||||||
|
|
||||||
|
**Step 1: Check all three nodes are importable from ComfyUI's perspective**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd /media/p5/Comfyui-Prismaudio
|
||||||
|
python -c "
|
||||||
|
import sys; sys.path.insert(0, '.')
|
||||||
|
import nodes
|
||||||
|
m = nodes.NODE_CLASS_MAPPINGS
|
||||||
|
print('SelVA nodes:', [k for k in m if 'Selva' in k])
|
||||||
|
assert 'SelvaModelLoader' in m
|
||||||
|
assert 'SelvaFeatureExtractor' in m
|
||||||
|
assert 'SelvaSampler' in m
|
||||||
|
print('All SelVA nodes registered OK')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: Verify no import errors in full node load**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -c "
|
||||||
|
import sys; sys.path.insert(0, '.')
|
||||||
|
from nodes.selva_model_loader import SelvaModelLoader
|
||||||
|
from nodes.selva_feature_extractor import SelvaFeatureExtractor
|
||||||
|
from nodes.selva_sampler import SelvaSampler
|
||||||
|
print('All imports clean')
|
||||||
|
"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: Final commit with any fixes**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git add -A
|
||||||
|
git commit -m "fix: selva integration smoke test fixes (if any)"
|
||||||
|
git push
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- The `FeaturesUtils.train()` is overridden to always call `super().train(False)` — SelVA models are always in eval mode
|
||||||
|
- `net_generator.update_seq_lengths` recalculates rotary position embeddings; call it before every generation when duration may vary
|
||||||
|
- ProgressBar tracking: `FlowMatching.to_data` calls `fn(t, x)` for each Euler step; wrapping `ode_wrapper` with a counter gives accurate progress
|
||||||
|
- The `feature_utils.vocode` returns audio at 16kHz for small_16k (uses BigVGAN) and 44.1kHz for 44k variants (uses VAE mel decoder directly)
|
||||||
|
- If `encode_text_t5` or `encode_text_clip` fail with missing model errors on first run, it's HuggingFace downloading `flan-t5-base` and `apple/DFN5B-CLIP-ViT-H-14-384` — this is expected and takes a few minutes once
|
||||||
@@ -1,606 +0,0 @@
|
|||||||
# Audio Dataset Pipeline Implementation Plan
|
|
||||||
|
|
||||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
|
||||||
|
|
||||||
**Goal:** Add 5 chainable ComfyUI nodes for in-memory audio dataset preprocessing: load → resample → LUFS normalize → inspect/filter → extract single item.
|
|
||||||
|
|
||||||
**Architecture:** Single new file `nodes/selva_dataset_pipeline.py` defines a custom `AUDIO_DATASET` type (list of dicts) and all 5 node classes. Nodes are stateless transforms — each takes `AUDIO_DATASET` and returns `AUDIO_DATASET`. No disk I/O except in the Loader. Register all nodes in `nodes/__init__.py`.
|
|
||||||
|
|
||||||
**Tech Stack:** `pyloudnorm` (BS.1770-4 LUFS), `soxr` (VHQ resampling), `torchaudio`, `torch`. Both confirmed present in the ComfyUI environment at `/media/p5/miniforge3/envs/latestcomfyui`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## The `AUDIO_DATASET` type
|
|
||||||
|
|
||||||
Used as the ComfyUI type string `"AUDIO_DATASET"`. At runtime it is a Python list of dicts:
|
|
||||||
|
|
||||||
```python
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"waveform": torch.Tensor, # shape [1, C, L], float32, range [-1, 1]
|
|
||||||
"sample_rate": int,
|
|
||||||
"name": str, # original filename stem, for reporting
|
|
||||||
},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 1: Create the file skeleton and AUDIO_DATASET constant
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create: `nodes/selva_dataset_pipeline.py`
|
|
||||||
|
|
||||||
**Step 1: Write the file with imports and type constant only**
|
|
||||||
|
|
||||||
```python
|
|
||||||
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.
|
|
||||||
|
|
||||||
Typical chain:
|
|
||||||
SelvaDatasetLoader
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetResampler (optional)
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetLUFSNormalizer (optional)
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetInspector (optional)
|
|
||||||
↓ AUDIO_DATASET + STRING report
|
|
||||||
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
# ComfyUI custom type name — passed between all dataset pipeline nodes
|
|
||||||
AUDIO_DATASET = "AUDIO_DATASET"
|
|
||||||
|
|
||||||
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Verify import works (no test framework needed — just a quick smoke check)**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /media/p5/Comfyui-Prismaudio
|
|
||||||
python3 -c "from nodes.selva_dataset_pipeline import AUDIO_DATASET; print(AUDIO_DATASET)"
|
|
||||||
```
|
|
||||||
Expected output: `AUDIO_DATASET`
|
|
||||||
|
|
||||||
**Step 3: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add nodes/selva_dataset_pipeline.py
|
|
||||||
git commit -m "feat: add audio dataset pipeline skeleton"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 2: SelvaDatasetLoader
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `nodes/selva_dataset_pipeline.py`
|
|
||||||
|
|
||||||
**Step 1: Add the Loader class**
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SelvaDatasetLoader:
|
|
||||||
"""Load all audio files in a folder into an in-memory AUDIO_DATASET."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"folder": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Absolute path to folder containing audio files. Searched recursively.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "load"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."
|
|
||||||
|
|
||||||
def load(self, folder: str):
|
|
||||||
folder = Path(folder.strip())
|
|
||||||
if not folder.exists():
|
|
||||||
raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
|
|
||||||
|
|
||||||
files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
|
|
||||||
if not files:
|
|
||||||
raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")
|
|
||||||
|
|
||||||
dataset = []
|
|
||||||
for f in sorted(files):
|
|
||||||
try:
|
|
||||||
wav, sr = torchaudio.load(str(f)) # [C, L]
|
|
||||||
wav = wav.unsqueeze(0).float() # [1, C, L]
|
|
||||||
dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
|
|
||||||
|
|
||||||
print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
|
|
||||||
return (dataset,)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Smoke test**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 -c "
|
|
||||||
from nodes.selva_dataset_pipeline import SelvaDatasetLoader
|
|
||||||
node = SelvaDatasetLoader()
|
|
||||||
ds, = node.load('/media/unraid/davinci/Selva/BJ')
|
|
||||||
print(len(ds), 'clips', ds[0]['name'], ds[0]['waveform'].shape, ds[0]['sample_rate'])
|
|
||||||
"
|
|
||||||
```
|
|
||||||
Expected: prints clip count, first clip name, shape like `torch.Size([1, 2, 352800])`, sample rate.
|
|
||||||
|
|
||||||
**Step 3: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add nodes/selva_dataset_pipeline.py
|
|
||||||
git commit -m "feat: add SelvaDatasetLoader node"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 3: SelvaDatasetResampler
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `nodes/selva_dataset_pipeline.py`
|
|
||||||
|
|
||||||
**Step 1: Add the Resampler class**
|
|
||||||
|
|
||||||
Uses `soxr` directly for VHQ quality. `soxr.resample` operates on numpy arrays, shape `[L, C]` (time-first).
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SelvaDatasetResampler:
|
|
||||||
"""Resample all clips in a dataset to a target sample rate using soxr VHQ."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"target_sr": ("INT", {
|
|
||||||
"default": 44100, "min": 8000, "max": 192000,
|
|
||||||
"tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "resample"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."
|
|
||||||
|
|
||||||
def resample(self, dataset, target_sr: int):
|
|
||||||
import soxr
|
|
||||||
|
|
||||||
out = []
|
|
||||||
changed = 0
|
|
||||||
for item in dataset:
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
if sr == target_sr:
|
|
||||||
out.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
wav = item["waveform"][0] # [C, L]
|
|
||||||
# soxr expects [L, C] (time-first), float64
|
|
||||||
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
|
|
||||||
wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
|
|
||||||
wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L]
|
|
||||||
out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
|
|
||||||
changed += 1
|
|
||||||
|
|
||||||
print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
|
|
||||||
return (out,)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Smoke test**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 -c "
|
|
||||||
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetResampler
|
|
||||||
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
|
|
||||||
ds2, = SelvaDatasetResampler().resample(ds, 44100)
|
|
||||||
print('ok', ds2[0]['sample_rate'], ds2[0]['waveform'].shape)
|
|
||||||
"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add nodes/selva_dataset_pipeline.py
|
|
||||||
git commit -m "feat: add SelvaDatasetResampler node (soxr VHQ)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 4: SelvaDatasetLUFSNormalizer
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `nodes/selva_dataset_pipeline.py`
|
|
||||||
|
|
||||||
**Step 1: Add the LUFS normalizer class**
|
|
||||||
|
|
||||||
`pyloudnorm.Meter` requires numpy float64 array shape `[L]` (mono) or `[L, C]` (multichannel, channels last). True peak limit applied after gain.
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SelvaDatasetLUFSNormalizer:
|
|
||||||
"""Normalize each clip to a target integrated LUFS level + true peak limit."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"target_lufs": ("FLOAT", {
|
|
||||||
"default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
|
|
||||||
"tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
|
|
||||||
}),
|
|
||||||
"true_peak_dbtp": ("FLOAT", {
|
|
||||||
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
|
|
||||||
"tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "normalize"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
|
|
||||||
"Skips clips that are too short for LUFS measurement (< 0.4 s)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
|
|
||||||
import pyloudnorm as pyln
|
|
||||||
|
|
||||||
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
|
|
||||||
out = []
|
|
||||||
skipped = 0
|
|
||||||
|
|
||||||
for item in dataset:
|
|
||||||
wav = item["waveform"][0] # [C, L]
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
|
|
||||||
# pyloudnorm wants [L] mono or [L, C] multichannel, float64
|
|
||||||
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
|
|
||||||
if wav_np.shape[1] == 1:
|
|
||||||
wav_np = wav_np[:, 0] # [L] mono
|
|
||||||
|
|
||||||
meter = pyln.Meter(sr)
|
|
||||||
try:
|
|
||||||
loudness = meter.integrated_loudness(wav_np)
|
|
||||||
except Exception:
|
|
||||||
skipped += 1
|
|
||||||
out.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not np.isfinite(loudness):
|
|
||||||
skipped += 1
|
|
||||||
out.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
gain_db = target_lufs - loudness
|
|
||||||
gain_linear = 10.0 ** (gain_db / 20.0)
|
|
||||||
|
|
||||||
wav_norm = wav * gain_linear
|
|
||||||
|
|
||||||
# True peak limit
|
|
||||||
peak = wav_norm.abs().max().item()
|
|
||||||
if peak > tp_linear:
|
|
||||||
wav_norm = wav_norm * (tp_linear / peak)
|
|
||||||
|
|
||||||
out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
|
|
||||||
f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return (out,)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Smoke test**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 -c "
|
|
||||||
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetLUFSNormalizer
|
|
||||||
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
|
|
||||||
ds2, = SelvaDatasetLUFSNormalizer().normalize(ds, -23.0, -1.0)
|
|
||||||
print('ok', ds2[0]['name'], ds2[0]['waveform'].abs().max().item())
|
|
||||||
"
|
|
||||||
```
|
|
||||||
Expected: peak ≤ ~0.89 (≈ -1 dBTP).
|
|
||||||
|
|
||||||
**Step 3: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add nodes/selva_dataset_pipeline.py
|
|
||||||
git commit -m "feat: add SelvaDatasetLUFSNormalizer node (pyloudnorm BS.1770-4)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 5: SelvaDatasetInspector
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `nodes/selva_dataset_pipeline.py`
|
|
||||||
|
|
||||||
**Step 1: Add helper functions for artifact detection**
|
|
||||||
|
|
||||||
```python
|
|
||||||
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
|
|
||||||
"""Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).
|
|
||||||
|
|
||||||
Method: compare mean energy in 1–5 kHz band vs 15–20 kHz band via STFT.
|
|
||||||
A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
|
|
||||||
"""
|
|
||||||
if sr < 32000:
|
|
||||||
return False # can't assess HF at low sample rates
|
|
||||||
|
|
||||||
n_fft = 2048
|
|
||||||
hop = 512
|
|
||||||
window = torch.hann_window(n_fft)
|
|
||||||
mono = wav[0].mean(0) # [L]
|
|
||||||
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
|
|
||||||
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
|
|
||||||
|
|
||||||
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1)
|
|
||||||
band_lo = (freqs >= 1000) & (freqs < 5000)
|
|
||||||
band_hi = (freqs >= 15000) & (freqs < 20000)
|
|
||||||
|
|
||||||
if band_hi.sum() == 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
|
|
||||||
energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
|
|
||||||
ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
|
|
||||||
return ratio_db > 40.0
|
|
||||||
|
|
||||||
|
|
||||||
def _estimate_snr(wav: torch.Tensor) -> float:
|
|
||||||
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
|
|
||||||
mono = wav[0].mean(0) # [L]
|
|
||||||
frames = mono.unfold(0, 2048, 512) # [N, 2048]
|
|
||||||
rms = frames.pow(2).mean(-1).sqrt() # [N]
|
|
||||||
p95 = torch.quantile(rms, 0.95).item()
|
|
||||||
p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
|
|
||||||
return 20.0 * np.log10(p95 / p05 + 1e-8)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Add the Inspector class**
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SelvaDatasetInspector:
|
|
||||||
"""Analyze each clip for quality issues and optionally filter out flagged clips."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"skip_rejected": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "If True, flagged clips are removed from the output dataset. "
|
|
||||||
"If False, all clips pass through but the report still lists issues.",
|
|
||||||
}),
|
|
||||||
"min_snr_db": ("FLOAT", {
|
|
||||||
"default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
|
|
||||||
"tooltip": "Clips with estimated SNR below this value are flagged.",
|
|
||||||
}),
|
|
||||||
"check_codec_artifacts": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET, "STRING")
|
|
||||||
RETURN_NAMES = ("dataset", "report")
|
|
||||||
FUNCTION = "inspect"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Analyze each clip for clipping, low SNR, and codec artifacts. "
|
|
||||||
"Outputs a filtered AUDIO_DATASET and a text report. "
|
|
||||||
"Connect report to a ShowText node to preview in the UI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float, check_codec_artifacts: bool):
|
|
||||||
clean = []
|
|
||||||
flagged = []
|
|
||||||
lines = ["SelVA Dataset Inspector Report", "=" * 40]
|
|
||||||
|
|
||||||
for item in dataset:
|
|
||||||
wav = item["waveform"]
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
name = item["name"]
|
|
||||||
issues = []
|
|
||||||
|
|
||||||
# Clipping
|
|
||||||
peak = wav.abs().max().item()
|
|
||||||
if peak > 0.99:
|
|
||||||
issues.append(f"clipping (peak={peak:.3f})")
|
|
||||||
|
|
||||||
# Low SNR
|
|
||||||
snr = _estimate_snr(wav)
|
|
||||||
if snr < min_snr_db:
|
|
||||||
issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")
|
|
||||||
|
|
||||||
# Codec artifacts
|
|
||||||
if check_codec_artifacts and _check_hf_shelf(wav, sr):
|
|
||||||
issues.append("codec artifact (HF shelf > 15 kHz)")
|
|
||||||
|
|
||||||
if issues:
|
|
||||||
flagged.append(name)
|
|
||||||
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
|
|
||||||
if not skip_rejected:
|
|
||||||
clean.append(item)
|
|
||||||
else:
|
|
||||||
clean.append(item)
|
|
||||||
lines.append(f" OK {name}")
|
|
||||||
|
|
||||||
lines.append("=" * 40)
|
|
||||||
lines.append(
|
|
||||||
f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
|
|
||||||
+ (" (removed)" if skip_rejected else " (kept)")
|
|
||||||
)
|
|
||||||
report = "\n".join(lines)
|
|
||||||
print(f"[DatasetInspector]\n{report}", flush=True)
|
|
||||||
return (clean, report)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Smoke test**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 -c "
|
|
||||||
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetInspector
|
|
||||||
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
|
|
||||||
clean, report = SelvaDatasetInspector().inspect(ds, skip_rejected=False, min_snr_db=15.0, check_codec_artifacts=True)
|
|
||||||
print(report)
|
|
||||||
"
|
|
||||||
```
|
|
||||||
Expected: report with per-clip OK/FLAGGED lines and summary counts.
|
|
||||||
|
|
||||||
**Step 4: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add nodes/selva_dataset_pipeline.py
|
|
||||||
git commit -m "feat: add SelvaDatasetInspector node (codec artifacts, SNR, clipping)"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 6: SelvaDatasetItemExtractor
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `nodes/selva_dataset_pipeline.py`
|
|
||||||
|
|
||||||
**Step 1: Add the extractor class**
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SelvaDatasetItemExtractor:
|
|
||||||
"""Extract a single AUDIO item from an AUDIO_DATASET by index.
|
|
||||||
|
|
||||||
Bridges the dataset pipeline to any node that accepts a standard AUDIO
|
|
||||||
input — save audio, HF Smoother, Spectral Matcher, etc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"index": ("INT", {
|
|
||||||
"default": 0, "min": 0, "max": 9999,
|
|
||||||
"tooltip": "0-based index. Wraps around if index >= dataset length.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO", "STRING", "INT")
|
|
||||||
RETURN_NAMES = ("audio", "name", "total")
|
|
||||||
FUNCTION = "extract"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Extract one clip from an AUDIO_DATASET by index. "
|
|
||||||
"Returns standard AUDIO (compatible with all audio nodes), "
|
|
||||||
"the clip name, and the total dataset length."
|
|
||||||
)
|
|
||||||
|
|
||||||
def extract(self, dataset, index: int):
|
|
||||||
if not dataset:
|
|
||||||
raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
|
|
||||||
idx = index % len(dataset)
|
|
||||||
item = dataset[idx]
|
|
||||||
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
|
|
||||||
print(
|
|
||||||
f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
|
|
||||||
f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return (audio, item["name"], len(dataset))
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Smoke test**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 -c "
|
|
||||||
from nodes.selva_dataset_pipeline import SelvaDatasetLoader, SelvaDatasetItemExtractor
|
|
||||||
ds, = SelvaDatasetLoader().load('/media/unraid/davinci/Selva/BJ')
|
|
||||||
audio, name, total = SelvaDatasetItemExtractor().extract(ds, 0)
|
|
||||||
print(name, total, audio['waveform'].shape, audio['sample_rate'])
|
|
||||||
"
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add nodes/selva_dataset_pipeline.py
|
|
||||||
git commit -m "feat: add SelvaDatasetItemExtractor node"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Task 7: Register all nodes in __init__.py
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `nodes/__init__.py:4-25`
|
|
||||||
|
|
||||||
**Step 1: Add the 5 new entries to `_NODES`**
|
|
||||||
|
|
||||||
Add inside the `_NODES` dict, after `"SelvaDittoOptimizer"`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
|
|
||||||
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
|
|
||||||
"SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
|
|
||||||
"SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
|
|
||||||
"SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Verify registration**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python3 -c "
|
|
||||||
import sys; sys.path.insert(0, '/media/p5/Comfyui-Prismaudio')
|
|
||||||
from nodes import NODE_CLASS_MAPPINGS
|
|
||||||
keys = [k for k in NODE_CLASS_MAPPINGS if 'Dataset' in k]
|
|
||||||
print(keys)
|
|
||||||
"
|
|
||||||
```
|
|
||||||
Expected: list of 5 dataset node keys.
|
|
||||||
|
|
||||||
**Step 3: Final commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add nodes/__init__.py
|
|
||||||
git commit -m "feat: register audio dataset pipeline nodes in __init__.py"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
5 nodes in `nodes/selva_dataset_pipeline.py`, all registered in `__init__.py`:
|
|
||||||
|
|
||||||
| Node | In | Out |
|
|
||||||
|------|----|-----|
|
|
||||||
| SelvaDatasetLoader | folder path | AUDIO_DATASET |
|
|
||||||
| SelvaDatasetResampler | AUDIO_DATASET | AUDIO_DATASET |
|
|
||||||
| SelvaDatasetLUFSNormalizer | AUDIO_DATASET | AUDIO_DATASET |
|
|
||||||
| SelvaDatasetInspector | AUDIO_DATASET | AUDIO_DATASET + STRING |
|
|
||||||
| SelvaDatasetItemExtractor | AUDIO_DATASET + index | AUDIO + name + total |
|
|
||||||
|
|
||||||
Dependencies: `pyloudnorm`, `soxr` — both confirmed present in the ComfyUI env.
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "alpha_scale_sweep",
|
|
||||||
"description": "Fix LoRA noise contamination (flatness 0.013→0.094 at alpha=rank). Root cause: alpha=rank (scale=1.0) at high rank drowns base model priors. Testing dramatically lower alpha to nudge rather than overwrite. All runs at lr=3e-4 (best stable LR from r128_sweet_spot).",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/alpha_scale_sweep",
|
|
||||||
"base": {
|
|
||||||
"steps": 6000,
|
|
||||||
"lr": 3e-4,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 200,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 2000,
|
|
||||||
"seed": 42,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0,
|
|
||||||
"lr_schedule": "constant"
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g1_r16_alpha4",
|
|
||||||
"group": "conservative",
|
|
||||||
"description": "Back to basics: rank=16 alpha=4 (scale=0.25). Small adapter, gentle scale — cleanest possible LoRA signal.",
|
|
||||||
"rank": 16,
|
|
||||||
"alpha": 4.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g1_r16_alpha16",
|
|
||||||
"group": "conservative",
|
|
||||||
"description": "rank=16 alpha=16 (scale=1.0) — the original default. Reference point: is the noise issue rank-specific or universal?",
|
|
||||||
"rank": 16,
|
|
||||||
"alpha": 16.0
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g2_r32_alpha8",
|
|
||||||
"group": "mid",
|
|
||||||
"description": "rank=32 alpha=8 (scale=0.25). More capacity than r16 but still gentle scale.",
|
|
||||||
"rank": 32,
|
|
||||||
"alpha": 8.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g2_r32_alpha32",
|
|
||||||
"group": "mid",
|
|
||||||
"description": "rank=32 alpha=32 (scale=1.0). Same rank, full scale — isolates whether scale or rank is causing noise.",
|
|
||||||
"rank": 32,
|
|
||||||
"alpha": 32.0
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g3_r128_alpha8",
|
|
||||||
"group": "high_rank_low_alpha",
|
|
||||||
"description": "rank=128 alpha=8 (scale=0.0625). High capacity, very gentle contribution — can r128 stay clean at low alpha?",
|
|
||||||
"rank": 128,
|
|
||||||
"alpha": 8.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_r128_alpha16",
|
|
||||||
"group": "high_rank_low_alpha",
|
|
||||||
"description": "rank=128 alpha=16 (scale=0.125). Slightly more signal than alpha=8.",
|
|
||||||
"rank": 128,
|
|
||||||
"alpha": 16.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_r128_alpha32",
|
|
||||||
"group": "high_rank_low_alpha",
|
|
||||||
"description": "rank=128 alpha=32 (scale=0.25). Same scale as r16_alpha4 and r32_alpha8 — comparable contribution across ranks.",
|
|
||||||
"rank": 128,
|
|
||||||
"alpha": 32.0
|
|
||||||
}
|
|
||||||
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "bigvgan_disc_fm_retest",
|
|
||||||
"description": "Retest discriminator feature matching after bfloat16 dtype fix. Uses optimal config from overnight sweep (snake_alpha, GAFilter, lr=1e-4, phase=1.0, L2-SP=1e-3, 5000 steps).",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_disc_fm_retest",
|
|
||||||
"base": {
|
|
||||||
"train_mode": "snake_alpha_only",
|
|
||||||
"steps": 5000,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 8,
|
|
||||||
"segment_seconds": 0.5,
|
|
||||||
"lambda_l2sp": 1e-3,
|
|
||||||
"use_gafilter": true,
|
|
||||||
"gafilter_kernel_size": 9,
|
|
||||||
"lambda_phase": 1.0,
|
|
||||||
"save_every": 1000,
|
|
||||||
"seed": 42,
|
|
||||||
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "snake_5k_control",
|
|
||||||
"description": "Control: best config from overnight sweep without discriminator. Baseline for A/B comparison."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "disc_fm_5k",
|
|
||||||
"description": "Discriminator feature matching at 5k steps. Tests if perceptual FM loss improves over mel+phase alone.",
|
|
||||||
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "bigvgan_optimized_dataset",
|
|
||||||
"description": "BigVGAN fine-tuning on optimized dataset (134 clips, 44.1kHz, LUFS-normalized). Standard mode (no LoRA) — trains decoder to faithfully reconstruct target domain audio from mel spectrograms. Uses optimal config from prior sweeps.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_optimized_dataset",
|
|
||||||
"base": {
|
|
||||||
"train_mode": "snake_alpha_only",
|
|
||||||
"steps": 5000,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 8,
|
|
||||||
"segment_seconds": 0.5,
|
|
||||||
"lambda_l2sp": 1e-3,
|
|
||||||
"use_gafilter": true,
|
|
||||||
"gafilter_kernel_size": 9,
|
|
||||||
"lambda_phase": 1.0,
|
|
||||||
"save_every": 1000,
|
|
||||||
"seed": 42
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "standard_5k",
|
|
||||||
"description": "Standard mode: mel from clean FLAC → BigVGAN → reconstruct FLAC. No LoRA. Directly improves VAE roundtrip quality."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "disc_fm_5k",
|
|
||||||
"description": "Standard mode + discriminator feature matching. Tests if perceptual loss helps on clean audio reconstruction.",
|
|
||||||
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "standard_10k",
|
|
||||||
"description": "Extended 10k steps. More data passes on 134 clips may extract more from the optimized dataset.",
|
|
||||||
"steps": 10000
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "bigvgan_overnight",
|
|
||||||
"description": "BigVGAN vocoder quality sweep. Axes: snake_alpha steps, all_params short run, GAFilter on/off, discriminator FM, phase loss weight. All use LoRA-distorted mels as input so vocoder learns to fix LoRA artifacts.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_overnight",
|
|
||||||
"base": {
|
|
||||||
"train_mode": "snake_alpha_only",
|
|
||||||
"steps": 3000,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 8,
|
|
||||||
"segment_seconds": 0.5,
|
|
||||||
"lambda_l2sp": 1e-3,
|
|
||||||
"use_gafilter": true,
|
|
||||||
"gafilter_kernel_size": 9,
|
|
||||||
"lambda_phase": 1.0,
|
|
||||||
"save_every": 1000,
|
|
||||||
"seed": 42,
|
|
||||||
"lora_adapter": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep/standard_baseline/adapter_final.pt"
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "snake_3k_baseline",
|
|
||||||
"description": "Snake alpha + GAFilter baseline. 3000 steps, same as first successful run but longer."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "snake_5k",
|
|
||||||
"description": "Snake alpha + GAFilter, 5000 steps. Test if longer training improves further.",
|
|
||||||
"steps": 5000
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "snake_no_gafilter",
|
|
||||||
"description": "Snake alpha only, no GAFilter. Isolate GAFilter contribution.",
|
|
||||||
"use_gafilter": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "snake_no_phase",
|
|
||||||
"description": "Snake alpha + GAFilter, no phase loss. Isolate phase loss contribution.",
|
|
||||||
"lambda_phase": 0.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "snake_phase_2",
|
|
||||||
"description": "Snake alpha + GAFilter, phase weight 2.0. Stronger phase penalty.",
|
|
||||||
"lambda_phase": 2.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "snake_lr5e-5",
|
|
||||||
"description": "Snake alpha + GAFilter, lower LR 5e-5. Test if slower converges better.",
|
|
||||||
"lr": 5e-5,
|
|
||||||
"steps": 5000
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "snake_disc_fm",
|
|
||||||
"description": "Snake alpha + GAFilter + discriminator feature matching. Perceptual loss should directly penalize harmonic smearing.",
|
|
||||||
"discriminator_path": "/media/unraid/davinci/Selva/BJ/experiment/bigvgan_discriminator_optimizer.pt"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "all_2k_l2sp1e-2",
|
|
||||||
"description": "All params, 2000 steps, strong L2-SP (1e-2). Test if full param tuning with heavy anchor beats snake-only.",
|
|
||||||
"train_mode": "all_params",
|
|
||||||
"steps": 2000,
|
|
||||||
"lr": 1e-5,
|
|
||||||
"lambda_l2sp": 1e-2
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "eval_r128_candidates",
|
|
||||||
"description": "Top candidates from r128_sweet_spot. Comparing the two lowest-loss runs, the stable lr=3e-4, and the curriculum run that hit 0.161 before regressing. Baseline included as perceptual reference.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_dir": "/media/unraid/davinci/Selva/BJ/evals/r128_candidates",
|
|
||||||
"steps": 25,
|
|
||||||
"seed": 42,
|
|
||||||
"adapters": [
|
|
||||||
{
|
|
||||||
"id": "baseline",
|
|
||||||
"description": "No LoRA — base model output for perceptual reference"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "lr_5e4_r128",
|
|
||||||
"description": "Best loss overall (0.137), still descending at step 10k",
|
|
||||||
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_5e4/adapter_final.pt"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "lr_3e4_r256",
|
|
||||||
"description": "Tied with lr_5e4 at 0.139, higher rank — does extra capacity help perceptually?",
|
|
||||||
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g4_r256_lr_3e4/adapter_final.pt"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "lr_3e4_r128",
|
|
||||||
"description": "Stable plateau from step 4k to 10k (0.221) — visually confirmed clean spectrograms",
|
|
||||||
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g1_r128_lr_3e4/adapter_final.pt"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "curriculum_lr_3e4",
|
|
||||||
"description": "Best min loss of all (0.161 at step 6k), regressed to 0.193 after curriculum switch — curious if the early checkpoint sounds better",
|
|
||||||
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_final.pt"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "curriculum_lr_3e4_step6000",
|
|
||||||
"description": "Same run at its actual best step (before regression) — compare against adapter_final to hear the regression",
|
|
||||||
"path": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot/g2_r128_lr_3e4_curriculum/adapter_step06000.pt"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "lora_logit_cosine_combo",
|
|
||||||
"description": "Combine the two best findings from optimized dataset sweep: logit-normal timestep sampling + cosine LR schedule. Both individually outperformed baseline by large margins (56% and 68% lower loss). Tests if gains stack.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_logit_cosine_combo",
|
|
||||||
"base": {
|
|
||||||
"rank": 128,
|
|
||||||
"lr": 3e-4,
|
|
||||||
"steps": 5000,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"save_every": 1000,
|
|
||||||
"seed": 42,
|
|
||||||
"init_mode": "pissa",
|
|
||||||
"use_rslora": true,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"lr_schedule": "constant"
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "logit_normal_cosine",
|
|
||||||
"description": "Logit-normal timesteps + cosine LR decay. Combining the two best individual improvements.",
|
|
||||||
"timestep_mode": "logit_normal",
|
|
||||||
"lr_schedule": "cosine"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "logit_normal_control",
|
|
||||||
"description": "Control: logit-normal only (constant LR). Reproduces previous winner for direct comparison.",
|
|
||||||
"timestep_mode": "logit_normal"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "lora_optimized_dataset",
|
|
||||||
"description": "LoRA training on optimized dataset (134 clips: resampled 44.1kHz, LUFS-normalized, spectral matched, HF smoothed, gain-augmented). Tests latent augmentation and schedule variants on top of known-best config (PiSSA, rank=128, lr=3e-4).",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features_v2_improved/",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/lora_optimized_dataset",
|
|
||||||
"base": {
|
|
||||||
"rank": 128,
|
|
||||||
"lr": 3e-4,
|
|
||||||
"steps": 5000,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"save_every": 1000,
|
|
||||||
"seed": 42,
|
|
||||||
"init_mode": "pissa",
|
|
||||||
"use_rslora": true,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"lr_schedule": "constant"
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "baseline",
|
|
||||||
"description": "Control: known-best config (PiSSA r128 lr=3e-4) on the optimized dataset. No latent augmentation."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "latent_mixup",
|
|
||||||
"description": "Latent mixup alpha=0.4 (MusicLDM). Tests if mixing training latents reduces memorization on 134 clips.",
|
|
||||||
"latent_mixup_alpha": 0.4
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "latent_noise",
|
|
||||||
"description": "Latent noise sigma=0.02. Mild Gaussian noise on training latents for regularization.",
|
|
||||||
"latent_noise_sigma": 0.02
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "mixup_and_noise",
|
|
||||||
"description": "Both latent mixup (0.4) and noise (0.02). Combined regularization.",
|
|
||||||
"latent_mixup_alpha": 0.4,
|
|
||||||
"latent_noise_sigma": 0.02
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "cosine_schedule",
|
|
||||||
"description": "Cosine LR decay. lr=3e-4 was stable with constant, but cosine may extract more from 5k steps.",
|
|
||||||
"lr_schedule": "cosine"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "cosine_mixup",
|
|
||||||
"description": "Cosine LR + latent mixup. Best regularization combo candidate.",
|
|
||||||
"lr_schedule": "cosine",
|
|
||||||
"latent_mixup_alpha": 0.4
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "logit_normal",
|
|
||||||
"description": "Logit-normal timestep sampling (sigma=1.0). Concentrates training near t=0.5 where flow matching is hardest.",
|
|
||||||
"timestep_mode": "logit_normal"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "curriculum_mixup",
|
|
||||||
"description": "Curriculum timesteps (logit_normal first 60%, then uniform) + latent mixup. Full regularization stack.",
|
|
||||||
"timestep_mode": "curriculum",
|
|
||||||
"latent_mixup_alpha": 0.4
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "pissa_sweep",
|
|
||||||
"description": "PiSSA vs standard init ablation at rank 128. Best prior config (lr=3e-4, bs=16, 10k steps) as baseline. PiSSA starts on-manifold via SVD init — should eliminate intruder dimensions. rsLoRA stabilises scaling at high rank.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/pissa_sweep",
|
|
||||||
"base": {
|
|
||||||
"steps": 10000,
|
|
||||||
"rank": 128,
|
|
||||||
"alpha": 0.0,
|
|
||||||
"lr": 3e-4,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 200,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 2000,
|
|
||||||
"seed": 42,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0,
|
|
||||||
"lr_schedule": "constant",
|
|
||||||
"init_mode": "pissa",
|
|
||||||
"use_rslora": true
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "standard_baseline",
|
|
||||||
"description": "Standard Kaiming init + classic alpha/rank scaling. Replicates best prior config for A/B comparison.",
|
|
||||||
"init_mode": "standard",
|
|
||||||
"use_rslora": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "pissa_rslora",
|
|
||||||
"description": "PiSSA init + rsLoRA scaling. Full Tier-S config. Should start on-manifold and avoid intruder dimensions."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "pissa_classic_scale",
|
|
||||||
"description": "PiSSA init + classic alpha/rank scaling. Isolates PiSSA contribution from rsLoRA.",
|
|
||||||
"use_rslora": false
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "standard_rslora",
|
|
||||||
"description": "Standard init + rsLoRA only. Isolates rsLoRA contribution from PiSSA.",
|
|
||||||
"init_mode": "standard"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "pissa_rslora_lr1e-4",
|
|
||||||
"description": "PiSSA+rsLoRA at lower lr=1e-4. PiSSA starts closer to optimum — may need less aggressive lr.",
|
|
||||||
"lr": 1e-4
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "pissa_rslora_lr5e-4",
|
|
||||||
"description": "PiSSA+rsLoRA at higher lr=5e-4. Test if on-manifold start tolerates faster learning.",
|
|
||||||
"lr": 5e-4
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "pissa_rslora_dropout",
|
|
||||||
"description": "PiSSA+rsLoRA with dropout 0.05. Note: PiSSA forces dropout=0 (principal components should not be dropped) — this tests standard init with rsLoRA + dropout.",
|
|
||||||
"init_mode": "standard",
|
|
||||||
"lora_dropout": 0.05
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "r128_sweet_spot",
|
|
||||||
"description": "Find the noise-free sweet spot on rank 128. LoRA+ ratio=16 caused noise — testing higher base LR without LoRA+ as a cleaner alternative. Target loss range 0.25–0.35. Also probing rank 256 since 102GB VRAM allows it.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/r128_sweet_spot",
|
|
||||||
"base": {
|
|
||||||
"steps": 10000,
|
|
||||||
"rank": 128,
|
|
||||||
"alpha": 0.0,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 200,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 2000,
|
|
||||||
"seed": 42,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g1_r128_lr_2e4",
|
|
||||||
"group": "lr",
|
|
||||||
"description": "LR=2e-4. Conservative 2× step up from baseline — noise-free descent toward sweet spot.",
|
|
||||||
"lr": 2e-4
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g1_r128_lr_3e4",
|
|
||||||
"group": "lr",
|
|
||||||
"description": "LR=3e-4. 3× baseline — landed at 0.41 on r64, should reach 0.25–0.35 on r128.",
|
|
||||||
"lr": 3e-4
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g1_r128_lr_5e4",
|
|
||||||
"group": "lr",
|
|
||||||
"description": "LR=5e-4. Aggressive but no LoRA+ B-matrix asymmetry — cleaner noise profile.",
|
|
||||||
"lr": 5e-4
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g2_r128_curriculum",
|
|
||||||
"group": "curriculum",
|
|
||||||
"description": "Curriculum only at baseline LR. Clean slow descent — reference for what curriculum contributes alone.",
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g2_r128_lr_3e4_curriculum",
|
|
||||||
"group": "curriculum",
|
|
||||||
"description": "LR=3e-4 + curriculum. Speed of higher LR with coverage of curriculum — no LoRA+.",
|
|
||||||
"lr": 3e-4,
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g2_r128_lr_3e4_curriculum_dropout",
|
|
||||||
"group": "curriculum",
|
|
||||||
"description": "LR=3e-4 + curriculum + dropout=0.05. Full controlled stack without LoRA+.",
|
|
||||||
"lr": 3e-4,
|
|
||||||
"timestep_mode": "curriculum",
|
|
||||||
"lora_dropout": 0.05
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g3_r128_lora_plus_4",
|
|
||||||
"group": "lora_plus",
|
|
||||||
"description": "LoRA+ ratio=4 (lr_B=4e-4). Much more conservative than ratio=16 — tests if noise came from ratio not the technique.",
|
|
||||||
"lora_plus_ratio": 4.0
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g4_r256_baseline",
|
|
||||||
"group": "rank256",
|
|
||||||
"description": "Rank 256 at baseline LR. 102GB VRAM makes this viable — does more capacity keep helping?",
|
|
||||||
"rank": 256
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g4_r256_lr_3e4",
|
|
||||||
"group": "rank256",
|
|
||||||
"description": "Rank 256 + LR=3e-4. Best rank + best LR candidate combined.",
|
|
||||||
"rank": 256,
|
|
||||||
"lr": 3e-4
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g5_r128_lr_2e4_cosine",
|
|
||||||
"group": "cosine",
|
|
||||||
"description": "LR=2e-4 + cosine decay. Fixes the oscillation observed at step 6000–8000 by decaying LR to ~0 instead of staying flat.",
|
|
||||||
"lr": 2e-4,
|
|
||||||
"lr_schedule": "cosine"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g5_r128_lr_3e4_cosine",
|
|
||||||
"group": "cosine",
|
|
||||||
"description": "LR=3e-4 + cosine decay. Higher LR with decay — should reach lower loss faster then lock in.",
|
|
||||||
"lr": 3e-4,
|
|
||||||
"lr_schedule": "cosine"
|
|
||||||
}
|
|
||||||
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "r64_overnight",
|
|
||||||
"description": "Focused rank-64 overnight sweep. All experiments use rank 64 as base — confirmed best from tier1_thorough early results. 8000 steps to reach convergence (none converged at 4000).",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/r64_overnight",
|
|
||||||
"base": {
|
|
||||||
"steps": 8000,
|
|
||||||
"rank": 64,
|
|
||||||
"alpha": 0.0,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 200,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 2000,
|
|
||||||
"seed": 42,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g1_r64_baseline",
|
|
||||||
"group": "rank",
|
|
||||||
"description": "Rank 64 baseline — clean reference at 8000 steps."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g1_r128_baseline",
|
|
||||||
"group": "rank",
|
|
||||||
"description": "Rank 128 — 102GB VRAM makes this free. Does doubling rank from 64 help further?",
|
|
||||||
"rank": 128
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g2_r64_alpha_32",
|
|
||||||
"group": "alpha",
|
|
||||||
"description": "Rank 64 alpha=32 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
|
|
||||||
"alpha": 32.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g2_r64_alpha_16",
|
|
||||||
"group": "alpha",
|
|
||||||
"description": "Rank 64 alpha=16 (scale=0.25). More aggressive scale reduction — may over-constrain.",
|
|
||||||
"alpha": 16.0
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g3_r64_lora_plus",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "LoRA+ ratio=16. lr_B = 16 × lr_A. Faster convergence at constant step budget.",
|
|
||||||
"lora_plus_ratio": 16.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_r64_dropout_0.05",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "Dropout=0.05. Light sparsity regularisation on LoRA path.",
|
|
||||||
"lora_dropout": 0.05
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_r64_dropout_0.1",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "Dropout=0.1. Stronger regularisation — tests if 49 clips needs heavier constraint.",
|
|
||||||
"lora_dropout": 0.1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_r64_curriculum",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "Curriculum sampling: logit_normal for steps 1-4800, then uniform (arXiv:2603.12517).",
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g4_r64_lr_low",
|
|
||||||
"group": "lr",
|
|
||||||
"description": "LR=3e-5. 3× lower — checks if 1e-4 is overshooting at rank 64.",
|
|
||||||
"lr": 3e-5
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g4_r64_lr_high",
|
|
||||||
"group": "lr",
|
|
||||||
"description": "LR=3e-4. 3× higher — may converge faster but risk instability.",
|
|
||||||
"lr": 3e-4
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g5_r64_target_full",
|
|
||||||
"group": "target",
|
|
||||||
"description": "Rank 64 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
|
|
||||||
"target": "attn.qkv linear1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g5_r128_target_full",
|
|
||||||
"group": "target",
|
|
||||||
"description": "Rank 128 + full target. Maximum possible coverage with available VRAM.",
|
|
||||||
"rank": 128,
|
|
||||||
"target": "attn.qkv linear1"
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g6_r64_full_tier1",
|
|
||||||
"group": "combined",
|
|
||||||
"description": "All Tier 1 at rank 64: LoRA+ 16 + dropout 0.05 + curriculum. Full stack at 8000 steps.",
|
|
||||||
"lora_plus_ratio": 16.0,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g6_r64_alpha32_full",
|
|
||||||
"group": "combined",
|
|
||||||
"description": "Rank 64 alpha=32 + all Tier 1. Best alpha scaling + best regularisation stack.",
|
|
||||||
"alpha": 32.0,
|
|
||||||
"lora_plus_ratio": 16.0,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g6_r128_full_tier1",
|
|
||||||
"group": "combined",
|
|
||||||
"description": "Rank 128 + all Tier 1. Tests if more capacity + regularisation beats rank 64 full.",
|
|
||||||
"rank": 128,
|
|
||||||
"lora_plus_ratio": 16.0,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
}
|
|
||||||
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "ti_sweep_1",
|
|
||||||
"description": "First TI sweep. n4_baseline (suffix, batch=16, lr=1e-3) completed — buzz artifact diagnosed as token norm drifting to 3.2x outside CLIP manifold. All new experiments use norm clamping (auto from dataset) + corrected lr/batch.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/ti_sweep_1",
|
|
||||||
"base": {
|
|
||||||
"steps": 3000,
|
|
||||||
"batch_size": 4,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"save_every": 1000,
|
|
||||||
"seed": 42,
|
|
||||||
"init_text": "",
|
|
||||||
"lr": 2e-4,
|
|
||||||
"n_tokens": 4,
|
|
||||||
"inject_mode": "suffix"
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "n4_baseline",
|
|
||||||
"group": "reference",
|
|
||||||
"description": "COMPLETED (old code, no norm clamp). batch=16, lr=1e-3. Token norm drifted to 3.2 → buzz artifact. Kept for loss curve comparison only."
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "n4_clamped",
|
|
||||||
"group": "norm_clamp",
|
|
||||||
"description": "Same as baseline but with norm clamping enabled. Primary diagnostic: does clamping alone fix the buzz? lr=2e-4, batch=4, suffix."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "n4_prefix_clamped",
|
|
||||||
"group": "norm_clamp",
|
|
||||||
"description": "Prefix injection + norm clamping. Best of both: high-attention positions, tokens stay on CLIP manifold.",
|
|
||||||
"inject_mode": "prefix"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "n8_prefix_clamped",
|
|
||||||
"group": "norm_clamp",
|
|
||||||
"description": "8 tokens, prefix, clamped. More capacity without the artifact.",
|
|
||||||
"n_tokens": 8,
|
|
||||||
"inject_mode": "prefix"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "n4_prefix_warm_clamped",
|
|
||||||
"group": "norm_clamp",
|
|
||||||
"description": "4 tokens, prefix, warm init from 'mechanical impact sound design', clamped. Should converge fastest — starts in-manifold, stays in-manifold.",
|
|
||||||
"inject_mode": "prefix",
|
|
||||||
"init_text": "mechanical impact sound design"
|
|
||||||
}
|
|
||||||
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "tier1_sweep",
|
|
||||||
"description": "Ablation of Tier 1 improvements: LoRA+, dropout, curriculum sampling. Baseline = uniform, no regularisation.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "lora_sweeps/tier1_sweep",
|
|
||||||
"base": {
|
|
||||||
"steps": 4000,
|
|
||||||
"rank": 16,
|
|
||||||
"alpha": 0.0,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 500,
|
|
||||||
"seed": 42,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "baseline",
|
|
||||||
"description": "Standard LoRA — no Tier 1 changes. Reference point."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "lora_plus_16",
|
|
||||||
"description": "LoRA+ only: lr_B = 16 * lr_A. Should converge faster in early steps.",
|
|
||||||
"lora_plus_ratio": 16.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "dropout_0.05",
|
|
||||||
"description": "LoRA dropout 0.05 only. Light regularisation for 49-clip dataset.",
|
|
||||||
"lora_dropout": 0.05
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "dropout_0.1",
|
|
||||||
"description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
|
|
||||||
"lora_dropout": 0.1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "curriculum",
|
|
||||||
"description": "Curriculum sampling only: logit_normal for steps 1-2400, then uniform. Should improve convergence vs pure uniform.",
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "full_tier1",
|
|
||||||
"description": "All Tier 1 combined: LoRA+ + dropout 0.05 + curriculum.",
|
|
||||||
"lora_plus_ratio": 16.0,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "rank_64",
|
|
||||||
"description": "Rank 64 baseline — MMAudio LoRA guide default. More expressive adapter for 49-clip dataset.",
|
|
||||||
"rank": 64
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "tier1_thorough",
|
|
||||||
"description": "Full overnight Tier 1 ablation on 49-clip BJ dataset. 4 groups: rank, alpha, regularisation, and best combinations. ~10-12h depending on GPU.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/tier1_thorough",
|
|
||||||
"base": {
|
|
||||||
"steps": 4000,
|
|
||||||
"rank": 16,
|
|
||||||
"alpha": 0.0,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 1000,
|
|
||||||
"seed": 42,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "g1_rank_16",
|
|
||||||
"group": "rank",
|
|
||||||
"description": "Rank 16 baseline — reference point for all groups."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g1_rank_32",
|
|
||||||
"group": "rank",
|
|
||||||
"description": "Rank 32 — midpoint. Does doubling rank improve quality without overfitting?",
|
|
||||||
"rank": 32
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g1_rank_64",
|
|
||||||
"group": "rank",
|
|
||||||
"description": "Rank 64 — MMAudio LoRA guide default. Maximum expressiveness at 49 clips.",
|
|
||||||
"rank": 64
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g2_alpha_half_r16",
|
|
||||||
"group": "alpha",
|
|
||||||
"description": "Alpha=8 with rank 16 (scale=0.5). Reduces intruder singular dimensions (arXiv:2410.21228).",
|
|
||||||
"alpha": 8.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g2_alpha_half_r64",
|
|
||||||
"group": "alpha",
|
|
||||||
"description": "Alpha=32 with rank 64 (scale=0.5). Best-practice scaling for high-rank adapters.",
|
|
||||||
"rank": 64,
|
|
||||||
"alpha": 32.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_lora_plus_4",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "LoRA+ ratio=4 — conservative asymmetric LR. Lower bound for the technique.",
|
|
||||||
"lora_plus_ratio": 4.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_lora_plus_16",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "LoRA+ ratio=16 — standard from FLUX LoRA literature. Faster early convergence.",
|
|
||||||
"lora_plus_ratio": 16.0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_dropout_0.05",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "LoRA dropout 0.05 only. Light sparsity regularisation (arXiv:2404.09610).",
|
|
||||||
"lora_dropout": 0.05
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_dropout_0.1",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "LoRA dropout 0.1 only. Stronger regularisation — may prevent overfitting past step 2000.",
|
|
||||||
"lora_dropout": 0.1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g3_curriculum",
|
|
||||||
"group": "regularisation",
|
|
||||||
"description": "Curriculum sampling only: logit_normal steps 1-2400, then uniform (arXiv:2603.12517).",
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g4_full_r16",
|
|
||||||
"group": "combined",
|
|
||||||
"description": "All Tier 1 at rank 16: LoRA+ 16 + dropout 0.05 + curriculum.",
|
|
||||||
"lora_plus_ratio": 16.0,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g4_full_r64",
|
|
||||||
"group": "combined",
|
|
||||||
"description": "All Tier 1 at rank 64 + alpha=32. Best expressiveness + best regularisation.",
|
|
||||||
"rank": 64,
|
|
||||||
"alpha": 32.0,
|
|
||||||
"lora_plus_ratio": 16.0,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"timestep_mode": "curriculum"
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g5_lr_low",
|
|
||||||
"group": "lr",
|
|
||||||
"description": "LR=3e-5 — 3× lower than baseline. Tests if 1e-4 is overshooting.",
|
|
||||||
"lr": 3e-5
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g5_lr_high",
|
|
||||||
"group": "lr",
|
|
||||||
"description": "LR=3e-4 — 3× higher than baseline. Tests if 1e-4 is too conservative.",
|
|
||||||
"lr": 3e-4
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g6_target_full_r16",
|
|
||||||
"group": "target",
|
|
||||||
"description": "Rank 16 targeting attn.qkv + linear1 (FFN projections). Doubles LoRA coverage.",
|
|
||||||
"target": "attn.qkv linear1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "g6_target_full_r64",
|
|
||||||
"group": "target",
|
|
||||||
"description": "Rank 64 + alpha=32 targeting attn.qkv + linear1. Maximum coverage + expressiveness.",
|
|
||||||
"rank": 64,
|
|
||||||
"alpha": 32.0,
|
|
||||||
"target": "attn.qkv linear1"
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
"id": "g4_full_r64_6k",
|
|
||||||
"group": "combined",
|
|
||||||
"description": "All Tier 1 at rank 64 + alpha=32, extended to 6000 steps. Checks if convergence is done at 4000.",
|
|
||||||
"rank": 64,
|
|
||||||
"alpha": 32.0,
|
|
||||||
"lora_plus_ratio": 16.0,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"timestep_mode": "curriculum",
|
|
||||||
"steps": 6000,
|
|
||||||
"save_every": 1000
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
{
|
|
||||||
"name": "vocoder_finetune",
|
|
||||||
"description": "Single run with fine-tuned BJ BigVGAN vocoder injected. Validates vocoder integration with LoRA training. Best known config: lr=3e-4, rank=128.",
|
|
||||||
"data_dir": "/media/unraid/davinci/Selva/BJ/features",
|
|
||||||
"output_root": "/media/unraid/davinci/Selva/BJ/experiment/vocoder_finetune",
|
|
||||||
"base": {
|
|
||||||
"steps": 10000,
|
|
||||||
"rank": 128,
|
|
||||||
"alpha": 0.0,
|
|
||||||
"lr": 3e-4,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 200,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 2000,
|
|
||||||
"seed": 42,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0,
|
|
||||||
"lr_schedule": "constant"
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{
|
|
||||||
"id": "r128_lr_3e4_bj_vocoder",
|
|
||||||
"description": "lr=3e-4 rank=128 with fine-tuned BJ BigVGAN vocoder. Direct comparison baseline against previous best g1_r128_lr_3e4."
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
+8
-34
@@ -2,39 +2,13 @@ NODE_CLASS_MAPPINGS = {}
|
|||||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
|
||||||
_NODES = {
|
_NODES = {
|
||||||
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
|
"PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"),
|
||||||
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
|
"PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"),
|
||||||
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
"PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
|
||||||
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
|
"PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
|
||||||
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
|
"PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
|
||||||
"SelvaLoraScheduler": (".selva_lora_scheduler", "SelvaLoraScheduler", "SelVA LoRA Scheduler"),
|
"PrismAudioLoRATrainer": (".lora_trainer", "PrismAudioLoRATrainer", "PrismAudio LoRA Trainer"),
|
||||||
"SelvaDatasetBrowser": (".selva_dataset_browser", "SelvaDatasetBrowser", "SelVA Dataset Browser"),
|
"PrismAudioLoRALoader": (".lora_loader", "PrismAudioLoRALoader", "PrismAudio LoRA Loader"),
|
||||||
"SelvaSkipExperiment": (".selva_skip_experiment", "SelvaSkipExperiment", "SelVA Skip Experiment"),
|
|
||||||
"SelvaLoraEvaluator": (".selva_lora_evaluator", "SelvaLoraEvaluator", "SelVA LoRA Evaluator"),
|
|
||||||
"SelvaVaeRoundtrip": (".selva_vae_roundtrip", "SelvaVaeRoundtrip", "SelVA VAE Roundtrip"),
|
|
||||||
"SelvaHfSmoother": (".selva_audio_preprocessors", "SelvaHfSmoother", "SelVA HF Smoother"),
|
|
||||||
"SelvaSpectralMatcher": (".selva_audio_preprocessors", "SelvaSpectralMatcher", "SelVA Spectral Matcher"),
|
|
||||||
"SelvaTextualInversionTrainer": (".selva_textual_inversion_trainer", "SelvaTextualInversionTrainer", "SelVA Textual Inversion Trainer"),
|
|
||||||
"SelvaTextualInversionLoader": (".selva_textual_inversion_loader", "SelvaTextualInversionLoader", "SelVA Textual Inversion Loader"),
|
|
||||||
"SelvaTiScheduler": (".selva_ti_scheduler", "SelvaTiScheduler", "SelVA TI Scheduler"),
|
|
||||||
"SelvaActivationSteeringExtractor": (".selva_activation_steering_extractor", "SelvaActivationSteeringExtractor", "SelVA Activation Steering Extractor"),
|
|
||||||
"SelvaActivationSteeringLoader": (".selva_activation_steering_loader", "SelvaActivationSteeringLoader", "SelVA Activation Steering Loader"),
|
|
||||||
"SelvaBigvganTrainer": (".selva_bigvgan_trainer", "SelvaBigvganTrainer", "SelVA BigVGAN Trainer"),
|
|
||||||
"SelvaBigvganLoader": (".selva_bigvgan_loader", "SelvaBigvganLoader", "SelVA BigVGAN Loader"),
|
|
||||||
"SelvaBigvganScheduler": (".selva_bigvgan_scheduler", "SelvaBigvganScheduler", "SelVA BigVGAN Scheduler"),
|
|
||||||
"SelvaDittoOptimizer": (".selva_ditto_optimizer", "SelvaDittoOptimizer", "SelVA DITTO Optimizer"),
|
|
||||||
"SelvaDatasetLoader": (".selva_dataset_pipeline", "SelvaDatasetLoader", "SelVA Dataset Loader"),
|
|
||||||
"SelvaDatasetResampler": (".selva_dataset_pipeline", "SelvaDatasetResampler", "SelVA Dataset Resampler"),
|
|
||||||
"SelvaDatasetLUFSNormalizer": (".selva_dataset_pipeline", "SelvaDatasetLUFSNormalizer", "SelVA Dataset LUFS Normalizer"),
|
|
||||||
"SelvaDatasetCompressor": (".selva_dataset_pipeline", "SelvaDatasetCompressor", "SelVA Dataset Compressor"),
|
|
||||||
"SelvaDatasetInspector": (".selva_dataset_pipeline", "SelvaDatasetInspector", "SelVA Dataset Inspector"),
|
|
||||||
"SelvaDatasetItemExtractor": (".selva_dataset_pipeline", "SelvaDatasetItemExtractor", "SelVA Dataset Item Extractor"),
|
|
||||||
"SelvaDatasetSaver": (".selva_dataset_pipeline", "SelvaDatasetSaver", "SelVA Dataset Saver"),
|
|
||||||
"SelvaHarmonicExciter": (".selva_audio_postprocess", "SelvaHarmonicExciter", "SelVA Harmonic Exciter"),
|
|
||||||
"SelvaOutputNormalizer": (".selva_audio_postprocess", "SelvaOutputNormalizer", "SelVA Output Normalizer"),
|
|
||||||
"SelvaDatasetSpectralMatcher": (".selva_dataset_pipeline", "SelvaDatasetSpectralMatcher", "SelVA Dataset Spectral Matcher"),
|
|
||||||
"SelvaDatasetHfSmoother": (".selva_dataset_pipeline", "SelvaDatasetHfSmoother", "SelVA Dataset HF Smoother"),
|
|
||||||
"SelvaDatasetAugmenter": (".selva_dataset_pipeline", "SelvaDatasetAugmenter", "SelVA Dataset Augmenter"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
@@ -44,4 +18,4 @@ for key, (module_path, class_name, display_name) in _NODES.items():
|
|||||||
NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
|
NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
|
||||||
NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
|
NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
print(f"[SelVA] Skipping {key}: {e}")
|
print(f"[PrismAudio] Skipping {key}: {e}")
|
||||||
|
|||||||
@@ -0,0 +1,228 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import hashlib
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY
|
||||||
|
from .feature_loader import PrismAudioFeatureLoader
|
||||||
|
|
||||||
|
# Managed venv created automatically when python_env is left as default
|
||||||
|
_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
|
||||||
|
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
|
||||||
|
|
||||||
|
def _jax_package():
|
||||||
|
"""Return the correct jax extra for the current CUDA version."""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
cuda_ver = torch.version.cuda or ""
|
||||||
|
major = int(cuda_ver.split(".")[0]) if cuda_ver else 0
|
||||||
|
if major >= 13:
|
||||||
|
return "jax[cuda13]"
|
||||||
|
elif major >= 12:
|
||||||
|
return "jax[cuda12]"
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return "jax" # CPU fallback
|
||||||
|
|
||||||
|
|
||||||
|
_EXTRACT_PACKAGES = [
|
||||||
|
"torch", "torchaudio", "torchvision",
|
||||||
|
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
|
||||||
|
"tensorflow-cpu>=2.16.0",
|
||||||
|
# jax CUDA extra is resolved at install time based on detected CUDA version
|
||||||
|
_jax_package(), "flax",
|
||||||
|
"transformers", "decord", "einops", "numpy",
|
||||||
|
"git+https://github.com/google-deepmind/videoprism.git",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _pip_install(pip, *packages, label=None):
|
||||||
|
"""Install one or more packages with visible output; raise on failure."""
|
||||||
|
tag = label or packages[0]
|
||||||
|
print(f"[PrismAudio] installing {tag} ...", flush=True)
|
||||||
|
result = subprocess.run(
|
||||||
|
[pip, "install", "--progress-bar", "on"] + list(packages),
|
||||||
|
capture_output=False,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[PrismAudio] Failed to install {tag} (exit {result.returncode}). "
|
||||||
|
"See pip output above for details."
|
||||||
|
)
|
||||||
|
print(f"[PrismAudio] {tag} OK", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_extract_env():
|
||||||
|
"""Create and populate the managed venv on first use."""
|
||||||
|
if os.path.exists(_MANAGED_PYTHON):
|
||||||
|
return _MANAGED_PYTHON
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
if os.path.exists(_MANAGED_VENV):
|
||||||
|
print("[PrismAudio] Removing incomplete venv and retrying...", flush=True)
|
||||||
|
shutil.rmtree(_MANAGED_VENV)
|
||||||
|
|
||||||
|
print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True)
|
||||||
|
subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True)
|
||||||
|
|
||||||
|
pip = os.path.join(_MANAGED_VENV, "bin", "pip")
|
||||||
|
|
||||||
|
print("[PrismAudio] Upgrading pip...", flush=True)
|
||||||
|
subprocess.run([pip, "install", "--upgrade", "pip"], check=True)
|
||||||
|
|
||||||
|
total = len(_EXTRACT_PACKAGES)
|
||||||
|
print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True)
|
||||||
|
|
||||||
|
for i, pkg in enumerate(_EXTRACT_PACKAGES, 1):
|
||||||
|
label = pkg.split("/")[-1] if pkg.startswith("git+") else pkg.split(">=")[0].split("==")[0].split("[")[0]
|
||||||
|
print(f"[PrismAudio] [{i}/{total}] {label}", flush=True)
|
||||||
|
_pip_install(pip, pkg, label=label)
|
||||||
|
|
||||||
|
print("[PrismAudio] Feature-extraction env ready.", flush=True)
|
||||||
|
return _MANAGED_PYTHON
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_inputs(video_tensor, cot_text, fps):
|
||||||
|
"""Create a hash of the inputs for caching."""
|
||||||
|
h = hashlib.sha256()
|
||||||
|
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
||||||
|
h.update(cot_text.encode())
|
||||||
|
h.update(str(fps).encode()) # fps affects frame sampling — must be part of the key
|
||||||
|
return h.hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
|
def _save_frames_to_npy(video_tensor, output_path):
|
||||||
|
"""Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8.
|
||||||
|
|
||||||
|
Lossless — avoids H.264 encode/decode roundtrip.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8")
|
||||||
|
np.save(output_path, frames_np)
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioFeatureExtractor:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"video": ("IMAGE",),
|
||||||
|
"caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info output to auto-set fps."}),
|
||||||
|
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Ignored if video_info is connected."}),
|
||||||
|
"python_env": (["managed_env", "comfyui_env"], {"tooltip": "managed_env: auto-created isolated venv with JAX/TF (recommended). comfyui_env: current ComfyUI Python — WARNING: may conflict with existing packages and destabilize ComfyUI."}),
|
||||||
|
"cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
|
||||||
|
"hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PRISMAUDIO_FEATURES", "FLOAT")
|
||||||
|
RETURN_NAMES = ("features", "fps")
|
||||||
|
FUNCTION = "extract_features"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def extract_features(self, video, caption_cot, video_info=None, fps=30.0, python_env="managed_env", cache_dir="", hf_token=""):
|
||||||
|
# Resolve fps from VHS video_info if connected
|
||||||
|
if video_info is not None:
|
||||||
|
fps = video_info["loaded_fps"]
|
||||||
|
|
||||||
|
if not caption_cot.strip():
|
||||||
|
print("[PrismAudio] Warning: caption_cot is empty — text features will be degenerate. "
|
||||||
|
"Provide a descriptive chain-of-thought caption for best results.", flush=True)
|
||||||
|
|
||||||
|
# Resolve python binary
|
||||||
|
if python_env == "comfyui_env":
|
||||||
|
print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
|
||||||
|
"Installing them here may conflict with existing packages and destabilize ComfyUI.", flush=True)
|
||||||
|
python_bin = sys.executable
|
||||||
|
else:
|
||||||
|
python_bin = _ensure_extract_env()
|
||||||
|
|
||||||
|
# Determine cache directory
|
||||||
|
if not cache_dir:
|
||||||
|
cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
cache_hash = _hash_inputs(video, caption_cot, fps)
|
||||||
|
cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
|
||||||
|
if os.path.exists(cached_path):
|
||||||
|
print(f"[PrismAudio] Using cached features: {cached_path}")
|
||||||
|
loader = PrismAudioFeatureLoader()
|
||||||
|
features, = loader.load_features(cached_path)
|
||||||
|
return (features, float(fps))
|
||||||
|
|
||||||
|
# Save frames to temp file (lossless .npy, no codec roundtrip)
|
||||||
|
import time
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
frames = video.shape[0]
|
||||||
|
print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True)
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
|
||||||
|
tmp_video = tmp.name
|
||||||
|
_save_frames_to_npy(video, tmp_video)
|
||||||
|
print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True)
|
||||||
|
|
||||||
|
# Build subprocess command
|
||||||
|
script_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)),
|
||||||
|
"scripts", "extract_features.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
synchformer_ckpt = os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth")
|
||||||
|
if not os.path.exists(synchformer_ckpt):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[PrismAudio] Synchformer checkpoint not found: {synchformer_ckpt}\n"
|
||||||
|
"Download synchformer_state_dict.pth from FunAudioLLM/PrismAudio and place it in models/prismaudio/."
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
python_bin,
|
||||||
|
script_path,
|
||||||
|
"--video", tmp_video,
|
||||||
|
"--cot_text", caption_cot,
|
||||||
|
"--output", cached_path,
|
||||||
|
"--source_fps", str(fps),
|
||||||
|
"--synchformer_ckpt", synchformer_ckpt,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build env: inherit current env, inject HF token if provided
|
||||||
|
import copy
|
||||||
|
env = copy.copy(os.environ)
|
||||||
|
token = hf_token.strip() if hf_token else os.environ.get("HF_TOKEN", "")
|
||||||
|
if token:
|
||||||
|
env["HF_TOKEN"] = token
|
||||||
|
env["HUGGING_FACE_HUB_TOKEN"] = token
|
||||||
|
else:
|
||||||
|
print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. "
|
||||||
|
"Add your token in the hf_token input or set HF_TOKEN env var.", flush=True)
|
||||||
|
|
||||||
|
print(f"[PrismAudio] Extracting features via subprocess (output streams live)...")
|
||||||
|
try:
|
||||||
|
# capture_output=False: let stdout/stderr stream directly to ComfyUI logs
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=False,
|
||||||
|
timeout=600, # 10 minute timeout
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. "
|
||||||
|
"See output above for details."
|
||||||
|
)
|
||||||
|
print("[PrismAudio] Feature extraction subprocess finished successfully.")
|
||||||
|
finally:
|
||||||
|
if os.path.exists(tmp_video):
|
||||||
|
os.unlink(tmp_video)
|
||||||
|
|
||||||
|
# Load the extracted features
|
||||||
|
loader = PrismAudioFeatureLoader()
|
||||||
|
features, = loader.load_features(cached_path)
|
||||||
|
return (features, float(fps))
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
# Keys consumed by the conditioners (video_features, text_features, sync_features)
|
||||||
|
# global_video_features and global_text_features are NOT consumed by any conditioner
|
||||||
|
# in the prismaudio.json config — they are unused.
|
||||||
|
REQUIRED_KEYS = [
|
||||||
|
"video_features",
|
||||||
|
"text_features",
|
||||||
|
"sync_features",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioFeatureLoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
|
||||||
|
RETURN_NAMES = ("features",)
|
||||||
|
FUNCTION = "load_features"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def load_features(self, npz_path):
|
||||||
|
if not os.path.exists(npz_path):
|
||||||
|
raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")
|
||||||
|
|
||||||
|
data = np.load(npz_path, allow_pickle=True)
|
||||||
|
|
||||||
|
features = {}
|
||||||
|
for key in REQUIRED_KEYS:
|
||||||
|
if key in data:
|
||||||
|
features[key] = torch.from_numpy(data[key]).float()
|
||||||
|
else:
|
||||||
|
print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
|
||||||
|
# Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None
|
||||||
|
# Sync_MLP requires length divisible by 8 (segments of 8 frames)
|
||||||
|
if key == "sync_features":
|
||||||
|
features[key] = torch.zeros(8, 768)
|
||||||
|
else:
|
||||||
|
features[key] = torch.zeros(1, 1024)
|
||||||
|
|
||||||
|
# Load duration if present
|
||||||
|
if "duration" in data:
|
||||||
|
features["duration"] = float(data["duration"])
|
||||||
|
|
||||||
|
return (features,)
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .utils import PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_lora_weights(dit: nn.Module, lora_state: dict, rank: int, alpha: float, strength: float):
|
||||||
|
"""Add LoRA delta weights directly into the base model's nn.Linear tensors.
|
||||||
|
|
||||||
|
delta_W = lora_B @ lora_A * scale * strength
|
||||||
|
applied as: linear.weight += delta_W
|
||||||
|
|
||||||
|
This is equivalent to LoRALinear at inference but requires no wrapper,
|
||||||
|
no extra memory, and no change to the model's forward call graph.
|
||||||
|
"""
|
||||||
|
scale = (alpha / rank) * strength
|
||||||
|
|
||||||
|
# Group saved keys by module path
|
||||||
|
a_map = {
|
||||||
|
k.replace(".lora_A.weight", ""): v
|
||||||
|
for k, v in lora_state.items() if k.endswith("lora_A.weight")
|
||||||
|
}
|
||||||
|
b_map = {
|
||||||
|
k.replace(".lora_B.weight", ""): v
|
||||||
|
for k, v in lora_state.items() if k.endswith("lora_B.weight")
|
||||||
|
}
|
||||||
|
|
||||||
|
merged = 0
|
||||||
|
for path, lora_A in a_map.items():
|
||||||
|
if path not in b_map:
|
||||||
|
print(f"[PrismAudio] LoRA merge: missing lora_B for {path}, skipping", flush=True)
|
||||||
|
continue
|
||||||
|
lora_B = b_map[path] # [out_features, rank]
|
||||||
|
# delta_W: [out_features, in_features]
|
||||||
|
delta_W = (lora_B.float() @ lora_A.float()) * scale
|
||||||
|
|
||||||
|
# Navigate to the parent module using PyTorch's get_submodule
|
||||||
|
*parent_parts, child_name = path.split(".")
|
||||||
|
try:
|
||||||
|
parent = dit.get_submodule(".".join(parent_parts)) if parent_parts else dit
|
||||||
|
except AttributeError as e:
|
||||||
|
print(f"[PrismAudio] LoRA merge: could not find module '{path}': {e}", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
linear = getattr(parent, child_name, None)
|
||||||
|
if not isinstance(linear, nn.Linear):
|
||||||
|
print(f"[PrismAudio] LoRA merge: expected nn.Linear at '{path}', got {type(linear)}", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
linear.weight.data.add_(delta_W.to(linear.weight.dtype))
|
||||||
|
merged += 1
|
||||||
|
|
||||||
|
print(f"[PrismAudio] LoRA merged {merged} layer(s) (strength={strength:.3f})", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioLoRALoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("PRISMAUDIO_MODEL",),
|
||||||
|
"lora_path": ("STRING", {"default": "", "tooltip": "Path to .safetensors LoRA file produced by PrismAudio LoRA Trainer"}),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05, "tooltip": "LoRA influence scale. 1.0 = full strength, 0.0 = base model only"}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PRISMAUDIO_MODEL",)
|
||||||
|
RETURN_NAMES = ("model",)
|
||||||
|
FUNCTION = "load_lora"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def load_lora(self, model, lora_path, strength):
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
if not os.path.exists(lora_path):
|
||||||
|
raise FileNotFoundError(f"[PrismAudio] LoRA file not found: {lora_path}")
|
||||||
|
|
||||||
|
config_path = lora_path.replace(".safetensors", "_config.json")
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"[PrismAudio] LoRA config not found: {config_path}\n"
|
||||||
|
"Expected a _config.json alongside the .safetensors file."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
rank = config["rank"]
|
||||||
|
alpha = config["alpha"]
|
||||||
|
|
||||||
|
lora_state = load_file(lora_path)
|
||||||
|
|
||||||
|
# Merge LoRA weights in-place into the DiT's base linear layers.
|
||||||
|
# ComfyUI re-executes the upstream ModelLoader on the next queue run
|
||||||
|
# when inputs change, providing a fresh base model as needed.
|
||||||
|
dit = model["model"].model # DiTWrapper
|
||||||
|
|
||||||
|
if strength == 0.0:
|
||||||
|
print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True)
|
||||||
|
return (model,)
|
||||||
|
|
||||||
|
_merge_lora_weights(dit, lora_state, rank, alpha, strength)
|
||||||
|
|
||||||
|
return (model,)
|
||||||
@@ -0,0 +1,284 @@
|
|||||||
|
import os
|
||||||
|
import math
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
PRISMAUDIO_CATEGORY, SAMPLE_RATE,
|
||||||
|
get_device, get_offload_device, soft_empty_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LoRA primitives
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
|
||||||
|
"""Low-rank adapter wrapping a frozen nn.Linear."""
|
||||||
|
|
||||||
|
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = linear
|
||||||
|
self.scale = alpha / rank
|
||||||
|
in_f, out_f = linear.in_features, linear.out_features
|
||||||
|
self.lora_A = nn.Linear(in_f, rank, bias=False)
|
||||||
|
self.lora_B = nn.Linear(rank, out_f, bias=False)
|
||||||
|
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||||
|
nn.init.zeros_(self.lora_B.weight)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear(x) + self.lora_B(self.lora_A(x)) * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
_TARGET_MODULE_PRESETS = {
|
||||||
|
"attn_only": {"to_q", "to_kv", "to_qkv", "to_out"},
|
||||||
|
"attn_ffn": {"to_q", "to_kv", "to_qkv", "to_out", "proj"},
|
||||||
|
"full": {"to_q", "to_kv", "to_qkv", "to_out", "proj", "project_in", "project_out"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_lora(module: nn.Module, target_attrs: set, rank: int, alpha: float):
|
||||||
|
"""Recursively replace matching nn.Linear layers with LoRALinear."""
|
||||||
|
for name, child in list(module.named_children()):
|
||||||
|
if isinstance(child, nn.Linear) and name in target_attrs:
|
||||||
|
setattr(module, name, LoRALinear(child, rank, alpha))
|
||||||
|
else:
|
||||||
|
_apply_lora(child, target_attrs, rank, alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def _unapply_lora(module: nn.Module):
|
||||||
|
"""Replace LoRALinear back with the original frozen Linear (no weight merge)."""
|
||||||
|
for name, child in list(module.named_children()):
|
||||||
|
if isinstance(child, LoRALinear):
|
||||||
|
child.linear.weight.requires_grad_(False)
|
||||||
|
setattr(module, name, child.linear)
|
||||||
|
else:
|
||||||
|
_unapply_lora(child)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lora_state_dict(module: nn.Module) -> dict:
|
||||||
|
"""Return only LoRA parameter tensors from a module's state dict."""
|
||||||
|
return {k: v for k, v in module.state_dict().items()
|
||||||
|
if "lora_A" in k or "lora_B" in k}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dataset helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_AUDIO_EXTS = (".wav", ".flac", ".mp3")
|
||||||
|
|
||||||
|
|
||||||
|
def _scan_dataset(dataset_dir: str):
|
||||||
|
"""Return list of (npz_path, audio_path) pairs matched by stem."""
|
||||||
|
pairs = []
|
||||||
|
for fname in os.listdir(dataset_dir):
|
||||||
|
if not fname.endswith(".npz"):
|
||||||
|
continue
|
||||||
|
stem = os.path.join(dataset_dir, fname[:-4])
|
||||||
|
for ext in _AUDIO_EXTS:
|
||||||
|
audio_path = stem + ext
|
||||||
|
if os.path.exists(audio_path):
|
||||||
|
pairs.append((stem + ".npz", audio_path))
|
||||||
|
break
|
||||||
|
return sorted(pairs)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_audio(audio_path: str, device: torch.device) -> torch.Tensor:
|
||||||
|
"""Load audio to [1, 2, samples] float32 tensor at SAMPLE_RATE."""
|
||||||
|
import torchaudio
|
||||||
|
waveform, sr = torchaudio.load(audio_path)
|
||||||
|
if sr != SAMPLE_RATE:
|
||||||
|
waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
|
||||||
|
if waveform.shape[0] == 1:
|
||||||
|
waveform = waveform.expand(2, -1)
|
||||||
|
elif waveform.shape[0] > 2:
|
||||||
|
waveform = waveform[:2]
|
||||||
|
return waveform.unsqueeze(0).to(device) # [1, 2, samples]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_metadata(npz_path: str, device: torch.device, dtype: torch.dtype) -> dict:
|
||||||
|
"""Load .npz features into a conditioner metadata dict."""
|
||||||
|
import numpy as np
|
||||||
|
data = np.load(npz_path, allow_pickle=True)
|
||||||
|
video_feat = torch.from_numpy(data["video_features"]).float().to(device, dtype=dtype)
|
||||||
|
text_feat = torch.from_numpy(data["text_features"]).float().to(device, dtype=dtype)
|
||||||
|
sync_feat = torch.from_numpy(data["sync_features"]).float().to(device, dtype=dtype)
|
||||||
|
has_video = bool(video_feat.abs().sum() > 0)
|
||||||
|
return {
|
||||||
|
"video_features": video_feat,
|
||||||
|
"text_features": text_feat,
|
||||||
|
"sync_features": sync_feat,
|
||||||
|
"video_exist": torch.tensor(has_video),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trainer node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class PrismAudioLoRATrainer:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("PRISMAUDIO_MODEL",),
|
||||||
|
"dataset_dir": ("STRING", {"default": "", "tooltip": "Directory containing paired .npz feature files and .wav/.flac audio files (matched by filename stem)"}),
|
||||||
|
"output_path": ("STRING", {"default": "", "tooltip": "Save path for .safetensors weights. Empty = models/prismaudio/lora/"}),
|
||||||
|
"lora_rank": ("INT", {"default": 64, "min": 1, "max": 512}),
|
||||||
|
"lora_alpha": ("FLOAT", {"default": 64.0, "min": 1.0, "max": 1024.0}),
|
||||||
|
"target_modules": (["attn_ffn", "attn_only", "full"], {"tooltip": "attn_only: Q/K/V/out only. attn_ffn: + FFN input (recommended). full: + transformer I/O projections"}),
|
||||||
|
"learning_rate": ("FLOAT", {"default": 1e-4, "min": 1e-7, "max": 1e-2, "step": 1e-6}),
|
||||||
|
"train_steps": ("INT", {"default": 1000, "min": 1, "max": 100000}),
|
||||||
|
"cfg_dropout_prob": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 0.5, "step": 0.01, "tooltip": "Probability of dropping conditioning per step — preserves CFG ability at inference"}),
|
||||||
|
"save_every": ("INT", {"default": 500, "min": 1, "max": 100000, "tooltip": "Save a checkpoint every N steps (in addition to final save)"}),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("STRING",)
|
||||||
|
RETURN_NAMES = ("lora_path",)
|
||||||
|
FUNCTION = "train"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def train(self, model, dataset_dir, output_path, lora_rank, lora_alpha,
|
||||||
|
target_modules, learning_rate, train_steps, cfg_dropout_prob, save_every, seed):
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
diffusion = model["model"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
# Scan dataset
|
||||||
|
pairs = _scan_dataset(dataset_dir)
|
||||||
|
if not pairs:
|
||||||
|
raise RuntimeError(f"[PrismAudio] No (.npz + audio) pairs found in: {dataset_dir}")
|
||||||
|
print(f"[PrismAudio] LoRA training — {len(pairs)} sample(s), {train_steps} steps", flush=True)
|
||||||
|
|
||||||
|
# Resolve output path
|
||||||
|
if not output_path:
|
||||||
|
import folder_paths
|
||||||
|
out_dir = os.path.join(folder_paths.models_dir, "prismaudio", "lora")
|
||||||
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
|
output_path = os.path.join(out_dir, f"prismaudio_lora_r{lora_rank}.safetensors")
|
||||||
|
|
||||||
|
# Move model to device
|
||||||
|
diffusion.model.to(device)
|
||||||
|
diffusion.conditioner.to(device)
|
||||||
|
diffusion.pretransform.to(device)
|
||||||
|
|
||||||
|
# Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
|
||||||
|
dit = diffusion.model # DiTWrapper
|
||||||
|
for p in dit.parameters():
|
||||||
|
p.requires_grad_(False)
|
||||||
|
|
||||||
|
target_attrs = _TARGET_MODULE_PRESETS[target_modules]
|
||||||
|
_apply_lora(dit, target_attrs, lora_rank, lora_alpha)
|
||||||
|
|
||||||
|
# Cast LoRA params to model dtype and move to device
|
||||||
|
for m in dit.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_A.to(device=device, dtype=dtype)
|
||||||
|
m.lora_B.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
trainable = [p for p in dit.parameters() if p.requires_grad]
|
||||||
|
n_params = sum(p.numel() for p in trainable)
|
||||||
|
print(f"[PrismAudio] LoRA trainable params: {n_params:,} ({n_params/1e6:.2f}M)", flush=True)
|
||||||
|
|
||||||
|
diffusion.conditioner.eval()
|
||||||
|
diffusion.pretransform.eval()
|
||||||
|
dit.train()
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(trainable, lr=learning_rate)
|
||||||
|
|
||||||
|
# GradScaler for fp16 to prevent underflow
|
||||||
|
use_scaler = (dtype == torch.float16)
|
||||||
|
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
|
||||||
|
|
||||||
|
pbar = comfy.utils.ProgressBar(train_steps)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for step in range(1, train_steps + 1):
|
||||||
|
npz_path, audio_path = random.choice(pairs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Encode audio to latent space
|
||||||
|
audio = _load_audio(audio_path, device)
|
||||||
|
x0 = diffusion.pretransform.encode(audio.float()).to(dtype) # [1, 64, L]
|
||||||
|
|
||||||
|
# Build conditioning from features
|
||||||
|
metadata = (_load_metadata(npz_path, device, dtype),)
|
||||||
|
conditioning = diffusion.conditioner(metadata, device)
|
||||||
|
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||||
|
|
||||||
|
# Rectified flow: interpolate between data and noise
|
||||||
|
t = torch.rand(x0.shape[0], device=device, dtype=dtype) # [1]
|
||||||
|
noise = torch.randn_like(x0)
|
||||||
|
# t expanded for broadcast: [1] -> [1, 1, 1]
|
||||||
|
t_bcast = t[:, None, None]
|
||||||
|
x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
|
||||||
|
v_target = noise - x0
|
||||||
|
|
||||||
|
with torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||||
|
v_pred = dit(x_t, t,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob=cfg_dropout_prob,
|
||||||
|
**cond_inputs)
|
||||||
|
|
||||||
|
loss = F.mse_loss(v_pred.float(), v_target.float())
|
||||||
|
|
||||||
|
if use_scaler:
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if step % 50 == 0:
|
||||||
|
print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True)
|
||||||
|
|
||||||
|
if step % save_every == 0:
|
||||||
|
ckpt_path = output_path.replace(".safetensors", f"_step{step}.safetensors")
|
||||||
|
save_file(_get_lora_state_dict(dit), ckpt_path)
|
||||||
|
print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Save final weights
|
||||||
|
save_file(_get_lora_state_dict(dit), output_path)
|
||||||
|
|
||||||
|
# Save config alongside weights so the loader knows the structure
|
||||||
|
config_path = output_path.replace(".safetensors", "_config.json")
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump({
|
||||||
|
"rank": lora_rank,
|
||||||
|
"alpha": lora_alpha,
|
||||||
|
"target_modules": sorted(target_attrs),
|
||||||
|
}, f, indent=2)
|
||||||
|
|
||||||
|
print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Always restore model to base state — even on exception.
|
||||||
|
# Without this, LoRA wrappers would persist in the cached model and
|
||||||
|
# subsequent training runs would apply LoRA on top of existing LoRA.
|
||||||
|
dit.eval()
|
||||||
|
_unapply_lora(dit)
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
diffusion.model.to(get_offload_device())
|
||||||
|
diffusion.conditioner.to(get_offload_device())
|
||||||
|
diffusion.pretransform.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
return (output_path,)
|
||||||
@@ -0,0 +1,154 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import folder_paths
|
||||||
|
import comfy.model_management as mm
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
|
||||||
|
get_device, get_offload_device, determine_precision, determine_offload_strategy,
|
||||||
|
soft_empty_cache, resolve_hf_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# HuggingFace repo for auto-download
|
||||||
|
HF_REPO_ID = "FunAudioLLM/PrismAudio"
|
||||||
|
REQUIRED_FILES = {
|
||||||
|
"diffusion": "prismaudio.ckpt",
|
||||||
|
"vae": "vae.ckpt",
|
||||||
|
"synchformer": "synchformer_state_dict.pth",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _download_if_missing(filename, model_dir, hf_token=None):
|
||||||
|
"""Download a model file from HuggingFace if not present locally."""
|
||||||
|
filepath = os.path.join(model_dir, filename)
|
||||||
|
if os.path.exists(filepath):
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
|
||||||
|
try:
|
||||||
|
downloaded = hf_hub_download(
|
||||||
|
repo_id=HF_REPO_ID,
|
||||||
|
filename=filename,
|
||||||
|
local_dir=model_dir,
|
||||||
|
token=hf_token or None,
|
||||||
|
)
|
||||||
|
return downloaded
|
||||||
|
except Exception as e:
|
||||||
|
if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[PrismAudio] Model '{filename}' requires license acceptance. "
|
||||||
|
f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
|
||||||
|
f"then set HF_TOKEN env var or run: huggingface-cli login"
|
||||||
|
) from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioModelLoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
register_model_folder()
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"precision": (["auto", "fp32", "fp16", "bf16"],),
|
||||||
|
"offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PRISMAUDIO_MODEL",)
|
||||||
|
RETURN_NAMES = ("model",)
|
||||||
|
FUNCTION = "load_model"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def load_model(self, precision, offload_strategy):
|
||||||
|
device = get_device()
|
||||||
|
dtype = determine_precision(precision, device)
|
||||||
|
strategy = determine_offload_strategy(offload_strategy)
|
||||||
|
token = resolve_hf_token()
|
||||||
|
model_dir = get_prismaudio_model_dir()
|
||||||
|
|
||||||
|
# Auto-download missing files
|
||||||
|
for key, filename in REQUIRED_FILES.items():
|
||||||
|
_download_if_missing(filename, model_dir, hf_token=token)
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
config_path = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(__file__)),
|
||||||
|
"prismaudio_core", "configs", "prismaudio.json"
|
||||||
|
)
|
||||||
|
with open(config_path) as f:
|
||||||
|
model_config = json.load(f)
|
||||||
|
|
||||||
|
# Create model from config
|
||||||
|
from prismaudio_core.factory import create_model_from_config
|
||||||
|
model = create_model_from_config(model_config)
|
||||||
|
|
||||||
|
# Load diffusion weights
|
||||||
|
diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
|
||||||
|
diffusion_state = comfy.utils.load_torch_file(diffusion_path)
|
||||||
|
# Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
|
||||||
|
if "state_dict" in diffusion_state:
|
||||||
|
diffusion_state = diffusion_state["state_dict"]
|
||||||
|
diff_result = model.load_state_dict(diffusion_state, strict=False)
|
||||||
|
print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
|
||||||
|
print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True)
|
||||||
|
if diff_result.missing_keys:
|
||||||
|
print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True)
|
||||||
|
if diff_result.unexpected_keys:
|
||||||
|
print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True)
|
||||||
|
# Sample a few ckpt keys to verify prefix alignment
|
||||||
|
sample_keys = list(diffusion_state.keys())[:5]
|
||||||
|
print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True)
|
||||||
|
|
||||||
|
# Load VAE weights separately
|
||||||
|
# Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
|
||||||
|
vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
|
||||||
|
vae_full_state = comfy.utils.load_torch_file(vae_path)
|
||||||
|
print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)
|
||||||
|
# Sample raw keys to see actual prefix
|
||||||
|
vae_sample_keys = list(vae_full_state.keys())[:8]
|
||||||
|
print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True)
|
||||||
|
# Strip "autoencoder." prefix from keys
|
||||||
|
vae_state = {}
|
||||||
|
prefix = "autoencoder."
|
||||||
|
for k, v in vae_full_state.items():
|
||||||
|
if k.startswith(prefix):
|
||||||
|
vae_state[k[len(prefix):]] = v
|
||||||
|
else:
|
||||||
|
vae_state[k] = v
|
||||||
|
print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
|
||||||
|
# Sample model keys to compare
|
||||||
|
model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
|
||||||
|
print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True)
|
||||||
|
# strict=False: vae.ckpt is a training checkpoint that also contains
|
||||||
|
# discriminator, loss modules, and EMA wrappers not present in the
|
||||||
|
# inference AudioAutoencoder — ignore those extra keys.
|
||||||
|
# Load directly into the inner AudioAutoencoder to get IncompatibleKeys back
|
||||||
|
# (AutoencoderPretransform.load_state_dict doesn't return the result)
|
||||||
|
vae_result = model.pretransform.model.load_state_dict(vae_state, strict=False)
|
||||||
|
print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
|
||||||
|
if vae_result.missing_keys:
|
||||||
|
print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)
|
||||||
|
|
||||||
|
# Apply precision: DiT + conditioners in user-selected dtype,
|
||||||
|
# but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
|
||||||
|
model.model.to(dtype) # DiTWrapper
|
||||||
|
model.conditioner.to(dtype) # MultiConditioner
|
||||||
|
# model.pretransform stays in fp32
|
||||||
|
|
||||||
|
if strategy == "keep_in_vram":
|
||||||
|
model = model.to(device)
|
||||||
|
else:
|
||||||
|
model = model.to(get_offload_device())
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
return ({
|
||||||
|
"model": model,
|
||||||
|
"dtype": dtype,
|
||||||
|
"strategy": strategy,
|
||||||
|
"config": model_config,
|
||||||
|
"model_dir": model_dir,
|
||||||
|
},)
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management as mm
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
||||||
|
get_device, get_offload_device, soft_empty_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioSampler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("PRISMAUDIO_MODEL",),
|
||||||
|
"features": ("PRISMAUDIO_FEATURES",),
|
||||||
|
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
|
||||||
|
"steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
|
||||||
|
"cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
|
||||||
|
"sync_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.05, "tooltip": "Scale factor for sync conditioning. Higher values tighten audio-visual sync at the cost of audio naturalness; 0.0 disables sync guidance entirely."}),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO",)
|
||||||
|
RETURN_NAMES = ("audio",)
|
||||||
|
FUNCTION = "generate"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def generate(self, model, features, duration, steps, cfg_scale, sync_strength, seed):
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
diffusion = model["model"]
|
||||||
|
|
||||||
|
# Resolve duration: 0 means use video duration from features
|
||||||
|
if duration <= 0:
|
||||||
|
if "duration" not in features:
|
||||||
|
raise ValueError("[PrismAudio] duration=0 but features contain no duration. Set duration manually or use PrismAudioFeatureExtractor.")
|
||||||
|
duration = features["duration"]
|
||||||
|
print(f"[PrismAudio] Using video duration from features: {duration:.2f}s", flush=True)
|
||||||
|
|
||||||
|
# Compute latent dimensions
|
||||||
|
latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
|
||||||
|
|
||||||
|
# Sync temporal coverage diagnostic
|
||||||
|
sync_frames = features["sync_features"].shape[0]
|
||||||
|
sync_duration_covered = sync_frames / 25.0 # Synchformer always extracts at 25fps
|
||||||
|
print(f"[PrismAudio] sync: {sync_frames} frames @ 25fps = {sync_duration_covered:.2f}s | "
|
||||||
|
f"audio target: {latent_length} latent frames = {duration:.2f}s", flush=True)
|
||||||
|
if abs(sync_duration_covered - duration) > 0.5:
|
||||||
|
print(f"[PrismAudio] Warning: sync coverage ({sync_duration_covered:.2f}s) differs from "
|
||||||
|
f"audio duration ({duration:.2f}s) by more than 0.5s — consider re-extracting features "
|
||||||
|
f"with the correct video duration.", flush=True)
|
||||||
|
|
||||||
|
# Note: no seq length config needed — the model adapts to input tensor shapes
|
||||||
|
# dynamically via its transformer architecture.
|
||||||
|
|
||||||
|
# Determine if video features are present (not all zeros)
|
||||||
|
has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
|
||||||
|
|
||||||
|
video_feat = features["video_features"].to(device, dtype=dtype)
|
||||||
|
sync_feat = features["sync_features"].to(device, dtype=dtype)
|
||||||
|
|
||||||
|
# Build metadata as a TUPLE of dicts (one per batch sample)
|
||||||
|
# MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
|
||||||
|
sample_meta = {
|
||||||
|
"video_features": video_feat,
|
||||||
|
"text_features": features["text_features"].to(device, dtype=dtype),
|
||||||
|
"sync_features": sync_feat,
|
||||||
|
"video_exist": torch.tensor(has_video),
|
||||||
|
}
|
||||||
|
metadata = (sample_meta,)
|
||||||
|
|
||||||
|
# Move model to device if offloaded
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
diffusion.model.to(device)
|
||||||
|
diffusion.conditioner.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||||
|
# Run conditioning
|
||||||
|
conditioning = diffusion.conditioner(metadata, device)
|
||||||
|
|
||||||
|
# Handle missing video: substitute learned empty embeddings
|
||||||
|
if not has_video:
|
||||||
|
_substitute_empty_features(diffusion, conditioning, device, dtype)
|
||||||
|
|
||||||
|
# Scale sync conditioning after the conditioner MLP (clean linear scale,
|
||||||
|
# avoids SiLU nonlinearity in Sync_MLP). The CFG null path always uses zeros,
|
||||||
|
# so this directly scales the sync guidance magnitude: cfg_scale * (strength*cond - 0).
|
||||||
|
# Only applied when video is present — T2A uses learned empty_sync_feat, not raw sync.
|
||||||
|
if has_video and sync_strength != 1.0 and 'sync_features' in conditioning:
|
||||||
|
conditioning['sync_features'][0] = conditioning['sync_features'][0] * sync_strength
|
||||||
|
|
||||||
|
# Assemble conditioning inputs for the DiT
|
||||||
|
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||||
|
|
||||||
|
# Generate noise from seed (MPS doesn't support torch.Generator)
|
||||||
|
gen_device = "cpu" if device.type == "mps" else device
|
||||||
|
generator = torch.Generator(device=gen_device).manual_seed(seed)
|
||||||
|
noise = torch.randn(
|
||||||
|
[1, IO_CHANNELS, latent_length],
|
||||||
|
generator=generator,
|
||||||
|
device=gen_device,
|
||||||
|
).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
# Sample with progress bar
|
||||||
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
|
from prismaudio_core.inference.sampling import sample_discrete_euler
|
||||||
|
|
||||||
|
def on_step(info):
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
fakes = sample_discrete_euler(
|
||||||
|
diffusion.model,
|
||||||
|
noise,
|
||||||
|
steps,
|
||||||
|
callback=on_step,
|
||||||
|
**cond_inputs,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
batch_cfg=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
fakes_f = fakes.float()
|
||||||
|
print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)
|
||||||
|
|
||||||
|
# Offload diffusion model and conditioner before VAE decode
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
diffusion.model.to(get_offload_device())
|
||||||
|
diffusion.conditioner.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
diffusion.pretransform.to(device)
|
||||||
|
|
||||||
|
# VAE decode in fp32 (snake activations overflow in fp16)
|
||||||
|
with torch.amp.autocast(device_type=device.type, enabled=False):
|
||||||
|
audio = diffusion.pretransform.decode(fakes_f)
|
||||||
|
|
||||||
|
# Offload VAE
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
diffusion.pretransform.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
# Peak normalize then clamp (matching reference: div by max abs before clamp)
|
||||||
|
audio = audio.float()
|
||||||
|
pre_norm_std = audio.std().item()
|
||||||
|
pre_norm_peak = audio.abs().max().item()
|
||||||
|
peak = audio.abs().max().clamp(min=1e-8)
|
||||||
|
audio = (audio / peak).clamp(-1, 1)
|
||||||
|
print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
|
||||||
|
|
||||||
|
# Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
|
||||||
|
return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
|
||||||
|
|
||||||
|
|
||||||
|
def _substitute_empty_features(diffusion, conditioning, device, dtype):
|
||||||
|
"""Replace video/sync conditioning with learned empty embeddings when video is absent.
|
||||||
|
|
||||||
|
empty_clip_feat and empty_sync_feat are learned null embeddings in the conditioner
|
||||||
|
output space (1024-dim). Passing zero features through bias-free Cond_MLP produces
|
||||||
|
near-zero activations, NOT the learned null signal the model was trained with.
|
||||||
|
|
||||||
|
The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
|
||||||
|
"""
|
||||||
|
dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model
|
||||||
|
|
||||||
|
# Substitute video_features with learned empty_clip_feat
|
||||||
|
if hasattr(dit, 'empty_clip_feat') and 'video_features' in conditioning:
|
||||||
|
empty = dit.empty_clip_feat.to(device, dtype=dtype) # [1, 1024]
|
||||||
|
batch_size = conditioning['video_features'][0].shape[0]
|
||||||
|
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
||||||
|
conditioning['video_features'][0] = empty_expanded
|
||||||
|
conditioning['video_features'][1] = torch.ones(batch_size, 1, device=device)
|
||||||
|
|
||||||
|
# Substitute sync_features with learned empty_sync_feat
|
||||||
|
if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
|
||||||
|
empty = dit.empty_sync_feat.to(device, dtype=dtype) # [1, 1024]
|
||||||
|
batch_size = conditioning['sync_features'][0].shape[0]
|
||||||
|
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
||||||
|
conditioning['sync_features'][0] = empty_expanded
|
||||||
|
conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
"""SelVA Activation Steering Extractor.
|
|
||||||
|
|
||||||
Computes per-block steering vectors by running the frozen generator on the
|
|
||||||
training dataset and recording how target style's conditioning shifts the DiT hidden
|
|
||||||
states vs. empty/unconditional conditioning.
|
|
||||||
|
|
||||||
For each block i:
|
|
||||||
steering[i] = mean(latent_hidden | target style conditions)
|
|
||||||
- mean(latent_hidden | empty conditions)
|
|
||||||
|
|
||||||
The resulting vectors are injected at inference time (via SelVA Sampler's
|
|
||||||
steering_strength input) to nudge the denoising trajectory toward target style's
|
|
||||||
activation patterns without modifying any model weights.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
from .selva_lora_trainer import _prepare_dataset
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_activations(generator, conditions, latent, t_tensor):
|
|
||||||
"""Run one predict_flow call, collecting latent hidden states per block.
|
|
||||||
|
|
||||||
Returns a list of [seq, hidden_dim] float32 CPU tensors,
|
|
||||||
one per block (joint_blocks first, then fused_blocks).
|
|
||||||
"""
|
|
||||||
activations = []
|
|
||||||
|
|
||||||
def make_hook(is_joint):
|
|
||||||
def hook(module, input, output):
|
|
||||||
h = output[0] if is_joint else output
|
|
||||||
activations.append(h.detach().float().mean(0).cpu()) # [seq, hidden]
|
|
||||||
return hook
|
|
||||||
|
|
||||||
handles = []
|
|
||||||
for block in generator.joint_blocks:
|
|
||||||
handles.append(block.register_forward_hook(make_hook(is_joint=True)))
|
|
||||||
for block in generator.fused_blocks:
|
|
||||||
handles.append(block.register_forward_hook(make_hook(is_joint=False)))
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
generator.predict_flow(latent, t_tensor, conditions)
|
|
||||||
finally:
|
|
||||||
for h in handles:
|
|
||||||
h.remove()
|
|
||||||
|
|
||||||
return activations # list of n_blocks tensors [seq, hidden]
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaActivationSteeringExtractor:
|
|
||||||
"""Computes activation steering vectors from a training dataset.
|
|
||||||
|
|
||||||
Runs the frozen generator on N clips at random timesteps with both
|
|
||||||
target style-conditioned and empty-conditioned inputs, then saves the mean
|
|
||||||
difference per DiT block to a .pt file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "extract"
|
|
||||||
RETURN_TYPES = ("STRING",)
|
|
||||||
RETURN_NAMES = ("steering_path",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Path to saved steering_vectors.pt — load with SelVA Activation Steering Loader.",)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Computes per-block activation steering vectors: mean(target style activations) − "
|
|
||||||
"mean(empty activations) at each DiT block. Load the result with "
|
|
||||||
"SelVA Activation Steering Loader and connect to the Sampler."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"data_dir": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Directory containing .npz feature files (same as LoRA/TI trainer).",
|
|
||||||
}),
|
|
||||||
"output_path": ("STRING", {
|
|
||||||
"default": "steering_vectors.pt",
|
|
||||||
"tooltip": "Where to save the steering vectors. Relative paths resolve to ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
"n_samples": ("INT", {
|
|
||||||
"default": 16, "min": 1, "max": 256,
|
|
||||||
"tooltip": "Number of clips to average over. More = more stable vectors, slower extraction.",
|
|
||||||
}),
|
|
||||||
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def extract(self, model, data_dir, output_path, n_samples, seed):
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
|
|
||||||
data_dir = Path(data_dir.strip())
|
|
||||||
if not data_dir.is_absolute():
|
|
||||||
data_dir = Path(folder_paths.models_dir) / data_dir
|
|
||||||
if not data_dir.exists():
|
|
||||||
raise FileNotFoundError(f"[Steering] data_dir not found: {data_dir}")
|
|
||||||
|
|
||||||
out_path = Path(output_path.strip())
|
|
||||||
if not out_path.is_absolute():
|
|
||||||
out_path = Path(folder_paths.get_output_directory()) / out_path
|
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\n[Steering] Extracting steering vectors n_samples={n_samples}", flush=True)
|
|
||||||
print(f"[Steering] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[Steering] output = {out_path}\n", flush=True)
|
|
||||||
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
generator = model["generator"]
|
|
||||||
generator.eval()
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
random.seed(seed)
|
|
||||||
indices = random.choices(range(len(dataset)), k=n_samples)
|
|
||||||
|
|
||||||
n_blocks = len(generator.joint_blocks) + len(generator.fused_blocks)
|
|
||||||
style_sums = [None] * n_blocks
|
|
||||||
empty_sums = [None] * n_blocks
|
|
||||||
counts = [0] * n_blocks
|
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(n_samples)
|
|
||||||
|
|
||||||
for sample_i, clip_idx in enumerate(indices):
|
|
||||||
x1_cpu, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
|
|
||||||
|
|
||||||
clip_f = clip_f_cpu.to(device, dtype) # [1, T_clip, 1024]
|
|
||||||
sync_f = sync_f_cpu.to(device, dtype) # [1, T_sync, 768]
|
|
||||||
text_clip = text_clip_cpu.to(device, dtype) # [1, 77, 1024]
|
|
||||||
|
|
||||||
# x1 shape is [1, latent_seq_len, latent_dim] — dim 1 is the sequence length.
|
|
||||||
clip_latent_seq_len = x1_cpu.shape[1]
|
|
||||||
|
|
||||||
generator.update_seq_lengths(
|
|
||||||
latent_seq_len=clip_latent_seq_len,
|
|
||||||
clip_seq_len=clip_f.shape[1],
|
|
||||||
sync_seq_len=sync_f.shape[1],
|
|
||||||
)
|
|
||||||
|
|
||||||
conditions = generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
|
||||||
empty_conditions = generator.get_empty_conditions(bs=1)
|
|
||||||
|
|
||||||
# Random timestep and noise latent for this clip
|
|
||||||
t_val = torch.rand(1).item()
|
|
||||||
t_tensor = torch.tensor([t_val], device=device, dtype=dtype)
|
|
||||||
latent = torch.randn(
|
|
||||||
1, clip_latent_seq_len, generator.latent_dim,
|
|
||||||
device=device, dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
style_acts = _collect_activations(generator, conditions, latent, t_tensor)
|
|
||||||
empty_acts = _collect_activations(generator, empty_conditions, latent, t_tensor)
|
|
||||||
|
|
||||||
for i, (st, em) in enumerate(zip(style_acts, empty_acts)):
|
|
||||||
if style_sums[i] is None:
|
|
||||||
style_sums[i] = st.clone()
|
|
||||||
empty_sums[i] = em.clone()
|
|
||||||
else:
|
|
||||||
style_sums[i] += st
|
|
||||||
empty_sums[i] += em
|
|
||||||
counts[i] += 1
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
if (sample_i + 1) % 4 == 0 or sample_i == n_samples - 1:
|
|
||||||
print(f"[Steering] Processed {sample_i + 1}/{n_samples} clips", flush=True)
|
|
||||||
|
|
||||||
# Steering vector per block: mean(target style) - mean(empty)
|
|
||||||
steering_vectors = []
|
|
||||||
for i in range(n_blocks):
|
|
||||||
vec = (style_sums[i] - empty_sums[i]) / counts[i] # [hidden]
|
|
||||||
steering_vectors.append(vec)
|
|
||||||
|
|
||||||
norm = vec.norm().item()
|
|
||||||
print(f"[Steering] Block {i:2d} steering_norm={norm:.4f}", flush=True)
|
|
||||||
|
|
||||||
n_joint = len(generator.joint_blocks)
|
|
||||||
payload = {
|
|
||||||
"steering_vectors": steering_vectors, # list of [seq, hidden] tensors
|
|
||||||
"n_blocks": n_blocks,
|
|
||||||
"n_joint": n_joint,
|
|
||||||
"n_fused": len(generator.fused_blocks),
|
|
||||||
"latent_seq_len": seq_cfg.latent_seq_len,
|
|
||||||
"n_samples": n_samples,
|
|
||||||
"seed": seed,
|
|
||||||
"mode": model["mode"],
|
|
||||||
"variant": model["variant"],
|
|
||||||
}
|
|
||||||
torch.save(payload, str(out_path))
|
|
||||||
print(f"\n[Steering] Saved: {out_path}", flush=True)
|
|
||||||
|
|
||||||
soft_empty_cache()
|
|
||||||
return (str(out_path),)
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
"""SelVA Activation Steering Loader.
|
|
||||||
|
|
||||||
Loads a steering_vectors.pt bundle produced by SelVA Activation Steering Extractor
|
|
||||||
and returns a STEERING_VECTORS dict for use by SelVA Sampler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaActivationSteeringLoader:
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "load"
|
|
||||||
RETURN_TYPES = ("STEERING_VECTORS",)
|
|
||||||
RETURN_NAMES = ("steering_vectors",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Steering vectors bundle — connect to SelVA Sampler's steering_vectors input.",)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Loads activation steering vectors from a .pt file produced by "
|
|
||||||
"SelVA Activation Steering Extractor. Connect to SelVA Sampler to nudge "
|
|
||||||
"denoising toward the target activation patterns."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"path": ("STRING", {
|
|
||||||
"default": "steering_vectors.pt",
|
|
||||||
"tooltip": "Path to steering_vectors.pt. Relative paths resolve to ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def load(self, path):
|
|
||||||
p = Path(path.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[Steering] File not found: {p}")
|
|
||||||
|
|
||||||
payload = torch.load(str(p), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
n_blocks = payload["n_blocks"]
|
|
||||||
n_joint = payload["n_joint"]
|
|
||||||
n_fused = payload["n_fused"]
|
|
||||||
n_vecs = len(payload["steering_vectors"])
|
|
||||||
|
|
||||||
print(f"[Steering] Loaded: {p}", flush=True)
|
|
||||||
print(f"[Steering] blocks={n_blocks} (joint={n_joint} fused={n_fused}) "
|
|
||||||
f"latent_seq_len={payload['latent_seq_len']} "
|
|
||||||
f"n_samples={payload['n_samples']}", flush=True)
|
|
||||||
print(f"[Steering] mode={payload.get('mode')} variant={payload.get('variant')}", flush=True)
|
|
||||||
|
|
||||||
norms = [payload["steering_vectors"][i].norm().item() for i in range(n_vecs)]
|
|
||||||
mean_norm = sum(norms) / len(norms)
|
|
||||||
print(f"[Steering] Mean steering norm across {n_vecs} blocks: {mean_norm:.4f}", flush=True)
|
|
||||||
|
|
||||||
return (payload,)
|
|
||||||
@@ -1,153 +0,0 @@
|
|||||||
"""SelVA Audio Post-Processing nodes.
|
|
||||||
|
|
||||||
Post-generation enhancement applied to standard AUDIO outputs:
|
|
||||||
SelvaHarmonicExciter — multi-band harmonic exciter (HPF → tanh → mix)
|
|
||||||
SelvaOutputNormalizer — LUFS normalization + true peak limiting
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaHarmonicExciter:
|
|
||||||
"""Multi-band harmonic exciter for post-generation enhancement.
|
|
||||||
|
|
||||||
Isolates high-frequency content above a cutoff, applies tanh saturation
|
|
||||||
to generate 2nd/3rd harmonics, then mixes back with the dry signal.
|
|
||||||
Restores harmonic richness lost during BigVGAN vocoder reconstruction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
"cutoff_hz": ("FLOAT", {
|
|
||||||
"default": 3000.0, "min": 500.0, "max": 16000.0, "step": 100.0,
|
|
||||||
"tooltip": "Highpass cutoff frequency in Hz. Only content above this is excited. "
|
|
||||||
"3000 Hz targets the upper harmonics BigVGAN tends to smear.",
|
|
||||||
}),
|
|
||||||
"drive": ("FLOAT", {
|
|
||||||
"default": 2.0, "min": 1.0, "max": 10.0, "step": 0.5,
|
|
||||||
"tooltip": "Saturation drive. Higher = more harmonics generated. "
|
|
||||||
"2-3 is subtle, 5+ is aggressive.",
|
|
||||||
}),
|
|
||||||
"mix": ("FLOAT", {
|
|
||||||
"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "Wet/dry blend. 0.1-0.2 is subtle enhancement, "
|
|
||||||
"0.5+ is aggressive harmonic addition.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "excite"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Multi-band harmonic exciter. Applies tanh saturation to the high-frequency band "
|
|
||||||
"to restore harmonics lost during BigVGAN vocoder reconstruction. "
|
|
||||||
"Uses pedalboard.HighpassFilter for band isolation."
|
|
||||||
)
|
|
||||||
|
|
||||||
def excite(self, audio, cutoff_hz: float, drive: float, mix: float):
|
|
||||||
from pedalboard import Pedalboard, HighpassFilter
|
|
||||||
|
|
||||||
wav = audio["waveform"][0] # [C, T]
|
|
||||||
sr = audio["sample_rate"]
|
|
||||||
|
|
||||||
wav_np = wav.float().numpy() # [C, T]
|
|
||||||
|
|
||||||
# Isolate HF band
|
|
||||||
board = Pedalboard([HighpassFilter(cutoff_frequency_hz=cutoff_hz)])
|
|
||||||
hf = board(wav_np, sr) # [C, T]
|
|
||||||
|
|
||||||
# Tanh saturation — normalize by drive so output stays in [-1, 1]
|
|
||||||
excited = np.tanh(hf * drive) / max(drive, 1.0)
|
|
||||||
|
|
||||||
# Mix back with dry
|
|
||||||
mixed = wav_np + mix * excited
|
|
||||||
|
|
||||||
# Soft clip to prevent going over
|
|
||||||
mixed = np.tanh(mixed)
|
|
||||||
|
|
||||||
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, T]
|
|
||||||
print(
|
|
||||||
f"[HarmonicExciter] cutoff={cutoff_hz}Hz drive={drive} mix={mix:.0%}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return ({"waveform": wav_out, "sample_rate": sr},)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaOutputNormalizer:
|
|
||||||
"""Normalize generated audio to a target LUFS level with true peak limiting.
|
|
||||||
|
|
||||||
Apply as the final node before saving — brings generated audio to a
|
|
||||||
consistent loudness target regardless of input video loudness variation.
|
|
||||||
Uses pyloudnorm (BS.1770-4).
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
"target_lufs": ("FLOAT", {
|
|
||||||
"default": -14.0, "min": -40.0, "max": -6.0, "step": 0.5,
|
|
||||||
"tooltip": "Target integrated loudness in LUFS. "
|
|
||||||
"-14 LUFS for streaming (Spotify/YouTube), "
|
|
||||||
"-9 to -7 for production masters.",
|
|
||||||
}),
|
|
||||||
"true_peak_dbtp": ("FLOAT", {
|
|
||||||
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
|
|
||||||
"tooltip": "True peak ceiling in dBTP applied after LUFS gain.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "normalize"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Normalize output audio to a target LUFS level (BS.1770-4) with true peak limiting. "
|
|
||||||
"Apply as the last node before saving. Uses pyloudnorm."
|
|
||||||
)
|
|
||||||
|
|
||||||
def normalize(self, audio, target_lufs: float, true_peak_dbtp: float):
|
|
||||||
import pyloudnorm as pyln
|
|
||||||
|
|
||||||
wav = audio["waveform"][0] # [C, T]
|
|
||||||
sr = audio["sample_rate"]
|
|
||||||
|
|
||||||
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
|
|
||||||
|
|
||||||
wav_np = wav.permute(1, 0).double().numpy() # [T, C]
|
|
||||||
if wav_np.shape[1] == 1:
|
|
||||||
wav_np = wav_np[:, 0] # [T] mono
|
|
||||||
|
|
||||||
meter = pyln.Meter(sr)
|
|
||||||
loudness = meter.integrated_loudness(wav_np)
|
|
||||||
|
|
||||||
if not np.isfinite(loudness):
|
|
||||||
print("[OutputNormalizer] Could not measure loudness — clip too short or silent. Passing through.", flush=True)
|
|
||||||
return (audio,)
|
|
||||||
|
|
||||||
gain_db = target_lufs - loudness
|
|
||||||
gain_linear = 10.0 ** (gain_db / 20.0)
|
|
||||||
|
|
||||||
wav_out = wav * gain_linear
|
|
||||||
|
|
||||||
peak = wav_out.abs().max().item()
|
|
||||||
if peak > tp_linear:
|
|
||||||
wav_out = wav_out * (tp_linear / peak)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"[OutputNormalizer] {loudness:.1f} LUFS → {target_lufs} LUFS "
|
|
||||||
f"gain={gain_db:+.1f}dB TP={true_peak_dbtp}dBTP",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return ({"waveform": wav_out.unsqueeze(0), "sample_rate": sr},)
|
|
||||||
@@ -1,293 +0,0 @@
|
|||||||
"""SelVA Audio Preprocessors — condition training clips for codec compatibility.
|
|
||||||
|
|
||||||
Two nodes that reduce the domain mismatch between custom training audio and the
|
|
||||||
MMAudio VAE's expected spectral distribution, improving LoRA training quality:
|
|
||||||
|
|
||||||
SelvaHfSmoother — soft low-pass blend to attenuate extreme HF content
|
|
||||||
SelvaSpectralMatcher — adaptive per-band EQ toward the codec's training distribution
|
|
||||||
|
|
||||||
Root cause they address: MMAudio was trained on natural sounds (speech, foley, env)
|
|
||||||
with limited engineered HF content. The BigVGANv2 vocoder (frozen, pre-trained) handles
|
|
||||||
the codec's HF reconstruction poorly for sound design / music training clips, because
|
|
||||||
those clips land in a latent-space region the vocoder never saw during training.
|
|
||||||
|
|
||||||
Recommended order: SpectralMatcher → HfSmoother → feature extraction → LoRA training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio.functional as AF
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
# ── Mel filterbank (same algorithm as selva_core/ext/mel_converter.py) ────────
|
|
||||||
|
|
||||||
def _mel_filterbank(sr: int, n_fft: int, n_mels: int,
|
|
||||||
fmin: float, fmax: float) -> torch.Tensor:
|
|
||||||
"""Returns mel filterbank matrix [n_mels, n_fft//2+1]."""
|
|
||||||
def hz_to_mel(f):
|
|
||||||
return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)
|
|
||||||
|
|
||||||
def mel_to_hz(m):
|
|
||||||
return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)
|
|
||||||
|
|
||||||
n_freqs = n_fft // 2 + 1
|
|
||||||
fft_freqs = np.linspace(0.0, sr / 2.0, n_freqs)
|
|
||||||
mel_pts = np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2)
|
|
||||||
hz_pts = mel_to_hz(mel_pts)
|
|
||||||
|
|
||||||
fb = np.zeros((n_mels, n_freqs), dtype=np.float32)
|
|
||||||
for m in range(1, n_mels + 1):
|
|
||||||
lo, mid, hi = hz_pts[m - 1], hz_pts[m], hz_pts[m + 1]
|
|
||||||
up = (fft_freqs - lo) / (mid - lo + 1e-12)
|
|
||||||
down = (hi - fft_freqs) / (hi - mid + 1e-12)
|
|
||||||
fb[m - 1] = np.maximum(0.0, np.minimum(up, down))
|
|
||||||
return torch.from_numpy(fb)
|
|
||||||
|
|
||||||
|
|
||||||
# ── VAE target log-mel means (source: selva_core/ext/autoencoder/vae.py) ──────
|
|
||||||
# These are the per-band expected log-mel energy means from MMAudio's training data.
|
|
||||||
# Used as the spectral matching target: clips are EQ'd to match this profile.
|
|
||||||
|
|
||||||
_TARGET_MEAN_80D = [
|
|
||||||
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439,
|
|
||||||
-1.2922, -1.2927, -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912,
|
|
||||||
-1.4313, -1.4152, -1.4527, -1.4728, -1.4568, -1.5101, -1.5051, -1.5172,
|
|
||||||
-1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131, -1.6081, -1.6331,
|
|
||||||
-1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
|
|
||||||
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377,
|
|
||||||
-1.8417, -1.8643, -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673,
|
|
||||||
-1.9824, -2.0042, -2.0215, -2.0436, -2.0766, -2.1064, -2.1418, -2.1855,
|
|
||||||
-2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282, -2.4659, -2.5072,
|
|
||||||
-2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673,
|
|
||||||
]
|
|
||||||
|
|
||||||
_TARGET_MEAN_128D = [
|
|
||||||
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006,
|
|
||||||
-2.2357, -2.4597, -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047,
|
|
||||||
-2.7483, -2.5926, -2.7462, -2.7033, -2.7386, -2.8112, -2.7502, -2.9594,
|
|
||||||
-2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157, -3.1191, -2.9893,
|
|
||||||
-3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
|
|
||||||
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509,
|
|
||||||
-3.5089, -3.4647, -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747,
|
|
||||||
-3.7072, -3.7279, -3.7283, -3.7795, -3.8259, -3.8447, -3.8663, -3.9182,
|
|
||||||
-3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121, -4.1488, -4.1874,
|
|
||||||
-4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
|
|
||||||
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053,
|
|
||||||
-5.4927, -5.5712, -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103,
|
|
||||||
-6.0955, -6.1673, -6.2362, -6.3120, -6.3926, -6.4797, -6.5565, -6.6511,
|
|
||||||
-6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663, -7.6136, -7.7469,
|
|
||||||
-7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
|
|
||||||
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861,
|
|
||||||
]
|
|
||||||
|
|
||||||
_MEL_CONFIGS = {
|
|
||||||
"16k": dict(sr=16_000, n_fft=1024, n_mels=80, hop=256, fmin=0, fmax=8_000,
|
|
||||||
target=_TARGET_MEAN_80D, log10=True),
|
|
||||||
"44k": dict(sr=44_100, n_fft=2048, n_mels=128, hop=512, fmin=0, fmax=22_050,
|
|
||||||
target=_TARGET_MEAN_128D, log10=False),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Node 1: HF Smoother ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class SelvaHfSmoother:
|
|
||||||
"""Soft high-frequency attenuation for LoRA training clip preprocessing.
|
|
||||||
|
|
||||||
Blends a low-pass filtered copy of the audio with the original. Attenuates
|
|
||||||
the extreme HF content common in engineered sound design that the BigVGANv2
|
|
||||||
vocoder handles poorly, bringing the clip closer to the spectral region the
|
|
||||||
MMAudio codec was trained on (natural sounds with limited HF energy).
|
|
||||||
|
|
||||||
A blend of 0.7 at 12 kHz is a transparent starting point — audible only on
|
|
||||||
close comparison. Increase blend or lower cutoff if roundtrip quality is still
|
|
||||||
poor after spectral matching.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
"cutoff_hz": ("FLOAT", {
|
|
||||||
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
|
|
||||||
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
|
|
||||||
}),
|
|
||||||
"blend": ("FLOAT", {
|
|
||||||
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "0 = original, 1 = fully filtered. 0.7 is a transparent starting point.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "process"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Blends a low-pass filtered version of the audio with the original to gently attenuate "
|
|
||||||
"high-frequency content that the SelVA codec handles poorly. "
|
|
||||||
"Use before feature extraction to improve LoRA training targets. "
|
|
||||||
"Run after SelVA Spectral Matcher for best results."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, audio, cutoff_hz: float, blend: float):
|
|
||||||
waveform = audio["waveform"].float() # [1, C, L]
|
|
||||||
sr = audio["sample_rate"]
|
|
||||||
|
|
||||||
filtered = AF.lowpass_biquad(waveform, sr, cutoff_hz)
|
|
||||||
out = blend * filtered + (1.0 - blend) * waveform
|
|
||||||
|
|
||||||
# Preserve RMS level — LPF removes energy, keep the clip at its original loudness
|
|
||||||
rms_in = waveform.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
rms_out = out.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
out = out * (rms_in / rms_out)
|
|
||||||
|
|
||||||
peak = out.abs().max()
|
|
||||||
if peak > 1.0:
|
|
||||||
out = out / peak
|
|
||||||
|
|
||||||
print(f"[HF Smoother] cutoff={cutoff_hz:.0f} Hz blend={blend:.2f} "
|
|
||||||
f"rms={rms_in:.4f}→{out.pow(2).mean().sqrt():.4f} "
|
|
||||||
f"peak={out.abs().max():.4f}", flush=True)
|
|
||||||
|
|
||||||
return ({"waveform": out, "sample_rate": sr},)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Node 2: Spectral Matcher ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class SelvaSpectralMatcher:
|
|
||||||
"""Adaptive per-band EQ toward the SelVA VAE's expected spectral distribution.
|
|
||||||
|
|
||||||
Computes the log-mel energy profile of the clip and compares it to the per-band
|
|
||||||
means stored in the VAE's normalization buffers (the statistics MMAudio was trained
|
|
||||||
on). Applies a smooth frequency-domain gain correction so the clip's spectral shape
|
|
||||||
matches what the codec expects, improving encode→decode roundtrip quality and
|
|
||||||
therefore LoRA training target quality.
|
|
||||||
|
|
||||||
The correction is additive in log space (multiplicative in linear), so it only
|
|
||||||
changes spectral balance — not the waveform's timing or phase structure.
|
|
||||||
|
|
||||||
max_gain_db clamps the correction to prevent extreme boosts on very quiet bands.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
"mode": (["44k", "16k"], {
|
|
||||||
"tooltip": "Must match the SelVA model you are training. "
|
|
||||||
"44k = large model, 16k = small model.",
|
|
||||||
}),
|
|
||||||
"strength": ("FLOAT", {
|
|
||||||
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "0 = no correction, 1 = full match to VAE distribution. "
|
|
||||||
"0.8 is a good starting point.",
|
|
||||||
}),
|
|
||||||
"max_gain_db": ("FLOAT", {
|
|
||||||
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
|
|
||||||
"tooltip": "Clamps per-band gain to ±dB. Prevents extreme boosts on "
|
|
||||||
"very quiet frequency bands. 12 dB is conservative.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "process"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Applies a smooth per-band gain correction to bring the audio's spectral profile "
|
|
||||||
"in line with the MMAudio VAE's expected distribution, derived from the per-band "
|
|
||||||
"normalization statistics baked into the VAE weights. "
|
|
||||||
"Use before feature extraction to improve LoRA training target quality. "
|
|
||||||
"Run before SelVA HF Smoother."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, audio, mode: str, strength: float, max_gain_db: float):
|
|
||||||
cfg = _MEL_CONFIGS[mode]
|
|
||||||
waveform = audio["waveform"].float() # [1, C, L]
|
|
||||||
sr_in = audio["sample_rate"]
|
|
||||||
sr_tgt = cfg["sr"]
|
|
||||||
n_fft = cfg["n_fft"]
|
|
||||||
hop = cfg["hop"]
|
|
||||||
|
|
||||||
# ── flatten to mono and resample if needed ────────────────────────────
|
|
||||||
wav = waveform[0].mean(0) # [L]
|
|
||||||
if sr_in != sr_tgt:
|
|
||||||
wav = AF.resample(wav.unsqueeze(0), sr_in, sr_tgt).squeeze(0)
|
|
||||||
|
|
||||||
device = wav.device
|
|
||||||
window = torch.hann_window(n_fft, device=device)
|
|
||||||
|
|
||||||
# ── STFT ──────────────────────────────────────────────────────────────
|
|
||||||
stft = torch.stft(wav, n_fft, hop_length=hop, win_length=n_fft,
|
|
||||||
window=window, center=True, return_complex=True) # [n_freqs, T]
|
|
||||||
mag = stft.abs() # [n_freqs, T]
|
|
||||||
|
|
||||||
# ── current log-mel mean per band ─────────────────────────────────────
|
|
||||||
fb = _mel_filterbank(sr_tgt, n_fft, cfg["n_mels"],
|
|
||||||
cfg["fmin"], cfg["fmax"]).to(device) # [n_mels, n_freqs]
|
|
||||||
|
|
||||||
mel_mag = torch.matmul(fb, mag).clamp(min=1e-5) # [n_mels, T]
|
|
||||||
if cfg["log10"]:
|
|
||||||
mel_log = torch.log10(mel_mag)
|
|
||||||
else:
|
|
||||||
mel_log = torch.log(mel_mag)
|
|
||||||
|
|
||||||
current_mean = mel_log.mean(dim=-1) # [n_mels]
|
|
||||||
target_mean = torch.tensor(cfg["target"], device=device) # [n_mels]
|
|
||||||
|
|
||||||
# ── per-mel-band gain (log space) ─────────────────────────────────────
|
|
||||||
mel_gain = (target_mean - current_mean) * strength # [n_mels]
|
|
||||||
|
|
||||||
# Clamp to ±max_gain_db
|
|
||||||
if cfg["log10"]:
|
|
||||||
max_log = max_gain_db / 20.0 # log10: 20 log10 = dB
|
|
||||||
else:
|
|
||||||
max_log = max_gain_db / 8.6859 # ln: 20 * log10(e) ≈ 8.686
|
|
||||||
mel_gain = mel_gain.clamp(-max_log, max_log)
|
|
||||||
|
|
||||||
# ── map mel gains → STFT frequency bins (weighted average) ────────────
|
|
||||||
fb_sum = fb.sum(0).clamp(min=1e-8) # [n_freqs]
|
|
||||||
freq_gain = (mel_gain @ fb) / fb_sum # [n_freqs]
|
|
||||||
|
|
||||||
if cfg["log10"]:
|
|
||||||
linear_gain = 10.0 ** freq_gain # [n_freqs]
|
|
||||||
else:
|
|
||||||
linear_gain = torch.exp(freq_gain) # [n_freqs]
|
|
||||||
|
|
||||||
# ── apply gain in frequency domain and reconstruct ───────────────────
|
|
||||||
stft_out = stft * linear_gain.unsqueeze(-1) # [n_freqs, T]
|
|
||||||
wav_out = torch.istft(stft_out, n_fft, hop_length=hop, win_length=n_fft,
|
|
||||||
window=window, center=True,
|
|
||||||
length=wav.shape[0]) # [L]
|
|
||||||
|
|
||||||
# ── resample back to original sr ──────────────────────────────────────
|
|
||||||
if sr_in != sr_tgt:
|
|
||||||
wav_out = AF.resample(wav_out.unsqueeze(0), sr_tgt, sr_in).squeeze(0)
|
|
||||||
|
|
||||||
# ── preserve original RMS level ───────────────────────────────────────
|
|
||||||
rms_in = wav.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
rms_out = wav_out.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
wav_out = wav_out * (rms_in / rms_out)
|
|
||||||
|
|
||||||
peak = wav_out.abs().max()
|
|
||||||
if peak > 1.0:
|
|
||||||
wav_out = wav_out / peak
|
|
||||||
|
|
||||||
# ── reshape to match input layout [1, C, L] ───────────────────────────
|
|
||||||
out = wav_out.unsqueeze(0).unsqueeze(0)
|
|
||||||
if waveform.shape[1] > 1:
|
|
||||||
out = out.expand(-1, waveform.shape[1], -1).clone()
|
|
||||||
|
|
||||||
gain_db_range = (
|
|
||||||
20.0 * torch.log10(linear_gain.clamp(min=1e-8))
|
|
||||||
)
|
|
||||||
print(f"[Spectral Matcher] mode={mode} strength={strength:.2f} "
|
|
||||||
f"gain [{gain_db_range.min():.1f}, {gain_db_range.max():.1f}] dB "
|
|
||||||
f"rms={rms_in:.4f}→{out.pow(2).mean().sqrt():.4f}", flush=True)
|
|
||||||
|
|
||||||
return ({"waveform": out, "sample_rate": sr_in},)
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
"""SelVA BigVGAN Loader.
|
|
||||||
|
|
||||||
Loads a fine-tuned BigVGAN vocoder checkpoint produced by SelVA BigVGAN Trainer
|
|
||||||
and replaces the vocoder weights in the loaded SELVA_MODEL in-place.
|
|
||||||
|
|
||||||
The model is modified in-place so ComfyUI's model cache is updated — no need to
|
|
||||||
reload the full SelVA model. Subsequent Sampler runs will use the fine-tuned vocoder.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
from .selva_bigvgan_trainer import inject_gafilters
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaBigvganLoader:
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "load"
|
|
||||||
RETURN_TYPES = ("SELVA_MODEL",)
|
|
||||||
RETURN_NAMES = ("model",)
|
|
||||||
OUTPUT_TOOLTIPS = ("SELVA_MODEL with the fine-tuned BigVGAN vocoder injected.",)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Loads a fine-tuned BigVGAN/BigVGANv2 vocoder checkpoint from SelVA BigVGAN Trainer "
|
|
||||||
"and replaces the vocoder weights in the SELVA_MODEL in-place. "
|
|
||||||
"Supports both 16k and 44k models. "
|
|
||||||
"Connect the output to SelVA Sampler instead of the base model loader."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"path": ("STRING", {
|
|
||||||
"default": "bigvgan_bj.pt",
|
|
||||||
"tooltip": "Path to fine-tuned vocoder checkpoint (.pt). "
|
|
||||||
"Relative paths resolve to ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def load(self, model, path):
|
|
||||||
p = Path(path.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[BigVGAN] Checkpoint not found: {p}")
|
|
||||||
|
|
||||||
ckpt = torch.load(str(p), map_location="cpu", weights_only=False)
|
|
||||||
if "generator" not in ckpt:
|
|
||||||
raise ValueError(f"[BigVGAN] Expected {{'generator': ...}} in checkpoint, got keys: {list(ckpt.keys())}")
|
|
||||||
|
|
||||||
mode = model["mode"]
|
|
||||||
if mode == "16k":
|
|
||||||
vocoder = model["feature_utils"].tod.vocoder.vocoder # BigVGANVocoder
|
|
||||||
elif mode == "44k":
|
|
||||||
vocoder = model["feature_utils"].tod.vocoder # BigVGANv2 directly
|
|
||||||
else:
|
|
||||||
raise ValueError(f"[BigVGAN] Unknown mode: {mode}")
|
|
||||||
|
|
||||||
# Remember device before injecting new modules (which default to CPU)
|
|
||||||
target_device = next(vocoder.parameters()).device
|
|
||||||
|
|
||||||
if ckpt.get("has_gafilter", False):
|
|
||||||
kernel_size = ckpt.get("gafilter_kernel_size", 9)
|
|
||||||
n_gaf = inject_gafilters(vocoder, kernel_size)
|
|
||||||
print(f"[BigVGAN] GAFilter injected: {n_gaf} filters kernel={kernel_size}", flush=True)
|
|
||||||
|
|
||||||
vocoder.load_state_dict(ckpt["generator"])
|
|
||||||
vocoder.to(target_device)
|
|
||||||
vocoder.eval()
|
|
||||||
|
|
||||||
print(f"[BigVGAN] Loaded fine-tuned vocoder from: {p}", flush=True)
|
|
||||||
return (model,)
|
|
||||||
@@ -1,625 +0,0 @@
|
|||||||
"""SelVA BigVGAN Vocoder Scheduler — runs a sweep of vocoder fine-tuning experiments.
|
|
||||||
|
|
||||||
Each experiment inherits from a shared `base` config and overrides specific keys.
|
|
||||||
Audio clips are loaded once and reused across all experiments. Results are written
|
|
||||||
to `experiment_summary.json` (updated after each completed run) and a comparison
|
|
||||||
loss-curve image.
|
|
||||||
|
|
||||||
JSON format:
|
|
||||||
{
|
|
||||||
"name": "bigvgan_sweep",
|
|
||||||
"description": "optional note",
|
|
||||||
"data_dir": "/path/to/audio/clips",
|
|
||||||
"output_root": "/path/to/output",
|
|
||||||
"base": { "train_mode": "snake_alpha_only", "steps": 2000, "lr": 1e-4, ... },
|
|
||||||
"experiments": [
|
|
||||||
{"id": "baseline", "description": "..."},
|
|
||||||
{"id": "all_5k", "train_mode": "all_params", "steps": 5000, "lr": 1e-5},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.model_management
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
from .selva_bigvgan_trainer import (
|
|
||||||
_do_train,
|
|
||||||
_pregenerate_lora_mels,
|
|
||||||
_load_wav,
|
|
||||||
)
|
|
||||||
from .selva_lora_trainer import _smooth_losses, _pil_to_tensor
|
|
||||||
from .selva_lora_scheduler import (
|
|
||||||
_get_system_info,
|
|
||||||
_resolve_path,
|
|
||||||
_draw_comparison_curves,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Defaults mirror SelvaBigvganTrainer INPUT_TYPES defaults
|
|
||||||
_PARAM_DEFAULTS = {
|
|
||||||
"train_mode": "snake_alpha_only",
|
|
||||||
"steps": 2000,
|
|
||||||
"lr": 1e-4,
|
|
||||||
"batch_size": 4,
|
|
||||||
"segment_seconds": 2.0,
|
|
||||||
"lambda_l2sp": 1e-3,
|
|
||||||
"use_gafilter": True,
|
|
||||||
"gafilter_kernel_size": 9,
|
|
||||||
"lambda_phase": 1.0,
|
|
||||||
"save_every": 500,
|
|
||||||
"seed": 42,
|
|
||||||
"discriminator_path": "",
|
|
||||||
"lora_adapter": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_config(base: dict, experiment: dict) -> dict:
|
|
||||||
"""Merge param defaults + file base + experiment overrides."""
|
|
||||||
cfg = dict(_PARAM_DEFAULTS)
|
|
||||||
cfg.update(base)
|
|
||||||
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_training_log(log_path: Path) -> list:
|
|
||||||
"""Parse BigVGAN training CSV → list of total_loss values."""
|
|
||||||
losses = []
|
|
||||||
if not log_path.exists():
|
|
||||||
return losses
|
|
||||||
try:
|
|
||||||
with open(log_path) as f:
|
|
||||||
reader = csv.DictReader(f)
|
|
||||||
for row in reader:
|
|
||||||
losses.append(float(row["total_loss"]))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return losses
|
|
||||||
|
|
||||||
|
|
||||||
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
|
|
||||||
total_steps: int) -> dict:
|
|
||||||
"""Build {step: loss} at each save_every boundary.
|
|
||||||
|
|
||||||
Uses round-to-nearest to handle log_interval that doesn't divide
|
|
||||||
save_every evenly (e.g. steps=3000 → log_interval=150, save_every=1000).
|
|
||||||
"""
|
|
||||||
result = {}
|
|
||||||
for target in range(save_every, total_steps + 1, save_every):
|
|
||||||
# loss_history[i] = loss at step (i+1)*log_interval
|
|
||||||
idx = round(target / log_interval) - 1
|
|
||||||
if 0 <= idx < len(loss_history):
|
|
||||||
result[str(target)] = round(loss_history[idx], 6)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaBigvganScheduler:
|
|
||||||
"""Runs a sweep of BigVGAN vocoder fine-tuning experiments from a JSON file.
|
|
||||||
|
|
||||||
Audio clips are loaded once and reused across all experiments. Each experiment
|
|
||||||
deep-copies the vocoder and trains independently. Results are written to
|
|
||||||
`experiment_summary.json` after every completed run so partial results are
|
|
||||||
preserved if the sweep is interrupted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "run"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("summary_path", "comparison_curves")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to experiment_summary.json — share this file to compare runs.",
|
|
||||||
"All smoothed loss curves overlaid on the same axes.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Runs a series of BigVGAN vocoder fine-tuning experiments defined in a JSON sweep file. "
|
|
||||||
"Audio clips are loaded once and reused across all experiments. "
|
|
||||||
"Results (loss, config, checkpoint paths) are collected in experiment_summary.json."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"experiments_file": ("STRING", {
|
|
||||||
"default": "bigvgan_experiments.json",
|
|
||||||
"tooltip": (
|
|
||||||
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
|
|
||||||
"models directory; absolute paths are used as-is."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def run(self, model, experiments_file):
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Read + validate the JSON file
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
exp_path = Path(experiments_file.strip())
|
|
||||||
if not exp_path.is_absolute():
|
|
||||||
candidate = Path(folder_paths.models_dir) / exp_path
|
|
||||||
if not candidate.exists():
|
|
||||||
candidate = Path(folder_paths.get_output_directory()) / exp_path
|
|
||||||
exp_path = candidate
|
|
||||||
if not exp_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"[BigVGAN Scheduler] Experiment file not found: {exp_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = json.loads(exp_path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
if "experiments" not in spec or not spec["experiments"]:
|
|
||||||
raise ValueError("[BigVGAN Scheduler] 'experiments' list is missing or empty.")
|
|
||||||
for i, exp in enumerate(spec["experiments"]):
|
|
||||||
if "id" not in exp:
|
|
||||||
raise ValueError(
|
|
||||||
f"[BigVGAN Scheduler] Experiment at index {i} is missing required 'id' field."
|
|
||||||
)
|
|
||||||
|
|
||||||
sweep_name = spec.get("name", exp_path.stem)
|
|
||||||
description = spec.get("description", "")
|
|
||||||
base_cfg = spec.get("base", {})
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. Resolve data_dir and output_root
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
if "data_dir" not in spec:
|
|
||||||
raise ValueError("[BigVGAN Scheduler] 'data_dir' is required in the sweep file.")
|
|
||||||
data_dir = _resolve_path(spec["data_dir"])
|
|
||||||
output_root = _resolve_path(spec.get("output_root", f"bigvgan_sweeps/{sweep_name}"))
|
|
||||||
output_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
mode = model["mode"]
|
|
||||||
dtype = model["dtype"]
|
|
||||||
feature_utils = model["feature_utils"]
|
|
||||||
mel_converter = feature_utils.mel_converter
|
|
||||||
strategy = model["strategy"]
|
|
||||||
|
|
||||||
if mode == "16k":
|
|
||||||
original_vocoder = feature_utils.tod.vocoder.vocoder
|
|
||||||
sample_rate = 16_000
|
|
||||||
elif mode == "44k":
|
|
||||||
original_vocoder = feature_utils.tod.vocoder
|
|
||||||
sample_rate = 44_100
|
|
||||||
else:
|
|
||||||
raise ValueError(f"[BigVGAN Scheduler] Unknown mode: {mode}")
|
|
||||||
|
|
||||||
print(f"\n[BigVGAN Scheduler] Sweep '{sweep_name}': "
|
|
||||||
f"{len(spec['experiments'])} experiment(s)", flush=True)
|
|
||||||
if description:
|
|
||||||
print(f"[BigVGAN Scheduler] {description}", flush=True)
|
|
||||||
print(f"[BigVGAN Scheduler] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[BigVGAN Scheduler] output_root = {output_root}\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. Load audio clips once
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Find minimum segment length across all experiments so we load enough
|
|
||||||
min_segment_seconds = float("inf")
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
min_segment_seconds = min(min_segment_seconds, float(cfg.get("segment_seconds", 2.0)))
|
|
||||||
min_segment_samples = int(min_segment_seconds * sample_rate)
|
|
||||||
|
|
||||||
audio_files = []
|
|
||||||
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg", "*.aac"):
|
|
||||||
audio_files.extend(data_dir.rglob(ext))
|
|
||||||
if not audio_files:
|
|
||||||
raise FileNotFoundError(f"[BigVGAN Scheduler] No audio files in {data_dir}")
|
|
||||||
|
|
||||||
print(f"[BigVGAN Scheduler] Loading {len(audio_files)} audio files...", flush=True)
|
|
||||||
clips = []
|
|
||||||
for af in audio_files:
|
|
||||||
try:
|
|
||||||
wav, sr = _load_wav(af)
|
|
||||||
if wav.shape[0] > 1:
|
|
||||||
wav = wav.mean(0, keepdim=True)
|
|
||||||
if sr != sample_rate:
|
|
||||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
|
||||||
wav = wav.squeeze(0) # [L]
|
|
||||||
if wav.shape[0] >= min_segment_samples:
|
|
||||||
clips.append(wav.cpu())
|
|
||||||
else:
|
|
||||||
print(f" [BigVGAN Scheduler] Skip {af.name}: "
|
|
||||||
f"shorter than {min_segment_seconds}s", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f" [BigVGAN Scheduler] Failed {af.name}: {e}", flush=True)
|
|
||||||
|
|
||||||
if not clips:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"[BigVGAN Scheduler] No usable clips (need audio >= {min_segment_seconds}s)"
|
|
||||||
)
|
|
||||||
print(f"[BigVGAN Scheduler] {len(clips)} clips ready\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. Offload unused components to free VRAM
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
comfy.model_management.unload_all_models()
|
|
||||||
feature_utils.to("cpu")
|
|
||||||
if "generator" in model:
|
|
||||||
model["generator"].to("cpu")
|
|
||||||
if "video_enc" in model:
|
|
||||||
model["video_enc"].to("cpu")
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. Pre-compute text CLIP embeddings if any experiment uses LoRA
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
text_clip_cache = {}
|
|
||||||
any_lora = any(
|
|
||||||
_merge_config(base_cfg, exp).get("lora_adapter", "")
|
|
||||||
for exp in spec["experiments"]
|
|
||||||
)
|
|
||||||
if any_lora:
|
|
||||||
npz_files = sorted(data_dir.glob("*.npz"))
|
|
||||||
if npz_files:
|
|
||||||
prompt_map = {}
|
|
||||||
prompts_file = data_dir / "prompts.txt"
|
|
||||||
if prompts_file.exists():
|
|
||||||
for line in prompts_file.read_text(encoding="utf-8").splitlines():
|
|
||||||
line = line.strip()
|
|
||||||
if not line or line.startswith("#"):
|
|
||||||
continue
|
|
||||||
if "|" in line:
|
|
||||||
fname, prompt = line.split("|", 1)
|
|
||||||
prompt_map[fname.strip()] = prompt.strip()
|
|
||||||
default_prompt = data_dir.name
|
|
||||||
|
|
||||||
clip_model = feature_utils.clip_model
|
|
||||||
if clip_model is not None:
|
|
||||||
clip_model.to(device)
|
|
||||||
try:
|
|
||||||
for npz_path in npz_files:
|
|
||||||
data = dict(np.load(str(npz_path), allow_pickle=False))
|
|
||||||
prompt = prompt_map.get(
|
|
||||||
npz_path.name, data.get("prompt", default_prompt)
|
|
||||||
)
|
|
||||||
if isinstance(prompt, np.ndarray):
|
|
||||||
prompt = str(prompt)
|
|
||||||
tc = feature_utils.encode_text_clip([prompt])
|
|
||||||
text_clip_cache[npz_path.name] = tc.clone().detach().cpu()
|
|
||||||
finally:
|
|
||||||
if clip_model is not None:
|
|
||||||
clip_model.to("cpu")
|
|
||||||
soft_empty_cache()
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
print(f"[BigVGAN Scheduler] Pre-encoded {len(text_clip_cache)} "
|
|
||||||
f"CLIP embeddings", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 6. Build or restore the summary (resume-aware)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary_path = output_root / "experiment_summary.json"
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = []
|
|
||||||
|
|
||||||
if summary_path.exists():
|
|
||||||
try:
|
|
||||||
existing = json.loads(summary_path.read_text(encoding="utf-8"))
|
|
||||||
for rec in existing.get("experiments", []):
|
|
||||||
if rec.get("results", {}).get("status") == "completed":
|
|
||||||
completed_ids.add(rec["id"])
|
|
||||||
lh = rec["results"].get("loss_history", [])
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": rec["id"],
|
|
||||||
"loss_history": lh,
|
|
||||||
"log_interval": rec["results"].get("log_interval", 100),
|
|
||||||
"start_step": 0,
|
|
||||||
})
|
|
||||||
summary = existing
|
|
||||||
summary["completed_at"] = None
|
|
||||||
if completed_ids:
|
|
||||||
print(f"[BigVGAN Scheduler] Resuming — skipping "
|
|
||||||
f"{len(completed_ids)} completed: "
|
|
||||||
f"{sorted(completed_ids)}", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[BigVGAN Scheduler] Could not read existing summary "
|
|
||||||
f"({e}) — starting fresh", flush=True)
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = []
|
|
||||||
summary = None
|
|
||||||
|
|
||||||
if not completed_ids:
|
|
||||||
summary = {
|
|
||||||
"sweep_name": sweep_name,
|
|
||||||
"description": description,
|
|
||||||
"sweep_file": str(exp_path),
|
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"completed_at": None,
|
|
||||||
"system": _get_system_info(),
|
|
||||||
"data_dir": str(data_dir),
|
|
||||||
"n_clips": len(clips),
|
|
||||||
"experiments": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _write_summary():
|
|
||||||
summary_path.write_text(
|
|
||||||
json.dumps(summary, indent=2), encoding="utf-8"
|
|
||||||
)
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 7. Compute total steps for progress bar
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
total_steps = 0
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
if exp["id"] not in completed_ids:
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
total_steps += int(cfg.get("steps", 2000))
|
|
||||||
pbar = comfy.utils.ProgressBar(max(total_steps, 1))
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 8. Run experiments in a worker thread
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# BigVGAN training requires a fresh thread because ComfyUI runs nodes
|
|
||||||
# inside torch.inference_mode(). inference_mode is thread-local — a
|
|
||||||
# new thread starts with it OFF, so all tensor operations produce
|
|
||||||
# normal autograd-compatible tensors.
|
|
||||||
_exc = [None]
|
|
||||||
|
|
||||||
def _worker():
|
|
||||||
try:
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
exp_id = exp["id"]
|
|
||||||
exp_desc = exp.get("description", "")
|
|
||||||
|
|
||||||
if exp_id in completed_ids:
|
|
||||||
print(f"[BigVGAN Scheduler] Skipping '{exp_id}' "
|
|
||||||
f"(already completed)", flush=True)
|
|
||||||
continue
|
|
||||||
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
|
|
||||||
# ── Extract experiment parameters ────────────────────
|
|
||||||
train_mode = str(cfg.get("train_mode", "snake_alpha_only"))
|
|
||||||
exp_steps = int(cfg.get("steps", 2000))
|
|
||||||
exp_lr = float(cfg.get("lr", 1e-4))
|
|
||||||
exp_bs = int(cfg.get("batch_size", 4))
|
|
||||||
exp_seg_s = float(cfg.get("segment_seconds", 2.0))
|
|
||||||
exp_l2sp = float(cfg.get("lambda_l2sp", 1e-3))
|
|
||||||
exp_gafilter = bool(cfg.get("use_gafilter", True))
|
|
||||||
exp_gaf_ks = int(cfg.get("gafilter_kernel_size", 9))
|
|
||||||
exp_phase = float(cfg.get("lambda_phase", 1.0))
|
|
||||||
exp_save = int(cfg.get("save_every", 500))
|
|
||||||
exp_seed = int(cfg.get("seed", 42))
|
|
||||||
exp_disc = str(cfg.get("discriminator_path", ""))
|
|
||||||
exp_lora = str(cfg.get("lora_adapter", ""))
|
|
||||||
|
|
||||||
segment_samples = int(exp_seg_s * sample_rate)
|
|
||||||
|
|
||||||
# Filter clips long enough for this experiment
|
|
||||||
exp_clips = [c for c in clips if c.shape[0] >= segment_samples]
|
|
||||||
if not exp_clips:
|
|
||||||
print(f"[BigVGAN Scheduler] '{exp_id}' skipped: "
|
|
||||||
f"no clips >= {exp_seg_s}s", flush=True)
|
|
||||||
summary["experiments"].append({
|
|
||||||
"id": exp_id, "description": exp_desc,
|
|
||||||
"config": dict(cfg),
|
|
||||||
"results": {
|
|
||||||
"status": "failed",
|
|
||||||
"error": f"No clips >= {exp_seg_s}s",
|
|
||||||
"duration_seconds": 0,
|
|
||||||
},
|
|
||||||
"checkpoint_path": None,
|
|
||||||
"output_dir": str(output_root / exp_id),
|
|
||||||
})
|
|
||||||
_write_summary()
|
|
||||||
continue
|
|
||||||
|
|
||||||
# ── Resolve discriminator path ───────────────────────
|
|
||||||
disc_path = None
|
|
||||||
if exp_disc:
|
|
||||||
disc_path = Path(exp_disc.strip())
|
|
||||||
if not disc_path.is_absolute():
|
|
||||||
disc_path = (
|
|
||||||
Path(folder_paths.get_output_directory()) / disc_path
|
|
||||||
)
|
|
||||||
if not disc_path.exists():
|
|
||||||
print(f"[BigVGAN Scheduler] '{exp_id}': "
|
|
||||||
f"discriminator not found: {disc_path}",
|
|
||||||
flush=True)
|
|
||||||
disc_path = None
|
|
||||||
|
|
||||||
# ── Pre-generate LoRA mels (disk-cached) ─────────────
|
|
||||||
lora_mel_pairs = None
|
|
||||||
if exp_lora:
|
|
||||||
lora_path = Path(exp_lora.strip())
|
|
||||||
if not lora_path.is_absolute():
|
|
||||||
lora_path = Path(folder_paths.base_path) / lora_path
|
|
||||||
if lora_path.exists():
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
lora_mel_pairs = _pregenerate_lora_mels(
|
|
||||||
model, data_dir, str(lora_path),
|
|
||||||
device, dtype, sample_rate,
|
|
||||||
seq_cfg.duration, seed=exp_seed,
|
|
||||||
cache_dir=str(output_root),
|
|
||||||
text_clip_cache=text_clip_cache,
|
|
||||||
)
|
|
||||||
if not lora_mel_pairs:
|
|
||||||
print(f"[BigVGAN Scheduler] '{exp_id}': "
|
|
||||||
f"no LoRA mel pairs generated",
|
|
||||||
flush=True)
|
|
||||||
lora_mel_pairs = None
|
|
||||||
if device.type == "cuda":
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
else:
|
|
||||||
print(f"[BigVGAN Scheduler] '{exp_id}': "
|
|
||||||
f"LoRA adapter not found: {lora_path}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
# ── Output dir ───────────────────────────────────────
|
|
||||||
exp_dir = output_root / exp_id
|
|
||||||
exp_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
out_path = exp_dir / f"bigvgan_{exp_id}.pt"
|
|
||||||
|
|
||||||
print(f"\n[BigVGAN Scheduler] ── Experiment '{exp_id}' ──",
|
|
||||||
flush=True)
|
|
||||||
if exp_desc:
|
|
||||||
print(f"[BigVGAN Scheduler] {exp_desc}", flush=True)
|
|
||||||
print(f"[BigVGAN Scheduler] mode={train_mode} "
|
|
||||||
f"steps={exp_steps} lr={exp_lr} bs={exp_bs} "
|
|
||||||
f"seg={exp_seg_s}s gafilter={exp_gafilter} "
|
|
||||||
f"phase={exp_phase} l2sp={exp_l2sp}", flush=True)
|
|
||||||
|
|
||||||
exp_record = {
|
|
||||||
"id": exp_id,
|
|
||||||
"description": exp_desc,
|
|
||||||
"config": {
|
|
||||||
"train_mode": train_mode, "steps": exp_steps,
|
|
||||||
"lr": exp_lr, "batch_size": exp_bs,
|
|
||||||
"segment_seconds": exp_seg_s,
|
|
||||||
"lambda_l2sp": exp_l2sp,
|
|
||||||
"use_gafilter": exp_gafilter,
|
|
||||||
"gafilter_kernel_size": exp_gaf_ks,
|
|
||||||
"lambda_phase": exp_phase,
|
|
||||||
"save_every": exp_save, "seed": exp_seed,
|
|
||||||
"discriminator_path": exp_disc,
|
|
||||||
"lora_adapter": exp_lora,
|
|
||||||
},
|
|
||||||
"results": {"status": "running"},
|
|
||||||
"checkpoint_path": None,
|
|
||||||
"output_dir": str(exp_dir),
|
|
||||||
}
|
|
||||||
summary["experiments"].append(exp_record)
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
t_start = time.monotonic()
|
|
||||||
try:
|
|
||||||
# Ensure mel_converter is on device for this experiment
|
|
||||||
mel_converter.to(device)
|
|
||||||
|
|
||||||
# Fresh vocoder copy — _do_train modifies it in-place
|
|
||||||
vocoder_copy = copy.deepcopy(original_vocoder)
|
|
||||||
|
|
||||||
checkpoint_path = _do_train(
|
|
||||||
vocoder_copy, mel_converter, exp_clips,
|
|
||||||
device, dtype, strategy, feature_utils,
|
|
||||||
segment_samples, sample_rate,
|
|
||||||
train_mode, exp_steps, exp_lr, exp_bs,
|
|
||||||
exp_l2sp, exp_gafilter, exp_gaf_ks,
|
|
||||||
exp_phase, exp_save, exp_seed,
|
|
||||||
out_path, disc_path, pbar,
|
|
||||||
lora_mel_pairs,
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
|
|
||||||
# Parse training CSV for loss history
|
|
||||||
log_path = exp_dir / f"bigvgan_{exp_id}_training_log.csv"
|
|
||||||
loss_history = _parse_training_log(log_path)
|
|
||||||
log_interval = max(1, exp_steps // 20)
|
|
||||||
smoothed = (
|
|
||||||
_smooth_losses(loss_history)
|
|
||||||
if loss_history else []
|
|
||||||
)
|
|
||||||
|
|
||||||
final_loss = (
|
|
||||||
round(smoothed[-1], 6) if smoothed else None
|
|
||||||
)
|
|
||||||
min_loss = (
|
|
||||||
round(min(smoothed), 6) if smoothed else None
|
|
||||||
)
|
|
||||||
min_idx = (
|
|
||||||
smoothed.index(min(smoothed))
|
|
||||||
if smoothed else None
|
|
||||||
)
|
|
||||||
min_loss_step = (
|
|
||||||
(min_idx + 1) * log_interval
|
|
||||||
if min_idx is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if loss_history:
|
|
||||||
quarter = max(1, len(loss_history) // 4)
|
|
||||||
loss_std = round(
|
|
||||||
float(np.std(loss_history[-quarter:])), 6
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
loss_std = None
|
|
||||||
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "completed",
|
|
||||||
"final_loss": final_loss,
|
|
||||||
"min_loss": min_loss,
|
|
||||||
"min_loss_step": min_loss_step,
|
|
||||||
"loss_std_last_quarter": loss_std,
|
|
||||||
"loss_at_steps": _loss_at_steps(
|
|
||||||
loss_history, log_interval,
|
|
||||||
exp_save, exp_steps,
|
|
||||||
),
|
|
||||||
"loss_history": [
|
|
||||||
round(v, 6) for v in loss_history
|
|
||||||
],
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
exp_record["checkpoint_path"] = checkpoint_path
|
|
||||||
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": exp_id,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"start_step": 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
print(f"[BigVGAN Scheduler] Experiment '{exp_id}' "
|
|
||||||
f"failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "failed",
|
|
||||||
"error": str(e),
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
finally:
|
|
||||||
# Clean up vocoder copy to free VRAM
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
_exc[0] = e
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
t = threading.Thread(target=_worker, daemon=True)
|
|
||||||
t.start()
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
if _exc[0] is not None:
|
|
||||||
raise _exc[0]
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 9. Finalise summary
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
|
|
||||||
_write_summary()
|
|
||||||
print(f"\n[BigVGAN Scheduler] Sweep complete. "
|
|
||||||
f"Summary: {summary_path}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 10. Comparison image
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
comparison_img = _draw_comparison_curves(all_curve_data)
|
|
||||||
comparison_img.save(str(output_root / "loss_comparison.png"))
|
|
||||||
comparison_tensor = _pil_to_tensor(comparison_img)
|
|
||||||
|
|
||||||
return (str(summary_path), comparison_tensor)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,106 +0,0 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetBrowser:
|
|
||||||
"""Browse a dataset.json file entry by entry using an integer index.
|
|
||||||
|
|
||||||
Each entry in the JSON is expected to have:
|
|
||||||
- "path" : base path (no extension) — directory that holds frame images
|
|
||||||
- "label" : text description of the clip
|
|
||||||
|
|
||||||
Derived outputs:
|
|
||||||
- video_path : path + ".mp4"
|
|
||||||
- audio_path : path + ".wav"
|
|
||||||
- frames_dir : path (the directory itself, for image-sequence loaders)
|
|
||||||
- label : entry["label"]
|
|
||||||
- count : total number of entries in the file
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset_json": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Absolute or ComfyUI-relative path to a dataset.json file.",
|
|
||||||
}),
|
|
||||||
"index": ("INT", {
|
|
||||||
"default": 0,
|
|
||||||
"min": 0,
|
|
||||||
"max": 9999,
|
|
||||||
"step": 1,
|
|
||||||
"tooltip": "Zero-based index of the entry to inspect.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "STRING", "INT")
|
|
||||||
RETURN_NAMES = ("video_path", "audio_wav", "audio_flac", "features_path", "frames_dir", "mask_dir", "label", "max_index")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"path + '.mp4'",
|
|
||||||
"features/ + name + '.wav'",
|
|
||||||
"features/ + name + '.flac'",
|
|
||||||
"features/ + name + '.npz' (pre-extracted SelVA features)",
|
|
||||||
"path (image-sequence directory)",
|
|
||||||
"path + '_mask' (mask image-sequence directory)",
|
|
||||||
"Text label for this clip",
|
|
||||||
"count - 1 — wire to a primitive INT's max to constrain the index widget",
|
|
||||||
)
|
|
||||||
FUNCTION = "browse"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Reads a dataset.json produced by the SelVA dataset preparation pipeline "
|
|
||||||
"and exposes one entry at a time via an integer index. "
|
|
||||||
"Outputs the video path, audio path, frames directory, label, and total entry count."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Re-read the file every call so edits are picked up without restarting ComfyUI.
|
|
||||||
IS_CHANGED = classmethod(lambda cls, **_: float("nan"))
|
|
||||||
|
|
||||||
def browse(self, dataset_json: str, index: int):
|
|
||||||
p = Path(dataset_json.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.base_path) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[SelVA Dataset Browser] File not found: {p}")
|
|
||||||
|
|
||||||
with p.open("r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
if not isinstance(data, list) or len(data) == 0:
|
|
||||||
raise ValueError(f"[SelVA Dataset Browser] Expected a non-empty JSON array in {p}")
|
|
||||||
|
|
||||||
count = len(data)
|
|
||||||
if index >= count:
|
|
||||||
raise IndexError(
|
|
||||||
f"[SelVA Dataset Browser] index {index} is out of range "
|
|
||||||
f"(dataset has {count} entries, last index is {count - 1})"
|
|
||||||
)
|
|
||||||
entry = data[index]
|
|
||||||
|
|
||||||
base = entry["path"]
|
|
||||||
label = entry.get("label", "")
|
|
||||||
|
|
||||||
p_base = Path(base)
|
|
||||||
feat_base = str(p_base.parent / "features" / p_base.name)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"[SelVA Dataset Browser] {index + 1}/{count} label='{label}' base={base}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
base + ".mp4",
|
|
||||||
feat_base + ".wav",
|
|
||||||
feat_base + ".flac",
|
|
||||||
feat_base + ".npz",
|
|
||||||
base,
|
|
||||||
base + "_mask",
|
|
||||||
label,
|
|
||||||
count - 1,
|
|
||||||
)
|
|
||||||
@@ -1,788 +0,0 @@
|
|||||||
"""SelVA Audio Dataset Pipeline — chainable in-memory preprocessing nodes.
|
|
||||||
|
|
||||||
Typical chain:
|
|
||||||
SelvaDatasetLoader
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetResampler (optional)
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetLUFSNormalizer (optional)
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetCompressor (optional)
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetSpectralMatcher (optional — batch spectral EQ)
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetHfSmoother (optional — batch HF attenuation)
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetAugmenter (optional — gain/pitch/stretch variants)
|
|
||||||
↓ AUDIO_DATASET
|
|
||||||
SelvaDatasetInspector (optional)
|
|
||||||
↓ AUDIO_DATASET + STRING report
|
|
||||||
SelvaDatasetItemExtractor → AUDIO (bridges to save/preview nodes)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
# ComfyUI custom type name — passed between all dataset pipeline nodes
|
|
||||||
AUDIO_DATASET = "AUDIO_DATASET"
|
|
||||||
|
|
||||||
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aac", ".m4a"}
|
|
||||||
_SOUNDFILE_EXTS = {".wav", ".flac", ".ogg"} # handled natively without FFmpeg
|
|
||||||
|
|
||||||
|
|
||||||
def _load_audio(path: Path):
|
|
||||||
"""Load audio file. Uses soundfile for WAV/FLAC/OGG to avoid torchcodec/FFmpeg issues."""
|
|
||||||
if path.suffix.lower() in _SOUNDFILE_EXTS:
|
|
||||||
import soundfile as sf
|
|
||||||
wav_np, sr = sf.read(str(path), dtype="float32", always_2d=True) # [L, C]
|
|
||||||
wav = torch.from_numpy(wav_np).T.unsqueeze(0) # [1, C, L]
|
|
||||||
else:
|
|
||||||
wav, sr = torchaudio.load(str(path)) # [C, L]
|
|
||||||
wav = wav.unsqueeze(0).float() # [1, C, L]
|
|
||||||
return wav, sr
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetLoader:
|
|
||||||
"""Load all audio files in a folder into an in-memory AUDIO_DATASET."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"folder": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Absolute path to folder containing audio files. Searched recursively.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "load"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = "Load all audio files from a folder into memory as an AUDIO_DATASET."
|
|
||||||
|
|
||||||
def load(self, folder: str):
|
|
||||||
folder = Path(folder.strip())
|
|
||||||
if not folder.exists():
|
|
||||||
raise FileNotFoundError(f"[DatasetLoader] Folder not found: {folder}")
|
|
||||||
|
|
||||||
files = [f for f in folder.rglob("*") if f.suffix.lower() in _AUDIO_EXTS]
|
|
||||||
if not files:
|
|
||||||
raise RuntimeError(f"[DatasetLoader] No audio files found in {folder}")
|
|
||||||
|
|
||||||
dataset = []
|
|
||||||
for f in sorted(files):
|
|
||||||
try:
|
|
||||||
wav, sr = _load_audio(f)
|
|
||||||
dataset.append({"waveform": wav, "sample_rate": sr, "name": f.stem})
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[DatasetLoader] Skipping {f.name}: {e}", flush=True)
|
|
||||||
|
|
||||||
print(f"[DatasetLoader] Loaded {len(dataset)} clips from {folder}", flush=True)
|
|
||||||
return (dataset,)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetResampler:
|
|
||||||
"""Resample all clips in a dataset to a target sample rate using soxr VHQ."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"target_sr": ("INT", {
|
|
||||||
"default": 44100, "min": 8000, "max": 192000,
|
|
||||||
"tooltip": "Target sample rate. 44100 for large SelVA model, 16000 for small.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "resample"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = "Resample all clips to target_sr using soxr VHQ. Skips clips already at target rate."
|
|
||||||
|
|
||||||
def resample(self, dataset, target_sr: int):
|
|
||||||
import soxr
|
|
||||||
|
|
||||||
out = []
|
|
||||||
changed = 0
|
|
||||||
for item in dataset:
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
if sr == target_sr:
|
|
||||||
out.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
wav = item["waveform"][0] # [C, L]
|
|
||||||
# soxr expects [L, C] (time-first), float64
|
|
||||||
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
|
|
||||||
wav_rs = soxr.resample(wav_np, sr, target_sr, quality="VHQ")
|
|
||||||
wav_t = torch.from_numpy(wav_rs).float().permute(1, 0).unsqueeze(0) # [1, C, L]
|
|
||||||
out.append({"waveform": wav_t, "sample_rate": target_sr, "name": item["name"]})
|
|
||||||
changed += 1
|
|
||||||
|
|
||||||
print(f"[DatasetResampler] {changed}/{len(dataset)} clips resampled → {target_sr} Hz", flush=True)
|
|
||||||
return (out,)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetLUFSNormalizer:
|
|
||||||
"""Normalize each clip to a target integrated LUFS level + true peak limit."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"target_lufs": ("FLOAT", {
|
|
||||||
"default": -23.0, "min": -40.0, "max": -6.0, "step": 0.5,
|
|
||||||
"tooltip": "Target integrated loudness in LUFS. -23 is EBU R128 standard.",
|
|
||||||
}),
|
|
||||||
"true_peak_dbtp": ("FLOAT", {
|
|
||||||
"default": -1.0, "min": -6.0, "max": 0.0, "step": 0.5,
|
|
||||||
"tooltip": "True peak ceiling in dBTP. Applied after LUFS gain.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "normalize"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Normalize each clip to target_lufs (BS.1770-4) then apply a true peak ceiling. "
|
|
||||||
"Skips clips that are too short for LUFS measurement (< 0.4 s)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def normalize(self, dataset, target_lufs: float, true_peak_dbtp: float):
|
|
||||||
import pyloudnorm as pyln
|
|
||||||
|
|
||||||
tp_linear = 10.0 ** (true_peak_dbtp / 20.0)
|
|
||||||
out = []
|
|
||||||
skipped = 0
|
|
||||||
|
|
||||||
for item in dataset:
|
|
||||||
wav = item["waveform"][0] # [C, L]
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
|
|
||||||
# pyloudnorm wants [L] mono or [L, C] multichannel, float64
|
|
||||||
wav_np = wav.permute(1, 0).double().numpy() # [L, C]
|
|
||||||
if wav_np.shape[1] == 1:
|
|
||||||
wav_np = wav_np[:, 0] # [L] mono
|
|
||||||
|
|
||||||
meter = pyln.Meter(sr)
|
|
||||||
try:
|
|
||||||
loudness = meter.integrated_loudness(wav_np)
|
|
||||||
except Exception:
|
|
||||||
skipped += 1
|
|
||||||
out.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not np.isfinite(loudness):
|
|
||||||
skipped += 1
|
|
||||||
out.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
gain_db = target_lufs - loudness
|
|
||||||
gain_linear = 10.0 ** (gain_db / 20.0)
|
|
||||||
|
|
||||||
wav_norm = wav * gain_linear
|
|
||||||
|
|
||||||
# True peak limit
|
|
||||||
peak = wav_norm.abs().max().item()
|
|
||||||
if peak > tp_linear:
|
|
||||||
wav_norm = wav_norm * (tp_linear / peak)
|
|
||||||
|
|
||||||
out.append({"waveform": wav_norm.unsqueeze(0), "sample_rate": sr, "name": item["name"]})
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"[LUFSNormalizer] {len(dataset) - skipped}/{len(dataset)} clips normalized "
|
|
||||||
f"target={target_lufs} LUFS TP={true_peak_dbtp} dBTP skipped={skipped}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return (out,)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetCompressor:
|
|
||||||
"""Apply mild parallel compression to reduce within-clip loudness variance.
|
|
||||||
|
|
||||||
Uses pedalboard.Compressor (2:1–3:1 ratio). Parallel (New York) style:
|
|
||||||
blends compressed signal with dry so transients are preserved while
|
|
||||||
the dynamic range is gently tightened. Apply after LUFS normalization.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"threshold_db": ("FLOAT", {
|
|
||||||
"default": -18.0, "min": -40.0, "max": -6.0, "step": 1.0,
|
|
||||||
"tooltip": "Compression kicks in above this level. -18 dB is a safe starting point after LUFS normalization.",
|
|
||||||
}),
|
|
||||||
"ratio": ("FLOAT", {
|
|
||||||
"default": 2.5, "min": 1.5, "max": 4.0, "step": 0.5,
|
|
||||||
"tooltip": "Compression ratio. 2:1–3:1 is mild; stay below 4:1 to avoid pumping.",
|
|
||||||
}),
|
|
||||||
"attack_ms": ("FLOAT", {
|
|
||||||
"default": 10.0, "min": 1.0, "max": 100.0, "step": 1.0,
|
|
||||||
"tooltip": "Attack time in ms. Slower attack preserves transients.",
|
|
||||||
}),
|
|
||||||
"release_ms": ("FLOAT", {
|
|
||||||
"default": 100.0, "min": 20.0, "max": 500.0, "step": 10.0,
|
|
||||||
"tooltip": "Release time in ms.",
|
|
||||||
}),
|
|
||||||
"mix": ("FLOAT", {
|
|
||||||
"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "Parallel blend: 0.0 = dry only, 1.0 = fully compressed. 0.3–0.5 is typical.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "compress"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Mild parallel compression to reduce within-clip dynamic range. "
|
|
||||||
"Blends compressed signal with dry at the given mix ratio. "
|
|
||||||
"Apply after LUFS normalization."
|
|
||||||
)
|
|
||||||
|
|
||||||
def compress(self, dataset, threshold_db: float, ratio: float,
|
|
||||||
attack_ms: float, release_ms: float, mix: float):
|
|
||||||
from pedalboard import Compressor, Pedalboard
|
|
||||||
|
|
||||||
board = Pedalboard([Compressor(
|
|
||||||
threshold_db=threshold_db,
|
|
||||||
ratio=ratio,
|
|
||||||
attack_ms=attack_ms,
|
|
||||||
release_ms=release_ms,
|
|
||||||
)])
|
|
||||||
|
|
||||||
out = []
|
|
||||||
for item in dataset:
|
|
||||||
wav = item["waveform"][0] # [C, L]
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
|
|
||||||
# pedalboard expects [C, L] float32 numpy
|
|
||||||
wav_np = wav.float().numpy() # [C, L]
|
|
||||||
compressed = board(wav_np, sr) # [C, L]
|
|
||||||
mixed = (1.0 - mix) * wav_np + mix * compressed
|
|
||||||
wav_out = torch.from_numpy(mixed).unsqueeze(0) # [1, C, L]
|
|
||||||
out.append({"waveform": wav_out, "sample_rate": sr, "name": item["name"]})
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"[DatasetCompressor] {len(out)} clips compressed "
|
|
||||||
f"thr={threshold_db}dB ratio={ratio}:1 mix={mix:.0%}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return (out,)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_hf_shelf(wav: torch.Tensor, sr: int) -> bool:
|
|
||||||
"""Return True if clip looks codec-compressed (hard HF shelf above 15 kHz).
|
|
||||||
|
|
||||||
Method: compare mean energy in 1–5 kHz band vs 15–20 kHz band via STFT.
|
|
||||||
A ratio > 40 dB (i.e. near-silence above 15 kHz) flags codec artifacts.
|
|
||||||
"""
|
|
||||||
if sr < 32000:
|
|
||||||
return False # can't assess HF at low sample rates
|
|
||||||
|
|
||||||
n_fft = 2048
|
|
||||||
hop = 512
|
|
||||||
mono = wav[0].mean(0) # [L]
|
|
||||||
window = torch.hann_window(n_fft, device=mono.device)
|
|
||||||
stft = torch.stft(mono, n_fft, hop, n_fft, window, return_complex=True)
|
|
||||||
mag_sq = stft.abs().pow(2).mean(-1) # [n_freqs]
|
|
||||||
|
|
||||||
freqs = torch.linspace(0, sr / 2, n_fft // 2 + 1, device=mono.device)
|
|
||||||
band_lo = (freqs >= 1000) & (freqs < 5000)
|
|
||||||
band_hi = (freqs >= 15000) & (freqs < 20000)
|
|
||||||
|
|
||||||
if band_hi.sum() == 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
energy_lo = mag_sq[band_lo].mean().clamp(min=1e-12)
|
|
||||||
energy_hi = mag_sq[band_hi].mean().clamp(min=1e-12)
|
|
||||||
ratio_db = 10.0 * torch.log10(energy_lo / energy_hi).item()
|
|
||||||
return ratio_db > 40.0
|
|
||||||
|
|
||||||
|
|
||||||
def _estimate_snr(wav: torch.Tensor) -> float:
|
|
||||||
"""Rough SNR estimate: ratio of 95th-percentile frame RMS to 5th-percentile frame RMS."""
|
|
||||||
mono = wav[0].mean(0) # [L]
|
|
||||||
if mono.shape[0] < 2048:
|
|
||||||
return 60.0 # clip too short to frame — assume clean
|
|
||||||
frames = mono.unfold(0, 2048, 512) # [N, 2048]
|
|
||||||
rms = frames.pow(2).mean(-1).sqrt() # [N]
|
|
||||||
p95 = torch.quantile(rms, 0.95).item()
|
|
||||||
p05 = torch.quantile(rms, 0.05).clamp(min=1e-8).item()
|
|
||||||
return 20.0 * np.log10(p95 / p05 + 1e-8)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetInspector:
|
|
||||||
"""Analyze each clip for quality issues and optionally filter out flagged clips."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"skip_rejected": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "If True, flagged clips are removed from the output dataset. "
|
|
||||||
"If False, all clips pass through but the report still lists issues.",
|
|
||||||
}),
|
|
||||||
"min_snr_db": ("FLOAT", {
|
|
||||||
"default": 15.0, "min": 0.0, "max": 60.0, "step": 1.0,
|
|
||||||
"tooltip": "Clips with estimated SNR below this value are flagged.",
|
|
||||||
}),
|
|
||||||
"check_codec_artifacts": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "Flag clips with a hard HF shelf above 15 kHz (MP3/codec artifact signature).",
|
|
||||||
}),
|
|
||||||
"max_silence_fraction": ("FLOAT", {
|
|
||||||
"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "Flag clips where more than this fraction of frames are near-silent "
|
|
||||||
"(< -60 dBFS). Set to 0 to disable silence detection.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET, "STRING")
|
|
||||||
RETURN_NAMES = ("dataset", "report")
|
|
||||||
FUNCTION = "inspect"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Analyze each clip for clipping, low SNR, and codec artifacts. "
|
|
||||||
"Outputs a filtered AUDIO_DATASET and a text report. "
|
|
||||||
"Connect report to a ShowText node to preview in the UI."
|
|
||||||
)
|
|
||||||
|
|
||||||
def inspect(self, dataset, skip_rejected: bool, min_snr_db: float,
|
|
||||||
check_codec_artifacts: bool, max_silence_fraction: float = 0.5):
|
|
||||||
clean = []
|
|
||||||
flagged = []
|
|
||||||
lines = ["SelVA Dataset Inspector Report", "=" * 40]
|
|
||||||
|
|
||||||
for item in dataset:
|
|
||||||
wav = item["waveform"]
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
name = item["name"]
|
|
||||||
issues = []
|
|
||||||
|
|
||||||
# Clipping
|
|
||||||
peak = wav.abs().max().item()
|
|
||||||
if peak > 0.99:
|
|
||||||
issues.append(f"clipping (peak={peak:.3f})")
|
|
||||||
|
|
||||||
# Low SNR
|
|
||||||
snr = _estimate_snr(wav)
|
|
||||||
if snr < min_snr_db:
|
|
||||||
issues.append(f"low SNR ({snr:.1f} dB < {min_snr_db} dB)")
|
|
||||||
|
|
||||||
# Codec artifacts
|
|
||||||
if check_codec_artifacts and _check_hf_shelf(wav, sr):
|
|
||||||
issues.append("codec artifact (HF shelf > 15 kHz)")
|
|
||||||
|
|
||||||
# Silence detection
|
|
||||||
if max_silence_fraction > 0:
|
|
||||||
mono = wav[0].mean(0)
|
|
||||||
if mono.shape[0] >= 2048:
|
|
||||||
frames = mono.unfold(0, 2048, 512)
|
|
||||||
rms = frames.pow(2).mean(-1).sqrt()
|
|
||||||
silent_frac = (rms < 1e-3).float().mean().item()
|
|
||||||
if silent_frac > max_silence_fraction:
|
|
||||||
issues.append(f"mostly silent ({silent_frac:.0%} < -60 dBFS)")
|
|
||||||
|
|
||||||
if issues:
|
|
||||||
flagged.append(name)
|
|
||||||
lines.append(f" FLAGGED {name}: {', '.join(issues)}")
|
|
||||||
if not skip_rejected:
|
|
||||||
clean.append(item)
|
|
||||||
else:
|
|
||||||
clean.append(item)
|
|
||||||
lines.append(f" OK {name}")
|
|
||||||
|
|
||||||
lines.append("=" * 40)
|
|
||||||
lines.append(
|
|
||||||
f"Total: {len(dataset)} Clean: {len(clean)} Flagged: {len(flagged)}"
|
|
||||||
+ (" (removed)" if skip_rejected else " (kept)")
|
|
||||||
)
|
|
||||||
report = "\n".join(lines)
|
|
||||||
print(f"[DatasetInspector]\n{report}", flush=True)
|
|
||||||
return (clean, report)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetItemExtractor:
|
|
||||||
"""Extract a single AUDIO item from an AUDIO_DATASET by index.
|
|
||||||
|
|
||||||
Bridges the dataset pipeline to any node that accepts a standard AUDIO
|
|
||||||
input — save audio, HF Smoother, Spectral Matcher, etc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"index": ("INT", {
|
|
||||||
"default": 0, "min": 0, "max": 9999,
|
|
||||||
"tooltip": "0-based index. Wraps around if index >= dataset length.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO", "STRING", "INT")
|
|
||||||
RETURN_NAMES = ("audio", "name", "total")
|
|
||||||
FUNCTION = "extract"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Extract one clip from an AUDIO_DATASET by index. "
|
|
||||||
"Returns standard AUDIO (compatible with all audio nodes), "
|
|
||||||
"the clip name, and the total dataset length."
|
|
||||||
)
|
|
||||||
|
|
||||||
def extract(self, dataset, index: int):
|
|
||||||
if not dataset:
|
|
||||||
raise RuntimeError("[DatasetItemExtractor] Dataset is empty.")
|
|
||||||
idx = index % len(dataset)
|
|
||||||
item = dataset[idx]
|
|
||||||
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
|
|
||||||
print(
|
|
||||||
f"[DatasetItemExtractor] [{idx}/{len(dataset)-1}] {item['name']} "
|
|
||||||
f"sr={item['sample_rate']} shape={tuple(item['waveform'].shape)}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return (audio, item["name"], len(dataset))
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetSaver:
|
|
||||||
"""Save all clips in an AUDIO_DATASET to disk as FLAC files.
|
|
||||||
|
|
||||||
Optionally copies matching .npz feature files from a source directory,
|
|
||||||
keeping FLAC/NPZ pairs in sync after the inspector has filtered clips.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"output_dir": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Absolute path to output folder. Created if it does not exist.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"npz_source_dir": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "If set, copies {name}.npz from this folder alongside each saved FLAC. "
|
|
||||||
"Missing NPZs are warned but do not abort the save.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING",)
|
|
||||||
RETURN_NAMES = ("report",)
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
FUNCTION = "save"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Save every clip in an AUDIO_DATASET to output_dir as FLAC. "
|
|
||||||
"If npz_source_dir is provided, copies the matching .npz file for each clip — "
|
|
||||||
"so rejected clips never get their NPZ copied."
|
|
||||||
)
|
|
||||||
|
|
||||||
def save(self, dataset, output_dir: str, npz_source_dir: str = ""):
|
|
||||||
import shutil
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
out = Path(output_dir.strip())
|
|
||||||
out.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
npz_src = Path(npz_source_dir.strip()) if npz_source_dir.strip() else None
|
|
||||||
|
|
||||||
saved = 0
|
|
||||||
npz_copied = 0
|
|
||||||
npz_missing = []
|
|
||||||
|
|
||||||
for item in dataset:
|
|
||||||
name = item["name"]
|
|
||||||
wav = item["waveform"][0] # [C, L]
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
|
|
||||||
# soundfile wants [L] mono or [L, C] multichannel, float32
|
|
||||||
wav_np = wav.permute(1, 0).float().numpy() # [L, C]
|
|
||||||
if wav_np.shape[1] == 1:
|
|
||||||
wav_np = wav_np[:, 0] # [L] mono
|
|
||||||
|
|
||||||
flac_path = out / f"{name}.flac"
|
|
||||||
sf.write(str(flac_path), wav_np, sr, subtype="PCM_24")
|
|
||||||
saved += 1
|
|
||||||
|
|
||||||
if npz_src is not None:
|
|
||||||
# Augmented clips store their origin name — use it to find the .npz
|
|
||||||
lookup = item.get("origin_name", name)
|
|
||||||
npz_path = npz_src / f"{lookup}.npz"
|
|
||||||
if npz_path.exists():
|
|
||||||
shutil.copy2(str(npz_path), str(out / f"{name}.npz"))
|
|
||||||
npz_copied += 1
|
|
||||||
else:
|
|
||||||
npz_missing.append(name)
|
|
||||||
|
|
||||||
lines = [
|
|
||||||
f"[DatasetSaver] Saved {saved} clips → {out}",
|
|
||||||
]
|
|
||||||
if npz_src is not None:
|
|
||||||
lines.append(f" NPZ copied: {npz_copied} missing: {len(npz_missing)}")
|
|
||||||
for n in npz_missing:
|
|
||||||
lines.append(f" MISSING NPZ: {n}")
|
|
||||||
|
|
||||||
report = "\n".join(lines)
|
|
||||||
print(report, flush=True)
|
|
||||||
return (report,)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Batch wrappers for audio preprocessors ───────────────────────────────────
|
|
||||||
|
|
||||||
class SelvaDatasetSpectralMatcher:
|
|
||||||
"""Apply SelVA Spectral Matcher to every clip in an AUDIO_DATASET.
|
|
||||||
|
|
||||||
Wraps SelvaSpectralMatcher so it works on batch datasets instead of
|
|
||||||
individual AUDIO items. Same parameters — see that node for details.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"mode": (["44k", "16k"], {
|
|
||||||
"tooltip": "Must match the SelVA model you are training. "
|
|
||||||
"44k = large model, 16k = small model.",
|
|
||||||
}),
|
|
||||||
"strength": ("FLOAT", {
|
|
||||||
"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "0 = no correction, 1 = full match to VAE distribution.",
|
|
||||||
}),
|
|
||||||
"max_gain_db": ("FLOAT", {
|
|
||||||
"default": 12.0, "min": 1.0, "max": 30.0, "step": 1.0,
|
|
||||||
"tooltip": "Clamps per-band gain to ±dB.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "process"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Apply adaptive spectral matching to every clip in a dataset. "
|
|
||||||
"Batch version of SelVA Spectral Matcher — same per-band EQ toward the "
|
|
||||||
"VAE's expected distribution."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, dataset, mode: str, strength: float, max_gain_db: float):
|
|
||||||
from .selva_audio_preprocessors import SelvaSpectralMatcher
|
|
||||||
matcher = SelvaSpectralMatcher()
|
|
||||||
out = []
|
|
||||||
for item in dataset:
|
|
||||||
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
|
|
||||||
(result,) = matcher.process(audio, mode, strength, max_gain_db)
|
|
||||||
new_item = dict(item) # preserve origin_name and any extra keys
|
|
||||||
new_item["waveform"] = result["waveform"]
|
|
||||||
new_item["sample_rate"] = result["sample_rate"]
|
|
||||||
out.append(new_item)
|
|
||||||
print(f"[DatasetSpectralMatcher] {len(out)} clips processed "
|
|
||||||
f"mode={mode} strength={strength}", flush=True)
|
|
||||||
return (out,)
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDatasetHfSmoother:
|
|
||||||
"""Apply SelVA HF Smoother to every clip in an AUDIO_DATASET.
|
|
||||||
|
|
||||||
Wraps SelvaHfSmoother so it works on batch datasets instead of
|
|
||||||
individual AUDIO items. Same parameters — see that node for details.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"cutoff_hz": ("FLOAT", {
|
|
||||||
"default": 12000.0, "min": 2000.0, "max": 20000.0, "step": 500.0,
|
|
||||||
"tooltip": "Low-pass cutoff. 12 kHz is gentle; lower = more aggressive.",
|
|
||||||
}),
|
|
||||||
"blend": ("FLOAT", {
|
|
||||||
"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "0 = original, 1 = fully filtered.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "process"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Apply soft HF attenuation to every clip in a dataset. "
|
|
||||||
"Batch version of SelVA HF Smoother — blends a low-pass filtered copy "
|
|
||||||
"with the original to tame extreme HF content."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process(self, dataset, cutoff_hz: float, blend: float):
|
|
||||||
from .selva_audio_preprocessors import SelvaHfSmoother
|
|
||||||
smoother = SelvaHfSmoother()
|
|
||||||
out = []
|
|
||||||
for item in dataset:
|
|
||||||
audio = {"waveform": item["waveform"], "sample_rate": item["sample_rate"]}
|
|
||||||
(result,) = smoother.process(audio, cutoff_hz, blend)
|
|
||||||
new_item = dict(item) # preserve origin_name and any extra keys
|
|
||||||
new_item["waveform"] = result["waveform"]
|
|
||||||
new_item["sample_rate"] = result["sample_rate"]
|
|
||||||
out.append(new_item)
|
|
||||||
print(f"[DatasetHfSmoother] {len(out)} clips processed "
|
|
||||||
f"cutoff={cutoff_hz:.0f}Hz blend={blend:.2f}", flush=True)
|
|
||||||
return (out,)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Dataset augmenter ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class SelvaDatasetAugmenter:
|
|
||||||
"""Create augmented variants of each clip to expand a small dataset.
|
|
||||||
|
|
||||||
Supports gain variation (always available) and optionally pitch shift
|
|
||||||
and time stretch via audiomentations. Install audiomentations for the
|
|
||||||
full feature set: ``pip install audiomentations``
|
|
||||||
|
|
||||||
Each original clip produces ``variants_per_clip`` augmented copies.
|
|
||||||
Originals are kept by default (toggle ``keep_originals``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"dataset": (AUDIO_DATASET,),
|
|
||||||
"variants_per_clip": ("INT", {
|
|
||||||
"default": 2, "min": 1, "max": 20,
|
|
||||||
"tooltip": "Number of augmented copies per original clip.",
|
|
||||||
}),
|
|
||||||
"gain_range_db": ("FLOAT", {
|
|
||||||
"default": 3.0, "min": 0.0, "max": 12.0, "step": 0.5,
|
|
||||||
"tooltip": "Random gain ±dB applied to each variant. 3 dB is subtle.",
|
|
||||||
}),
|
|
||||||
"seed": ("INT", {"default": 42}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"pitch_range_semitones": ("FLOAT", {
|
|
||||||
"default": 0.0, "min": 0.0, "max": 4.0, "step": 0.25,
|
|
||||||
"tooltip": "Random pitch shift ±semitones. Requires audiomentations. 0 = disabled.",
|
|
||||||
}),
|
|
||||||
"time_stretch_range": ("FLOAT", {
|
|
||||||
"default": 0.0, "min": 0.0, "max": 0.3, "step": 0.05,
|
|
||||||
"tooltip": "Random time stretch ±fraction (0.1 = 90%–110% speed). "
|
|
||||||
"Requires audiomentations. 0 = disabled.",
|
|
||||||
}),
|
|
||||||
"keep_originals": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "Include the original unaugmented clips in the output.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (AUDIO_DATASET,)
|
|
||||||
RETURN_NAMES = ("dataset",)
|
|
||||||
FUNCTION = "augment"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Create augmented variants of each clip (gain, pitch, time stretch) "
|
|
||||||
"to expand small training datasets. Gain is always available; pitch and "
|
|
||||||
"time stretch require audiomentations (pip install audiomentations)."
|
|
||||||
)
|
|
||||||
|
|
||||||
def augment(self, dataset, variants_per_clip: int, gain_range_db: float,
|
|
||||||
seed: int, pitch_range_semitones: float = 0.0,
|
|
||||||
time_stretch_range: float = 0.0, keep_originals: bool = True):
|
|
||||||
rng = np.random.RandomState(seed)
|
|
||||||
|
|
||||||
# Try audiomentations for pitch/stretch
|
|
||||||
use_am = False
|
|
||||||
am_compose = None
|
|
||||||
needs_am = pitch_range_semitones > 0 or time_stretch_range > 0
|
|
||||||
if needs_am:
|
|
||||||
try:
|
|
||||||
import audiomentations as am
|
|
||||||
transforms = []
|
|
||||||
if pitch_range_semitones > 0:
|
|
||||||
transforms.append(am.PitchShift(
|
|
||||||
min_semitones=-pitch_range_semitones,
|
|
||||||
max_semitones=pitch_range_semitones,
|
|
||||||
p=0.5,
|
|
||||||
))
|
|
||||||
if time_stretch_range > 0:
|
|
||||||
transforms.append(am.TimeStretch(
|
|
||||||
min_rate=1.0 - time_stretch_range,
|
|
||||||
max_rate=1.0 + time_stretch_range,
|
|
||||||
leave_length_unchanged=True,
|
|
||||||
p=0.5,
|
|
||||||
))
|
|
||||||
am_compose = am.Compose(transforms)
|
|
||||||
use_am = True
|
|
||||||
except ImportError:
|
|
||||||
print("[DatasetAugmenter] audiomentations not installed — "
|
|
||||||
"pitch_shift and time_stretch disabled. "
|
|
||||||
"Install: pip install audiomentations", flush=True)
|
|
||||||
|
|
||||||
out = []
|
|
||||||
if keep_originals:
|
|
||||||
out.extend(dataset)
|
|
||||||
|
|
||||||
for item in dataset:
|
|
||||||
wav = item["waveform"] # [1, C, L]
|
|
||||||
sr = item["sample_rate"]
|
|
||||||
name = item["name"]
|
|
||||||
|
|
||||||
for v in range(variants_per_clip):
|
|
||||||
# Gain variation (always applied)
|
|
||||||
gain_db = rng.uniform(-gain_range_db, gain_range_db) if gain_range_db > 0 else 0.0
|
|
||||||
gain_lin = 10.0 ** (gain_db / 20.0)
|
|
||||||
wav_aug = wav * gain_lin
|
|
||||||
|
|
||||||
# Pitch/stretch via audiomentations
|
|
||||||
if use_am and am_compose is not None:
|
|
||||||
wav_np = wav_aug[0].numpy() # [C, L] float32
|
|
||||||
if wav_np.shape[0] == 1:
|
|
||||||
wav_np = wav_np[0] # [L] mono for audiomentations
|
|
||||||
wav_np = am_compose(samples=wav_np, sample_rate=sr)
|
|
||||||
if wav_np.ndim == 1:
|
|
||||||
wav_np = wav_np[np.newaxis, :] # back to [1, L]
|
|
||||||
wav_aug = torch.from_numpy(wav_np).unsqueeze(0) # [1, C, L]
|
|
||||||
|
|
||||||
# Prevent clipping
|
|
||||||
peak = wav_aug.abs().max()
|
|
||||||
if peak > 1.0:
|
|
||||||
wav_aug = wav_aug / peak
|
|
||||||
|
|
||||||
out.append({
|
|
||||||
"waveform": wav_aug,
|
|
||||||
"sample_rate": sr,
|
|
||||||
"name": f"{name}_aug{v:02d}",
|
|
||||||
"origin_name": name,
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"[DatasetAugmenter] {len(dataset)} originals → {len(out)} total clips "
|
|
||||||
f"gain=±{gain_range_db:.1f}dB"
|
|
||||||
+ (f" pitch=±{pitch_range_semitones:.1f}st" if pitch_range_semitones > 0 else "")
|
|
||||||
+ (f" stretch=±{time_stretch_range:.0%}" if time_stretch_range > 0 else ""),
|
|
||||||
flush=True)
|
|
||||||
return (out,)
|
|
||||||
@@ -1,515 +0,0 @@
|
|||||||
"""SelVA DITTO Optimizer.
|
|
||||||
|
|
||||||
Inference-time noise optimization: optimizes the initial noise latent x_0
|
|
||||||
using a style loss against target style reference clips, backpropagating through the
|
|
||||||
ODE solver. All model weights remain frozen — only x_0 changes.
|
|
||||||
|
|
||||||
Based on DITTO: Diffusion Inference-Time T-Optimization (arXiv:2401.12179,
|
|
||||||
ICML 2024 Oral). Adapted for SelVA's flow-matching Euler ODE.
|
|
||||||
|
|
||||||
Style loss: mel-spectrogram statistics matching (mean spectrum + Gram matrix)
|
|
||||||
against target style reference clips. Runs entirely before the vocoder — optimization
|
|
||||||
only requires the DiT + VAE decoder, not BigVGAN.
|
|
||||||
|
|
||||||
Memory strategy: gradient checkpointing at each ODE step — stores O(1 DiT
|
|
||||||
forward pass activations) instead of O(N steps). Backward recomputes each
|
|
||||||
step's activations on demand.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchaudio
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.model_management
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
|
||||||
|
|
||||||
|
|
||||||
def _load_wav(path):
|
|
||||||
"""Load audio file to [channels, samples] float32 tensor."""
|
|
||||||
try:
|
|
||||||
return torchaudio.load(str(path))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
import soundfile as sf
|
|
||||||
data, sr = sf.read(str(path), dtype="float32", always_2d=True)
|
|
||||||
wav = torch.from_numpy(data.T)
|
|
||||||
return wav, sr
|
|
||||||
|
|
||||||
|
|
||||||
def _mel_style_loss(mel_gen, ref_mean, ref_gram, gram_weight=0.0):
|
|
||||||
"""Style loss between generated mel and precomputed reference statistics.
|
|
||||||
|
|
||||||
mel_gen: [1, n_mels, T] generated mel spectrogram (with grad)
|
|
||||||
ref_mean: [n_mels] mean spectrum of reference clips (detached)
|
|
||||||
ref_gram: [n_mels, n_mels] Gram matrix of reference clips (detached)
|
|
||||||
gram_weight: weight for Gram matrix component — 0 = mean spectrum only.
|
|
||||||
Start at 0; enable only if mean-only optimization converges
|
|
||||||
without noise, then increase slowly (0.01–0.1).
|
|
||||||
"""
|
|
||||||
m = mel_gen.squeeze(0) # [n_mels, T]
|
|
||||||
|
|
||||||
# Mean spectrum loss — captures spectral envelope
|
|
||||||
gen_mean = m.mean(dim=-1) # [n_mels]
|
|
||||||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
|
||||||
|
|
||||||
if gram_weight <= 0.0:
|
|
||||||
return loss_mean
|
|
||||||
|
|
||||||
# Gram matrix loss — captures timbral texture (can add noise if too high)
|
|
||||||
gram_gen = (m @ m.T) / m.shape[-1] # [n_mels, n_mels]
|
|
||||||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
|
||||||
|
|
||||||
return loss_mean + gram_weight * loss_gram
|
|
||||||
|
|
||||||
|
|
||||||
def _latent_style_loss(z, ref_mean, ref_gram, gram_weight=0.0):
|
|
||||||
"""Style loss computed directly in VAE latent space.
|
|
||||||
|
|
||||||
z: [T_lat, C_lat] unnormalized latent at ODE endpoint (with grad)
|
|
||||||
ref_mean: [C_lat] mean latent vector of reference clips
|
|
||||||
ref_gram: [C_lat, C_lat] Gram matrix of reference latents
|
|
||||||
gram_weight: weight for Gram component — 0 = mean only (recommended start)
|
|
||||||
|
|
||||||
Operating in latent space avoids backprop through the VAE decoder, which
|
|
||||||
is @torch.inference_mode() and produces noisy, unstable gradients.
|
|
||||||
"""
|
|
||||||
# Mean latent loss — matches average activation per channel
|
|
||||||
gen_mean = z.mean(dim=0) # [C_lat]
|
|
||||||
loss_mean = F.l1_loss(gen_mean, ref_mean)
|
|
||||||
|
|
||||||
if gram_weight <= 0.0:
|
|
||||||
return loss_mean
|
|
||||||
|
|
||||||
# Gram matrix — inter-channel covariance, position-invariant
|
|
||||||
gram_gen = (z.T @ z) / z.shape[0] # [C_lat, C_lat]
|
|
||||||
loss_gram = F.mse_loss(gram_gen, ref_gram)
|
|
||||||
|
|
||||||
return loss_mean + gram_weight * loss_gram
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaDittoOptimizer:
|
|
||||||
"""DITTO inference-time noise optimization.
|
|
||||||
|
|
||||||
Freezes all model weights and optimizes only the initial noise latent x_0
|
|
||||||
to make the generated audio sound like the target style reference clips.
|
|
||||||
No training data or gradient updates to the model — per-video per-run.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"features": ("SELVA_FEATURES",),
|
|
||||||
"prompt": ("STRING", {
|
|
||||||
"default": "", "multiline": True,
|
|
||||||
"tooltip": "Sound description. Leave empty to use features prompt.",
|
|
||||||
}),
|
|
||||||
"negative_prompt": ("STRING", {
|
|
||||||
"default": "", "multiline": False,
|
|
||||||
}),
|
|
||||||
"reference_dir": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Directory with target style reference audio files (.wav/.flac/.mp3). "
|
|
||||||
"Reference mel statistics are precomputed from these once.",
|
|
||||||
}),
|
|
||||||
"n_opt_steps": ("INT", {
|
|
||||||
"default": 50, "min": 5, "max": 500,
|
|
||||||
"tooltip": "Gradient optimization steps on x_0. 50 is a good start; "
|
|
||||||
"each step requires ~2 DiT forward passes.",
|
|
||||||
}),
|
|
||||||
"opt_lr": ("FLOAT", {
|
|
||||||
"default": 0.02, "min": 0.001, "max": 2.0, "step": 0.001,
|
|
||||||
"tooltip": "Adam learning rate for x_0 optimization. "
|
|
||||||
"0.02–0.05 is recommended; 0.1 (paper default) causes oscillation.",
|
|
||||||
}),
|
|
||||||
"n_ode_steps": ("INT", {
|
|
||||||
"default": 10, "min": 5, "max": 50,
|
|
||||||
"tooltip": "Euler ODE steps run during each optimization iteration. "
|
|
||||||
"Lower = faster optimization (10–15 is a good trade-off). "
|
|
||||||
"Final generation always uses the steps parameter below.",
|
|
||||||
}),
|
|
||||||
"n_grad_steps": ("INT", {
|
|
||||||
"default": 5, "min": 1, "max": 50,
|
|
||||||
"tooltip": "ODE steps to differentiate through (truncated BPTT). "
|
|
||||||
"Higher = more accurate gradient, more VRAM. "
|
|
||||||
"Must be ≤ n_ode_steps. 5 is a good default.",
|
|
||||||
}),
|
|
||||||
"style_weight": ("FLOAT", {
|
|
||||||
"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.05,
|
|
||||||
"tooltip": "Weight of the target style style loss. High values push harder toward "
|
|
||||||
"target style style but add noise. Start at 0.1 and increase slowly.",
|
|
||||||
}),
|
|
||||||
"gram_weight": ("FLOAT", {
|
|
||||||
"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
|
|
||||||
"tooltip": "Weight of the Gram matrix (timbral texture) loss relative to "
|
|
||||||
"the mean spectrum loss. 0 = mean spectrum only (less noise). "
|
|
||||||
"0.1 adds texture matching but can introduce white noise.",
|
|
||||||
}),
|
|
||||||
"anchor_weight": ("FLOAT", {
|
|
||||||
"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1,
|
|
||||||
"tooltip": "L2 penalty keeping x0 near its initial N(0,1) noise. "
|
|
||||||
"Prevents optimization from pushing x0 out of the flow's "
|
|
||||||
"expected distribution (which causes white noise). "
|
|
||||||
"Higher = cleaner audio, weaker style. 1.0 is a safe default.",
|
|
||||||
}),
|
|
||||||
"steps": ("INT", {
|
|
||||||
"default": 25, "min": 1, "max": 200,
|
|
||||||
"tooltip": "Euler steps for the final generation pass (after optimization).",
|
|
||||||
}),
|
|
||||||
"cfg_strength": ("FLOAT", {
|
|
||||||
"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
|
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"normalize": ("BOOLEAN", {"default": True}),
|
|
||||||
"target_lufs": ("FLOAT", {
|
|
||||||
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
OUTPUT_TOOLTIPS = ("DITTO-optimized audio — x_0 steered toward target style style.",)
|
|
||||||
FUNCTION = "optimize"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"DITTO inference-time noise optimization (arXiv:2401.12179). "
|
|
||||||
"Optimizes the initial noise latent x_0 to match target style reference clips "
|
|
||||||
"via mel statistics style loss, backpropagating through the ODE. "
|
|
||||||
"All model weights frozen — zero quality degradation risk."
|
|
||||||
)
|
|
||||||
|
|
||||||
def optimize(self, model, features, prompt, negative_prompt,
|
|
||||||
reference_dir, n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
|
||||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
|
||||||
normalize=True, target_lufs=-27.0):
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
strategy = model["strategy"]
|
|
||||||
net_generator = model["generator"]
|
|
||||||
feature_utils = model["feature_utils"]
|
|
||||||
mel_converter = feature_utils.mel_converter
|
|
||||||
|
|
||||||
# Validate variant match
|
|
||||||
feat_variant = features.get("variant")
|
|
||||||
if feat_variant is not None and feat_variant != model["variant"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"[DITTO] Variant mismatch: features='{feat_variant}' model='{model['variant']}'. "
|
|
||||||
f"Re-run Feature Extractor."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not prompt or not prompt.strip():
|
|
||||||
prompt = features.get("prompt", "")
|
|
||||||
|
|
||||||
# Resolve duration and seq_cfg
|
|
||||||
duration = features.get("duration", 0)
|
|
||||||
if duration <= 0:
|
|
||||||
raise ValueError("[DITTO] Features contain no duration field.")
|
|
||||||
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
|
|
||||||
sample_rate = seq_cfg.sampling_rate
|
|
||||||
|
|
||||||
# Load reference clips and encode to latent space.
|
|
||||||
# Style loss is computed in latent space (after net_generator.unnormalize)
|
|
||||||
# rather than mel space — this avoids backpropagating through the VAE
|
|
||||||
# decoder (which is @torch.inference_mode() and produces noisy gradients).
|
|
||||||
ref_dir = Path(reference_dir.strip())
|
|
||||||
if not ref_dir.is_absolute():
|
|
||||||
ref_dir = Path(folder_paths.models_dir) / ref_dir
|
|
||||||
if not ref_dir.exists():
|
|
||||||
raise FileNotFoundError(f"[DITTO] reference_dir not found: {ref_dir}")
|
|
||||||
|
|
||||||
ref_files = []
|
|
||||||
for ext in ("*.wav", "*.flac", "*.mp3", "*.ogg"):
|
|
||||||
ref_files.extend(ref_dir.rglob(ext))
|
|
||||||
if not ref_files:
|
|
||||||
raise FileNotFoundError(f"[DITTO] No audio files in reference_dir: {ref_dir}")
|
|
||||||
|
|
||||||
if not hasattr(feature_utils.tod.vae, "encoder"):
|
|
||||||
raise RuntimeError(
|
|
||||||
"[DITTO] VAE encoder not available — model was loaded with need_vae_encoder=False. "
|
|
||||||
"Reload the model with the encoder enabled."
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"[DITTO] Loading {len(ref_files)} reference clips...", flush=True)
|
|
||||||
mel_converter.to(device, torch.float32) # cuFFT requires float32
|
|
||||||
|
|
||||||
ref_latents = []
|
|
||||||
with torch.no_grad():
|
|
||||||
for rf in ref_files:
|
|
||||||
try:
|
|
||||||
wav, sr = _load_wav(rf)
|
|
||||||
if wav.shape[0] > 1:
|
|
||||||
wav = wav.mean(0, keepdim=True)
|
|
||||||
if sr != sample_rate:
|
|
||||||
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
|
||||||
wav = wav.squeeze(0).to(device, torch.float32)
|
|
||||||
mel = mel_converter(wav.unsqueeze(0)).to(dtype) # [1, n_mels, T_mel]
|
|
||||||
# encode → sample → VAE latent space (matches unnormalize(x) in loss)
|
|
||||||
z = feature_utils.tod.encode(mel) # DiagonalGaussianDistribution
|
|
||||||
z_sample = z.sample().transpose(1, 2) # [1, T_lat, C_lat]
|
|
||||||
ref_latents.append(z_sample.to(dtype).squeeze(0).clone()) # [T_lat, C_lat]
|
|
||||||
except Exception as e:
|
|
||||||
print(f" [DITTO] Skip {rf.name}: {e}", flush=True)
|
|
||||||
|
|
||||||
if not ref_latents:
|
|
||||||
raise RuntimeError("[DITTO] No usable reference clips.")
|
|
||||||
|
|
||||||
# Precompute reference latent statistics (done once — detached, no grad)
|
|
||||||
with torch.no_grad():
|
|
||||||
all_means = torch.stack([z.mean(dim=0) for z in ref_latents])
|
|
||||||
ref_mean = all_means.mean(0) # [C_lat]
|
|
||||||
all_grams = [(z.T @ z) / z.shape[0] for z in ref_latents]
|
|
||||||
ref_gram = torch.stack(all_grams).mean(0) # [C_lat, C_lat]
|
|
||||||
|
|
||||||
print(f"[DITTO] Reference latent stats from {len(ref_latents)} clips "
|
|
||||||
f"n_opt={n_opt_steps} lr={opt_lr} ode_steps={n_ode_steps} "
|
|
||||||
f"grad_steps={n_grad_steps}", flush=True)
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
net_generator.to(device)
|
|
||||||
feature_utils.to(device)
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(n_opt_steps + steps)
|
|
||||||
|
|
||||||
_result = [None]
|
|
||||||
_exc = [None]
|
|
||||||
|
|
||||||
def _worker():
|
|
||||||
try:
|
|
||||||
_result[0] = _do_optimize(
|
|
||||||
net_generator, feature_utils, mel_converter,
|
|
||||||
features, prompt, negative_prompt,
|
|
||||||
ref_mean, ref_gram,
|
|
||||||
seq_cfg, sample_rate, device, dtype,
|
|
||||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
|
||||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
|
||||||
normalize, target_lufs, pbar,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
_exc[0] = e
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
t = threading.Thread(target=_worker, daemon=True)
|
|
||||||
t.start()
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
net_generator.to(get_offload_device())
|
|
||||||
feature_utils.to(get_offload_device())
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
if _exc[0] is not None:
|
|
||||||
raise _exc[0]
|
|
||||||
return (_result[0],)
|
|
||||||
|
|
||||||
|
|
||||||
def _do_optimize(net_generator, feature_utils, mel_converter,
|
|
||||||
features, prompt, negative_prompt,
|
|
||||||
ref_mean, ref_gram,
|
|
||||||
seq_cfg, sample_rate, device, dtype,
|
|
||||||
n_opt_steps, opt_lr, n_ode_steps, n_grad_steps,
|
|
||||||
style_weight, gram_weight, anchor_weight, steps, cfg_strength, seed,
|
|
||||||
normalize, target_lufs, pbar):
|
|
||||||
"""Optimization loop — runs in a fresh thread (no inference_mode active)."""
|
|
||||||
|
|
||||||
# Strip inference flags from ref stats (came from main thread) and cast to
|
|
||||||
# model dtype. ref_mean/ref_gram are float32 (computed via cuFFT mel path);
|
|
||||||
# mel_gen is model dtype (bfloat16). Mixed-dtype loss → float32 gradient →
|
|
||||||
# "Found dtype Float but expected BFloat16" in backward through bfloat16 ops.
|
|
||||||
ref_mean = ref_mean.clone().detach().to(dtype)
|
|
||||||
ref_gram = ref_gram.clone().detach().to(dtype)
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
clip_f = features["clip_features"].to(device, dtype).clone()
|
|
||||||
sync_f = features["sync_features"].to(device, dtype).clone()
|
|
||||||
|
|
||||||
# Strip inference-mode flags from all model weights and buffers BEFORE any
|
|
||||||
# forward pass. Parameters were loaded in ComfyUI's inference_mode context;
|
|
||||||
# operations on inference tensors produce inference tensors, so conditions
|
|
||||||
# computed from tainted weights would also be tainted. clone() outside
|
|
||||||
# inference_mode produces a normal tensor regardless of the source flag.
|
|
||||||
def _strip_inference(module):
|
|
||||||
for mod in module.modules():
|
|
||||||
for name, buf in list(mod._buffers.items()):
|
|
||||||
if buf is not None:
|
|
||||||
mod._buffers[name] = buf.clone()
|
|
||||||
for name, param in list(mod._parameters.items()):
|
|
||||||
if param is not None:
|
|
||||||
mod._parameters[name] = torch.nn.Parameter(
|
|
||||||
param.data.clone(), requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
_strip_inference(net_generator)
|
|
||||||
_strip_inference(feature_utils)
|
|
||||||
_strip_inference(mel_converter)
|
|
||||||
|
|
||||||
net_generator.update_seq_lengths(
|
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
|
||||||
clip_seq_len=clip_f.shape[1],
|
|
||||||
sync_seq_len=sync_f.shape[1],
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
text_clip = feature_utils.encode_text_clip([prompt])
|
|
||||||
|
|
||||||
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
|
||||||
if negative_prompt.strip() else None
|
|
||||||
|
|
||||||
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
|
||||||
empty_conditions = net_generator.get_empty_conditions(
|
|
||||||
bs=1, negative_text_features=neg_text_clip
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clone all tensors inside conditions/empty_conditions to ensure no inference
|
|
||||||
# flags survived from intermediate computations inside preprocess_conditions.
|
|
||||||
def _clone_nested(obj):
|
|
||||||
if isinstance(obj, torch.Tensor):
|
|
||||||
return obj.clone()
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
return {k: _clone_nested(v) for k, v in obj.items()}
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
return type(obj)(_clone_nested(v) for v in obj)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
conditions = _clone_nested(conditions)
|
|
||||||
empty_conditions = _clone_nested(empty_conditions)
|
|
||||||
|
|
||||||
# Initial noise — x_0 is the parameter we optimize
|
|
||||||
x0_init = torch.randn(
|
|
||||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
|
||||||
device=device, dtype=dtype,
|
|
||||||
)
|
|
||||||
x0 = torch.nn.Parameter(x0_init.clone())
|
|
||||||
x0_init = x0_init.detach() # anchor — kept fixed, no grad
|
|
||||||
optimizer = torch.optim.Adam([x0], lr=opt_lr)
|
|
||||||
|
|
||||||
# n_grad_steps must not exceed n_ode_steps
|
|
||||||
n_grad_steps = min(n_grad_steps, n_ode_steps)
|
|
||||||
n_free_steps = n_ode_steps - n_grad_steps # steps run without gradient
|
|
||||||
|
|
||||||
ts = torch.linspace(0.0, 1.0, n_ode_steps + 1, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
print(f"[DITTO] Optimizing x_0 "
|
|
||||||
f"free_steps={n_free_steps} grad_steps={n_grad_steps}", flush=True)
|
|
||||||
|
|
||||||
# Freeze all model weights (double-check — should already be frozen at inference)
|
|
||||||
net_generator.requires_grad_(False)
|
|
||||||
feature_utils.requires_grad_(False)
|
|
||||||
mel_converter.requires_grad_(False)
|
|
||||||
|
|
||||||
for opt_step in range(n_opt_steps):
|
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
|
||||||
|
|
||||||
# ── Phase 1: run first (n_ode_steps - n_grad_steps) steps without grad ──
|
|
||||||
# Detach from x0 so Phase 1 does not build a computation graph.
|
|
||||||
with torch.no_grad():
|
|
||||||
x = x0.detach()
|
|
||||||
for i in range(n_free_steps):
|
|
||||||
t = ts[i]
|
|
||||||
dt = ts[i + 1] - t
|
|
||||||
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
|
||||||
x = x + dt * flow
|
|
||||||
|
|
||||||
# Straight-through estimator: reconnect x to x0's gradient path by
|
|
||||||
# adding the zero tensor (x0 - x0.detach()). This adds zero value but
|
|
||||||
# creates a grad_fn pointing back to x0, so loss.backward() will
|
|
||||||
# propagate ∂loss/∂x (at the Phase-1/2 boundary) directly to x0.grad.
|
|
||||||
# The approximation is ∂x_prefix/∂x0 ≈ I — the no-grad prefix is
|
|
||||||
# treated as identity for gradient purposes (truncated BPTT).
|
|
||||||
#
|
|
||||||
# x may carry an inference tensor flag from Phase 1 (derived from
|
|
||||||
# conditions which were built outside inference_mode but may have
|
|
||||||
# propagated the flag). .clone() strips it so the STE addition does
|
|
||||||
# not try to save an inference tensor for backward.
|
|
||||||
x = x.clone()
|
|
||||||
x = x + (x0 - x0.detach())
|
|
||||||
|
|
||||||
# ── Phase 2: run last n_grad_steps with gradient + checkpointing ──
|
|
||||||
for i in range(n_free_steps, n_ode_steps):
|
|
||||||
t = ts[i]
|
|
||||||
dt = ts[i + 1] - t
|
|
||||||
|
|
||||||
# Gradient checkpointing: recompute forward during backward,
|
|
||||||
# avoiding storage of DiT activations for each step.
|
|
||||||
def _ode_step(x_in, t=t):
|
|
||||||
return net_generator.ode_wrapper(t, x_in, conditions, empty_conditions, cfg_strength)
|
|
||||||
|
|
||||||
flow = torch.utils.checkpoint.checkpoint(
|
|
||||||
_ode_step, x, use_reentrant=False
|
|
||||||
)
|
|
||||||
x = x + dt * flow
|
|
||||||
|
|
||||||
# ── Style loss in latent space ───────────────────────────────────────
|
|
||||||
# Unnormalize x back to VAE latent space — fully differentiable, no
|
|
||||||
# decode needed. ref_mean/ref_gram are computed from encoded reference
|
|
||||||
# clips in the same space. Avoids backprop through VAE decoder which
|
|
||||||
# is @torch.inference_mode() and produces noisy gradients.
|
|
||||||
x_un = net_generator.unnormalize(x) # [1, T_lat, C_lat]
|
|
||||||
style_loss = style_weight * _latent_style_loss(x_un.squeeze(0), ref_mean, ref_gram, gram_weight)
|
|
||||||
|
|
||||||
# Anchor regularization — penalize x0 drifting from its initial N(0,1)
|
|
||||||
# value. Flow matching ODE expects x0 ~ N(0,1); large deviations push
|
|
||||||
# the ODE into an out-of-distribution region that decodes as white noise.
|
|
||||||
anchor_loss = anchor_weight * F.mse_loss(x0, x0_init)
|
|
||||||
loss = style_loss + anchor_loss
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
loss.backward() # gradient flows through Phase 2 + STE back to x0.grad
|
|
||||||
torch.nn.utils.clip_grad_norm_([x0], 1.0)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
if (opt_step + 1) % max(1, n_opt_steps // 10) == 0:
|
|
||||||
print(f"[DITTO] {opt_step+1}/{n_opt_steps} "
|
|
||||||
f"style={style_loss.item():.4f} anchor={anchor_loss.item():.4f} "
|
|
||||||
f"x0_std={x0.data.std().item():.3f}", flush=True)
|
|
||||||
|
|
||||||
# ── Final generation with optimized x_0 ─────────────────────────────────
|
|
||||||
print(f"[DITTO] Optimization done. Final generation ({steps} steps)...", flush=True)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
fm_ts = torch.linspace(0.0, 1.0, steps + 1, device=device, dtype=dtype)
|
|
||||||
x = x0.detach()
|
|
||||||
for i in range(steps):
|
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
|
||||||
t = fm_ts[i]
|
|
||||||
dt = fm_ts[i + 1] - t
|
|
||||||
flow = net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
|
||||||
x = x + dt * flow
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
x1_unnorm = net_generator.unnormalize(x)
|
|
||||||
spec = feature_utils.decode(x1_unnorm)
|
|
||||||
audio = feature_utils.vocode(spec)
|
|
||||||
|
|
||||||
print(f"[DITTO] latent stats: mean={x.float().mean():.4f} std={x.float().std():.4f}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
audio = audio.float()
|
|
||||||
if audio.dim() == 2:
|
|
||||||
audio = audio.unsqueeze(1)
|
|
||||||
elif audio.dim() == 3 and audio.shape[1] != 1:
|
|
||||||
audio = audio.mean(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
if normalize:
|
|
||||||
target_rms = 10 ** (target_lufs / 20.0)
|
|
||||||
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
audio = audio * (target_rms / rms)
|
|
||||||
peak = audio.abs().max().clamp(min=1e-8)
|
|
||||||
if peak > 1.0:
|
|
||||||
audio = audio / peak
|
|
||||||
|
|
||||||
print(f"[DITTO] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
|
||||||
return {"waveform": audio.cpu(), "sample_rate": sample_rate}
|
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
import os
|
|
||||||
import hashlib
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
|
||||||
|
|
||||||
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
|
|
||||||
_CLIP_SIZE = 384
|
|
||||||
_SYNC_SIZE = 224
|
|
||||||
_CLIP_FPS = 8
|
|
||||||
_SYNC_FPS = 25
|
|
||||||
|
|
||||||
# Sync normalization applied externally: maps [0,1] → [-1,1] with mean=std=0.5
|
|
||||||
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
|
||||||
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def _sample_frames(video, source_fps, target_fps, duration):
|
|
||||||
"""Sample frames from [T,H,W,C] float32 at target_fps; returns [N,H,W,C]."""
|
|
||||||
T = video.shape[0]
|
|
||||||
n_out = max(1, int(duration * target_fps))
|
|
||||||
indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
|
|
||||||
return video[indices]
|
|
||||||
|
|
||||||
|
|
||||||
def _resize_frames(frames, size):
|
|
||||||
"""Resize [N,H,W,C] float32 [0,1] → [N,C,H,W] at target size."""
|
|
||||||
x = frames.permute(0, 3, 1, 2) # [N, C, H, W]
|
|
||||||
x = F.interpolate(x.float(), size=(size, size), mode="bicubic", align_corners=False)
|
|
||||||
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
|
||||||
"""
|
|
||||||
Apply a ComfyUI MASK to resized frames.
|
|
||||||
|
|
||||||
frames: [N, C, H, W] float [0,1]
|
|
||||||
mask: [M, H', W'] float [0,1] — M=1 static or M=T per-frame
|
|
||||||
source_fps: original video fps (for accurate temporal sampling)
|
|
||||||
target_fps: sampling fps of this frame set (CLIP_FPS or SYNC_FPS)
|
|
||||||
mask_strength: 0=no effect, 1=full masking; background filled with 0.5 (neutral gray)
|
|
||||||
|
|
||||||
Background pixels are filled with 0.5 rather than 0 — less out-of-distribution
|
|
||||||
for CLIP, and maps to 0 (neutral) after [-1,1] normalization on the sync path.
|
|
||||||
"""
|
|
||||||
N, C, H, W = frames.shape
|
|
||||||
M = mask.shape[0]
|
|
||||||
mask_f = mask.float().unsqueeze(1) # [M, 1, H', W']
|
|
||||||
if mask_f.shape[2] != H or mask_f.shape[3] != W:
|
|
||||||
mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact") # [M, 1, H, W]
|
|
||||||
|
|
||||||
# Temporal sampling — use same index formula as _sample_frames for accuracy
|
|
||||||
if M == 1:
|
|
||||||
mask_f = mask_f.expand(N, -1, -1, -1)
|
|
||||||
else:
|
|
||||||
indices = [min(int(i / target_fps * source_fps), M - 1) for i in range(N)]
|
|
||||||
mask_f = mask_f[indices] # [N, 1, H, W]
|
|
||||||
|
|
||||||
mask_f = mask_f.to(frames.device)
|
|
||||||
|
|
||||||
# alpha=1 on foreground, (1-strength) on background → blend toward neutral gray
|
|
||||||
alpha = 1.0 - mask_strength * (1.0 - mask_f)
|
|
||||||
return frames * alpha + 0.5 * (1.0 - alpha)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_named_path(cache_dir: str, name: str) -> str:
|
|
||||||
"""Return cache_dir/name.npz, incrementing to name_001.npz etc. if the file already exists."""
|
|
||||||
# Sanitize: replace path separators so the name stays inside cache_dir
|
|
||||||
name = name.replace("/", "_").replace("\\", "_").replace("\x00", "_")
|
|
||||||
i = 1
|
|
||||||
while True:
|
|
||||||
p = os.path.join(cache_dir, f"{name}_{i:03d}.npz")
|
|
||||||
if not os.path.exists(p):
|
|
||||||
return p
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
|
|
||||||
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
|
||||||
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
|
||||||
h = hashlib.sha256()
|
|
||||||
raw = video_tensor.cpu().numpy().tobytes()
|
|
||||||
n = len(raw)
|
|
||||||
chunk = 512 * 1024 # 512 KB per sample
|
|
||||||
h.update(raw[:chunk])
|
|
||||||
h.update(raw[n // 2: n // 2 + chunk])
|
|
||||||
h.update(raw[max(0, n - chunk):])
|
|
||||||
if mask is not None:
|
|
||||||
raw_m = mask.cpu().numpy().tobytes()
|
|
||||||
nm = len(raw_m)
|
|
||||||
chunk_m = 256 * 1024
|
|
||||||
h.update(raw_m[:chunk_m])
|
|
||||||
h.update(raw_m[nm // 2: nm // 2 + chunk_m])
|
|
||||||
h.update(raw_m[max(0, nm - chunk_m):])
|
|
||||||
h.update(str(round(mask_strength, 4)).encode())
|
|
||||||
h.update(str(mask_clip).encode())
|
|
||||||
h.update(str(mask_sync).encode())
|
|
||||||
h.update(prompt.encode())
|
|
||||||
h.update(str(fps).encode())
|
|
||||||
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
|
||||||
h.update(variant.encode())
|
|
||||||
return h.hexdigest()[:32]
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaFeatureExtractor:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"video": ("IMAGE",),
|
|
||||||
"prompt": ("STRING", {
|
|
||||||
"default": "", "multiline": True,
|
|
||||||
"tooltip": "Describes the sounds to generate. Used to focus the visual sync features on motion relevant to the prompt — more specific prompts produce cleaner audio sync. Wire the prompt output directly to the Sampler so you only type it once.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"video_info": ("VHS_VIDEOINFO", {
|
|
||||||
"tooltip": "VHS_VIDEOINFO from VHS LoadVideo. Automatically sets the correct source fps — always connect this when loading video with VHS nodes.",
|
|
||||||
}),
|
|
||||||
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001,
|
|
||||||
"tooltip": "Source fps of the input video. Ignored when video_info is connected."}),
|
|
||||||
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
|
|
||||||
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
|
|
||||||
"cache_dir": ("STRING", {"default": "",
|
|
||||||
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
|
|
||||||
"name": ("STRING", {"default": "",
|
|
||||||
"tooltip": "Optional filename for the saved .npz (without extension). If provided, features are always saved with this name instead of a content hash — useful for building a named training dataset. Auto-increments: dog_bark → dog_bark_001 → dog_bark_002 if the file already exists. Leave empty to use the default content-hash cache."}),
|
|
||||||
"mask": ("MASK", {
|
|
||||||
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
|
|
||||||
}),
|
|
||||||
"mask_strength": ("FLOAT", {
|
|
||||||
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "How strongly to suppress the background. 1.0 = full neutral fill; 0.0 = no masking effect. Values in between blend smoothly.",
|
|
||||||
}),
|
|
||||||
"mask_clip": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "Apply the mask to CLIP visual features (384px). Disable if you want CLIP to see the full scene context while sync features stay focused.",
|
|
||||||
}),
|
|
||||||
"mask_sync": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
|
|
||||||
RETURN_NAMES = ("features", "fps", "prompt")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Extracted feature bundle — connect to Sampler.",
|
|
||||||
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
|
|
||||||
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
|
|
||||||
)
|
|
||||||
OUTPUT_NODE = True # always execute: the node's purpose is saving .npz files to disk
|
|
||||||
FUNCTION = "extract_features"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
|
|
||||||
|
|
||||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
|
||||||
duration=0.0, cache_dir="", name="", mask=None,
|
|
||||||
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
|
||||||
if video_info is not None:
|
|
||||||
fps = video_info["loaded_fps"]
|
|
||||||
|
|
||||||
T = video.shape[0]
|
|
||||||
if duration <= 0:
|
|
||||||
duration = T / fps
|
|
||||||
duration = min(duration, T / fps) # clamp to actual video length
|
|
||||||
|
|
||||||
if not prompt.strip():
|
|
||||||
print("[SelVA] Warning: empty prompt — TextSynchformer sync features will be unfocused.", flush=True)
|
|
||||||
|
|
||||||
# Cache
|
|
||||||
if not cache_dir:
|
|
||||||
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if name.strip():
|
|
||||||
# Named mode: always extract and save to an incremented filename
|
|
||||||
cached_path = _resolve_named_path(cache_dir, name.strip())
|
|
||||||
else:
|
|
||||||
# Hash mode: skip extraction if identical inputs were already processed
|
|
||||||
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
|
|
||||||
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
|
|
||||||
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
|
||||||
if os.path.exists(cached_path):
|
|
||||||
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
|
|
||||||
cached = _load_cached(cached_path)
|
|
||||||
return (cached, float(fps), cached.get("prompt", prompt))
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
strategy = model["strategy"]
|
|
||||||
feature_utils = model["feature_utils"]
|
|
||||||
net_video_enc = model["video_enc"]
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
feature_utils.to(device)
|
|
||||||
net_video_enc.to(device)
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
|
||||||
pbar = comfy.utils.ProgressBar(3)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
|
||||||
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
|
||||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
|
||||||
if mask is not None and mask_clip:
|
|
||||||
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
|
|
||||||
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
|
||||||
_clip_tag = f"(masked strength={mask_strength})" if mask is not None and mask_clip else ("(mask skipped)" if mask is not None else "")
|
|
||||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {_clip_tag}", flush=True)
|
|
||||||
|
|
||||||
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
|
||||||
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
|
||||||
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
|
||||||
if mask is not None and mask_sync:
|
|
||||||
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
|
|
||||||
# Pad to minimum 16 frames (TextSynchformer segment size)
|
|
||||||
if sync_frames.shape[0] < 16:
|
|
||||||
pad = 16 - sync_frames.shape[0]
|
|
||||||
sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
|
|
||||||
# Normalize [0,1] → [-1,1]
|
|
||||||
mean = _SYNC_MEAN.to(sync_frames.device)
|
|
||||||
std = _SYNC_STD.to(sync_frames.device)
|
|
||||||
sync_frames = (sync_frames - mean) / std
|
|
||||||
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
|
|
||||||
_sync_tag = f"(masked strength={mask_strength})" if mask is not None and mask_sync else ("(mask skipped)" if mask is not None else "")
|
|
||||||
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {_sync_tag}", flush=True)
|
|
||||||
|
|
||||||
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
|
|
||||||
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
|
|
||||||
pbar.update(1)
|
|
||||||
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
|
|
||||||
sync_features = net_video_enc.encode_video_with_sync(
|
|
||||||
sync_input, text_f=text_f, text_mask=text_mask
|
|
||||||
) # [1, T_sync, 768]
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
|
|
||||||
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
feature_utils.to(get_offload_device())
|
|
||||||
net_video_enc.to(get_offload_device())
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
np.savez(
|
|
||||||
cached_path,
|
|
||||||
clip_features=clip_features.cpu().float().numpy(),
|
|
||||||
sync_features=sync_features.cpu().float().numpy(),
|
|
||||||
duration=float(duration),
|
|
||||||
prompt=np.array(prompt),
|
|
||||||
variant=np.array(model["variant"]),
|
|
||||||
)
|
|
||||||
print(f"[SelVA] Features cached: {cached_path}", flush=True)
|
|
||||||
|
|
||||||
return ({
|
|
||||||
"clip_features": clip_features.cpu(),
|
|
||||||
"sync_features": sync_features.cpu(),
|
|
||||||
"duration": float(duration),
|
|
||||||
"prompt": prompt,
|
|
||||||
"variant": model["variant"],
|
|
||||||
}, float(fps), prompt)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cached(path):
|
|
||||||
data = np.load(path, allow_pickle=False)
|
|
||||||
features = {
|
|
||||||
"clip_features": torch.from_numpy(data["clip_features"]),
|
|
||||||
"sync_features": torch.from_numpy(data["sync_features"]),
|
|
||||||
"duration": float(data["duration"]),
|
|
||||||
}
|
|
||||||
if "prompt" in data:
|
|
||||||
features["prompt"] = str(data["prompt"])
|
|
||||||
if "variant" in data:
|
|
||||||
features["variant"] = str(data["variant"])
|
|
||||||
return features
|
|
||||||
@@ -1,421 +0,0 @@
|
|||||||
"""SelVA LoRA Evaluator — generates eval samples from multiple adapters for comparison.
|
|
||||||
|
|
||||||
JSON format:
|
|
||||||
{
|
|
||||||
"name": "eval_batch_1",
|
|
||||||
"data_dir": "/path/to/features",
|
|
||||||
"output_dir": "/path/to/evals/batch1",
|
|
||||||
"steps": 25,
|
|
||||||
"seed": 42,
|
|
||||||
"adapters": [
|
|
||||||
{"id": "baseline"},
|
|
||||||
{"id": "lr_3e4_10k", "path": "/path/to/adapter_final.pt"},
|
|
||||||
{"id": "lr_5e4_10k", "path": "/path/to/adapter_final.pt"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
Empty / missing "path" = baseline (no LoRA applied).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
from .selva_lora_trainer import (
|
|
||||||
_prepare_dataset,
|
|
||||||
_eval_sample,
|
|
||||||
_spectral_metrics,
|
|
||||||
_save_spectrogram,
|
|
||||||
_pil_to_tensor,
|
|
||||||
_find_audio,
|
|
||||||
_load_audio,
|
|
||||||
)
|
|
||||||
from selva_core.model.lora import apply_lora, load_lora
|
|
||||||
|
|
||||||
|
|
||||||
def _avg_metrics(metrics_list: list) -> dict:
|
|
||||||
"""Average spectral metrics across multiple clips, ignoring None entries."""
|
|
||||||
keys = ["hf_energy_ratio", "spectral_centroid_hz", "spectral_rolloff_hz",
|
|
||||||
"spectral_flatness", "temporal_variance"]
|
|
||||||
valid = [m for m in metrics_list if m]
|
|
||||||
if not valid:
|
|
||||||
return {}
|
|
||||||
return {k: round(float(sum(m[k] for m in valid) / len(valid)), 4) for k in keys}
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(raw: str) -> Path:
|
|
||||||
p = Path(raw.strip())
|
|
||||||
unix_style_on_windows = sys.platform == "win32" and p.is_absolute() and not p.drive
|
|
||||||
if not p.is_absolute() or unix_style_on_windows:
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_stem(adapter_id: str) -> str:
|
|
||||||
"""Replace characters illegal in filenames."""
|
|
||||||
for ch in r'/\:*?"<>|':
|
|
||||||
adapter_id = adapter_id.replace(ch, "_")
|
|
||||||
return adapter_id
|
|
||||||
|
|
||||||
|
|
||||||
def _draw_metric_comparison(adapter_ids: list, metrics_list: list, output_path: Path):
|
|
||||||
"""Draw a 2×2 grid of horizontal bar charts comparing spectral metrics.
|
|
||||||
|
|
||||||
Saves a PNG to output_path and returns a ComfyUI IMAGE tensor.
|
|
||||||
"""
|
|
||||||
from matplotlib.figure import Figure
|
|
||||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
|
||||||
|
|
||||||
METRICS = [
|
|
||||||
("hf_energy_ratio", "HF Energy Ratio (>4 kHz)"),
|
|
||||||
("spectral_centroid_hz", "Spectral Centroid (Hz)"),
|
|
||||||
("spectral_flatness", "Spectral Flatness"),
|
|
||||||
("temporal_variance", "Temporal Variance"),
|
|
||||||
]
|
|
||||||
COLORS = [
|
|
||||||
"#4285F4", "#EA4335", "#34A853", "#FBBC05",
|
|
||||||
"#9B59B6", "#1ABC9C", "#E67E22", "#95A5A6",
|
|
||||||
]
|
|
||||||
|
|
||||||
fig = Figure(figsize=(12, max(4, len(adapter_ids) * 0.6 + 2)), dpi=110, tight_layout=True)
|
|
||||||
axes = [fig.add_subplot(2, 2, i + 1) for i in range(4)]
|
|
||||||
|
|
||||||
for ax, (key, title) in zip(axes, METRICS):
|
|
||||||
values = []
|
|
||||||
colors = []
|
|
||||||
for i, m in enumerate(metrics_list):
|
|
||||||
v = m.get(key, 0.0) if m else 0.0
|
|
||||||
values.append(v)
|
|
||||||
colors.append(COLORS[i % len(COLORS)])
|
|
||||||
|
|
||||||
bars = ax.barh(adapter_ids, values, color=colors, height=0.6)
|
|
||||||
ax.set_title(title, fontsize=9)
|
|
||||||
ax.set_xlabel(key, fontsize=8)
|
|
||||||
ax.tick_params(axis="y", labelsize=7)
|
|
||||||
ax.tick_params(axis="x", labelsize=7)
|
|
||||||
|
|
||||||
# Value labels on bars
|
|
||||||
for bar, val in zip(bars, values):
|
|
||||||
w = bar.get_width()
|
|
||||||
ax.text(w * 1.01, bar.get_y() + bar.get_height() / 2,
|
|
||||||
f"{val:.3f}", va="center", ha="left", fontsize=6)
|
|
||||||
|
|
||||||
canvas = FigureCanvasAgg(fig)
|
|
||||||
canvas.draw()
|
|
||||||
canvas.print_figure(str(output_path), dpi=110)
|
|
||||||
|
|
||||||
buf = canvas.buffer_rgba()
|
|
||||||
w, h = canvas.get_width_height()
|
|
||||||
arr = np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 4)[:, :, :3]
|
|
||||||
from PIL import Image
|
|
||||||
return _pil_to_tensor(Image.fromarray(arr))
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaLoraEvaluator:
|
|
||||||
"""Evaluates a batch of LoRA adapters on a fixed reference clip.
|
|
||||||
|
|
||||||
Generates one audio sample per adapter, computes spectral metrics for each,
|
|
||||||
and produces a comparison chart. Use this after a sweep to compare candidates
|
|
||||||
before running the next round of training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "run"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("summary_path", "comparison_image")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to eval_summary.json — contains spectral metrics per adapter.",
|
|
||||||
"Bar chart comparing spectral metrics across all evaluated adapters.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Evaluates multiple LoRA adapters by generating one audio sample per adapter "
|
|
||||||
"from a fixed reference clip, then collects spectral metrics for comparison. "
|
|
||||||
"Input is a JSON file listing adapter paths. Empty path = baseline (no LoRA)."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"eval_file": ("STRING", {
|
|
||||||
"default": "eval_batch.json",
|
|
||||||
"tooltip": (
|
|
||||||
"Path to the JSON evaluation spec. Relative paths resolve "
|
|
||||||
"to the ComfyUI output directory. "
|
|
||||||
"Each adapter entry needs an 'id' and an optional 'path'. "
|
|
||||||
"Omit 'path' for a no-LoRA baseline."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def run(self, model, eval_file):
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Resolve and parse the JSON file
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
eval_path = Path(eval_file.strip())
|
|
||||||
if not eval_path.is_absolute():
|
|
||||||
candidate = Path(folder_paths.models_dir) / eval_path
|
|
||||||
if not candidate.exists():
|
|
||||||
candidate = Path(folder_paths.get_output_directory()) / eval_path
|
|
||||||
eval_path = candidate
|
|
||||||
if not eval_path.exists():
|
|
||||||
raise FileNotFoundError(f"[LoRA Evaluator] Eval file not found: {eval_path}")
|
|
||||||
|
|
||||||
spec = json.loads(eval_path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
if "adapters" not in spec or not spec["adapters"]:
|
|
||||||
raise ValueError("[LoRA Evaluator] 'adapters' list is missing or empty.")
|
|
||||||
for i, a in enumerate(spec["adapters"]):
|
|
||||||
if "id" not in a:
|
|
||||||
raise ValueError(f"[LoRA Evaluator] Adapter at index {i} missing 'id'.")
|
|
||||||
|
|
||||||
if "data_dir" not in spec:
|
|
||||||
raise ValueError("[LoRA Evaluator] 'data_dir' is required.")
|
|
||||||
if "output_dir" not in spec:
|
|
||||||
raise ValueError("[LoRA Evaluator] 'output_dir' is required.")
|
|
||||||
|
|
||||||
name = spec.get("name", eval_path.stem)
|
|
||||||
data_dir = _resolve_path(spec["data_dir"])
|
|
||||||
output_dir = _resolve_path(spec["output_dir"])
|
|
||||||
steps = int(spec.get("steps", 25))
|
|
||||||
seed = int(spec.get("seed", 42))
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\n[LoRA Evaluator] '{name}': {len(spec['adapters'])} adapter(s)", flush=True)
|
|
||||||
print(f"[LoRA Evaluator] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[LoRA Evaluator] output_dir = {output_dir}\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. Prepare dataset (VAE encode once)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
|
|
||||||
feature_utils_orig = model["feature_utils"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. Collect reference metrics for all dataset clips
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
import shutil
|
|
||||||
npz_files = sorted(data_dir.glob("*.npz"))
|
|
||||||
ref_dir = output_dir / "reference"
|
|
||||||
ref_dir.mkdir(exist_ok=True)
|
|
||||||
ref_clips = [] # list of {clip, wav_path, spectral_metrics}
|
|
||||||
|
|
||||||
print(f"[LoRA Evaluator] Computing reference metrics for {len(npz_files)} clip(s)...",
|
|
||||||
flush=True)
|
|
||||||
for npz_path in npz_files:
|
|
||||||
audio_path = _find_audio(npz_path)
|
|
||||||
if audio_path is None:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
ref_wav = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
|
||||||
ref_wav = ref_wav.unsqueeze(0) # [1, L]
|
|
||||||
ref_out = ref_dir / f"{npz_path.stem}{audio_path.suffix}"
|
|
||||||
shutil.copy2(str(audio_path), str(ref_out))
|
|
||||||
metrics = _spectral_metrics(ref_wav, seq_cfg.sampling_rate)
|
|
||||||
ref_clips.append({
|
|
||||||
"clip": npz_path.stem,
|
|
||||||
"wav_path": str(ref_out),
|
|
||||||
"spectral_metrics": metrics,
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[LoRA Evaluator] Reference {npz_path.name} failed: {e}", flush=True)
|
|
||||||
|
|
||||||
# Average reference metrics across all clips
|
|
||||||
ref_avg = _avg_metrics([c["spectral_metrics"] for c in ref_clips])
|
|
||||||
print(f"[LoRA Evaluator] Reference avg — "
|
|
||||||
f"centroid={ref_avg.get('spectral_centroid_hz', 0):.0f}Hz "
|
|
||||||
f"hf={ref_avg.get('hf_energy_ratio', 0):.3f} "
|
|
||||||
f"flatness={ref_avg.get('spectral_flatness', 0):.4f}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. Build summary skeleton
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary = {
|
|
||||||
"name": name,
|
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"completed_at": None,
|
|
||||||
"data_dir": str(data_dir),
|
|
||||||
"output_dir": str(output_dir),
|
|
||||||
"n_clips": len(ref_clips),
|
|
||||||
"steps": steps,
|
|
||||||
"seed": seed,
|
|
||||||
"reference_avg": ref_avg,
|
|
||||||
"reference_clips": ref_clips,
|
|
||||||
"adapters": [],
|
|
||||||
}
|
|
||||||
summary_path = output_dir / "eval_summary.json"
|
|
||||||
|
|
||||||
def _write_summary():
|
|
||||||
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. Per-adapter evaluation loop (all clips)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
n_clips = len(dataset)
|
|
||||||
pbar = comfy.utils.ProgressBar(len(spec["adapters"]) * n_clips)
|
|
||||||
|
|
||||||
for adapter_spec in spec["adapters"]:
|
|
||||||
adapter_id = adapter_spec["id"]
|
|
||||||
adapter_path = (adapter_spec.get("path") or "").strip()
|
|
||||||
safe_id = _safe_stem(adapter_id)
|
|
||||||
clip_dir = output_dir / safe_id
|
|
||||||
clip_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
record = {
|
|
||||||
"id": adapter_id,
|
|
||||||
"path": adapter_path or None,
|
|
||||||
"meta": None,
|
|
||||||
"clips": [],
|
|
||||||
"avg_metrics": None,
|
|
||||||
"status": "running",
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"\n[LoRA Evaluator] ── '{adapter_id}' ({n_clips} clips) ──", flush=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.inference_mode(False):
|
|
||||||
generator = copy.deepcopy(model["generator"])
|
|
||||||
|
|
||||||
if adapter_path:
|
|
||||||
pt_path = Path(adapter_path)
|
|
||||||
if not pt_path.is_absolute():
|
|
||||||
pt_path = Path(folder_paths.base_path) / pt_path
|
|
||||||
if not pt_path.exists():
|
|
||||||
raise FileNotFoundError(f"Adapter not found: {pt_path}")
|
|
||||||
|
|
||||||
ckpt = torch.load(str(pt_path), map_location="cpu",
|
|
||||||
weights_only=False)
|
|
||||||
if isinstance(ckpt, dict) and "state_dict" in ckpt:
|
|
||||||
state_dict = ckpt["state_dict"]
|
|
||||||
meta = ckpt.get("meta", {})
|
|
||||||
else:
|
|
||||||
state_dict = ckpt
|
|
||||||
meta = {}
|
|
||||||
|
|
||||||
rank = int(meta.get("rank", 16))
|
|
||||||
alpha = float(meta.get("alpha", float(rank)))
|
|
||||||
target = list(meta.get("target", ["attn.qkv"]))
|
|
||||||
dropout = float(meta.get("lora_dropout", 0.0))
|
|
||||||
use_rslora = meta.get("use_rslora", False)
|
|
||||||
record["meta"] = {"rank": rank, "alpha": alpha, "target": target}
|
|
||||||
|
|
||||||
# Always use standard init for loading — PiSSA checkpoints
|
|
||||||
# include linear.weight (residual) in state_dict
|
|
||||||
n = apply_lora(generator, rank=rank, alpha=alpha,
|
|
||||||
target_suffixes=tuple(target), dropout=dropout,
|
|
||||||
init_mode="standard", use_rslora=use_rslora)
|
|
||||||
if n == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"apply_lora matched 0 layers (target={target})"
|
|
||||||
)
|
|
||||||
load_lora(generator, state_dict)
|
|
||||||
print(f"[LoRA Evaluator] Loaded {pt_path.name} "
|
|
||||||
f"(rank={rank}, {n} layers)", flush=True)
|
|
||||||
else:
|
|
||||||
print("[LoRA Evaluator] Baseline (no LoRA)", flush=True)
|
|
||||||
|
|
||||||
generator = generator.to(device, dtype)
|
|
||||||
generator.update_seq_lengths(
|
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
|
||||||
clip_seq_len=seq_cfg.clip_seq_len,
|
|
||||||
sync_seq_len=seq_cfg.sync_seq_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
clip_metrics_list = []
|
|
||||||
for clip_idx in range(n_clips):
|
|
||||||
clip_stem = npz_files[clip_idx].stem
|
|
||||||
wav, sr = _eval_sample(
|
|
||||||
generator, feature_utils_orig, dataset,
|
|
||||||
seq_cfg, device, dtype,
|
|
||||||
num_steps=steps, seed=seed, clip_idx=clip_idx,
|
|
||||||
)
|
|
||||||
if wav is None:
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
wav_path = clip_dir / f"{clip_stem}.wav"
|
|
||||||
try:
|
|
||||||
torchaudio.save(str(wav_path), wav, sr)
|
|
||||||
except RuntimeError:
|
|
||||||
import soundfile as sf
|
|
||||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
|
||||||
|
|
||||||
metrics = _spectral_metrics(wav, sr)
|
|
||||||
clip_metrics_list.append(metrics)
|
|
||||||
record["clips"].append({
|
|
||||||
"clip": clip_stem,
|
|
||||||
"wav_path": str(wav_path),
|
|
||||||
"spectral_metrics": metrics,
|
|
||||||
})
|
|
||||||
print(f" [{clip_idx+1}/{n_clips}] {clip_stem} "
|
|
||||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
|
||||||
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
record["avg_metrics"] = _avg_metrics(clip_metrics_list)
|
|
||||||
record["status"] = "completed"
|
|
||||||
avg = record["avg_metrics"]
|
|
||||||
print(f"[LoRA Evaluator] '{adapter_id}' avg — "
|
|
||||||
f"centroid={avg.get('spectral_centroid_hz', 0):.0f}Hz "
|
|
||||||
f"hf={avg.get('hf_energy_ratio', 0):.3f} "
|
|
||||||
f"flatness={avg.get('spectral_flatness', 0):.4f}", flush=True)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
record["status"] = "failed"
|
|
||||||
record["error"] = str(e)
|
|
||||||
print(f"[LoRA Evaluator] '{adapter_id}' failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
pbar.update(n_clips - len(record["clips"]))
|
|
||||||
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
del generator
|
|
||||||
except NameError:
|
|
||||||
pass
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
summary["adapters"].append(record)
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. Finalise summary
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
|
|
||||||
_write_summary()
|
|
||||||
print(f"\n[LoRA Evaluator] Done. Summary: {summary_path}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 6. Comparison chart
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
completed = [r for r in summary["adapters"] if r.get("status") == "completed"]
|
|
||||||
if completed:
|
|
||||||
ids = ["reference"] + [r["id"] for r in completed]
|
|
||||||
metrics_list = [summary["reference_avg"]] + [r["avg_metrics"] for r in completed]
|
|
||||||
chart_path = output_dir / "metric_comparison.png"
|
|
||||||
comparison = _draw_metric_comparison(ids, metrics_list, chart_path)
|
|
||||||
print(f"[LoRA Evaluator] Comparison chart: {chart_path}", flush=True)
|
|
||||||
else:
|
|
||||||
from PIL import Image
|
|
||||||
comparison = _pil_to_tensor(Image.new("RGB", (400, 200), (255, 255, 255)))
|
|
||||||
|
|
||||||
return (str(summary_path), comparison)
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
import copy
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
from selva_core.model.lora import apply_lora, load_lora
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaLoraLoader:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"adapter_path": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.",
|
|
||||||
}),
|
|
||||||
"strength": ("FLOAT", {
|
|
||||||
"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05,
|
|
||||||
"tooltip": "Scale applied to all LoRA contributions. "
|
|
||||||
"1.0 = full adapter strength. "
|
|
||||||
"0.0 = effectively disables the adapter. "
|
|
||||||
"Values above 1.0 exaggerate the effect.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("SELVA_MODEL",)
|
|
||||||
RETURN_NAMES = ("model",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",)
|
|
||||||
FUNCTION = "load"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Loads a LoRA adapter produced by train_lora.py and applies it to the generator. "
|
|
||||||
"The base model is not modified — a shallow copy of the model bundle is returned."
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, model: dict, adapter_path: str, strength: float) -> tuple:
|
|
||||||
if not adapter_path.strip():
|
|
||||||
raise ValueError("[SelVA LoRA] adapter_path is empty.")
|
|
||||||
|
|
||||||
# Resolve path: allow absolute or relative to ComfyUI base
|
|
||||||
from pathlib import Path
|
|
||||||
p = Path(adapter_path)
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.base_path) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}")
|
|
||||||
|
|
||||||
checkpoint = torch.load(str(p), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
# Support both raw state_dict and {state_dict, meta} formats
|
|
||||||
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
|
||||||
state_dict = checkpoint["state_dict"]
|
|
||||||
meta = checkpoint.get("meta", {})
|
|
||||||
else:
|
|
||||||
state_dict = checkpoint
|
|
||||||
meta = {}
|
|
||||||
|
|
||||||
rank = int(meta.get("rank", 16))
|
|
||||||
alpha = float(meta.get("alpha", float(rank)))
|
|
||||||
target = list(meta.get("target", ["attn.qkv"]))
|
|
||||||
init_mode = meta.get("init_mode", "standard")
|
|
||||||
use_rslora = meta.get("use_rslora", False)
|
|
||||||
|
|
||||||
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
|
|
||||||
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} "
|
|
||||||
f"init={init_mode} rslora={use_rslora} strength={strength}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
# Shallow-copy the model bundle so the original generator is not mutated
|
|
||||||
patched = {**model}
|
|
||||||
generator = copy.deepcopy(model["generator"])
|
|
||||||
|
|
||||||
# For PiSSA, use standard init (the base weights will be overwritten
|
|
||||||
# by load_state_dict since the checkpoint includes linear.weight)
|
|
||||||
n = apply_lora(generator, rank=rank, alpha=alpha,
|
|
||||||
target_suffixes=tuple(target),
|
|
||||||
init_mode="standard", use_rslora=use_rslora)
|
|
||||||
if n == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"[SelVA LoRA] No layers matched target={target}. "
|
|
||||||
"Check that the adapter was trained with the same target suffixes."
|
|
||||||
)
|
|
||||||
load_lora(generator, state_dict)
|
|
||||||
|
|
||||||
# Sanity check: confirm lora_A weights are non-zero (lora_B starts at zero by design)
|
|
||||||
norms = [p.norm().item() for name, p in generator.named_parameters()
|
|
||||||
if "lora_A" in name]
|
|
||||||
if norms:
|
|
||||||
print(f"[SelVA LoRA] lora_A weight norms: min={min(norms):.4f} "
|
|
||||||
f"max={max(norms):.4f} mean={sum(norms)/len(norms):.4f}", flush=True)
|
|
||||||
else:
|
|
||||||
print("[SelVA LoRA] WARNING: no lora_A params found after loading!", flush=True)
|
|
||||||
|
|
||||||
# Apply strength scaling: multiply all lora_B params by strength
|
|
||||||
# (lora_B is initialised to zero, so scaling A is equivalent but less clean)
|
|
||||||
if strength != 1.0:
|
|
||||||
with torch.no_grad():
|
|
||||||
for name, param in generator.named_parameters():
|
|
||||||
if "lora_B" in name:
|
|
||||||
param.mul_(strength)
|
|
||||||
|
|
||||||
generator.to(model["generator"].parameters().__next__().device)
|
|
||||||
patched["generator"] = generator
|
|
||||||
|
|
||||||
print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True)
|
|
||||||
return (patched,)
|
|
||||||
@@ -1,539 +0,0 @@
|
|||||||
"""SelVA LoRA Scheduler — runs a sweep of training experiments from a JSON file.
|
|
||||||
|
|
||||||
Each experiment inherits from a shared `base` config and overrides specific keys.
|
|
||||||
The dataset is loaded once and reused across all experiments. Results are written
|
|
||||||
to `experiment_summary.json` (updated after each completed run) and a comparison
|
|
||||||
loss-curve image showing all runs on the same axes.
|
|
||||||
|
|
||||||
JSON format:
|
|
||||||
{
|
|
||||||
"name": "tier1_sweep",
|
|
||||||
"description": "optional human note",
|
|
||||||
"data_dir": "dataset/dog_bark",
|
|
||||||
"output_root": "lora_output/tier1_sweep",
|
|
||||||
"base": { "rank": 16, "lr": 1e-4, "steps": 2000, ... },
|
|
||||||
"experiments": [
|
|
||||||
{"id": "baseline", "description": "..."},
|
|
||||||
{"id": "lora_plus_16", "lora_plus_ratio": 16.0},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image, ImageDraw
|
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device
|
|
||||||
from .selva_lora_trainer import (
|
|
||||||
SelvaLoraTrainer,
|
|
||||||
SkipExperiment,
|
|
||||||
_prepare_dataset,
|
|
||||||
_smooth_losses,
|
|
||||||
_pil_to_tensor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_system_info() -> dict:
|
|
||||||
"""Collect GPU / torch version info for the summary header."""
|
|
||||||
info: dict = {
|
|
||||||
"torch_version": torch.__version__,
|
|
||||||
"cuda_version": torch.version.cuda or "N/A",
|
|
||||||
"gpu_name": None,
|
|
||||||
"gpu_vram_gb": None,
|
|
||||||
}
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
try:
|
|
||||||
info["gpu_name"] = torch.cuda.get_device_name(0)
|
|
||||||
props = torch.cuda.get_device_properties(0)
|
|
||||||
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
# Defaults mirror SelvaLoraTrainer INPUT_TYPES defaults
|
|
||||||
_PARAM_DEFAULTS = {
|
|
||||||
"alpha": 0.0,
|
|
||||||
"target": "attn.qkv",
|
|
||||||
"batch_size": 4,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"grad_accum": 1,
|
|
||||||
"save_every": 500,
|
|
||||||
"resume_path": "",
|
|
||||||
"seed": 42,
|
|
||||||
"timestep_mode": "uniform",
|
|
||||||
"logit_normal_sigma": 1.0,
|
|
||||||
"curriculum_switch": 0.6,
|
|
||||||
"lora_dropout": 0.0,
|
|
||||||
"lora_plus_ratio": 1.0,
|
|
||||||
"lr_schedule": "constant",
|
|
||||||
"init_mode": "pissa",
|
|
||||||
"use_rslora": True,
|
|
||||||
"latent_mixup_alpha": 0.0,
|
|
||||||
"latent_noise_sigma": 0.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Palette for comparison chart: one color per experiment (cycles if > 8)
|
|
||||||
_PALETTE = [
|
|
||||||
(66, 133, 244), # blue
|
|
||||||
(234, 67, 53), # red
|
|
||||||
(52, 168, 83), # green
|
|
||||||
(251, 188, 5), # yellow
|
|
||||||
(155, 89, 182), # purple
|
|
||||||
(26, 188, 156), # teal
|
|
||||||
(230, 126, 34), # orange
|
|
||||||
(149, 165, 166), # grey
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(raw: str) -> Path:
|
|
||||||
"""Resolve path the same way SelvaLoraTrainer does (relative → ComfyUI output dir)."""
|
|
||||||
p = Path(raw.strip())
|
|
||||||
unix_style_on_windows = (
|
|
||||||
sys.platform == "win32" and p.is_absolute() and not p.drive
|
|
||||||
)
|
|
||||||
if not p.is_absolute() or unix_style_on_windows:
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_config(base: dict, experiment: dict) -> dict:
|
|
||||||
"""Merge base defaults + file base + experiment overrides."""
|
|
||||||
cfg = dict(_PARAM_DEFAULTS)
|
|
||||||
cfg.update(base)
|
|
||||||
# Don't carry id/description into the training params
|
|
||||||
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
|
|
||||||
start_step: int, total_steps: int) -> dict:
|
|
||||||
"""Build a dict of {step: loss} at each save_every boundary.
|
|
||||||
|
|
||||||
loss_history[i] = average loss over steps [start + i*log_interval + 1 …
|
|
||||||
start + (i+1)*log_interval].
|
|
||||||
"""
|
|
||||||
result = {}
|
|
||||||
targets = range(save_every, total_steps + 1, save_every)
|
|
||||||
for target in targets:
|
|
||||||
# index of the loss entry nearest to this step
|
|
||||||
idx = (target - start_step) // log_interval - 1
|
|
||||||
if 0 <= idx < len(loss_history):
|
|
||||||
result[str(target)] = round(loss_history[idx], 6)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _draw_comparison_curves(
|
|
||||||
experiments_data: list, # list of dicts: {id, loss_history, log_interval, start_step}
|
|
||||||
) -> Image.Image:
|
|
||||||
"""Draw all smoothed loss curves on the same axes, one color per experiment."""
|
|
||||||
W, H = 900, 420
|
|
||||||
pl, pr, pt, pb = 75, 160, 30, 50 # wider right margin for legend
|
|
||||||
|
|
||||||
img = Image.new("RGB", (W, H), (255, 255, 255))
|
|
||||||
draw = ImageDraw.Draw(img)
|
|
||||||
|
|
||||||
pw = W - pl - pr
|
|
||||||
ph = H - pt - pb
|
|
||||||
|
|
||||||
# Collect all smoothed series
|
|
||||||
series = []
|
|
||||||
for i, ed in enumerate(experiments_data):
|
|
||||||
lh = ed.get("loss_history") or []
|
|
||||||
if len(lh) < 2:
|
|
||||||
continue
|
|
||||||
sm = _smooth_losses(lh)
|
|
||||||
series.append({
|
|
||||||
"id": ed["id"],
|
|
||||||
"smoothed": sm,
|
|
||||||
"log_interval": ed.get("log_interval", 50),
|
|
||||||
"start_step": ed.get("start_step", 0),
|
|
||||||
"color": _PALETTE[i % len(_PALETTE)],
|
|
||||||
})
|
|
||||||
|
|
||||||
if not series:
|
|
||||||
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
|
|
||||||
return img
|
|
||||||
|
|
||||||
all_vals = [v for s in series for v in s["smoothed"]]
|
|
||||||
lo, hi = min(all_vals), max(all_vals)
|
|
||||||
if hi == lo:
|
|
||||||
hi = lo + 1e-6
|
|
||||||
rng = hi - lo
|
|
||||||
|
|
||||||
# Horizontal grid + y-axis labels
|
|
||||||
for i in range(5):
|
|
||||||
y = pt + int(i * ph / 4)
|
|
||||||
val = hi - i * rng / 4
|
|
||||||
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
|
|
||||||
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
|
|
||||||
|
|
||||||
# Draw each curve
|
|
||||||
for s in series:
|
|
||||||
n = len(s["smoothed"])
|
|
||||||
pts = []
|
|
||||||
for j, v in enumerate(s["smoothed"]):
|
|
||||||
x = pl + int(j * pw / max(n - 1, 1))
|
|
||||||
y = pt + int((1.0 - (v - lo) / rng) * ph)
|
|
||||||
pts.append((x, y))
|
|
||||||
draw.line(pts, fill=s["color"], width=2)
|
|
||||||
|
|
||||||
# Axes
|
|
||||||
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
|
|
||||||
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
|
|
||||||
draw.text((pl + 4, 8), "Loss comparison (smoothed)", fill=(40, 40, 40))
|
|
||||||
|
|
||||||
# Legend (right side)
|
|
||||||
lx = W - pr + 10
|
|
||||||
ly = pt
|
|
||||||
for s in series:
|
|
||||||
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
|
|
||||||
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
|
|
||||||
ly += 20
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaLoraScheduler:
|
|
||||||
"""Runs a sweep of LoRA training experiments defined in a JSON file.
|
|
||||||
|
|
||||||
The dataset (VAE encoding + .npz loading) is performed once and shared
|
|
||||||
across all experiments. Each experiment deep-copies the generator and trains
|
|
||||||
independently. Results are written to `experiment_summary.json` after every
|
|
||||||
completed run so partial results are preserved if the sweep is interrupted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "run"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("summary_path", "comparison_curves")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to experiment_summary.json — share this file to compare runs.",
|
|
||||||
"All smoothed loss curves overlaid on the same axes.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Runs a series of LoRA training experiments defined in a JSON sweep file. "
|
|
||||||
"The dataset is encoded once and reused across all experiments. "
|
|
||||||
"Results (loss, config, adapter paths) are collected in experiment_summary.json."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"experiments_file": ("STRING", {
|
|
||||||
"default": "experiments.json",
|
|
||||||
"tooltip": (
|
|
||||||
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
|
|
||||||
"models directory; absolute paths are used as-is. "
|
|
||||||
"See LORA_TRAINING.md for the file format."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def run(self, model, experiments_file):
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Read + validate the JSON file
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
exp_path = Path(experiments_file.strip())
|
|
||||||
if not exp_path.is_absolute():
|
|
||||||
# Try relative to ComfyUI models dir first, then output dir
|
|
||||||
candidate = Path(folder_paths.models_dir) / exp_path
|
|
||||||
if not candidate.exists():
|
|
||||||
candidate = Path(folder_paths.get_output_directory()) / exp_path
|
|
||||||
exp_path = candidate
|
|
||||||
if not exp_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"[LoRA Scheduler] Experiment file not found: {exp_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = json.loads(exp_path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
if "experiments" not in spec or not spec["experiments"]:
|
|
||||||
raise ValueError("[LoRA Scheduler] 'experiments' list is missing or empty.")
|
|
||||||
for i, exp in enumerate(spec["experiments"]):
|
|
||||||
if "id" not in exp:
|
|
||||||
raise ValueError(
|
|
||||||
f"[LoRA Scheduler] Experiment at index {i} is missing required 'id' field."
|
|
||||||
)
|
|
||||||
|
|
||||||
sweep_name = spec.get("name", exp_path.stem)
|
|
||||||
description = spec.get("description", "")
|
|
||||||
base_cfg = spec.get("base", {})
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. Resolve data_dir and output_root
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
if "data_dir" not in spec:
|
|
||||||
raise ValueError("[LoRA Scheduler] 'data_dir' is required in the sweep file.")
|
|
||||||
data_dir = _resolve_path(spec["data_dir"])
|
|
||||||
output_root = _resolve_path(spec.get("output_root", f"lora_sweeps/{sweep_name}"))
|
|
||||||
output_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
|
|
||||||
print(f"\n[LoRA Scheduler] Sweep '{sweep_name}': "
|
|
||||||
f"{len(spec['experiments'])} experiment(s)", flush=True)
|
|
||||||
if description:
|
|
||||||
print(f"[LoRA Scheduler] {description}", flush=True)
|
|
||||||
print(f"[LoRA Scheduler] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[LoRA Scheduler] output_root = {output_root}\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. Load + encode dataset once
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
n_clips = len(list(data_dir.glob("*.npz")))
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. Build or restore the summary (resume-aware)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary_path = output_root / "experiment_summary.json"
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = [] # collected for comparison image
|
|
||||||
|
|
||||||
if summary_path.exists():
|
|
||||||
try:
|
|
||||||
existing = json.loads(summary_path.read_text(encoding="utf-8"))
|
|
||||||
for rec in existing.get("experiments", []):
|
|
||||||
if rec.get("results", {}).get("status") == "completed":
|
|
||||||
completed_ids.add(rec["id"])
|
|
||||||
lh = rec["results"].get("loss_history", [])
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": rec["id"],
|
|
||||||
"loss_history": lh,
|
|
||||||
"log_interval": rec["results"].get("log_interval", 50),
|
|
||||||
"start_step": 0,
|
|
||||||
})
|
|
||||||
# Restore the original summary, clear completed_at so it gets set again
|
|
||||||
summary = existing
|
|
||||||
summary["completed_at"] = None
|
|
||||||
if completed_ids:
|
|
||||||
print(f"[LoRA Scheduler] Resuming — skipping {len(completed_ids)} "
|
|
||||||
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[LoRA Scheduler] Could not read existing summary ({e}) — starting fresh",
|
|
||||||
flush=True)
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = []
|
|
||||||
summary = None
|
|
||||||
|
|
||||||
if not completed_ids:
|
|
||||||
summary = {
|
|
||||||
"sweep_name": sweep_name,
|
|
||||||
"description": description,
|
|
||||||
"sweep_file": str(exp_path),
|
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"completed_at": None,
|
|
||||||
"system": _get_system_info(),
|
|
||||||
"data_dir": str(data_dir),
|
|
||||||
"n_clips": n_clips,
|
|
||||||
"experiments": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _write_summary():
|
|
||||||
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. Run each experiment
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
trainer = SelvaLoraTrainer()
|
|
||||||
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
|
|
||||||
log_interval = 50 # matches _train_inner
|
|
||||||
|
|
||||||
feature_utils_orig = model["feature_utils"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
variant = model["variant"]
|
|
||||||
mode = model["mode"]
|
|
||||||
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
exp_id = exp["id"]
|
|
||||||
exp_desc = exp.get("description", "")
|
|
||||||
|
|
||||||
if exp_id in completed_ids:
|
|
||||||
print(f"[LoRA Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
|
|
||||||
pbar_outer.update(1)
|
|
||||||
continue
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
|
|
||||||
# Required training params
|
|
||||||
steps = int(cfg.get("steps", 2000))
|
|
||||||
rank = int(cfg.get("rank", 16))
|
|
||||||
lr = float(cfg.get("lr", 1e-4))
|
|
||||||
alpha = float(cfg.get("alpha", 0.0))
|
|
||||||
target = str(cfg.get("target", "attn.qkv"))
|
|
||||||
batch_size = int(cfg.get("batch_size", 4))
|
|
||||||
warmup = int(cfg.get("warmup_steps", 100))
|
|
||||||
grad_accum = int(cfg.get("grad_accum", 1))
|
|
||||||
save_every = int(cfg.get("save_every", 500))
|
|
||||||
resume_path = str(cfg.get("resume_path", ""))
|
|
||||||
seed = int(cfg.get("seed", 42))
|
|
||||||
ts_mode = str(cfg.get("timestep_mode", "uniform"))
|
|
||||||
ln_sigma = float(cfg.get("logit_normal_sigma", 1.0))
|
|
||||||
curr_switch = float(cfg.get("curriculum_switch", 0.6))
|
|
||||||
dropout = float(cfg.get("lora_dropout", 0.0))
|
|
||||||
plus_ratio = float(cfg.get("lora_plus_ratio", 1.0))
|
|
||||||
lr_schedule = str(cfg.get("lr_schedule", "constant"))
|
|
||||||
init_mode = str(cfg.get("init_mode", "pissa"))
|
|
||||||
use_rslora = bool(cfg.get("use_rslora", True))
|
|
||||||
alpha_val = alpha if alpha > 0.0 else float(2 * rank)
|
|
||||||
target_suffixes = tuple(target.strip().split())
|
|
||||||
|
|
||||||
output_dir = output_root / exp_id
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\n[LoRA Scheduler] ── Experiment '{exp_id}' ──", flush=True)
|
|
||||||
if exp_desc:
|
|
||||||
print(f"[LoRA Scheduler] {exp_desc}", flush=True)
|
|
||||||
|
|
||||||
exp_record = {
|
|
||||||
"id": exp_id,
|
|
||||||
"description": exp_desc,
|
|
||||||
"config": {
|
|
||||||
"rank": rank, "alpha": alpha_val, "lr": lr, "steps": steps,
|
|
||||||
"batch_size": batch_size, "warmup_steps": warmup,
|
|
||||||
"grad_accum": grad_accum, "save_every": save_every,
|
|
||||||
"seed": seed, "target": list(target_suffixes),
|
|
||||||
"timestep_mode": ts_mode, "logit_normal_sigma": ln_sigma,
|
|
||||||
"curriculum_switch": curr_switch,
|
|
||||||
"lora_dropout": dropout, "lora_plus_ratio": plus_ratio,
|
|
||||||
"lr_schedule": lr_schedule,
|
|
||||||
"init_mode": init_mode, "use_rslora": use_rslora,
|
|
||||||
},
|
|
||||||
"results": {"status": "running"},
|
|
||||||
"adapter_path": None,
|
|
||||||
"output_dir": str(output_dir),
|
|
||||||
}
|
|
||||||
summary["experiments"].append(exp_record)
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
t_start = time.monotonic()
|
|
||||||
try:
|
|
||||||
with torch.inference_mode(False), torch.enable_grad():
|
|
||||||
r = trainer._train_inner(
|
|
||||||
model, dataset, feature_utils_orig, seq_cfg,
|
|
||||||
device, dtype, variant, mode,
|
|
||||||
data_dir, output_dir, steps, rank, lr,
|
|
||||||
alpha_val, target_suffixes, batch_size, warmup,
|
|
||||||
grad_accum, save_every, resume_path, seed,
|
|
||||||
ts_mode, ln_sigma, curr_switch, dropout, plus_ratio,
|
|
||||||
lr_schedule, init_mode, use_rslora,
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
loss_history = r["loss_history"]
|
|
||||||
grad_norm_history = r.get("grad_norm_history", [])
|
|
||||||
spectral_metrics = r.get("spectral_metrics", {})
|
|
||||||
run_start_step = r.get("start_step", 0)
|
|
||||||
smoothed = _smooth_losses(loss_history) if loss_history else []
|
|
||||||
|
|
||||||
# Scalar summary metrics
|
|
||||||
final_loss = round(smoothed[-1], 6) if smoothed else None
|
|
||||||
min_loss = round(min(smoothed), 6) if smoothed else None
|
|
||||||
min_idx = smoothed.index(min(smoothed)) if smoothed else None
|
|
||||||
min_loss_step = (
|
|
||||||
run_start_step + (min_idx + 1) * log_interval
|
|
||||||
if min_idx is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stability: std-dev of raw loss over last 25% of steps
|
|
||||||
if loss_history:
|
|
||||||
quarter = max(1, len(loss_history) // 4)
|
|
||||||
last_q = loss_history[-quarter:]
|
|
||||||
loss_std_last_quarter = round(float(np.std(last_q)), 6)
|
|
||||||
else:
|
|
||||||
loss_std_last_quarter = None
|
|
||||||
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "completed",
|
|
||||||
"final_loss": final_loss,
|
|
||||||
"min_loss": min_loss,
|
|
||||||
"min_loss_step": min_loss_step,
|
|
||||||
"loss_std_last_quarter": loss_std_last_quarter,
|
|
||||||
"loss_at_steps": _loss_at_steps(
|
|
||||||
loss_history, log_interval, save_every, run_start_step, steps
|
|
||||||
),
|
|
||||||
"loss_history": [round(v, 6) for v in loss_history],
|
|
||||||
"grad_norm_history": grad_norm_history,
|
|
||||||
"spectral_metrics": {str(k): v for k, v in spectral_metrics.items()},
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
exp_record["adapter_path"] = r["adapter_path"]
|
|
||||||
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": exp_id,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"start_step": 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
except SkipExperiment as e:
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
print(f"[LoRA Scheduler] Experiment '{exp_id}' skipped: {e}", flush=True)
|
|
||||||
partial = getattr(e, "partial", {})
|
|
||||||
lh = partial.get("loss_history", [])
|
|
||||||
smoothed = _smooth_losses(lh) if lh else []
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "skipped",
|
|
||||||
"stopped_at_step": partial.get("stopped_at_step"),
|
|
||||||
"final_loss": round(smoothed[-1], 6) if smoothed else None,
|
|
||||||
"loss_history": [round(v, 6) for v in lh],
|
|
||||||
"grad_norm_history": partial.get("grad_norm_history", []),
|
|
||||||
"spectral_metrics": {str(k): v for k, v in partial.get("spectral_metrics", {}).items()},
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
_write_summary()
|
|
||||||
pbar_outer.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
print(f"[LoRA Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "failed",
|
|
||||||
"error": str(e),
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
_write_summary()
|
|
||||||
pbar_outer.update(1)
|
|
||||||
# Continue to next experiment rather than aborting the whole sweep
|
|
||||||
continue
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
pbar_outer.update(1)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 6. Finalise summary
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
|
|
||||||
_write_summary()
|
|
||||||
print(f"\n[LoRA Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 7. Comparison image
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
comparison_img = _draw_comparison_curves(all_curve_data)
|
|
||||||
comparison_img.save(str(output_root / "loss_comparison.png"))
|
|
||||||
comparison_tensor = _pil_to_tensor(comparison_img)
|
|
||||||
|
|
||||||
return (str(summary_path), comparison_tensor)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,171 +0,0 @@
|
|||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy
|
|
||||||
|
|
||||||
# Variant → (generator filename, mode, has_bigvgan)
|
|
||||||
_VARIANTS = {
|
|
||||||
"small_16k": ("generator_small_16k_sup_5.pth", "16k", True),
|
|
||||||
"small_44k": ("generator_small_44k_sup_5.pth", "44k", False),
|
|
||||||
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
|
|
||||||
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
|
|
||||||
}
|
|
||||||
|
|
||||||
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
|
|
||||||
_PRISMAUDIO_DIR = Path(folder_paths.models_dir) / "prismaudio"
|
|
||||||
|
|
||||||
|
|
||||||
_HF_REPO = "jnwnlee/SelVA"
|
|
||||||
|
|
||||||
# filename → (hf_repo_path, expected_md5 or None to skip check)
|
|
||||||
# Note: 44k generators are named 44khz in the HF repo; md5=None since the
|
|
||||||
# original download_utils had the wrong filenames so those md5s are unverified.
|
|
||||||
_WEIGHTS = {
|
|
||||||
"video_enc_sup_5.pth": ("weights/video_enc_sup_5.pth", "ff09a6dc36148536ee4db97eba081d05"),
|
|
||||||
"generator_small_16k_sup_5.pth": ("weights/generator_small_16k_sup_5.pth", "1cb0f0deec52de37f67b1fd9965337d0"),
|
|
||||||
"generator_small_44k_sup_5.pth": ("weights/generator_small_44khz_sup_5.pth", None),
|
|
||||||
"generator_medium_44k_sup_5.pth":("weights/generator_medium_44khz_sup_5.pth", None),
|
|
||||||
"generator_large_44k_sup_5.pth": ("weights/generator_large_44khz_sup_5.pth", None),
|
|
||||||
"v1-16.pth": ("ext_weights/v1-16.pth", "69f56803f59a549a1a507c93859fd4d7"),
|
|
||||||
"v1-44.pth": ("ext_weights/v1-44.pth", "fab020275fa44c6589820ce025191600"),
|
|
||||||
"best_netG.pt": ("ext_weights/best_netG.pt", "eeaf372a38a9c31c362120aba2dde292"),
|
|
||||||
"synchformer_state_dict.pth": ("ext_weights/synchformer_state_dict.pth", "5b2f5594b0730f70e41e549b7c94390c"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _md5(path):
|
|
||||||
import hashlib
|
|
||||||
h = hashlib.md5()
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
for chunk in iter(lambda: f.read(8 * 1024 * 1024), b""):
|
|
||||||
h.update(chunk)
|
|
||||||
return h.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure(filename, subdir=None):
|
|
||||||
"""Return path to weight file. Re-downloads if missing or MD5 mismatch."""
|
|
||||||
import shutil
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
|
|
||||||
dest_dir = _SELVA_DIR / subdir if subdir else _SELVA_DIR
|
|
||||||
dest_path = dest_dir / filename
|
|
||||||
|
|
||||||
entry = _WEIGHTS.get(filename)
|
|
||||||
if entry is None:
|
|
||||||
raise ValueError(f"[SelVA] Unknown weight file: {filename}")
|
|
||||||
repo_path, expected_md5 = entry
|
|
||||||
|
|
||||||
if dest_path.exists():
|
|
||||||
if expected_md5 is None:
|
|
||||||
return str(dest_path)
|
|
||||||
actual = _md5(dest_path)
|
|
||||||
if actual == expected_md5:
|
|
||||||
return str(dest_path)
|
|
||||||
print(f"[SelVA] {filename}: MD5 mismatch ({actual} ≠ {expected_md5}), re-downloading...", flush=True)
|
|
||||||
dest_path.unlink()
|
|
||||||
|
|
||||||
print(f"[SelVA] Downloading {filename} from {_HF_REPO}...", flush=True)
|
|
||||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
cached = hf_hub_download(repo_id=_HF_REPO, filename=repo_path)
|
|
||||||
shutil.copy2(cached, dest_path)
|
|
||||||
print(f"[SelVA] Saved to {dest_path}", flush=True)
|
|
||||||
return str(dest_path)
|
|
||||||
|
|
||||||
|
|
||||||
def _synchformer_path():
|
|
||||||
"""Return synchformer path, reusing models/prismaudio/ if already present."""
|
|
||||||
prismaudio_path = _PRISMAUDIO_DIR / "synchformer_state_dict.pth"
|
|
||||||
if prismaudio_path.exists():
|
|
||||||
return str(prismaudio_path)
|
|
||||||
return _ensure("synchformer_state_dict.pth")
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaModelLoader:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"variant": (list(_VARIANTS.keys()), {
|
|
||||||
"tooltip": "Model size and output sample rate. small_16k is fastest (16 kHz). 44k variants output 44.1 kHz. larger = better quality, more VRAM.",
|
|
||||||
}),
|
|
||||||
"precision": (["bf16", "fp16", "fp32"], {
|
|
||||||
"tooltip": "Compute dtype. bf16 is recommended on Ampere+ GPUs. fp16 for older NVIDIA hardware. fp32 if you see NaN outputs.",
|
|
||||||
}),
|
|
||||||
"offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"], {
|
|
||||||
"tooltip": "auto picks keep_in_vram if ≥16 GB VRAM is free, otherwise offload_to_cpu. offload_to_cpu moves weights to RAM between nodes, saving VRAM at the cost of speed.",
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("SELVA_MODEL",)
|
|
||||||
RETURN_NAMES = ("model",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Loaded model bundle — connect to Feature Extractor and Sampler.",)
|
|
||||||
FUNCTION = "load_model"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = "Loads the SelVA generator, TextSynchformer encoder, CLIP, T5, and VAE. Weights are auto-downloaded from HuggingFace on first use."
|
|
||||||
|
|
||||||
def load_model(self, variant, precision, offload_strategy):
|
|
||||||
from selva_core.model.networks_generator import get_my_mmaudio
|
|
||||||
from selva_core.model.networks_video_enc import get_my_textsynch
|
|
||||||
from selva_core.model.utils.features_utils import FeaturesUtils
|
|
||||||
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
|
|
||||||
|
|
||||||
gen_filename, mode, has_bigvgan = _VARIANTS[variant]
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
if precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
|
|
||||||
print("[SelVA] Warning: bf16 not supported on this GPU — falling back to fp16.", flush=True)
|
|
||||||
precision = "fp16"
|
|
||||||
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
|
||||||
strategy = determine_offload_strategy(offload_strategy)
|
|
||||||
|
|
||||||
print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True)
|
|
||||||
video_enc_path = _ensure("video_enc_sup_5.pth")
|
|
||||||
gen_path = _ensure(gen_filename)
|
|
||||||
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
|
||||||
vae_path = _ensure(vae_name, subdir="ext")
|
|
||||||
synch_path = _synchformer_path()
|
|
||||||
bigvgan_path = _ensure("best_netG.pt", subdir="ext") if has_bigvgan else None
|
|
||||||
|
|
||||||
print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
|
|
||||||
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
|
|
||||||
net_video_enc.load_weights(
|
|
||||||
torch.load(video_enc_path, map_location="cpu", weights_only=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
|
|
||||||
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
|
|
||||||
net_generator = get_my_mmaudio(variant).to(device, dtype).eval()
|
|
||||||
net_generator.load_weights(
|
|
||||||
torch.load(gen_path, map_location="cpu", weights_only=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
|
|
||||||
feature_utils = FeaturesUtils(
|
|
||||||
tod_vae_ckpt=vae_path,
|
|
||||||
synchformer_ckpt=synch_path,
|
|
||||||
enable_conditions=True,
|
|
||||||
mode=mode,
|
|
||||||
bigvgan_vocoder_ckpt=bigvgan_path,
|
|
||||||
need_vae_encoder=True,
|
|
||||||
).to(device, dtype).eval()
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
net_generator.to(get_offload_device())
|
|
||||||
net_video_enc.to(get_offload_device())
|
|
||||||
feature_utils.to(get_offload_device())
|
|
||||||
|
|
||||||
print(f"[SelVA] Model ready: variant={variant} dtype={dtype} strategy={strategy}", flush=True)
|
|
||||||
|
|
||||||
return ({
|
|
||||||
"generator": net_generator,
|
|
||||||
"video_enc": net_video_enc,
|
|
||||||
"feature_utils": feature_utils,
|
|
||||||
"variant": variant,
|
|
||||||
"mode": mode,
|
|
||||||
"strategy": strategy,
|
|
||||||
"dtype": dtype,
|
|
||||||
"seq_cfg": seq_cfg,
|
|
||||||
},)
|
|
||||||
@@ -1,279 +0,0 @@
|
|||||||
import torch
|
|
||||||
import comfy.utils
|
|
||||||
import comfy.model_management
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
|
||||||
from .selva_textual_inversion_trainer import _inject_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaSampler:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"features": ("SELVA_FEATURES",),
|
|
||||||
"prompt": ("STRING", {
|
|
||||||
"default": "", "multiline": True,
|
|
||||||
"tooltip": "Sound description for CLIP text conditioning. Leave empty to reuse the prompt from the Feature Extractor (wire its prompt output here). Changing this without re-extracting features shifts CLIP conditioning but not sync features.",
|
|
||||||
}),
|
|
||||||
"negative_prompt": ("STRING", {
|
|
||||||
"default": "", "multiline": False,
|
|
||||||
"tooltip": "Sounds to suppress, e.g. 'speech, music, wind noise'. Steered away from via CFG. Leave empty for unconditional guidance baseline.",
|
|
||||||
}),
|
|
||||||
"duration": ("FLOAT", {
|
|
||||||
"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
|
|
||||||
"tooltip": "Output audio length in seconds. 0 = match the video duration stored in features.",
|
|
||||||
}),
|
|
||||||
"steps": ("INT", {"default": 25, "min": 1, "max": 200,
|
|
||||||
"tooltip": "Euler steps for the flow matching ODE. 25 is the SelVA default. Diminishing returns above 50; below 10 may sound rough."}),
|
|
||||||
"cfg_strength": ("FLOAT", {"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1,
|
|
||||||
"tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strictly but can introduce artifacts. SelVA default is 4.5; useful range is roughly 3–7."}),
|
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"steering_vectors": ("STEERING_VECTORS", {
|
|
||||||
"tooltip": "Activation steering bundle from SelVA Activation Steering Loader. "
|
|
||||||
"Nudges each DiT block's hidden state toward the extracted pattern.",
|
|
||||||
}),
|
|
||||||
"steering_strength": ("FLOAT", {
|
|
||||||
"default": 0.1, "min": 0.0, "max": 2.0, "step": 0.05,
|
|
||||||
"tooltip": "Scale applied to each steering vector before adding to block output. "
|
|
||||||
"Start around 0.1–0.3; higher values risk destabilizing the ODE.",
|
|
||||||
}),
|
|
||||||
"normalize": ("BOOLEAN", {
|
|
||||||
"default": True,
|
|
||||||
"tooltip": "Normalize output level. Uses RMS normalization to target_lufs rather than peak normalization, so level matches typical audio content.",
|
|
||||||
}),
|
|
||||||
"target_lufs": ("FLOAT", {
|
|
||||||
"default": -27.0, "min": -40.0, "max": -6.0, "step": 1.0,
|
|
||||||
"tooltip": "Target RMS level in dBFS when normalize=True. -27 matches the measured RMS of LUFS-normalized training clips. Increase toward -20 for louder output.",
|
|
||||||
}),
|
|
||||||
"textual_inversion": ("TEXTUAL_INVERSION", {
|
|
||||||
"tooltip": "Learned token embeddings from SelVA Textual Inversion Loader. "
|
|
||||||
"Injects style tokens into CLIP conditioning without modifying model weights.",
|
|
||||||
}),
|
|
||||||
"ti_strength": ("FLOAT", {
|
|
||||||
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
|
||||||
"tooltip": "Blends between original CLIP conditioning (0.0) and full TI injection (1.0). "
|
|
||||||
"Reduce toward 0.3–0.5 if TI produces buzz artifacts.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Generated audio waveform — connect to VHS_VideoCombine or Save Audio.",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
|
|
||||||
|
|
||||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, steering_vectors=None, steering_strength=0.1, normalize=True, target_lufs=-27.0, textual_inversion=None, ti_strength=1.0):
|
|
||||||
import dataclasses
|
|
||||||
from selva_core.model.flow_matching import FlowMatching
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
strategy = model["strategy"]
|
|
||||||
net_generator = model["generator"]
|
|
||||||
feature_utils = model["feature_utils"]
|
|
||||||
|
|
||||||
# Validate that features were extracted with the same model variant
|
|
||||||
feat_variant = features.get("variant")
|
|
||||||
if feat_variant is not None and feat_variant != model["variant"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"[SelVA] Variant mismatch: features were extracted with '{feat_variant}' "
|
|
||||||
f"but model is '{model['variant']}'. Re-run the Feature Extractor with the current model."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resolve prompt: use override if given, otherwise fall back to features prompt
|
|
||||||
if not prompt or not prompt.strip():
|
|
||||||
prompt = features.get("prompt", "")
|
|
||||||
if prompt:
|
|
||||||
print(f"[SelVA] Using prompt from features: '{prompt[:60]}'", flush=True)
|
|
||||||
else:
|
|
||||||
print("[SelVA] Warning: no prompt in features or sampler — CLIP text conditioning will be empty.", flush=True)
|
|
||||||
|
|
||||||
# Resolve duration
|
|
||||||
if duration <= 0:
|
|
||||||
if "duration" not in features:
|
|
||||||
raise ValueError("[SelVA] duration=0 but features contain no duration field.")
|
|
||||||
duration = features["duration"]
|
|
||||||
print(f"[SelVA] Using video duration from features: {duration:.2f}s", flush=True)
|
|
||||||
|
|
||||||
# Derive sequence config for this duration from the model's mode template
|
|
||||||
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
|
|
||||||
sample_rate = seq_cfg.sampling_rate
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
net_generator.to(device)
|
|
||||||
feature_utils.to(device)
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
try:
|
|
||||||
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
|
|
||||||
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
|
|
||||||
|
|
||||||
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
|
|
||||||
|
|
||||||
# Update model rotary position embeddings for actual feature shapes and duration.
|
|
||||||
# Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches.
|
|
||||||
net_generator.update_seq_lengths(
|
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
|
||||||
clip_seq_len=clip_f.shape[1],
|
|
||||||
sync_seq_len=sync_f.shape[1],
|
|
||||||
)
|
|
||||||
print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# Encode text conditioning
|
|
||||||
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
|
|
||||||
|
|
||||||
# Encode negative prompt (or use empty conditions)
|
|
||||||
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
|
||||||
if negative_prompt.strip() else None
|
|
||||||
|
|
||||||
# Inject textual inversion tokens into CLIP conditioning
|
|
||||||
if textual_inversion is not None:
|
|
||||||
emb = textual_inversion["embeddings"].to(device, dtype) # [K, 1024]
|
|
||||||
K = emb.shape[0]
|
|
||||||
inject_mode = textual_inversion.get("inject_mode", "suffix")
|
|
||||||
ti_text = _inject_tokens(text_clip, emb, K, inject_mode)
|
|
||||||
text_clip = torch.lerp(text_clip, ti_text, ti_strength)
|
|
||||||
if neg_text_clip is not None:
|
|
||||||
ti_neg = _inject_tokens(neg_text_clip, emb, K, inject_mode)
|
|
||||||
neg_text_clip = torch.lerp(neg_text_clip, ti_neg, ti_strength)
|
|
||||||
print(f"[SelVA] Textual inversion: {K} tokens mode={inject_mode} strength={ti_strength}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
|
||||||
empty_conditions = net_generator.get_empty_conditions(
|
|
||||||
bs=1, negative_text_features=neg_text_clip
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initial noise (MPS doesn't support torch.Generator on device)
|
|
||||||
gen_device = "cpu" if device.type == "mps" else device
|
|
||||||
rng = torch.Generator(device=gen_device).manual_seed(seed)
|
|
||||||
x0 = torch.randn(
|
|
||||||
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
|
||||||
device=gen_device, dtype=dtype, generator=rng,
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
# Activation steering: apply only during the conditional predict_flow pass
|
|
||||||
# so steering gets amplified by cfg_strength rather than canceling out.
|
|
||||||
steering_handles = []
|
|
||||||
_orig_predict_flow = None
|
|
||||||
if steering_vectors is not None and steering_strength > 0.0:
|
|
||||||
vecs = steering_vectors["steering_vectors"]
|
|
||||||
n_joint = steering_vectors["n_joint"]
|
|
||||||
|
|
||||||
# Patch predict_flow to flag which pass is conditional.
|
|
||||||
# ode_wrapper calls predict_flow(conditions) and predict_flow(empty_conditions);
|
|
||||||
# identity check tells us which is which.
|
|
||||||
_is_cond_pass = [False]
|
|
||||||
_orig_predict_flow = net_generator.predict_flow
|
|
||||||
|
|
||||||
def _tracked_predict_flow(latent, t, cond):
|
|
||||||
_is_cond_pass[0] = (cond is conditions)
|
|
||||||
return _orig_predict_flow(latent, t, cond)
|
|
||||||
|
|
||||||
net_generator.predict_flow = _tracked_predict_flow
|
|
||||||
|
|
||||||
def _make_steering_hook(vec_cpu, is_joint, strength, dev, dt):
|
|
||||||
vec = vec_cpu.to(dev, dt) # [seq, hidden]
|
|
||||||
def hook(module, input, output):
|
|
||||||
if not _is_cond_pass[0]:
|
|
||||||
return # skip unconditional pass; steering amplified by cfg_strength
|
|
||||||
# Interpolate steering vec to match actual output seq length
|
|
||||||
# (handles generation at different duration than extraction)
|
|
||||||
if is_joint:
|
|
||||||
out_seq = output[0].shape[1]
|
|
||||||
else:
|
|
||||||
out_seq = output.shape[1]
|
|
||||||
v = vec
|
|
||||||
if v.shape[0] != out_seq:
|
|
||||||
v = torch.nn.functional.interpolate(
|
|
||||||
v.T.unsqueeze(0), # [1, hidden, seq_orig]
|
|
||||||
size=out_seq,
|
|
||||||
mode="linear",
|
|
||||||
align_corners=False,
|
|
||||||
).squeeze(0).T # [seq_new, hidden]
|
|
||||||
if is_joint:
|
|
||||||
latent_out = output[0] + strength * v
|
|
||||||
return (latent_out,) + output[1:]
|
|
||||||
else:
|
|
||||||
return output + strength * v
|
|
||||||
return hook
|
|
||||||
|
|
||||||
blocks = list(net_generator.joint_blocks) + list(net_generator.fused_blocks)
|
|
||||||
for i, block in enumerate(blocks):
|
|
||||||
is_joint = i < n_joint
|
|
||||||
if i < len(vecs):
|
|
||||||
h = block.register_forward_hook(
|
|
||||||
_make_steering_hook(vecs[i], is_joint, steering_strength, device, dtype)
|
|
||||||
)
|
|
||||||
steering_handles.append(h)
|
|
||||||
print(f"[SelVA] Activation steering: {len(steering_handles)} blocks "
|
|
||||||
f"strength={steering_strength} (conditional pass only)", flush=True)
|
|
||||||
|
|
||||||
# Flow matching ODE (Euler)
|
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
|
||||||
|
|
||||||
def ode_wrapper_tracked(t, x):
|
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
|
||||||
pbar.update(1)
|
|
||||||
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
|
||||||
|
|
||||||
try:
|
|
||||||
x1 = fm.to_data(ode_wrapper_tracked, x0)
|
|
||||||
except torch.cuda.OutOfMemoryError:
|
|
||||||
raise RuntimeError(
|
|
||||||
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
|
||||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if _orig_predict_flow is not None:
|
|
||||||
net_generator.predict_flow = _orig_predict_flow
|
|
||||||
for h in steering_handles:
|
|
||||||
h.remove()
|
|
||||||
|
|
||||||
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
|
||||||
|
|
||||||
# Decode: latent → mel → audio
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
x1_unnorm = net_generator.unnormalize(x1)
|
|
||||||
spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram
|
|
||||||
audio = feature_utils.vocode(spec) # mel → waveform
|
|
||||||
except torch.cuda.OutOfMemoryError:
|
|
||||||
raise RuntimeError(
|
|
||||||
"[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy "
|
|
||||||
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
net_generator.to(get_offload_device())
|
|
||||||
feature_utils.to(get_offload_device())
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
# Ensure [1, 1, samples] and normalize to [-1,1]
|
|
||||||
audio = audio.float()
|
|
||||||
if audio.dim() == 2:
|
|
||||||
audio = audio.unsqueeze(1)
|
|
||||||
elif audio.dim() == 3 and audio.shape[1] != 1:
|
|
||||||
audio = audio.mean(dim=1, keepdim=True) # stereo → mono
|
|
||||||
|
|
||||||
if normalize:
|
|
||||||
target_rms = 10 ** (target_lufs / 20.0)
|
|
||||||
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
audio = audio * (target_rms / rms)
|
|
||||||
# If RMS normalization pushes peaks into clipping, scale back to
|
|
||||||
# preserve dynamics rather than hard-clipping (no saturation)
|
|
||||||
peak = audio.abs().max().clamp(min=1e-8)
|
|
||||||
if peak > 1.0:
|
|
||||||
audio = audio / peak
|
|
||||||
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
|
||||||
|
|
||||||
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaSkipExperiment:
|
|
||||||
"""Writes skip_current.flag into a sweep output_root.
|
|
||||||
|
|
||||||
Queue this node while a SelVA LoRA Scheduler sweep is running to skip
|
|
||||||
the current experiment and move to the next one. The trainer picks up
|
|
||||||
the flag within 50 steps (~a few seconds).
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"output_root": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "output_root of the running sweep — same value as in your experiments JSON.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("STRING",)
|
|
||||||
RETURN_NAMES = ("flag_path",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Path where the flag was written.",)
|
|
||||||
FUNCTION = "skip"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Signals the running SelVA LoRA Scheduler to skip the current experiment "
|
|
||||||
"and move to the next one. Queue this node while the scheduler is running. "
|
|
||||||
"Partial scalars collected so far are saved in the summary."
|
|
||||||
)
|
|
||||||
|
|
||||||
def skip(self, output_root: str):
|
|
||||||
p = Path(output_root.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[SelVA Skip] output_root not found: {p}")
|
|
||||||
|
|
||||||
flag = p / "skip_current.flag"
|
|
||||||
flag.touch()
|
|
||||||
print(f"[SelVA Skip] Flag written: {flag}", flush=True)
|
|
||||||
return (str(flag),)
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
"""SelVA Textual Inversion Loader.
|
|
||||||
|
|
||||||
Loads a .pt file produced by SelvaTextualInversionTrainer and returns a
|
|
||||||
TEXTUAL_INVERSION bundle that the SelVA Sampler can inject into text conditioning.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaTextualInversionLoader:
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"path": ("STRING", {
|
|
||||||
"default": "textual_inversion.pt",
|
|
||||||
"tooltip": "Path to a .pt file produced by SelVA Textual Inversion Trainer. "
|
|
||||||
"Relative paths resolve to the ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("TEXTUAL_INVERSION",)
|
|
||||||
RETURN_NAMES = ("textual_inversion",)
|
|
||||||
OUTPUT_TOOLTIPS = ("Learned token embeddings — connect to SelVA Sampler's textual_inversion input.",)
|
|
||||||
FUNCTION = "load"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Loads learned CLIP token embeddings produced by SelVA Textual Inversion Trainer. "
|
|
||||||
"Connect the output to the SelVA Sampler's optional textual_inversion input to guide "
|
|
||||||
"generation toward the training data style without degrading audio quality."
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self, path: str) -> tuple:
|
|
||||||
p = Path(path.strip())
|
|
||||||
if not p.is_absolute():
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p
|
|
||||||
if not p.exists():
|
|
||||||
raise FileNotFoundError(f"[TI Loader] File not found: {p}")
|
|
||||||
|
|
||||||
data = torch.load(str(p), map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
embeddings = data["embeddings"] # [K, 1024]
|
|
||||||
n_tokens = int(data.get("n_tokens", embeddings.shape[0]))
|
|
||||||
|
|
||||||
print(f"[TI Loader] Loaded '{p.name}' n_tokens={n_tokens} "
|
|
||||||
f"shape={tuple(embeddings.shape)}", flush=True)
|
|
||||||
if data.get("init_text"):
|
|
||||||
print(f"[TI Loader] init_text='{data['init_text']}'", flush=True)
|
|
||||||
if data.get("step"):
|
|
||||||
print(f"[TI Loader] trained {data['step']} / {data.get('steps', '?')} steps "
|
|
||||||
f"lr={data.get('lr', '?')}", flush=True)
|
|
||||||
|
|
||||||
inject_mode = data.get("inject_mode", "suffix")
|
|
||||||
print(f"[TI Loader] inject_mode='{inject_mode}'", flush=True)
|
|
||||||
|
|
||||||
bundle = {
|
|
||||||
"embeddings": embeddings, # [K, 1024] float32 on CPU
|
|
||||||
"n_tokens": n_tokens,
|
|
||||||
"inject_mode": inject_mode,
|
|
||||||
"path": str(p),
|
|
||||||
"init_text": data.get("init_text", ""),
|
|
||||||
}
|
|
||||||
return (bundle,)
|
|
||||||
@@ -1,450 +0,0 @@
|
|||||||
"""SelVA Textual Inversion Trainer.
|
|
||||||
|
|
||||||
Learns K token embedding vectors in CLIP space that guide the base model
|
|
||||||
to generate audio in the style of the training clips — without modifying
|
|
||||||
any model weights.
|
|
||||||
|
|
||||||
Key difference from LoRA:
|
|
||||||
- ALL generator parameters are frozen (requires_grad=False)
|
|
||||||
- Only K×1024 token embeddings receive gradients
|
|
||||||
- Latents stay on the decoder's natural manifold → no quality degradation
|
|
||||||
- The learned tokens shift WHICH latents are generated, not HOW
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
1. Train on your .npz audio features
|
|
||||||
2. Load result with SelVA Textual Inversion Loader
|
|
||||||
3. Connect to SelVA Sampler optional input
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import random
|
|
||||||
import traceback
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
from selva_core.model.flow_matching import FlowMatching
|
|
||||||
from .selva_lora_trainer import (
|
|
||||||
_prepare_dataset,
|
|
||||||
_eval_sample,
|
|
||||||
_spectral_metrics,
|
|
||||||
_save_spectrogram,
|
|
||||||
_smooth_losses,
|
|
||||||
_draw_loss_curve,
|
|
||||||
_pil_to_tensor,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Eval helper with token injection
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _inject_tokens(text_clip: torch.Tensor, tokens: torch.Tensor,
|
|
||||||
n_tokens: int, inject_mode: str) -> torch.Tensor:
|
|
||||||
"""Build a text_clip tensor with learned tokens injected.
|
|
||||||
|
|
||||||
inject_mode:
|
|
||||||
"suffix" — replace last n_tokens positions (EOS/padding zone)
|
|
||||||
"prefix" — replace positions 1:1+n_tokens (after BOS, before content)
|
|
||||||
|
|
||||||
Always uses torch.cat so gradient flows to `tokens` when tokens.requires_grad.
|
|
||||||
Works for both training (tokens is a Parameter) and eval (tokens is detached).
|
|
||||||
"""
|
|
||||||
if inject_mode == "prefix":
|
|
||||||
bos = text_clip[:, :1, :].detach() # [B, 1, D]
|
|
||||||
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
|
|
||||||
rest = text_clip[:, 1 + n_tokens:, :].detach() # [B, 75-K, D]
|
|
||||||
return torch.cat([bos, toks, rest], dim=1) # [B, 77, D]
|
|
||||||
else: # suffix (default)
|
|
||||||
front = text_clip[:, :-n_tokens, :].detach() # [B, 77-K, D]
|
|
||||||
toks = tokens.unsqueeze(0).expand(text_clip.shape[0], -1, -1) # [B, K, D]
|
|
||||||
return torch.cat([front, toks], dim=1) # [B, 77, D]
|
|
||||||
|
|
||||||
|
|
||||||
def _eval_sample_ti(generator, learned_tokens, n_tokens, inject_mode,
|
|
||||||
feature_utils_orig, dataset, seq_cfg,
|
|
||||||
device, dtype, num_steps=25, seed=42, clip_idx=0):
|
|
||||||
"""Inference pass with learned tokens injected into text conditioning."""
|
|
||||||
generator.eval()
|
|
||||||
try:
|
|
||||||
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = dataset[clip_idx]
|
|
||||||
clip_f = clip_f_cpu.to(device, dtype)
|
|
||||||
sync_f = sync_f_cpu.to(device, dtype)
|
|
||||||
text_clip = text_clip_cpu.to(device, dtype).clone()
|
|
||||||
|
|
||||||
emb = learned_tokens.detach().to(device, dtype)
|
|
||||||
text_input = _inject_tokens(text_clip, emb, n_tokens, inject_mode)
|
|
||||||
|
|
||||||
rng = torch.Generator(device=device).manual_seed(seed)
|
|
||||||
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
|
||||||
device=device, dtype=dtype, generator=rng)
|
|
||||||
|
|
||||||
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
|
||||||
|
|
||||||
def velocity_fn(t, x):
|
|
||||||
return generator.forward(x, clip_f, sync_f, text_input,
|
|
||||||
t.reshape(1).to(device, dtype))
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
x1_pred = eval_fm.to_data(velocity_fn, x0)
|
|
||||||
x1_unnorm = generator.unnormalize(x1_pred)
|
|
||||||
|
|
||||||
tod = feature_utils_orig.tod
|
|
||||||
tod_orig_dev = next(tod.parameters()).device
|
|
||||||
tod.to(device)
|
|
||||||
try:
|
|
||||||
spec = feature_utils_orig.decode(x1_unnorm)
|
|
||||||
audio = feature_utils_orig.vocode(spec)
|
|
||||||
finally:
|
|
||||||
tod.to(tod_orig_dev)
|
|
||||||
|
|
||||||
audio = audio.float().cpu()
|
|
||||||
if audio.dim() == 2:
|
|
||||||
audio = audio.unsqueeze(1)
|
|
||||||
elif audio.dim() == 3 and audio.shape[1] != 1:
|
|
||||||
audio = audio.mean(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
target_rms = 10 ** (-27.0 / 20.0)
|
|
||||||
rms = audio.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
audio = (audio * (target_rms / rms))
|
|
||||||
peak = audio.abs().max().clamp(min=1e-8)
|
|
||||||
if peak > 1.0:
|
|
||||||
audio = audio / peak
|
|
||||||
return audio.squeeze(0), seq_cfg.sampling_rate
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[TI Trainer] Eval sample failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
return None, None
|
|
||||||
finally:
|
|
||||||
generator.train()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Node
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class SelvaTextualInversionTrainer:
|
|
||||||
"""Learns K CLIP token embeddings that steer SelVA toward a target audio style.
|
|
||||||
|
|
||||||
Unlike LoRA, all model weights are frozen. Only the K×1024 embedding tensor
|
|
||||||
receives gradients, keeping generated latents on the decoder's natural manifold
|
|
||||||
and preserving base model audio quality while shifting generation style.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "train"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("embeddings_path", "loss_curve")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to saved .pt embeddings — load with SelVA Textual Inversion Loader.",
|
|
||||||
"Smoothed training loss curve.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Trains K learnable CLIP token embeddings against your audio dataset "
|
|
||||||
"with all model weights frozen. The tokens are then injected into the "
|
|
||||||
"sampler to guide generation toward the training data style without "
|
|
||||||
"degrading audio quality."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"data_dir": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Directory containing .npz feature files and paired audio files (same as LoRA trainer).",
|
|
||||||
}),
|
|
||||||
"output_path": ("STRING", {
|
|
||||||
"default": "textual_inversion.pt",
|
|
||||||
"tooltip": "Where to save the learned embeddings. Relative paths resolve to ComfyUI output directory.",
|
|
||||||
}),
|
|
||||||
"n_tokens": ("INT", {
|
|
||||||
"default": 4, "min": 1, "max": 16,
|
|
||||||
"tooltip": "Number of learnable token vectors. More tokens = more expressive but slower to train. 4 is a good default.",
|
|
||||||
}),
|
|
||||||
"steps": ("INT", {
|
|
||||||
"default": 3000, "min": 100, "max": 50000,
|
|
||||||
"tooltip": "Training steps. 3000 is a reasonable starting point.",
|
|
||||||
}),
|
|
||||||
"lr": ("FLOAT", {
|
|
||||||
"default": 2e-4, "min": 1e-5, "max": 1e-1, "step": 1e-5,
|
|
||||||
"tooltip": "Learning rate. 2e-4 matches the LoRA working regime. Higher LR (1e-3) causes token norm to drift without plateauing on small datasets.",
|
|
||||||
}),
|
|
||||||
"batch_size": ("INT", {
|
|
||||||
"default": 4, "min": 1, "max": 64,
|
|
||||||
"tooltip": "Clips sampled per training step. Smaller batch (4–8) gives more diverse gradients and helps token norm saturate rather than drift.",
|
|
||||||
}),
|
|
||||||
"seed": ("INT", {"default": 42, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
"save_every": ("INT", {
|
|
||||||
"default": 1000, "min": 100, "max": 10000,
|
|
||||||
"tooltip": "Save a checkpoint and generate an eval sample every N steps.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"inject_mode": (["suffix", "prefix"], {
|
|
||||||
"default": "suffix",
|
|
||||||
"tooltip": (
|
|
||||||
"Where to inject the learned tokens in the 77-token CLIP sequence. "
|
|
||||||
"'suffix' replaces the last K positions (EOS/padding — may be ignored by the model). "
|
|
||||||
"'prefix' replaces positions 1:1+K right after BOS — higher attention weight, stronger style signal."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
"init_text": ("STRING", {
|
|
||||||
"default": "",
|
|
||||||
"tooltip": "Optional text phrase to warm-start token values via CLIP. Leave empty for random init (N(0, 0.02)). Example: 'industrial sound design'.",
|
|
||||||
}),
|
|
||||||
"warmup_steps": ("INT", {
|
|
||||||
"default": 100, "min": 0, "max": 1000,
|
|
||||||
"tooltip": "Linear LR warmup steps.",
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def train(self, model, data_dir, output_path, n_tokens, steps, lr,
|
|
||||||
batch_size, seed, save_every,
|
|
||||||
inject_mode="suffix", init_text="", warmup_steps=100):
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
mode = model["mode"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
feature_utils_orig = model["feature_utils"]
|
|
||||||
|
|
||||||
# --- Resolve paths ---
|
|
||||||
data_dir = Path(data_dir.strip())
|
|
||||||
if not data_dir.is_absolute():
|
|
||||||
data_dir = Path(folder_paths.models_dir) / data_dir
|
|
||||||
if not data_dir.exists():
|
|
||||||
raise FileNotFoundError(f"[TI Trainer] data_dir not found: {data_dir}")
|
|
||||||
|
|
||||||
out_path = Path(output_path.strip())
|
|
||||||
if not out_path.is_absolute():
|
|
||||||
out_path = Path(folder_paths.get_output_directory()) / out_path
|
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"\n[TI Trainer] n_tokens={n_tokens} steps={steps} lr={lr:.2e}", flush=True)
|
|
||||||
print(f"[TI Trainer] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[TI Trainer] output = {out_path}\n", flush=True)
|
|
||||||
|
|
||||||
# --- Load dataset (reuse LoRA trainer helper) ---
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
|
|
||||||
# Training must run outside inference_mode so autograd works
|
|
||||||
with torch.inference_mode(False), torch.enable_grad():
|
|
||||||
r = self._train_inner(
|
|
||||||
model, dataset, feature_utils_orig, seq_cfg,
|
|
||||||
device, dtype, mode,
|
|
||||||
data_dir, out_path,
|
|
||||||
n_tokens, steps, lr, batch_size,
|
|
||||||
warmup_steps, seed, save_every, init_text, inject_mode,
|
|
||||||
)
|
|
||||||
smoothed = _smooth_losses(r["loss_history"]) if r["loss_history"] else []
|
|
||||||
curve_img = _draw_loss_curve(r["loss_history"], log_interval=50, smoothed=smoothed)
|
|
||||||
return (r["embeddings_path"], _pil_to_tensor(curve_img))
|
|
||||||
|
|
||||||
def _train_inner(
|
|
||||||
self, model, dataset, feature_utils_orig, seq_cfg,
|
|
||||||
device, dtype, mode,
|
|
||||||
data_dir, out_path,
|
|
||||||
n_tokens, steps, lr, batch_size,
|
|
||||||
warmup_steps, seed, save_every, init_text, inject_mode="suffix",
|
|
||||||
):
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
# --- Generator (frozen) ---
|
|
||||||
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
|
||||||
generator.requires_grad_(False)
|
|
||||||
generator.update_seq_lengths(
|
|
||||||
latent_seq_len=seq_cfg.latent_seq_len,
|
|
||||||
clip_seq_len=seq_cfg.clip_seq_len,
|
|
||||||
sync_seq_len=seq_cfg.sync_seq_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Init learned tokens ---
|
|
||||||
# Call encode_text_clip outside the grad context (it has @inference_mode),
|
|
||||||
# grab values only (no grad needed), then wrap as nn.Parameter.
|
|
||||||
if init_text.strip():
|
|
||||||
with torch.no_grad():
|
|
||||||
init_embed = feature_utils_orig.encode_text_clip([init_text.strip()])
|
|
||||||
# Positions 1:1+n_tokens — after BOS, before EOS — have actual content
|
|
||||||
init_vals = init_embed[0, 1:1 + n_tokens, :].detach().clone().float()
|
|
||||||
if init_vals.shape[0] < n_tokens:
|
|
||||||
# Prompt was very short; pad remaining with small noise
|
|
||||||
pad = torch.randn(n_tokens - init_vals.shape[0], init_vals.shape[1]) * 0.02
|
|
||||||
init_vals = torch.cat([init_vals, pad], dim=0)
|
|
||||||
learned_tokens = torch.nn.Parameter(init_vals.to(device, dtype))
|
|
||||||
print(f"[TI Trainer] Init from '{init_text.strip()}' (positions 1–{n_tokens})", flush=True)
|
|
||||||
else:
|
|
||||||
learned_tokens = torch.nn.Parameter(
|
|
||||||
torch.randn(n_tokens, 1024, device=device, dtype=dtype) * 0.02
|
|
||||||
)
|
|
||||||
print(f"[TI Trainer] Init: random N(0, 0.02)", flush=True)
|
|
||||||
|
|
||||||
# --- Measure CLIP token norm from the dataset (content positions 1–20) ---
|
|
||||||
# Learned tokens must stay within this range or the model treats them as
|
|
||||||
# out-of-distribution and produces buzz artifacts instead of style shift.
|
|
||||||
with torch.no_grad():
|
|
||||||
sample_norms = []
|
|
||||||
for item in dataset[:min(len(dataset), 20)]:
|
|
||||||
tc = item[3].squeeze(0) # [77, 1024]
|
|
||||||
sample_norms.append(tc[1:20].norm(dim=-1)) # skip BOS/EOS
|
|
||||||
clip_norm_ref = torch.cat(sample_norms).mean().item()
|
|
||||||
clip_norm_limit = clip_norm_ref * 1.5 # 50% headroom above real tokens
|
|
||||||
print(f"[TI Trainer] CLIP token norm ref={clip_norm_ref:.4f} "
|
|
||||||
f"limit={clip_norm_limit:.4f}", flush=True)
|
|
||||||
|
|
||||||
# --- Optimizer + scheduler ---
|
|
||||||
optimizer = torch.optim.AdamW([learned_tokens], lr=lr, weight_decay=1e-2)
|
|
||||||
|
|
||||||
def lr_lambda(s):
|
|
||||||
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
|
|
||||||
|
|
||||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
||||||
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
|
||||||
|
|
||||||
# --- Checkpoint dir ---
|
|
||||||
ckpt_dir = out_path.parent / out_path.stem
|
|
||||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# --- Baseline sample (once, before any training) ---
|
|
||||||
print(f"[TI Trainer] Generating baseline sample...", flush=True)
|
|
||||||
baseline_wav, baseline_sr = _eval_sample(
|
|
||||||
generator, feature_utils_orig, dataset, seq_cfg, device, dtype, seed=seed,
|
|
||||||
)
|
|
||||||
if baseline_wav is not None:
|
|
||||||
baseline_path = ckpt_dir / "baseline.wav"
|
|
||||||
try:
|
|
||||||
torchaudio.save(str(baseline_path), baseline_wav, baseline_sr)
|
|
||||||
except RuntimeError:
|
|
||||||
import soundfile as sf
|
|
||||||
sf.write(str(baseline_path), baseline_wav.squeeze(0).numpy(), baseline_sr)
|
|
||||||
try:
|
|
||||||
_save_spectrogram(baseline_wav, baseline_sr, ckpt_dir / "baseline.png")
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
print(f"[TI Trainer] Baseline saved: {baseline_path}", flush=True)
|
|
||||||
|
|
||||||
# --- Training loop ---
|
|
||||||
generator.train()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
log_interval = 50
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
|
||||||
loss_history = []
|
|
||||||
running_loss = 0.0
|
|
||||||
|
|
||||||
print(f"[TI Trainer] Training {steps} steps batch_size={batch_size}\n", flush=True)
|
|
||||||
|
|
||||||
for step in range(1, steps + 1):
|
|
||||||
batch = random.choices(dataset, k=batch_size)
|
|
||||||
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
|
||||||
|
|
||||||
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
|
||||||
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
|
||||||
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
|
||||||
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype).clone()
|
|
||||||
|
|
||||||
# Inject learned tokens — gradient flows via torch.cat (not in-place assignment).
|
|
||||||
text_input = _inject_tokens(text_clip, learned_tokens, n_tokens, inject_mode)
|
|
||||||
|
|
||||||
x1 = generator.normalize(x1)
|
|
||||||
t = torch.rand(batch_size, device=device, dtype=dtype)
|
|
||||||
x0 = torch.randn_like(x1)
|
|
||||||
xt = fm.get_conditional_flow(x0, x1, t)
|
|
||||||
|
|
||||||
v_pred = generator.forward(xt, clip_f, sync_f, text_input, t)
|
|
||||||
loss = fm.loss(v_pred, x0, x1).mean()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
torch.nn.utils.clip_grad_norm_([learned_tokens], max_norm=1.0)
|
|
||||||
optimizer.step()
|
|
||||||
scheduler.step()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
# Clamp token norm to CLIP manifold — prevents out-of-distribution
|
|
||||||
# embeddings that cause buzz artifacts instead of style shift.
|
|
||||||
with torch.no_grad():
|
|
||||||
norms = learned_tokens.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
|
||||||
scale = (clip_norm_limit / norms).clamp(max=1.0)
|
|
||||||
learned_tokens.data.mul_(scale)
|
|
||||||
|
|
||||||
running_loss += loss.item()
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
if step % log_interval == 0:
|
|
||||||
avg = running_loss / log_interval
|
|
||||||
loss_history.append(round(avg, 6))
|
|
||||||
running_loss = 0.0
|
|
||||||
lr_now = scheduler.get_last_lr()[0]
|
|
||||||
norm = learned_tokens.norm(dim=-1).mean().item()
|
|
||||||
print(f"[TI Trainer] step {step:5d}/{steps} "
|
|
||||||
f"loss={avg:.4f} lr={lr_now:.2e} "
|
|
||||||
f"token_norm={norm:.4f}/{clip_norm_limit:.4f}", flush=True)
|
|
||||||
|
|
||||||
if step % save_every == 0 or step == steps:
|
|
||||||
# Save checkpoint
|
|
||||||
ckpt = {
|
|
||||||
"embeddings": learned_tokens.detach().cpu(),
|
|
||||||
"n_tokens": n_tokens,
|
|
||||||
"inject_mode": inject_mode,
|
|
||||||
"step": step,
|
|
||||||
"init_text": init_text,
|
|
||||||
"lr": lr,
|
|
||||||
"steps": steps,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
}
|
|
||||||
ckpt_path = ckpt_dir / f"step_{step:05d}.pt"
|
|
||||||
torch.save(ckpt, str(ckpt_path))
|
|
||||||
|
|
||||||
# Eval sample
|
|
||||||
wav, sr = _eval_sample_ti(
|
|
||||||
generator, learned_tokens, n_tokens, inject_mode,
|
|
||||||
feature_utils_orig, dataset, seq_cfg,
|
|
||||||
device, dtype, seed=seed,
|
|
||||||
)
|
|
||||||
if wav is not None:
|
|
||||||
wav_path = ckpt_dir / f"step_{step:05d}.wav"
|
|
||||||
try:
|
|
||||||
torchaudio.save(str(wav_path), wav, sr)
|
|
||||||
except RuntimeError:
|
|
||||||
import soundfile as sf
|
|
||||||
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
|
||||||
|
|
||||||
try:
|
|
||||||
metrics = _spectral_metrics(wav, sr)
|
|
||||||
_save_spectrogram(wav, sr, ckpt_dir / f"step_{step:05d}.png")
|
|
||||||
print(f"[TI Trainer] step {step} "
|
|
||||||
f"centroid={metrics['spectral_centroid_hz']:.0f}Hz "
|
|
||||||
f"flatness={metrics['spectral_flatness']:.4f} "
|
|
||||||
f"hf={metrics['hf_energy_ratio']:.3f}", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[TI Trainer] Spectral/spectrogram failed: {e}", flush=True)
|
|
||||||
|
|
||||||
print(f"[TI Trainer] Checkpoint: {ckpt_path}", flush=True)
|
|
||||||
|
|
||||||
# --- Final save ---
|
|
||||||
final = {
|
|
||||||
"embeddings": learned_tokens.detach().cpu(),
|
|
||||||
"n_tokens": n_tokens,
|
|
||||||
"inject_mode": inject_mode,
|
|
||||||
"step": steps,
|
|
||||||
"init_text": init_text,
|
|
||||||
"lr": lr,
|
|
||||||
"steps": steps,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
}
|
|
||||||
torch.save(final, str(out_path))
|
|
||||||
print(f"\n[TI Trainer] Done. Saved: {out_path}", flush=True)
|
|
||||||
|
|
||||||
soft_empty_cache()
|
|
||||||
return {
|
|
||||||
"embeddings_path": str(out_path),
|
|
||||||
"loss_history": loss_history,
|
|
||||||
}
|
|
||||||
@@ -1,479 +0,0 @@
|
|||||||
"""SelVA Textual Inversion Scheduler — sweeps TI training experiments from a JSON file.
|
|
||||||
|
|
||||||
Each experiment inherits from a shared `base` config and overrides specific keys.
|
|
||||||
The dataset is loaded once and reused across all experiments. Results are written
|
|
||||||
to `experiment_summary.json` (updated after each completed run) and a comparison
|
|
||||||
loss-curve image showing all runs on the same axes.
|
|
||||||
|
|
||||||
JSON format:
|
|
||||||
{
|
|
||||||
"name": "ti_sweep_1",
|
|
||||||
"description": "optional human note",
|
|
||||||
"data_dir": "dataset/bj_sounds",
|
|
||||||
"output_root": "ti_output/sweep_1",
|
|
||||||
"base": {
|
|
||||||
"n_tokens": 4,
|
|
||||||
"lr": 1e-3,
|
|
||||||
"steps": 3000,
|
|
||||||
"batch_size": 16,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"seed": 42,
|
|
||||||
"save_every": 1000
|
|
||||||
},
|
|
||||||
"experiments": [
|
|
||||||
{"id": "baseline", "description": "default 4 tokens"},
|
|
||||||
{"id": "n8_tokens", "n_tokens": 8},
|
|
||||||
{"id": "lr_5e4", "lr": 5e-4},
|
|
||||||
{"id": "warm_init", "init_text": "industrial sound design"},
|
|
||||||
{"id": "n4_more_steps", "steps": 5000}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import comfy.utils
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device
|
|
||||||
from .selva_lora_trainer import (
|
|
||||||
_prepare_dataset,
|
|
||||||
_smooth_losses,
|
|
||||||
_pil_to_tensor,
|
|
||||||
)
|
|
||||||
from .selva_textual_inversion_trainer import SelvaTextualInversionTrainer
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers (shared with LoRA scheduler, inlined to keep modules independent)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _get_system_info() -> dict:
|
|
||||||
info: dict = {
|
|
||||||
"torch_version": torch.__version__,
|
|
||||||
"cuda_version": torch.version.cuda or "N/A",
|
|
||||||
"gpu_name": None,
|
|
||||||
"gpu_vram_gb": None,
|
|
||||||
}
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
try:
|
|
||||||
info["gpu_name"] = torch.cuda.get_device_name(0)
|
|
||||||
props = torch.cuda.get_device_properties(0)
|
|
||||||
info["gpu_vram_gb"] = round(props.total_memory / 1e9, 1)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
_PARAM_DEFAULTS = {
|
|
||||||
"n_tokens": 4,
|
|
||||||
"lr": 2e-4,
|
|
||||||
"steps": 3000,
|
|
||||||
"batch_size": 4,
|
|
||||||
"warmup_steps": 100,
|
|
||||||
"seed": 42,
|
|
||||||
"save_every": 1000,
|
|
||||||
"init_text": "",
|
|
||||||
"inject_mode": "suffix",
|
|
||||||
}
|
|
||||||
|
|
||||||
_PALETTE = [
|
|
||||||
(66, 133, 244),
|
|
||||||
(234, 67, 53),
|
|
||||||
(52, 168, 83),
|
|
||||||
(251, 188, 5),
|
|
||||||
(155, 89, 182),
|
|
||||||
(26, 188, 156),
|
|
||||||
(230, 126, 34),
|
|
||||||
(149, 165, 166),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(raw: str) -> Path:
|
|
||||||
p = Path(raw.strip())
|
|
||||||
unix_style_on_windows = (
|
|
||||||
sys.platform == "win32" and p.is_absolute() and not p.drive
|
|
||||||
)
|
|
||||||
if not p.is_absolute() or unix_style_on_windows:
|
|
||||||
p = Path(folder_paths.get_output_directory()) / p.relative_to(p.anchor)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_config(base: dict, experiment: dict) -> dict:
|
|
||||||
cfg = dict(_PARAM_DEFAULTS)
|
|
||||||
cfg.update(base)
|
|
||||||
cfg.update({k: v for k, v in experiment.items() if k not in ("id", "description")})
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
def _loss_at_steps(loss_history: list, log_interval: int, save_every: int,
|
|
||||||
total_steps: int) -> dict:
|
|
||||||
result = {}
|
|
||||||
for target in range(save_every, total_steps + 1, save_every):
|
|
||||||
idx = target // log_interval - 1
|
|
||||||
if 0 <= idx < len(loss_history):
|
|
||||||
result[str(target)] = round(loss_history[idx], 6)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _draw_comparison_curves(experiments_data: list) -> "Image.Image":
|
|
||||||
from PIL import Image, ImageDraw
|
|
||||||
|
|
||||||
W, H = 900, 420
|
|
||||||
pl, pr, pt, pb = 75, 160, 30, 50
|
|
||||||
|
|
||||||
img = Image.new("RGB", (W, H), (255, 255, 255))
|
|
||||||
draw = ImageDraw.Draw(img)
|
|
||||||
pw = W - pl - pr
|
|
||||||
ph = H - pt - pb
|
|
||||||
|
|
||||||
series = []
|
|
||||||
for i, ed in enumerate(experiments_data):
|
|
||||||
lh = ed.get("loss_history") or []
|
|
||||||
if len(lh) < 2:
|
|
||||||
continue
|
|
||||||
sm = _smooth_losses(lh)
|
|
||||||
series.append({
|
|
||||||
"id": ed["id"],
|
|
||||||
"smoothed": sm,
|
|
||||||
"color": _PALETTE[i % len(_PALETTE)],
|
|
||||||
})
|
|
||||||
|
|
||||||
if not series:
|
|
||||||
draw.text((pl + 10, pt + 10), "No data to plot", fill=(80, 80, 80))
|
|
||||||
return img
|
|
||||||
|
|
||||||
all_vals = [v for s in series for v in s["smoothed"]]
|
|
||||||
lo, hi = min(all_vals), max(all_vals)
|
|
||||||
if hi == lo:
|
|
||||||
hi = lo + 1e-6
|
|
||||||
rng = hi - lo
|
|
||||||
|
|
||||||
for i in range(5):
|
|
||||||
y = pt + int(i * ph / 4)
|
|
||||||
val = hi - i * rng / 4
|
|
||||||
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
|
|
||||||
draw.text((2, y - 7), f"{val:.4f}", fill=(100, 100, 100))
|
|
||||||
|
|
||||||
for s in series:
|
|
||||||
n = len(s["smoothed"])
|
|
||||||
pts = []
|
|
||||||
for j, v in enumerate(s["smoothed"]):
|
|
||||||
x = pl + int(j * pw / max(n - 1, 1))
|
|
||||||
y = pt + int((1.0 - (v - lo) / rng) * ph)
|
|
||||||
pts.append((x, y))
|
|
||||||
draw.line(pts, fill=s["color"], width=2)
|
|
||||||
|
|
||||||
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
|
|
||||||
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
|
|
||||||
draw.text((pl + 4, 8), "TI loss comparison (smoothed)", fill=(40, 40, 40))
|
|
||||||
|
|
||||||
lx, ly = W - pr + 10, pt
|
|
||||||
for s in series:
|
|
||||||
draw.rectangle([(lx, ly + 3), (lx + 14, ly + 13)], fill=s["color"])
|
|
||||||
draw.text((lx + 18, ly), s["id"][:20], fill=(40, 40, 40))
|
|
||||||
ly += 20
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Node
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class SelvaTiScheduler:
|
|
||||||
"""Runs a sweep of Textual Inversion experiments defined in a JSON file.
|
|
||||||
|
|
||||||
The dataset is loaded once and reused. Each experiment calls
|
|
||||||
SelvaTextualInversionTrainer._train_inner() with its own config.
|
|
||||||
Results are written to experiment_summary.json after every completed run.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
FUNCTION = "run"
|
|
||||||
RETURN_TYPES = ("STRING", "IMAGE")
|
|
||||||
RETURN_NAMES = ("summary_path", "comparison_curves")
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Path to experiment_summary.json — compare runs across sweeps.",
|
|
||||||
"All smoothed loss curves overlaid on the same axes.",
|
|
||||||
)
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Runs a series of Textual Inversion experiments from a JSON sweep file. "
|
|
||||||
"The dataset is encoded once and reused. Results (loss, config, embeddings "
|
|
||||||
"paths) are collected in experiment_summary.json after each run."
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"experiments_file": ("STRING", {
|
|
||||||
"default": "ti_experiments.json",
|
|
||||||
"tooltip": (
|
|
||||||
"Path to JSON sweep file. Relative paths resolve to the ComfyUI "
|
|
||||||
"output directory. See node description for the file format."
|
|
||||||
),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def run(self, model, experiments_file):
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Read + validate JSON
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
exp_path = Path(experiments_file.strip())
|
|
||||||
if not exp_path.is_absolute():
|
|
||||||
candidate = Path(folder_paths.models_dir) / exp_path
|
|
||||||
if not candidate.exists():
|
|
||||||
candidate = Path(folder_paths.get_output_directory()) / exp_path
|
|
||||||
exp_path = candidate
|
|
||||||
if not exp_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"[TI Scheduler] Experiment file not found: {exp_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
spec = json.loads(exp_path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
if "experiments" not in spec or not spec["experiments"]:
|
|
||||||
raise ValueError("[TI Scheduler] 'experiments' list is missing or empty.")
|
|
||||||
for i, exp in enumerate(spec["experiments"]):
|
|
||||||
if "id" not in exp:
|
|
||||||
raise ValueError(
|
|
||||||
f"[TI Scheduler] Experiment at index {i} is missing required 'id' field."
|
|
||||||
)
|
|
||||||
|
|
||||||
sweep_name = spec.get("name", exp_path.stem)
|
|
||||||
description = spec.get("description", "")
|
|
||||||
base_cfg = spec.get("base", {})
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. Resolve data_dir and output_root
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
if "data_dir" not in spec:
|
|
||||||
raise ValueError("[TI Scheduler] 'data_dir' is required in the sweep file.")
|
|
||||||
data_dir = _resolve_path(spec["data_dir"])
|
|
||||||
output_root = _resolve_path(spec.get("output_root", f"ti_sweeps/{sweep_name}"))
|
|
||||||
output_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
mode = model["mode"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
feature_utils_orig = model["feature_utils"]
|
|
||||||
|
|
||||||
print(f"\n[TI Scheduler] Sweep '{sweep_name}': "
|
|
||||||
f"{len(spec['experiments'])} experiment(s)", flush=True)
|
|
||||||
if description:
|
|
||||||
print(f"[TI Scheduler] {description}", flush=True)
|
|
||||||
print(f"[TI Scheduler] data_dir = {data_dir}", flush=True)
|
|
||||||
print(f"[TI Scheduler] output_root = {output_root}\n", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. Load dataset once
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
n_clips = len(list(data_dir.glob("*.npz")))
|
|
||||||
dataset = _prepare_dataset(model, data_dir, device)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. Build or restore summary (resume-aware)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary_path = output_root / "experiment_summary.json"
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = []
|
|
||||||
|
|
||||||
if summary_path.exists():
|
|
||||||
try:
|
|
||||||
existing = json.loads(summary_path.read_text(encoding="utf-8"))
|
|
||||||
for rec in existing.get("experiments", []):
|
|
||||||
if rec.get("results", {}).get("status") == "completed":
|
|
||||||
completed_ids.add(rec["id"])
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": rec["id"],
|
|
||||||
"loss_history": rec["results"].get("loss_history", []),
|
|
||||||
})
|
|
||||||
summary = existing
|
|
||||||
summary["completed_at"] = None
|
|
||||||
if completed_ids:
|
|
||||||
print(f"[TI Scheduler] Resuming — skipping {len(completed_ids)} "
|
|
||||||
f"completed experiment(s): {sorted(completed_ids)}", flush=True)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[TI Scheduler] Could not read existing summary ({e}) — starting fresh",
|
|
||||||
flush=True)
|
|
||||||
completed_ids = set()
|
|
||||||
all_curve_data = []
|
|
||||||
summary = None
|
|
||||||
|
|
||||||
if not completed_ids:
|
|
||||||
summary = {
|
|
||||||
"sweep_name": sweep_name,
|
|
||||||
"description": description,
|
|
||||||
"sweep_file": str(exp_path),
|
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"completed_at": None,
|
|
||||||
"system": _get_system_info(),
|
|
||||||
"data_dir": str(data_dir),
|
|
||||||
"n_clips": n_clips,
|
|
||||||
"experiments": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
comparison_img_path = output_root / "loss_comparison.png"
|
|
||||||
|
|
||||||
def _write_summary():
|
|
||||||
summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
|
||||||
|
|
||||||
def _save_comparison():
|
|
||||||
try:
|
|
||||||
img = _draw_comparison_curves(all_curve_data)
|
|
||||||
img.save(str(comparison_img_path))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[TI Scheduler] Comparison image failed: {e}", flush=True)
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. Run each experiment
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
trainer = SelvaTextualInversionTrainer()
|
|
||||||
pbar_outer = comfy.utils.ProgressBar(len(spec["experiments"]))
|
|
||||||
log_interval = 50 # matches _train_inner
|
|
||||||
|
|
||||||
for exp in spec["experiments"]:
|
|
||||||
exp_id = exp["id"]
|
|
||||||
exp_desc = exp.get("description", "")
|
|
||||||
|
|
||||||
if exp_id in completed_ids:
|
|
||||||
print(f"[TI Scheduler] Skipping '{exp_id}' (already completed)", flush=True)
|
|
||||||
pbar_outer.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
cfg = _merge_config(base_cfg, exp)
|
|
||||||
|
|
||||||
n_tokens = int(cfg["n_tokens"])
|
|
||||||
lr = float(cfg["lr"])
|
|
||||||
steps = int(cfg["steps"])
|
|
||||||
batch_size = int(cfg["batch_size"])
|
|
||||||
warmup = int(cfg["warmup_steps"])
|
|
||||||
seed = int(cfg["seed"])
|
|
||||||
save_every = int(cfg["save_every"])
|
|
||||||
init_text = str(cfg["init_text"])
|
|
||||||
inject_mode = str(cfg["inject_mode"])
|
|
||||||
|
|
||||||
output_dir = output_root / exp_id
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
out_path = output_dir / "embeddings.pt"
|
|
||||||
|
|
||||||
print(f"\n[TI Scheduler] ── Experiment '{exp_id}' ──", flush=True)
|
|
||||||
if exp_desc:
|
|
||||||
print(f"[TI Scheduler] {exp_desc}", flush=True)
|
|
||||||
print(f"[TI Scheduler] n_tokens={n_tokens} lr={lr:.2e} steps={steps} "
|
|
||||||
f"batch_size={batch_size} warmup={warmup} seed={seed} "
|
|
||||||
f"inject_mode={inject_mode}", flush=True)
|
|
||||||
if init_text:
|
|
||||||
print(f"[TI Scheduler] init_text='{init_text}'", flush=True)
|
|
||||||
|
|
||||||
exp_record = {
|
|
||||||
"id": exp_id,
|
|
||||||
"description": exp_desc,
|
|
||||||
"config": {
|
|
||||||
"n_tokens": n_tokens,
|
|
||||||
"lr": lr,
|
|
||||||
"steps": steps,
|
|
||||||
"batch_size": batch_size,
|
|
||||||
"warmup_steps": warmup,
|
|
||||||
"seed": seed,
|
|
||||||
"save_every": save_every,
|
|
||||||
"init_text": init_text,
|
|
||||||
"inject_mode": inject_mode,
|
|
||||||
},
|
|
||||||
"results": {"status": "running"},
|
|
||||||
"embeddings_path": None,
|
|
||||||
"output_dir": str(output_dir),
|
|
||||||
}
|
|
||||||
summary["experiments"].append(exp_record)
|
|
||||||
_write_summary()
|
|
||||||
|
|
||||||
t_start = time.monotonic()
|
|
||||||
try:
|
|
||||||
with torch.inference_mode(False), torch.enable_grad():
|
|
||||||
r = trainer._train_inner(
|
|
||||||
model, dataset, feature_utils_orig, seq_cfg,
|
|
||||||
device, dtype, mode,
|
|
||||||
data_dir, out_path,
|
|
||||||
n_tokens, steps, lr, batch_size,
|
|
||||||
warmup, seed, save_every, init_text, inject_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
loss_history = r["loss_history"]
|
|
||||||
smoothed = _smooth_losses(loss_history) if loss_history else []
|
|
||||||
|
|
||||||
final_loss = round(smoothed[-1], 6) if smoothed else None
|
|
||||||
min_loss = round(min(smoothed), 6) if smoothed else None
|
|
||||||
min_idx = smoothed.index(min(smoothed)) if smoothed else None
|
|
||||||
min_loss_step = (min_idx + 1) * log_interval if min_idx is not None else None
|
|
||||||
|
|
||||||
loss_std_last_quarter = None
|
|
||||||
if loss_history:
|
|
||||||
quarter = max(1, len(loss_history) // 4)
|
|
||||||
loss_std_last_quarter = round(float(np.std(loss_history[-quarter:])), 6)
|
|
||||||
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "completed",
|
|
||||||
"final_loss": final_loss,
|
|
||||||
"min_loss": min_loss,
|
|
||||||
"min_loss_step": min_loss_step,
|
|
||||||
"loss_std_last_quarter": loss_std_last_quarter,
|
|
||||||
"loss_at_steps": _loss_at_steps(
|
|
||||||
loss_history, log_interval, save_every, steps
|
|
||||||
),
|
|
||||||
"loss_history": [round(v, 6) for v in loss_history],
|
|
||||||
"log_interval": log_interval,
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
exp_record["embeddings_path"] = r["embeddings_path"]
|
|
||||||
|
|
||||||
all_curve_data.append({
|
|
||||||
"id": exp_id,
|
|
||||||
"loss_history": loss_history,
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
duration = time.monotonic() - t_start
|
|
||||||
print(f"[TI Scheduler] Experiment '{exp_id}' failed: {e}", flush=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
exp_record["results"] = {
|
|
||||||
"status": "failed",
|
|
||||||
"error": str(e),
|
|
||||||
"duration_seconds": round(duration, 1),
|
|
||||||
}
|
|
||||||
_write_summary()
|
|
||||||
pbar_outer.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
_write_summary()
|
|
||||||
_save_comparison()
|
|
||||||
pbar_outer.update(1)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 6. Finalise
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
summary["completed_at"] = datetime.now(timezone.utc).isoformat()
|
|
||||||
_write_summary()
|
|
||||||
print(f"\n[TI Scheduler] Sweep complete. Summary: {summary_path}", flush=True)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 7. Comparison image (final update, then return to ComfyUI)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
_save_comparison()
|
|
||||||
comparison_img = _draw_comparison_curves(all_curve_data)
|
|
||||||
return (str(summary_path), _pil_to_tensor(comparison_img))
|
|
||||||
@@ -1,157 +0,0 @@
|
|||||||
"""SelVA VAE Roundtrip — encode audio through the VAE then decode straight back.
|
|
||||||
|
|
||||||
Useful for diagnosing codec reconstruction quality: if the output sounds
|
|
||||||
saturated/degraded compared to the input, the VAE/DAC is the bottleneck,
|
|
||||||
not the diffusion model or LoRA.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchaudio
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
|
|
||||||
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
|
||||||
|
|
||||||
|
|
||||||
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
|
|
||||||
|
|
||||||
|
|
||||||
class SelvaVaeRoundtrip:
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("SELVA_MODEL",),
|
|
||||||
"audio": ("AUDIO",),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio_reconstructed",)
|
|
||||||
OUTPUT_TOOLTIPS = (
|
|
||||||
"Audio after VAE encode → decode roundtrip. "
|
|
||||||
"Compare to the input to hear codec reconstruction quality.",
|
|
||||||
)
|
|
||||||
FUNCTION = "roundtrip"
|
|
||||||
CATEGORY = SELVA_CATEGORY
|
|
||||||
DESCRIPTION = (
|
|
||||||
"Encodes the input audio through the SelVA VAE then decodes it straight back. "
|
|
||||||
"Use this to isolate codec reconstruction quality from generation quality. "
|
|
||||||
"If the output sounds degraded compared to the input, the VAE/DAC is the "
|
|
||||||
"bottleneck — not the model or LoRA."
|
|
||||||
)
|
|
||||||
|
|
||||||
def roundtrip(self, model, audio):
|
|
||||||
from selva_core.model.utils.features_utils import FeaturesUtils
|
|
||||||
|
|
||||||
mode = model["mode"]
|
|
||||||
seq_cfg = model["seq_cfg"]
|
|
||||||
dtype = model["dtype"]
|
|
||||||
device = get_device()
|
|
||||||
generator = model["generator"]
|
|
||||||
feature_utils = model["feature_utils"]
|
|
||||||
|
|
||||||
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
|
||||||
vae_path = _SELVA_DIR / "ext" / vae_name
|
|
||||||
if not vae_path.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"[VAE Roundtrip] VAE weight not found: {vae_path}. "
|
|
||||||
"Run SelVA Model Loader first to auto-download weights."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load encoder only — decoder/vocoder come from model["feature_utils"]
|
|
||||||
# to mirror exactly what the sampler uses.
|
|
||||||
# AutoEncoderModule requires vocoder_ckpt_path even when only encoding,
|
|
||||||
# so pass the BigVGAN path (weights won't actually be used for decode here).
|
|
||||||
bigvgan_path = _SELVA_DIR / "ext" / "best_netG.pt"
|
|
||||||
print("[VAE Roundtrip] Loading VAE encoder...", flush=True)
|
|
||||||
vae_enc = FeaturesUtils(
|
|
||||||
tod_vae_ckpt=str(vae_path),
|
|
||||||
enable_conditions=False,
|
|
||||||
mode=mode,
|
|
||||||
need_vae_encoder=True,
|
|
||||||
bigvgan_vocoder_ckpt=str(bigvgan_path) if bigvgan_path.exists() else None,
|
|
||||||
).to(device).eval()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Prepare input audio
|
|
||||||
waveform = audio["waveform"] # [1, C, L]
|
|
||||||
sr_in = audio["sample_rate"]
|
|
||||||
|
|
||||||
wav = waveform[0].mean(0) # mono [L]
|
|
||||||
|
|
||||||
if sr_in != seq_cfg.sampling_rate:
|
|
||||||
wav = torchaudio.functional.resample(
|
|
||||||
wav.unsqueeze(0), sr_in, seq_cfg.sampling_rate
|
|
||||||
).squeeze(0)
|
|
||||||
print(f"[VAE Roundtrip] Resampled {sr_in} → {seq_cfg.sampling_rate} Hz",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
target_len = int(seq_cfg.duration * seq_cfg.sampling_rate)
|
|
||||||
if wav.shape[0] > target_len:
|
|
||||||
wav = wav[:target_len]
|
|
||||||
elif wav.shape[0] < target_len:
|
|
||||||
wav = F.pad(wav, (0, target_len - wav.shape[0]))
|
|
||||||
|
|
||||||
wav_b = wav.unsqueeze(0).to(device).float() # [1, L]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# Encode: audio → raw latent [1, latent_dim, T]
|
|
||||||
dist = vae_enc.encode_audio(wav_b)
|
|
||||||
latent = dist.mode().clone()
|
|
||||||
|
|
||||||
# Trim/pad to exact model sequence length (same as _prepare_dataset)
|
|
||||||
tgt = seq_cfg.latent_seq_len
|
|
||||||
if latent.shape[2] < tgt:
|
|
||||||
latent = F.pad(latent, (0, tgt - latent.shape[2]))
|
|
||||||
elif latent.shape[2] > tgt:
|
|
||||||
latent = latent[:, :, :tgt]
|
|
||||||
|
|
||||||
# To [B, T, latent_dim] — layout the generator uses
|
|
||||||
latent_t = latent.transpose(1, 2).to(dtype)
|
|
||||||
print(f"[VAE Roundtrip] Encoded: mean={latent_t.mean():.4f} std={latent_t.std():.4f}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
# Normalize → unnormalize mirrors the training/inference pipeline:
|
|
||||||
# training normalizes encoded latents; sampler unnormalizes before decode.
|
|
||||||
# This ensures the latent is in the same space the decoder expects.
|
|
||||||
latent_norm = generator.normalize(latent_t.clone())
|
|
||||||
latent_unnorm = generator.unnormalize(latent_norm)
|
|
||||||
print(f"[VAE Roundtrip] Norm→unnorm: mean={latent_unnorm.mean():.4f} std={latent_unnorm.std():.4f}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
# Decode using model's feature_utils — same path as the sampler
|
|
||||||
tod = feature_utils.tod
|
|
||||||
tod_orig_device = next(tod.parameters()).device
|
|
||||||
tod.to(device)
|
|
||||||
try:
|
|
||||||
spec = feature_utils.decode(latent_unnorm)
|
|
||||||
out = feature_utils.vocode(spec)
|
|
||||||
finally:
|
|
||||||
tod.to(tod_orig_device)
|
|
||||||
|
|
||||||
out = out.float().cpu()
|
|
||||||
if out.dim() == 1:
|
|
||||||
out = out.unsqueeze(0).unsqueeze(0)
|
|
||||||
elif out.dim() == 2:
|
|
||||||
out = out.unsqueeze(1)
|
|
||||||
elif out.dim() == 3 and out.shape[1] != 1:
|
|
||||||
out = out.mean(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
rms = out.pow(2).mean().sqrt().clamp(min=1e-8)
|
|
||||||
target_rms = 10 ** (-27.0 / 20.0)
|
|
||||||
out = out * (target_rms / rms)
|
|
||||||
out = out.clamp(-1.0, 1.0)
|
|
||||||
|
|
||||||
print(f"[VAE Roundtrip] Output: shape={tuple(out.shape)} "
|
|
||||||
f"peak={out.abs().max():.4f} rms={out.pow(2).mean().sqrt():.4f}",
|
|
||||||
flush=True)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
del vae_enc
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
return ({"waveform": out, "sample_rate": seq_cfg.sampling_rate},)
|
|
||||||
@@ -0,0 +1,160 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_management as mm
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
||||||
|
get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
|
||||||
|
)
|
||||||
|
from .sampler import _substitute_empty_features
|
||||||
|
|
||||||
|
|
||||||
|
class PrismAudioTextOnly:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("PRISMAUDIO_MODEL",),
|
||||||
|
"text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}),
|
||||||
|
"duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
|
||||||
|
"steps": ("INT", {"default": 100, "min": 1, "max": 100}),
|
||||||
|
"cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1}),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO",)
|
||||||
|
RETURN_NAMES = ("audio",)
|
||||||
|
FUNCTION = "generate"
|
||||||
|
CATEGORY = PRISMAUDIO_CATEGORY
|
||||||
|
|
||||||
|
def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
diffusion = model["model"]
|
||||||
|
|
||||||
|
latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
|
||||||
|
|
||||||
|
# Encode text with T5-Gemma
|
||||||
|
text_features = _encode_text_t5(text_prompt, device, dtype)
|
||||||
|
|
||||||
|
# Build metadata: tuple of one dict per sample
|
||||||
|
# Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence)
|
||||||
|
# Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768]
|
||||||
|
# These will be substituted with learned empty embeddings after conditioning
|
||||||
|
sample_meta = {
|
||||||
|
"video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
|
||||||
|
"text_features": text_features.to(device, dtype=dtype),
|
||||||
|
"sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
|
||||||
|
"video_exist": torch.tensor(False),
|
||||||
|
}
|
||||||
|
metadata = (sample_meta,)
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
diffusion.model.to(device)
|
||||||
|
diffusion.conditioner.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||||
|
conditioning = diffusion.conditioner(metadata, device)
|
||||||
|
|
||||||
|
# Substitute empty features for video/sync
|
||||||
|
_substitute_empty_features(diffusion, conditioning, device, dtype)
|
||||||
|
|
||||||
|
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||||
|
|
||||||
|
# Generate noise from seed (MPS doesn't support torch.Generator)
|
||||||
|
gen_device = "cpu" if device.type == "mps" else device
|
||||||
|
generator = torch.Generator(device=gen_device).manual_seed(seed)
|
||||||
|
noise = torch.randn(
|
||||||
|
[1, IO_CHANNELS, latent_length],
|
||||||
|
generator=generator,
|
||||||
|
device=gen_device,
|
||||||
|
).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
|
from prismaudio_core.inference.sampling import sample_discrete_euler
|
||||||
|
|
||||||
|
def on_step(info):
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
fakes = sample_discrete_euler(
|
||||||
|
diffusion.model,
|
||||||
|
noise,
|
||||||
|
steps,
|
||||||
|
callback=on_step,
|
||||||
|
**cond_inputs,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
batch_cfg=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
fakes_f = fakes.float()
|
||||||
|
print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
diffusion.model.to(get_offload_device())
|
||||||
|
diffusion.conditioner.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
diffusion.pretransform.to(device)
|
||||||
|
|
||||||
|
# VAE decode in fp32 (snake activations overflow in fp16)
|
||||||
|
with torch.amp.autocast(device_type=device.type, enabled=False):
|
||||||
|
audio = diffusion.pretransform.decode(fakes_f)
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
diffusion.pretransform.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
# Peak normalize then clamp
|
||||||
|
audio = audio.float()
|
||||||
|
pre_norm_std = audio.std().item()
|
||||||
|
pre_norm_peak = audio.abs().max().item()
|
||||||
|
peak = audio.abs().max().clamp(min=1e-8)
|
||||||
|
audio = (audio / peak).clamp(-1, 1)
|
||||||
|
print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
|
||||||
|
print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)
|
||||||
|
|
||||||
|
return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
|
||||||
|
|
||||||
|
|
||||||
|
# T5-Gemma encoder singleton
|
||||||
|
_t5_model = None
|
||||||
|
_t5_tokenizer = None
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_text_t5(text, device, dtype):
|
||||||
|
"""Encode text using T5-Gemma.
|
||||||
|
|
||||||
|
Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
|
||||||
|
FeaturesUtils.encode_t5_text() implementation.
|
||||||
|
No truncation applied (matching reference behavior).
|
||||||
|
"""
|
||||||
|
global _t5_model, _t5_tokenizer
|
||||||
|
|
||||||
|
if _t5_model is None:
|
||||||
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||||
|
model_id = "google/t5gemma-l-l-ul2-it"
|
||||||
|
token = resolve_hf_token()
|
||||||
|
print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
|
||||||
|
_t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
|
||||||
|
_t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
|
||||||
|
_t5_model.eval()
|
||||||
|
|
||||||
|
_t5_model.to(device, dtype=dtype)
|
||||||
|
|
||||||
|
tokens = _t5_tokenizer(
|
||||||
|
text,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = _t5_model(**tokens)
|
||||||
|
|
||||||
|
# Move T5 off GPU after encoding to save VRAM
|
||||||
|
_t5_model.to("cpu")
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
return outputs.last_hidden_state.squeeze(0) # [seq_len, dim]
|
||||||
+47
-4
@@ -1,7 +1,21 @@
|
|||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import folder_paths
|
||||||
import comfy.model_management as mm
|
import comfy.model_management as mm
|
||||||
|
|
||||||
SELVA_CATEGORY = "SelVA"
|
PRISMAUDIO_CATEGORY = "PrismAudio"
|
||||||
|
SAMPLE_RATE = 44100
|
||||||
|
DOWNSAMPLING_RATIO = 2048
|
||||||
|
IO_CHANNELS = 64
|
||||||
|
|
||||||
|
def get_prismaudio_model_dir():
|
||||||
|
model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
|
||||||
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
return model_dir
|
||||||
|
|
||||||
|
def register_model_folder():
|
||||||
|
model_dir = get_prismaudio_model_dir()
|
||||||
|
folder_paths.add_model_folder_path("prismaudio", model_dir)
|
||||||
|
|
||||||
def get_device():
|
def get_device():
|
||||||
return mm.get_torch_device()
|
return mm.get_torch_device()
|
||||||
@@ -9,13 +23,42 @@ def get_device():
|
|||||||
def get_offload_device():
|
def get_offload_device():
|
||||||
return mm.unet_offload_device()
|
return mm.unet_offload_device()
|
||||||
|
|
||||||
|
def get_free_memory(device=None):
|
||||||
|
if device is None:
|
||||||
|
device = get_device()
|
||||||
|
return mm.get_free_memory(device)
|
||||||
|
|
||||||
def soft_empty_cache():
|
def soft_empty_cache():
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
|
def determine_precision(preference, device):
|
||||||
|
if preference != "auto":
|
||||||
|
return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
|
||||||
|
if device.type == "cpu":
|
||||||
|
return torch.float32
|
||||||
|
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
||||||
|
return torch.bfloat16
|
||||||
|
return torch.float16
|
||||||
|
|
||||||
def determine_offload_strategy(preference):
|
def determine_offload_strategy(preference):
|
||||||
if preference != "auto":
|
if preference != "auto":
|
||||||
return preference
|
return preference
|
||||||
free_mem = mm.get_free_memory(get_device())
|
free_mem = get_free_memory()
|
||||||
if free_mem / (1024 ** 3) >= 16:
|
gb = free_mem / (1024 ** 3)
|
||||||
|
if gb >= 24:
|
||||||
return "keep_in_vram"
|
return "keep_in_vram"
|
||||||
return "offload_to_cpu"
|
else:
|
||||||
|
return "offload_to_cpu"
|
||||||
|
|
||||||
|
def try_import_flash_attn():
|
||||||
|
try:
|
||||||
|
import flash_attn
|
||||||
|
return flash_attn
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def resolve_hf_token():
|
||||||
|
env_token = os.environ.get("HF_TOKEN")
|
||||||
|
if env_token:
|
||||||
|
return env_token
|
||||||
|
return None
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""
|
||||||
|
PrismAudio core inference modules.
|
||||||
|
Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
|
||||||
|
Only inference-critical code — no training, no JAX/TF dependencies.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,141 @@
|
|||||||
|
{
|
||||||
|
"model_type": "diffusion_cond",
|
||||||
|
"sample_size": 397312,
|
||||||
|
"sample_rate": 44100,
|
||||||
|
"audio_channels": 2,
|
||||||
|
"model": {
|
||||||
|
"pretransform": {
|
||||||
|
"type": "autoencoder",
|
||||||
|
"iterate_batch": true,
|
||||||
|
"config": {
|
||||||
|
"encoder": {
|
||||||
|
"type": "oobleck",
|
||||||
|
"config": {
|
||||||
|
"in_channels": 2,
|
||||||
|
"channels": 128,
|
||||||
|
"c_mults": [1, 2, 4, 8, 16],
|
||||||
|
"strides": [2, 4, 4, 8, 8],
|
||||||
|
"latent_dim": 128,
|
||||||
|
"use_snake": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"decoder": {
|
||||||
|
"type": "oobleck",
|
||||||
|
"config": {
|
||||||
|
"out_channels": 2,
|
||||||
|
"channels": 128,
|
||||||
|
"c_mults": [1, 2, 4, 8, 16],
|
||||||
|
"strides": [2, 4, 4, 8, 8],
|
||||||
|
"latent_dim": 64,
|
||||||
|
"use_snake": true,
|
||||||
|
"final_tanh": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"bottleneck": {
|
||||||
|
"type": "vae"
|
||||||
|
},
|
||||||
|
"latent_dim": 64,
|
||||||
|
"downsampling_ratio": 2048,
|
||||||
|
"io_channels": 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"conditioning": {
|
||||||
|
"configs": [
|
||||||
|
{
|
||||||
|
"id": "video_features",
|
||||||
|
"type": "cond_mlp",
|
||||||
|
"config": {
|
||||||
|
"dim": 1024,
|
||||||
|
"output_dim": 1024
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "text_features",
|
||||||
|
"type": "cond_mlp",
|
||||||
|
"config": {
|
||||||
|
"dim": 1024,
|
||||||
|
"output_dim": 1024
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "sync_features",
|
||||||
|
"type": "sync_mlp",
|
||||||
|
"config": {
|
||||||
|
"dim": 768,
|
||||||
|
"output_dim": 1024
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"cond_dim": 768
|
||||||
|
},
|
||||||
|
"diffusion": {
|
||||||
|
"cross_attention_cond_ids": ["video_features","text_features"],
|
||||||
|
"add_cond_ids": ["video_features"],
|
||||||
|
"sync_cond_ids": ["sync_features"],
|
||||||
|
"type": "dit",
|
||||||
|
"diffusion_objective": "rectified_flow",
|
||||||
|
"config": {
|
||||||
|
"io_channels": 64,
|
||||||
|
"embed_dim": 1024,
|
||||||
|
"depth": 24,
|
||||||
|
"num_heads": 16,
|
||||||
|
"cond_token_dim": 1024,
|
||||||
|
"add_token_dim": 1024,
|
||||||
|
"sync_token_dim": 1024,
|
||||||
|
"project_cond_tokens": false,
|
||||||
|
"transformer_type": "continuous_transformer",
|
||||||
|
"attn_kwargs":{
|
||||||
|
"qk_norm": "rns"
|
||||||
|
},
|
||||||
|
"use_gated": true,
|
||||||
|
"use_sync_gated": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"io_channels": 64
|
||||||
|
},
|
||||||
|
"training": {
|
||||||
|
"use_ema": true,
|
||||||
|
"log_loss_info": false,
|
||||||
|
"cfg_dropout_prob": 0.1,
|
||||||
|
"pre_encoded": true,
|
||||||
|
"timestep_sampler": "trunc_logit_normal",
|
||||||
|
"optimizer_configs": {
|
||||||
|
"diffusion": {
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"config": {
|
||||||
|
"lr": 1e-4,
|
||||||
|
"betas": [0.9, 0.999],
|
||||||
|
"weight_decay": 1e-3
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"scheduler": {
|
||||||
|
"type": "InverseLR",
|
||||||
|
"config": {
|
||||||
|
"inv_gamma": 100000,
|
||||||
|
"power": 0.5,
|
||||||
|
"warmup": 0.99
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"demo": {
|
||||||
|
"demo_every": 5000,
|
||||||
|
"demo_steps": 24,
|
||||||
|
"num_demos": 10,
|
||||||
|
"demo_cond": [
|
||||||
|
"dataset/videoprism/test/0Cu33yBwAPg_000060.npz",
|
||||||
|
"dataset/videoprism/test/bmKtI808DsU_000009.npz",
|
||||||
|
"dataset/videoprism/test/VC0c22cJTbM_000424.npz",
|
||||||
|
"dataset/videoprism/test/F3gsbUTdc2U_000090.npz",
|
||||||
|
"dataset/videoprism/test/WatvT8A8iug_000100.npz",
|
||||||
|
"dataset/videoprism/test/0nvBTp-q7tU_000112.npz",
|
||||||
|
"dataset/videoprism/test/3-PFuDkTM48_000080.npz",
|
||||||
|
"dataset/videoprism/test/luSAuu-BoPs_000232.npz",
|
||||||
|
"dataset/videoprism/test/__8UJxW0aOQ_000002.npz",
|
||||||
|
"dataset/videoprism/test/_0m_YMpQayA_000168.npz"
|
||||||
|
],
|
||||||
|
"demo_cfg_scales": [5]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,413 @@
|
|||||||
|
"""
|
||||||
|
Model factory functions for PrismAudio inference.
|
||||||
|
|
||||||
|
Extracted from:
|
||||||
|
- PrismAudio/models/factory.py
|
||||||
|
- PrismAudio/models/autoencoders.py (create_autoencoder_from_config)
|
||||||
|
- PrismAudio/models/diffusion.py (create_diffusion_cond_from_config)
|
||||||
|
- PrismAudio/models/conditioners.py (create_multi_conditioner_from_conditioning_config)
|
||||||
|
|
||||||
|
Source: https://github.com/FunAudioLLM/ThinkSound (prismaudio branch)
|
||||||
|
Only inference-critical factory functions are retained.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import typing as tp
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def create_model_from_config(model_config):
|
||||||
|
model_type = model_config.get('model_type', None)
|
||||||
|
|
||||||
|
assert model_type is not None, 'model_type must be specified in model config'
|
||||||
|
|
||||||
|
if model_type == 'autoencoder':
|
||||||
|
return create_autoencoder_from_config(model_config)
|
||||||
|
elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond":
|
||||||
|
return create_diffusion_cond_from_config(model_config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
|
|
||||||
|
|
||||||
|
def create_pretransform_from_config(pretransform_config, sample_rate):
|
||||||
|
pretransform_type = pretransform_config.get('type', None)
|
||||||
|
|
||||||
|
assert pretransform_type is not None, 'type must be specified in pretransform config'
|
||||||
|
|
||||||
|
if pretransform_type == 'autoencoder':
|
||||||
|
from prismaudio_core.models.pretransforms import AutoencoderPretransform
|
||||||
|
|
||||||
|
# Create fake top-level config to pass sample rate to autoencoder constructor
|
||||||
|
# This is a bit of a hack but it keeps us from re-defining the sample rate in the config
|
||||||
|
autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
|
||||||
|
autoencoder = create_autoencoder_from_config(autoencoder_config)
|
||||||
|
|
||||||
|
scale = pretransform_config.get("scale", 1.0)
|
||||||
|
model_half = pretransform_config.get("model_half", False)
|
||||||
|
iterate_batch = pretransform_config.get("iterate_batch", False)
|
||||||
|
chunked = pretransform_config.get("chunked", False)
|
||||||
|
|
||||||
|
pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
|
||||||
|
elif pretransform_type == 'wavelet':
|
||||||
|
raise NotImplementedError("wavelet pretransform type is not supported")
|
||||||
|
elif pretransform_type == 'pqmf':
|
||||||
|
from prismaudio_core.models.pretransforms import PQMFPretransform
|
||||||
|
pqmf_config = pretransform_config["config"]
|
||||||
|
pretransform = PQMFPretransform(**pqmf_config)
|
||||||
|
elif pretransform_type == 'dac_pretrained':
|
||||||
|
from prismaudio_core.models.pretransforms import PretrainedDACPretransform
|
||||||
|
pretrained_dac_config = pretransform_config["config"]
|
||||||
|
pretransform = PretrainedDACPretransform(**pretrained_dac_config)
|
||||||
|
elif pretransform_type == "audiocraft_pretrained":
|
||||||
|
from prismaudio_core.models.pretransforms import AudiocraftCompressionPretransform
|
||||||
|
|
||||||
|
audiocraft_config = pretransform_config["config"]
|
||||||
|
pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
|
||||||
|
|
||||||
|
enable_grad = pretransform_config.get('enable_grad', False)
|
||||||
|
pretransform.enable_grad = enable_grad
|
||||||
|
|
||||||
|
pretransform.eval().requires_grad_(pretransform.enable_grad)
|
||||||
|
|
||||||
|
return pretransform
|
||||||
|
|
||||||
|
|
||||||
|
def create_bottleneck_from_config(bottleneck_config):
|
||||||
|
bottleneck_type = bottleneck_config.get('type', None)
|
||||||
|
|
||||||
|
assert bottleneck_type is not None, 'type must be specified in bottleneck config'
|
||||||
|
|
||||||
|
if bottleneck_type == 'tanh':
|
||||||
|
from prismaudio_core.models.bottleneck import TanhBottleneck
|
||||||
|
bottleneck = TanhBottleneck()
|
||||||
|
elif bottleneck_type == 'vae':
|
||||||
|
from prismaudio_core.models.bottleneck import VAEBottleneck
|
||||||
|
bottleneck = VAEBottleneck()
|
||||||
|
elif bottleneck_type == 'rvq':
|
||||||
|
from prismaudio_core.models.bottleneck import RVQBottleneck
|
||||||
|
|
||||||
|
quantizer_params = {
|
||||||
|
"dim": 128,
|
||||||
|
"codebook_size": 1024,
|
||||||
|
"num_quantizers": 8,
|
||||||
|
"decay": 0.99,
|
||||||
|
"kmeans_init": True,
|
||||||
|
"kmeans_iters": 50,
|
||||||
|
"threshold_ema_dead_code": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
quantizer_params.update(bottleneck_config["config"])
|
||||||
|
|
||||||
|
bottleneck = RVQBottleneck(**quantizer_params)
|
||||||
|
elif bottleneck_type == "dac_rvq":
|
||||||
|
from prismaudio_core.models.bottleneck import DACRVQBottleneck
|
||||||
|
|
||||||
|
bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
|
||||||
|
|
||||||
|
elif bottleneck_type == 'rvq_vae':
|
||||||
|
from prismaudio_core.models.bottleneck import RVQVAEBottleneck
|
||||||
|
|
||||||
|
quantizer_params = {
|
||||||
|
"dim": 128,
|
||||||
|
"codebook_size": 1024,
|
||||||
|
"num_quantizers": 8,
|
||||||
|
"decay": 0.99,
|
||||||
|
"kmeans_init": True,
|
||||||
|
"kmeans_iters": 50,
|
||||||
|
"threshold_ema_dead_code": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
quantizer_params.update(bottleneck_config["config"])
|
||||||
|
|
||||||
|
bottleneck = RVQVAEBottleneck(**quantizer_params)
|
||||||
|
|
||||||
|
elif bottleneck_type == 'dac_rvq_vae':
|
||||||
|
from prismaudio_core.models.bottleneck import DACRVQVAEBottleneck
|
||||||
|
bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
|
||||||
|
elif bottleneck_type == 'l2_norm':
|
||||||
|
from prismaudio_core.models.bottleneck import L2Bottleneck
|
||||||
|
bottleneck = L2Bottleneck()
|
||||||
|
elif bottleneck_type == "wasserstein":
|
||||||
|
from prismaudio_core.models.bottleneck import WassersteinBottleneck
|
||||||
|
bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
|
||||||
|
elif bottleneck_type == "fsq":
|
||||||
|
from prismaudio_core.models.bottleneck import FSQBottleneck
|
||||||
|
bottleneck = FSQBottleneck(**bottleneck_config["config"])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
|
||||||
|
|
||||||
|
requires_grad = bottleneck_config.get('requires_grad', True)
|
||||||
|
if not requires_grad:
|
||||||
|
for param in bottleneck.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
return bottleneck
|
||||||
|
|
||||||
|
|
||||||
|
def create_autoencoder_from_config(config: Dict[str, Any]):
|
||||||
|
"""Create an AudioAutoencoder from a config dictionary.
|
||||||
|
|
||||||
|
Originally in PrismAudio/models/autoencoders.py.
|
||||||
|
"""
|
||||||
|
from prismaudio_core.models.autoencoders import (
|
||||||
|
AudioAutoencoder,
|
||||||
|
create_encoder_from_config,
|
||||||
|
create_decoder_from_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
ae_config = config["model"]
|
||||||
|
|
||||||
|
encoder = create_encoder_from_config(ae_config["encoder"])
|
||||||
|
decoder = create_decoder_from_config(ae_config["decoder"])
|
||||||
|
|
||||||
|
bottleneck = ae_config.get("bottleneck", None)
|
||||||
|
|
||||||
|
latent_dim = ae_config.get("latent_dim", None)
|
||||||
|
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||||
|
downsampling_ratio = ae_config.get("downsampling_ratio", None)
|
||||||
|
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||||
|
io_channels = ae_config.get("io_channels", None)
|
||||||
|
assert io_channels is not None, "io_channels must be specified in model config"
|
||||||
|
sample_rate = config.get("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||||
|
|
||||||
|
in_channels = ae_config.get("in_channels", None)
|
||||||
|
out_channels = ae_config.get("out_channels", None)
|
||||||
|
|
||||||
|
pretransform = ae_config.get("pretransform", None)
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
|
||||||
|
if bottleneck is not None:
|
||||||
|
bottleneck = create_bottleneck_from_config(bottleneck)
|
||||||
|
|
||||||
|
soft_clip = ae_config["decoder"].get("soft_clip", False)
|
||||||
|
|
||||||
|
return AudioAutoencoder(
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
io_channels=io_channels,
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
downsampling_ratio=downsampling_ratio,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
bottleneck=bottleneck,
|
||||||
|
pretransform=pretransform,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
soft_clip=soft_clip
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]):
|
||||||
|
"""Create a MultiConditioner from a conditioning config dictionary.
|
||||||
|
|
||||||
|
Originally in PrismAudio/models/conditioners.py.
|
||||||
|
"""
|
||||||
|
from prismaudio_core.models.conditioners import (
|
||||||
|
MultiConditioner,
|
||||||
|
T5Conditioner,
|
||||||
|
CLAPTextConditioner,
|
||||||
|
CLIPTextConditioner,
|
||||||
|
MetaCLIPTextConditioner,
|
||||||
|
CLAPAudioConditioner,
|
||||||
|
Cond_MLP,
|
||||||
|
Global_MLP,
|
||||||
|
Sync_MLP,
|
||||||
|
Cond_MLP_1,
|
||||||
|
Cond_ConvMLP,
|
||||||
|
Cond_MLP_Global,
|
||||||
|
Cond_MLP_Global_1,
|
||||||
|
Cond_MLP_Global_2,
|
||||||
|
Video_Global,
|
||||||
|
Video_Sync,
|
||||||
|
Text_Linear,
|
||||||
|
CLIPConditioner,
|
||||||
|
IntConditioner,
|
||||||
|
NumberConditioner,
|
||||||
|
PhonemeConditioner,
|
||||||
|
TokenizerLUTConditioner,
|
||||||
|
PretransformConditioner,
|
||||||
|
mm_unchang,
|
||||||
|
)
|
||||||
|
from prismaudio_core.models.utils import load_ckpt_state_dict
|
||||||
|
|
||||||
|
conditioners = {}
|
||||||
|
cond_dim = config["cond_dim"]
|
||||||
|
|
||||||
|
default_keys = config.get("default_keys", {})
|
||||||
|
|
||||||
|
for conditioner_info in config["configs"]:
|
||||||
|
id = conditioner_info["id"]
|
||||||
|
|
||||||
|
conditioner_type = conditioner_info["type"]
|
||||||
|
|
||||||
|
conditioner_config = {"output_dim": cond_dim}
|
||||||
|
|
||||||
|
conditioner_config.update(conditioner_info["config"])
|
||||||
|
if conditioner_type == "t5":
|
||||||
|
conditioners[id] = T5Conditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "clap_text":
|
||||||
|
conditioners[id] = CLAPTextConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "clip_text":
|
||||||
|
conditioners[id] = CLIPTextConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "metaclip_text":
|
||||||
|
conditioners[id] = MetaCLIPTextConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "clap_audio":
|
||||||
|
conditioners[id] = CLAPAudioConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "cond_mlp":
|
||||||
|
conditioners[id] = Cond_MLP(**conditioner_config)
|
||||||
|
elif conditioner_type == "global_mlp":
|
||||||
|
conditioners[id] = Global_MLP(**conditioner_config)
|
||||||
|
elif conditioner_type == "sync_mlp":
|
||||||
|
conditioners[id] = Sync_MLP(**conditioner_config)
|
||||||
|
elif conditioner_type == "cond_mlp_1":
|
||||||
|
conditioners[id] = Cond_MLP_1(**conditioner_config)
|
||||||
|
elif conditioner_type == "cond_convmlp":
|
||||||
|
conditioners[id] = Cond_ConvMLP(**conditioner_config)
|
||||||
|
elif conditioner_type == "cond_mlp_global":
|
||||||
|
conditioners[id] = Cond_MLP_Global(**conditioner_config)
|
||||||
|
elif conditioner_type == "cond_mlp_global_1":
|
||||||
|
conditioners[id] = Cond_MLP_Global_1(**conditioner_config)
|
||||||
|
elif conditioner_type == "cond_mlp_global_2":
|
||||||
|
conditioners[id] = Cond_MLP_Global_2(**conditioner_config)
|
||||||
|
elif conditioner_type == "video_global":
|
||||||
|
conditioners[id] = Video_Global(**conditioner_config)
|
||||||
|
elif conditioner_type == "video_sync":
|
||||||
|
conditioners[id] = Video_Sync(**conditioner_config)
|
||||||
|
elif conditioner_type == "text_linear":
|
||||||
|
conditioners[id] = Text_Linear(**conditioner_config)
|
||||||
|
elif conditioner_type == "video_clip":
|
||||||
|
conditioners[id] = CLIPConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "int":
|
||||||
|
conditioners[id] = IntConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "number":
|
||||||
|
conditioners[id] = NumberConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "phoneme":
|
||||||
|
conditioners[id] = PhonemeConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "lut":
|
||||||
|
conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
|
||||||
|
elif conditioner_type == "pretransform":
|
||||||
|
sample_rate = conditioner_config.pop("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
|
||||||
|
|
||||||
|
pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
|
||||||
|
|
||||||
|
if conditioner_config.get("pretransform_ckpt_path", None) is not None:
|
||||||
|
pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
|
||||||
|
|
||||||
|
conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
|
||||||
|
elif conditioner_type == "mm_unchang":
|
||||||
|
conditioners[id] = mm_unchang(**conditioner_config)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown conditioner type: {conditioner_type}")
|
||||||
|
|
||||||
|
return MultiConditioner(conditioners, default_keys=default_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||||
|
"""Create a ConditionedDiffusionModelWrapper from a config dictionary.
|
||||||
|
|
||||||
|
Originally in PrismAudio/models/diffusion.py.
|
||||||
|
"""
|
||||||
|
from prismaudio_core.models.diffusion import (
|
||||||
|
ConditionedDiffusionModelWrapper,
|
||||||
|
MMConditionedDiffusionModelWrapper,
|
||||||
|
UNetCFG1DWrapper,
|
||||||
|
UNet1DCondWrapper,
|
||||||
|
DiTWrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = config["model"]
|
||||||
|
|
||||||
|
model_type = config["model_type"]
|
||||||
|
|
||||||
|
diffusion_config = model_config.get('diffusion', None)
|
||||||
|
assert diffusion_config is not None, "Must specify diffusion config"
|
||||||
|
|
||||||
|
diffusion_model_type = diffusion_config.get('type', None)
|
||||||
|
assert diffusion_model_type is not None, "Must specify diffusion model type"
|
||||||
|
|
||||||
|
diffusion_model_config = diffusion_config.get('config', None)
|
||||||
|
assert diffusion_model_config is not None, "Must specify diffusion model config"
|
||||||
|
|
||||||
|
if diffusion_model_type == 'adp_cfg_1d':
|
||||||
|
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
|
||||||
|
elif diffusion_model_type == 'adp_1d':
|
||||||
|
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
||||||
|
elif diffusion_model_type == 'dit':
|
||||||
|
diffusion_model = DiTWrapper(**diffusion_model_config)
|
||||||
|
elif diffusion_model_type == 'mmdit':
|
||||||
|
raise NotImplementedError("mmdit diffusion model type is not supported")
|
||||||
|
|
||||||
|
io_channels = model_config.get('io_channels', None)
|
||||||
|
assert io_channels is not None, "Must specify io_channels in model config"
|
||||||
|
|
||||||
|
sample_rate = config.get('sample_rate', None)
|
||||||
|
assert sample_rate is not None, "Must specify sample_rate in config"
|
||||||
|
|
||||||
|
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
|
||||||
|
|
||||||
|
conditioning_config = model_config.get('conditioning', None)
|
||||||
|
|
||||||
|
conditioner = None
|
||||||
|
if conditioning_config is not None:
|
||||||
|
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
|
||||||
|
|
||||||
|
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
|
||||||
|
add_cond_ids = diffusion_config.get('add_cond_ids', [])
|
||||||
|
sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
|
||||||
|
global_cond_ids = diffusion_config.get('global_cond_ids', [])
|
||||||
|
input_concat_ids = diffusion_config.get('input_concat_ids', [])
|
||||||
|
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
|
||||||
|
mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
|
||||||
|
zero_init = diffusion_config.get('zero_init', False)
|
||||||
|
pretransform = model_config.get("pretransform", None)
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
min_input_length = pretransform.downsampling_ratio
|
||||||
|
else:
|
||||||
|
min_input_length = 1
|
||||||
|
|
||||||
|
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
|
||||||
|
min_input_length *= np.prod(diffusion_model_config["factors"])
|
||||||
|
elif diffusion_model_type == "dit":
|
||||||
|
min_input_length *= diffusion_model.model.patch_size
|
||||||
|
|
||||||
|
# Get the proper wrapper class
|
||||||
|
|
||||||
|
extra_kwargs = {}
|
||||||
|
|
||||||
|
if model_type == "mm_diffusion_cond":
|
||||||
|
wrapper_fn = MMConditionedDiffusionModelWrapper
|
||||||
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
|
extra_kwargs["mm_cond_ids"] = mm_cond_ids
|
||||||
|
|
||||||
|
if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
|
||||||
|
wrapper_fn = ConditionedDiffusionModelWrapper
|
||||||
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
|
|
||||||
|
elif model_type == "diffusion_prior":
|
||||||
|
raise NotImplementedError("diffusion_prior model type is not supported")
|
||||||
|
|
||||||
|
return wrapper_fn(
|
||||||
|
diffusion_model,
|
||||||
|
conditioner,
|
||||||
|
min_input_length=min_input_length,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
cross_attn_cond_ids=cross_attention_ids,
|
||||||
|
global_cond_ids=global_cond_ids,
|
||||||
|
input_concat_ids=input_concat_ids,
|
||||||
|
prepend_cond_ids=prepend_cond_ids,
|
||||||
|
add_cond_ids=add_cond_ids,
|
||||||
|
sync_cond_ids=sync_cond_ids,
|
||||||
|
pretransform=pretransform,
|
||||||
|
io_channels=io_channels,
|
||||||
|
zero_init=zero_init,
|
||||||
|
**extra_kwargs
|
||||||
|
)
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from .sampling import sample_discrete_euler
|
||||||
|
from .utils import set_audio_channels, prepare_audio
|
||||||
|
|
||||||
|
__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
|
||||||
|
"""Discrete Euler sampler for rectified flow, with optional callback.
|
||||||
|
|
||||||
|
Modified from PrismAudio to add callback parameter for ComfyUI progress reporting.
|
||||||
|
Original uses tqdm internally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The diffusion model (DiTWrapper)
|
||||||
|
x: Initial noise tensor [B, C, T]
|
||||||
|
steps: Number of sampling steps
|
||||||
|
sigma_max: Maximum sigma (default 1.0 for rectified flow)
|
||||||
|
callback: Optional callable({"i": step, "x": current_x}) for progress
|
||||||
|
**extra_args: Passed to model() — includes cross_attn_cond, add_cond,
|
||||||
|
sync_cond, cfg_scale, batch_cfg, etc.
|
||||||
|
"""
|
||||||
|
t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
|
||||||
|
dt = t_next - t_curr
|
||||||
|
t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
|
||||||
|
x = x + dt * model(x, t_curr_tensor, **extra_args)
|
||||||
|
if callback is not None:
|
||||||
|
callback({"i": i, "x": x})
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torchaudio import transforms as T
|
||||||
|
|
||||||
|
|
||||||
|
def set_audio_channels(audio, target_channels):
|
||||||
|
"""Convert audio tensor to target number of channels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: Audio tensor of shape [B, C, T]
|
||||||
|
target_channels: Desired number of channels (1 for mono, 2 for stereo)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Audio tensor with the target number of channels.
|
||||||
|
"""
|
||||||
|
if target_channels == 1:
|
||||||
|
# Convert to mono
|
||||||
|
audio = audio.mean(1, keepdim=True)
|
||||||
|
elif target_channels == 2:
|
||||||
|
# Convert to stereo
|
||||||
|
if audio.shape[1] == 1:
|
||||||
|
audio = audio.repeat(1, 2, 1)
|
||||||
|
elif audio.shape[1] > 2:
|
||||||
|
audio = audio[:, :2, :]
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
||||||
|
"""Resample, pad/trim, and convert channels of an audio tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T])
|
||||||
|
in_sr: Input sample rate
|
||||||
|
target_sr: Target sample rate
|
||||||
|
target_length: Target length in samples (padded or cropped)
|
||||||
|
target_channels: Target number of channels
|
||||||
|
device: Torch device to place the audio on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Audio tensor of shape [B, target_channels, target_length] on device.
|
||||||
|
"""
|
||||||
|
audio = audio.to(device)
|
||||||
|
|
||||||
|
if in_sr != target_sr:
|
||||||
|
resample_tf = T.Resample(in_sr, target_sr).to(device)
|
||||||
|
audio = resample_tf(audio)
|
||||||
|
|
||||||
|
# Add batch dimension
|
||||||
|
if audio.dim() == 1:
|
||||||
|
audio = audio.unsqueeze(0).unsqueeze(0)
|
||||||
|
elif audio.dim() == 2:
|
||||||
|
audio = audio.unsqueeze(0)
|
||||||
|
|
||||||
|
# Pad or crop to target_length
|
||||||
|
if audio.shape[-1] < target_length:
|
||||||
|
audio = F.pad(audio, (0, target_length - audio.shape[-1]))
|
||||||
|
elif audio.shape[-1] > target_length:
|
||||||
|
audio = audio[:, :, :target_length]
|
||||||
|
|
||||||
|
audio = set_audio_channels(audio, target_channels)
|
||||||
|
|
||||||
|
return audio
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
"""
|
||||||
|
PrismAudio model modules for inference.
|
||||||
|
|
||||||
|
Re-exports create_model_from_config from the factory module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from prismaudio_core.factory import create_model_from_config
|
||||||
|
|
||||||
|
__all__ = ["create_model_from_config"]
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,821 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torchaudio import transforms as T
|
||||||
|
from alias_free_torch import Activation1d
|
||||||
|
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
||||||
|
from typing import Literal, Dict, Any
|
||||||
|
|
||||||
|
from .blocks import SnakeBeta
|
||||||
|
from .bottleneck import Bottleneck, DiscreteBottleneck
|
||||||
|
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
||||||
|
from .pretransforms import Pretransform
|
||||||
|
from .utils import checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
||||||
|
"""Minimal stub for inference.utils.prepare_audio used by autoencoders."""
|
||||||
|
import torchaudio.transforms as T
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if in_sr != target_sr:
|
||||||
|
resample_tf = T.Resample(in_sr, target_sr).to(device)
|
||||||
|
audio = resample_tf(audio)
|
||||||
|
|
||||||
|
if audio.shape[0] > target_channels:
|
||||||
|
audio = audio[:target_channels]
|
||||||
|
elif audio.shape[0] < target_channels:
|
||||||
|
audio = audio.repeat(target_channels // audio.shape[0] + 1, 1)[:target_channels]
|
||||||
|
|
||||||
|
if audio.shape[-1] < target_length:
|
||||||
|
audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[-1]))
|
||||||
|
elif audio.shape[-1] > target_length:
|
||||||
|
audio = audio[..., :target_length]
|
||||||
|
|
||||||
|
return audio.unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
|
||||||
|
from prismaudio_core.factory import create_pretransform_from_config
|
||||||
|
return create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def _lazy_create_bottleneck_from_config(bottleneck):
|
||||||
|
from prismaudio_core.factory import create_bottleneck_from_config
|
||||||
|
return create_bottleneck_from_config(bottleneck)
|
||||||
|
|
||||||
|
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||||
|
if activation == "elu":
|
||||||
|
act = nn.ELU()
|
||||||
|
elif activation == "snake":
|
||||||
|
act = SnakeBeta(channels)
|
||||||
|
elif activation == "none":
|
||||||
|
act = nn.Identity()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown activation {activation}")
|
||||||
|
|
||||||
|
if antialias:
|
||||||
|
act = Activation1d(act)
|
||||||
|
|
||||||
|
return act
|
||||||
|
|
||||||
|
class ResidualUnit(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dilation = dilation
|
||||||
|
|
||||||
|
padding = (dilation * (7-1)) // 2
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=7, dilation=dilation, padding=padding),
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||||
|
WNConv1d(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
kernel_size=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
res = x
|
||||||
|
|
||||||
|
#x = checkpoint(self.layers, x)
|
||||||
|
x = self.layers(x)
|
||||||
|
|
||||||
|
return x + res
|
||||||
|
|
||||||
|
class EncoderBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=1, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=3, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=in_channels,
|
||||||
|
out_channels=in_channels, dilation=9, use_snake=use_snake),
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||||
|
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
class DecoderBlock(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if use_nearest_upsample:
|
||||||
|
upsample_layer = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=stride, mode="nearest"),
|
||||||
|
WNConv1d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=2*stride,
|
||||||
|
stride=1,
|
||||||
|
bias=False,
|
||||||
|
padding='same')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||||
|
upsample_layer,
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=1, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=3, use_snake=use_snake),
|
||||||
|
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||||
|
dilation=9, use_snake=use_snake),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
class OobleckEncoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=2,
|
||||||
|
channels=128,
|
||||||
|
latent_dim=32,
|
||||||
|
c_mults = [1, 2, 4, 8],
|
||||||
|
strides = [2, 4, 8, 8],
|
||||||
|
use_snake=False,
|
||||||
|
antialias_activation=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
c_mults = [1] + c_mults
|
||||||
|
|
||||||
|
self.depth = len(c_mults)
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(self.depth-1):
|
||||||
|
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
|
||||||
|
|
||||||
|
layers += [
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
|
||||||
|
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
class OobleckDecoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
out_channels=2,
|
||||||
|
channels=128,
|
||||||
|
latent_dim=32,
|
||||||
|
c_mults = [1, 2, 4, 8],
|
||||||
|
strides = [2, 4, 8, 8],
|
||||||
|
use_snake=False,
|
||||||
|
antialias_activation=False,
|
||||||
|
use_nearest_upsample=False,
|
||||||
|
final_tanh=True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
c_mults = [1] + c_mults
|
||||||
|
|
||||||
|
self.depth = len(c_mults)
|
||||||
|
|
||||||
|
layers = [
|
||||||
|
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(self.depth-1, 0, -1):
|
||||||
|
layers += [DecoderBlock(
|
||||||
|
in_channels=c_mults[i]*channels,
|
||||||
|
out_channels=c_mults[i-1]*channels,
|
||||||
|
stride=strides[i-1],
|
||||||
|
use_snake=use_snake,
|
||||||
|
antialias_activation=antialias_activation,
|
||||||
|
use_nearest_upsample=use_nearest_upsample
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
layers += [
|
||||||
|
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
|
||||||
|
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
|
||||||
|
nn.Tanh() if final_tanh else nn.Identity()
|
||||||
|
]
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
|
||||||
|
class DACEncoderWrapper(nn.Module):
|
||||||
|
def __init__(self, in_channels=1, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
from dac.model.dac import Encoder as DACEncoder
|
||||||
|
|
||||||
|
latent_dim = kwargs.pop("latent_dim", None)
|
||||||
|
|
||||||
|
encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
|
||||||
|
self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
|
||||||
|
# Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
|
||||||
|
self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
|
||||||
|
|
||||||
|
if in_channels != 1:
|
||||||
|
self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.encoder(x)
|
||||||
|
x = self.proj_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class DACDecoderWrapper(nn.Module):
|
||||||
|
def __init__(self, latent_dim, out_channels=1, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
from dac.model.dac import Decoder as DACDecoder
|
||||||
|
|
||||||
|
self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
|
||||||
|
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.decoder(x)
|
||||||
|
|
||||||
|
class AudioAutoencoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
latent_dim,
|
||||||
|
downsampling_ratio,
|
||||||
|
sample_rate,
|
||||||
|
io_channels=2,
|
||||||
|
bottleneck: Bottleneck = None,
|
||||||
|
pretransform: Pretransform = None,
|
||||||
|
in_channels = None,
|
||||||
|
out_channels = None,
|
||||||
|
soft_clip = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.downsampling_ratio = downsampling_ratio
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.in_channels = io_channels
|
||||||
|
self.out_channels = io_channels
|
||||||
|
|
||||||
|
self.min_length = self.downsampling_ratio
|
||||||
|
|
||||||
|
if in_channels is not None:
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
if out_channels is not None:
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.bottleneck = bottleneck
|
||||||
|
|
||||||
|
self.encoder = encoder
|
||||||
|
|
||||||
|
self.decoder = decoder
|
||||||
|
|
||||||
|
self.pretransform = pretransform
|
||||||
|
|
||||||
|
self.soft_clip = soft_clip
|
||||||
|
|
||||||
|
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
|
||||||
|
|
||||||
|
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
||||||
|
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
if self.pretransform is not None and not skip_pretransform:
|
||||||
|
if self.pretransform.enable_grad:
|
||||||
|
if iterate_batch:
|
||||||
|
audios = []
|
||||||
|
for i in range(audio.shape[0]):
|
||||||
|
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||||
|
audio = torch.cat(audios, dim=0)
|
||||||
|
else:
|
||||||
|
audio = self.pretransform.encode(audio)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
if iterate_batch:
|
||||||
|
audios = []
|
||||||
|
for i in range(audio.shape[0]):
|
||||||
|
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||||
|
audio = torch.cat(audios, dim=0)
|
||||||
|
else:
|
||||||
|
audio = self.pretransform.encode(audio)
|
||||||
|
|
||||||
|
if self.encoder is not None:
|
||||||
|
if iterate_batch:
|
||||||
|
latents = []
|
||||||
|
for i in range(audio.shape[0]):
|
||||||
|
latents.append(self.encoder(audio[i:i+1]))
|
||||||
|
latents = torch.cat(latents, dim=0)
|
||||||
|
else:
|
||||||
|
latents = self.encoder(audio)
|
||||||
|
else:
|
||||||
|
latents = audio
|
||||||
|
|
||||||
|
if self.bottleneck is not None:
|
||||||
|
# TODO: Add iterate batch logic, needs to merge the info dicts
|
||||||
|
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
|
||||||
|
|
||||||
|
info.update(bottleneck_info)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return latents, info
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def decode(self, latents, iterate_batch=False, **kwargs):
|
||||||
|
|
||||||
|
if self.bottleneck is not None:
|
||||||
|
if iterate_batch:
|
||||||
|
decoded = []
|
||||||
|
for i in range(latents.shape[0]):
|
||||||
|
decoded.append(self.bottleneck.decode(latents[i:i+1]))
|
||||||
|
latents = torch.cat(decoded, dim=0)
|
||||||
|
else:
|
||||||
|
latents = self.bottleneck.decode(latents)
|
||||||
|
|
||||||
|
if iterate_batch:
|
||||||
|
decoded = []
|
||||||
|
for i in range(latents.shape[0]):
|
||||||
|
decoded.append(self.decoder(latents[i:i+1]))
|
||||||
|
decoded = torch.cat(decoded, dim=0)
|
||||||
|
else:
|
||||||
|
decoded = self.decoder(latents, **kwargs)
|
||||||
|
|
||||||
|
if self.pretransform is not None:
|
||||||
|
if self.pretransform.enable_grad:
|
||||||
|
if iterate_batch:
|
||||||
|
decodeds = []
|
||||||
|
for i in range(decoded.shape[0]):
|
||||||
|
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||||
|
decoded = torch.cat(decodeds, dim=0)
|
||||||
|
else:
|
||||||
|
decoded = self.pretransform.decode(decoded)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
if iterate_batch:
|
||||||
|
decodeds = []
|
||||||
|
for i in range(latents.shape[0]):
|
||||||
|
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||||
|
decoded = torch.cat(decodeds, dim=0)
|
||||||
|
else:
|
||||||
|
decoded = self.pretransform.decode(decoded)
|
||||||
|
|
||||||
|
if self.soft_clip:
|
||||||
|
decoded = torch.tanh(decoded)
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens, **kwargs):
|
||||||
|
'''
|
||||||
|
Decode discrete tokens to audio
|
||||||
|
Only works with discrete autoencoders
|
||||||
|
'''
|
||||||
|
|
||||||
|
assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
|
||||||
|
|
||||||
|
latents = self.bottleneck.decode_tokens(tokens, **kwargs)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_audio_for_encoder(self, audio, in_sr):
|
||||||
|
'''
|
||||||
|
Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
|
||||||
|
If the model is mono, stereo audio will be converted to mono.
|
||||||
|
Audio will be silence-padded to be a multiple of the model's downsampling ratio.
|
||||||
|
Audio will be resampled to the model's sample rate.
|
||||||
|
The output will have batch size 1 and be shape (1 x Channels x Length)
|
||||||
|
'''
|
||||||
|
return self.preprocess_audio_list_for_encoder([audio], [in_sr])
|
||||||
|
|
||||||
|
def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
|
||||||
|
'''
|
||||||
|
Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder.
|
||||||
|
The audio in that list can be of different lengths and channels.
|
||||||
|
in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
|
||||||
|
All audio will be resampled to the model's sample rate.
|
||||||
|
Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
|
||||||
|
If the model is mono, all audio will be converted to mono.
|
||||||
|
The output will be a tensor of shape (Batch x Channels x Length)
|
||||||
|
'''
|
||||||
|
batch_size = len(audio_list)
|
||||||
|
if isinstance(in_sr_list, int):
|
||||||
|
in_sr_list = [in_sr_list]*batch_size
|
||||||
|
assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
|
||||||
|
new_audio = []
|
||||||
|
max_length = 0
|
||||||
|
# resample & find the max length
|
||||||
|
for i in range(batch_size):
|
||||||
|
audio = audio_list[i]
|
||||||
|
in_sr = in_sr_list[i]
|
||||||
|
if len(audio.shape) == 3 and audio.shape[0] == 1:
|
||||||
|
# batchsize 1 was given by accident. Just squeeze it.
|
||||||
|
audio = audio.squeeze(0)
|
||||||
|
elif len(audio.shape) == 1:
|
||||||
|
# Mono signal, channel dimension is missing, unsqueeze it in
|
||||||
|
audio = audio.unsqueeze(0)
|
||||||
|
assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
|
||||||
|
# Resample audio
|
||||||
|
if in_sr != self.sample_rate:
|
||||||
|
resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
|
||||||
|
audio = resample_tf(audio)
|
||||||
|
new_audio.append(audio)
|
||||||
|
if audio.shape[-1] > max_length:
|
||||||
|
max_length = audio.shape[-1]
|
||||||
|
# Pad every audio to the same length, multiple of model's downsampling ratio
|
||||||
|
padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
|
||||||
|
for i in range(batch_size):
|
||||||
|
# Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
|
||||||
|
new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
|
||||||
|
target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
|
||||||
|
# convert to tensor
|
||||||
|
return torch.stack(new_audio)
|
||||||
|
|
||||||
|
def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
||||||
|
'''
|
||||||
|
Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder.
|
||||||
|
If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
|
||||||
|
Overlap and chunk_size params are both measured in number of latents (not audio samples)
|
||||||
|
# and therefore you likely could use the same values with decode_audio.
|
||||||
|
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
|
||||||
|
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
||||||
|
You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
|
||||||
|
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
||||||
|
Smaller chunk_size uses less memory, but more compute.
|
||||||
|
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
||||||
|
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
||||||
|
'''
|
||||||
|
if not chunked:
|
||||||
|
# default behavior. Encode the entire audio in parallel
|
||||||
|
return self.encode(audio, **kwargs)
|
||||||
|
else:
|
||||||
|
# CHUNKED ENCODING
|
||||||
|
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
|
||||||
|
samples_per_latent = self.downsampling_ratio
|
||||||
|
total_size = audio.shape[2] # in samples
|
||||||
|
batch_size = audio.shape[0]
|
||||||
|
chunk_size *= samples_per_latent # converting metric in latents to samples
|
||||||
|
overlap *= samples_per_latent # converting metric in latents to samples
|
||||||
|
hop_size = chunk_size - overlap
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, total_size - chunk_size + 1, hop_size):
|
||||||
|
chunk = audio[:,:,i:i+chunk_size]
|
||||||
|
chunks.append(chunk)
|
||||||
|
if i+chunk_size != total_size:
|
||||||
|
# Final chunk
|
||||||
|
chunk = audio[:,:,-chunk_size:]
|
||||||
|
chunks.append(chunk)
|
||||||
|
chunks = torch.stack(chunks)
|
||||||
|
num_chunks = chunks.shape[0]
|
||||||
|
# Note: y_size might be a different value from the latent length used in diffusion training
|
||||||
|
# because we can encode audio of varying lengths
|
||||||
|
# However, the audio should've been padded to a multiple of samples_per_latent by now.
|
||||||
|
y_size = total_size // samples_per_latent
|
||||||
|
# Create an empty latent, we will populate it with chunks as we encode them
|
||||||
|
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
|
||||||
|
for i in range(num_chunks):
|
||||||
|
x_chunk = chunks[i,:]
|
||||||
|
# encode the chunk
|
||||||
|
y_chunk = self.encode(x_chunk)
|
||||||
|
# figure out where to put the audio along the time domain
|
||||||
|
if i == num_chunks-1:
|
||||||
|
# final chunk always goes at the end
|
||||||
|
t_end = y_size
|
||||||
|
t_start = t_end - y_chunk.shape[2]
|
||||||
|
else:
|
||||||
|
t_start = i * hop_size // samples_per_latent
|
||||||
|
t_end = t_start + chunk_size // samples_per_latent
|
||||||
|
# remove the edges of the overlaps
|
||||||
|
ol = overlap//samples_per_latent//2
|
||||||
|
chunk_start = 0
|
||||||
|
chunk_end = y_chunk.shape[2]
|
||||||
|
if i > 0:
|
||||||
|
# no overlap for the start of the first chunk
|
||||||
|
t_start += ol
|
||||||
|
chunk_start += ol
|
||||||
|
if i < num_chunks-1:
|
||||||
|
# no overlap for the end of the last chunk
|
||||||
|
t_end -= ol
|
||||||
|
chunk_end -= ol
|
||||||
|
# paste the chunked audio into our y_final output audio
|
||||||
|
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
||||||
|
return y_final
|
||||||
|
|
||||||
|
def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
||||||
|
'''
|
||||||
|
Decode latents to audio.
|
||||||
|
If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
|
||||||
|
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
|
||||||
|
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
||||||
|
You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
|
||||||
|
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
||||||
|
Smaller chunk_size uses less memory, but more compute.
|
||||||
|
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
||||||
|
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
||||||
|
'''
|
||||||
|
if not chunked:
|
||||||
|
# default behavior. Decode the entire latent in parallel
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
else:
|
||||||
|
# chunked decoding
|
||||||
|
hop_size = chunk_size - overlap
|
||||||
|
total_size = latents.shape[2]
|
||||||
|
batch_size = latents.shape[0]
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, total_size - chunk_size + 1, hop_size):
|
||||||
|
chunk = latents[:,:,i:i+chunk_size]
|
||||||
|
chunks.append(chunk)
|
||||||
|
if i+chunk_size != total_size:
|
||||||
|
# Final chunk
|
||||||
|
chunk = latents[:,:,-chunk_size:]
|
||||||
|
chunks.append(chunk)
|
||||||
|
chunks = torch.stack(chunks)
|
||||||
|
num_chunks = chunks.shape[0]
|
||||||
|
# samples_per_latent is just the downsampling ratio
|
||||||
|
samples_per_latent = self.downsampling_ratio
|
||||||
|
# Create an empty waveform, we will populate it with chunks as decode them
|
||||||
|
y_size = total_size * samples_per_latent
|
||||||
|
y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
|
||||||
|
for i in range(num_chunks):
|
||||||
|
x_chunk = chunks[i,:]
|
||||||
|
# decode the chunk
|
||||||
|
y_chunk = self.decode(x_chunk)
|
||||||
|
# figure out where to put the audio along the time domain
|
||||||
|
if i == num_chunks-1:
|
||||||
|
# final chunk always goes at the end
|
||||||
|
t_end = y_size
|
||||||
|
t_start = t_end - y_chunk.shape[2]
|
||||||
|
else:
|
||||||
|
t_start = i * hop_size * samples_per_latent
|
||||||
|
t_end = t_start + chunk_size * samples_per_latent
|
||||||
|
# remove the edges of the overlaps
|
||||||
|
ol = (overlap//2) * samples_per_latent
|
||||||
|
chunk_start = 0
|
||||||
|
chunk_end = y_chunk.shape[2]
|
||||||
|
if i > 0:
|
||||||
|
# no overlap for the start of the first chunk
|
||||||
|
t_start += ol
|
||||||
|
chunk_start += ol
|
||||||
|
if i < num_chunks-1:
|
||||||
|
# no overlap for the end of the last chunk
|
||||||
|
t_end -= ol
|
||||||
|
chunk_end -= ol
|
||||||
|
# paste the chunked audio into our y_final output audio
|
||||||
|
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
||||||
|
return y_final
|
||||||
|
|
||||||
|
|
||||||
|
class DiffusionAutoencoder(AudioAutoencoder):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
diffusion: ConditionedDiffusionModel,
|
||||||
|
diffusion_downsampling_ratio,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.diffusion = diffusion
|
||||||
|
|
||||||
|
self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
|
||||||
|
|
||||||
|
if self.encoder is not None:
|
||||||
|
# Shrink the initial encoder parameters to avoid saturated latents
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.encoder.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def decode(self, latents, steps=100):
|
||||||
|
|
||||||
|
upsampled_length = latents.shape[2] * self.downsampling_ratio
|
||||||
|
|
||||||
|
if self.bottleneck is not None:
|
||||||
|
latents = self.bottleneck.decode(latents)
|
||||||
|
|
||||||
|
if self.decoder is not None:
|
||||||
|
latents = self.decoder(latents)
|
||||||
|
|
||||||
|
# Upsample latents to match diffusion length
|
||||||
|
if latents.shape[2] != upsampled_length:
|
||||||
|
latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
|
||||||
|
|
||||||
|
noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
|
||||||
|
from prismaudio_core.inference.sampling import sample
|
||||||
|
decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
|
||||||
|
|
||||||
|
if self.pretransform is not None:
|
||||||
|
if self.pretransform.enable_grad:
|
||||||
|
decoded = self.pretransform.decode(decoded)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
decoded = self.pretransform.decode(decoded)
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
# AE factories
|
||||||
|
|
||||||
|
def create_encoder_from_config(encoder_config: Dict[str, Any]):
|
||||||
|
encoder_type = encoder_config.get("type", None)
|
||||||
|
assert encoder_type is not None, "Encoder type must be specified"
|
||||||
|
|
||||||
|
if encoder_type == "oobleck":
|
||||||
|
encoder = OobleckEncoder(
|
||||||
|
**encoder_config["config"]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif encoder_type == "seanet":
|
||||||
|
from encodec.modules import SEANetEncoder
|
||||||
|
seanet_encoder_config = encoder_config["config"]
|
||||||
|
|
||||||
|
#SEANet encoder expects strides in reverse order
|
||||||
|
seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
|
||||||
|
encoder = SEANetEncoder(
|
||||||
|
**seanet_encoder_config
|
||||||
|
)
|
||||||
|
elif encoder_type == "dac":
|
||||||
|
dac_config = encoder_config["config"]
|
||||||
|
|
||||||
|
encoder = DACEncoderWrapper(**dac_config)
|
||||||
|
elif encoder_type == "local_attn":
|
||||||
|
from .local_attention import TransformerEncoder1D
|
||||||
|
|
||||||
|
local_attn_config = encoder_config["config"]
|
||||||
|
|
||||||
|
encoder = TransformerEncoder1D(
|
||||||
|
**local_attn_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown encoder type {encoder_type}")
|
||||||
|
|
||||||
|
requires_grad = encoder_config.get("requires_grad", True)
|
||||||
|
if not requires_grad:
|
||||||
|
for param in encoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
def create_decoder_from_config(decoder_config: Dict[str, Any]):
|
||||||
|
decoder_type = decoder_config.get("type", None)
|
||||||
|
assert decoder_type is not None, "Decoder type must be specified"
|
||||||
|
|
||||||
|
if decoder_type == "oobleck":
|
||||||
|
decoder = OobleckDecoder(
|
||||||
|
**decoder_config["config"]
|
||||||
|
)
|
||||||
|
elif decoder_type == "seanet":
|
||||||
|
from encodec.modules import SEANetDecoder
|
||||||
|
|
||||||
|
decoder = SEANetDecoder(
|
||||||
|
**decoder_config["config"]
|
||||||
|
)
|
||||||
|
elif decoder_type == "dac":
|
||||||
|
dac_config = decoder_config["config"]
|
||||||
|
|
||||||
|
decoder = DACDecoderWrapper(**dac_config)
|
||||||
|
elif decoder_type == "local_attn":
|
||||||
|
from .local_attention import TransformerDecoder1D
|
||||||
|
|
||||||
|
local_attn_config = decoder_config["config"]
|
||||||
|
|
||||||
|
decoder = TransformerDecoder1D(
|
||||||
|
**local_attn_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown decoder type {decoder_type}")
|
||||||
|
|
||||||
|
requires_grad = decoder_config.get("requires_grad", True)
|
||||||
|
if not requires_grad:
|
||||||
|
for param in decoder.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
return decoder
|
||||||
|
|
||||||
|
def create_autoencoder_from_config(config: Dict[str, Any]):
|
||||||
|
|
||||||
|
ae_config = config["model"]
|
||||||
|
|
||||||
|
encoder = create_encoder_from_config(ae_config["encoder"])
|
||||||
|
decoder = create_decoder_from_config(ae_config["decoder"])
|
||||||
|
|
||||||
|
bottleneck = ae_config.get("bottleneck", None)
|
||||||
|
|
||||||
|
latent_dim = ae_config.get("latent_dim", None)
|
||||||
|
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||||
|
downsampling_ratio = ae_config.get("downsampling_ratio", None)
|
||||||
|
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||||
|
io_channels = ae_config.get("io_channels", None)
|
||||||
|
assert io_channels is not None, "io_channels must be specified in model config"
|
||||||
|
sample_rate = config.get("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||||
|
|
||||||
|
in_channels = ae_config.get("in_channels", None)
|
||||||
|
out_channels = ae_config.get("out_channels", None)
|
||||||
|
|
||||||
|
pretransform = ae_config.get("pretransform", None)
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
|
||||||
|
if bottleneck is not None:
|
||||||
|
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
||||||
|
|
||||||
|
soft_clip = ae_config["decoder"].get("soft_clip", False)
|
||||||
|
|
||||||
|
return AudioAutoencoder(
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
io_channels=io_channels,
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
downsampling_ratio=downsampling_ratio,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
bottleneck=bottleneck,
|
||||||
|
pretransform=pretransform,
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
soft_clip=soft_clip
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_diffAE_from_config(config: Dict[str, Any]):
|
||||||
|
|
||||||
|
diffae_config = config["model"]
|
||||||
|
|
||||||
|
if "encoder" in diffae_config:
|
||||||
|
encoder = create_encoder_from_config(diffae_config["encoder"])
|
||||||
|
else:
|
||||||
|
encoder = None
|
||||||
|
|
||||||
|
if "decoder" in diffae_config:
|
||||||
|
decoder = create_decoder_from_config(diffae_config["decoder"])
|
||||||
|
else:
|
||||||
|
decoder = None
|
||||||
|
|
||||||
|
diffusion_model_type = diffae_config["diffusion"]["type"]
|
||||||
|
|
||||||
|
if diffusion_model_type == "DAU1d":
|
||||||
|
diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
|
||||||
|
elif diffusion_model_type == "adp_1d":
|
||||||
|
diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
|
||||||
|
elif diffusion_model_type == "dit":
|
||||||
|
diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
|
||||||
|
|
||||||
|
latent_dim = diffae_config.get("latent_dim", None)
|
||||||
|
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||||
|
downsampling_ratio = diffae_config.get("downsampling_ratio", None)
|
||||||
|
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||||
|
io_channels = diffae_config.get("io_channels", None)
|
||||||
|
assert io_channels is not None, "io_channels must be specified in model config"
|
||||||
|
sample_rate = config.get("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||||
|
|
||||||
|
bottleneck = diffae_config.get("bottleneck", None)
|
||||||
|
|
||||||
|
pretransform = diffae_config.get("pretransform", None)
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
|
||||||
|
|
||||||
|
if bottleneck is not None:
|
||||||
|
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
||||||
|
|
||||||
|
diffusion_downsampling_ratio = None
|
||||||
|
|
||||||
|
if diffusion_model_type == "DAU1d":
|
||||||
|
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
|
||||||
|
elif diffusion_model_type == "adp_1d":
|
||||||
|
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
|
||||||
|
elif diffusion_model_type == "dit":
|
||||||
|
diffusion_downsampling_ratio = 1
|
||||||
|
|
||||||
|
return DiffusionAutoencoder(
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
diffusion=diffusion,
|
||||||
|
io_channels=io_channels,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
latent_dim=latent_dim,
|
||||||
|
downsampling_ratio=downsampling_ratio,
|
||||||
|
diffusion_downsampling_ratio=diffusion_downsampling_ratio,
|
||||||
|
bottleneck=bottleneck,
|
||||||
|
pretransform=pretransform
|
||||||
|
)
|
||||||
@@ -0,0 +1,331 @@
|
|||||||
|
from functools import reduce
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from torch.backends.cuda import sdp_kernel
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from dac.nn.layers import Snake1d
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
|
||||||
|
def __init__(self, main, skip=None):
|
||||||
|
super().__init__()
|
||||||
|
self.main = nn.Sequential(*main)
|
||||||
|
self.skip = skip if skip else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.main(input) + self.skip(input)
|
||||||
|
|
||||||
|
class ResConvBlock(ResidualBlock):
|
||||||
|
def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
|
||||||
|
skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
|
||||||
|
super().__init__([
|
||||||
|
nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
|
||||||
|
nn.GroupNorm(1, c_mid),
|
||||||
|
Snake1d(c_mid) if use_snake else nn.GELU(),
|
||||||
|
nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
|
||||||
|
nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
|
||||||
|
(Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
|
||||||
|
], skip)
|
||||||
|
|
||||||
|
class SelfAttention1d(nn.Module):
|
||||||
|
def __init__(self, c_in, n_head=1, dropout_rate=0.):
|
||||||
|
super().__init__()
|
||||||
|
assert c_in % n_head == 0
|
||||||
|
self.norm = nn.GroupNorm(1, c_in)
|
||||||
|
self.n_head = n_head
|
||||||
|
self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
|
||||||
|
self.out_proj = nn.Conv1d(c_in, c_in, 1)
|
||||||
|
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||||
|
|
||||||
|
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
|
||||||
|
|
||||||
|
if not self.use_flash:
|
||||||
|
return
|
||||||
|
|
||||||
|
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
|
||||||
|
|
||||||
|
if device_properties.major == 8 and device_properties.minor == 0:
|
||||||
|
# Use flash attention for A100 GPUs
|
||||||
|
self.sdp_kernel_config = (True, False, False)
|
||||||
|
else:
|
||||||
|
# Don't use flash attention for other GPUs
|
||||||
|
self.sdp_kernel_config = (False, True, True)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
n, c, s = input.shape
|
||||||
|
qkv = self.qkv_proj(self.norm(input))
|
||||||
|
qkv = qkv.view(
|
||||||
|
[n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
|
||||||
|
q, k, v = qkv.chunk(3, dim=1)
|
||||||
|
scale = k.shape[3]**-0.25
|
||||||
|
|
||||||
|
if self.use_flash:
|
||||||
|
with sdp_kernel(*self.sdp_kernel_config):
|
||||||
|
y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
|
||||||
|
else:
|
||||||
|
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
|
||||||
|
y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
|
||||||
|
|
||||||
|
|
||||||
|
return input + self.dropout(self.out_proj(y))
|
||||||
|
|
||||||
|
class SkipBlock(nn.Module):
|
||||||
|
def __init__(self, *main):
|
||||||
|
super().__init__()
|
||||||
|
self.main = nn.Sequential(*main)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.cat([self.main(input), input], dim=1)
|
||||||
|
|
||||||
|
class FourierFeatures(nn.Module):
|
||||||
|
def __init__(self, in_features, out_features, std=1.):
|
||||||
|
super().__init__()
|
||||||
|
assert out_features % 2 == 0
|
||||||
|
self.weight = nn.Parameter(torch.randn(
|
||||||
|
[out_features // 2, in_features]) * std)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
f = 2 * math.pi * input @ self.weight.T
|
||||||
|
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||||
|
|
||||||
|
def expand_to_planes(input, shape):
|
||||||
|
return input[..., None].repeat([1, 1, shape[2]])
|
||||||
|
|
||||||
|
_kernels = {
|
||||||
|
'linear':
|
||||||
|
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||||
|
'cubic':
|
||||||
|
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
|
||||||
|
0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||||
|
'lanczos3':
|
||||||
|
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
|
||||||
|
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
|
||||||
|
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
|
||||||
|
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
|
||||||
|
}
|
||||||
|
|
||||||
|
class Downsample1d(nn.Module):
|
||||||
|
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
|
||||||
|
super().__init__()
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
kernel_1d = torch.tensor(_kernels[kernel])
|
||||||
|
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||||
|
self.register_buffer('kernel', kernel_1d)
|
||||||
|
self.channels_last = channels_last
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.channels_last:
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = F.pad(x, (self.pad,) * 2, self.pad_mode)
|
||||||
|
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
|
||||||
|
indices = torch.arange(x.shape[1], device=x.device)
|
||||||
|
weight[indices, indices] = self.kernel.to(weight)
|
||||||
|
x = F.conv1d(x, weight, stride=2)
|
||||||
|
if self.channels_last:
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample1d(nn.Module):
|
||||||
|
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
|
||||||
|
super().__init__()
|
||||||
|
self.pad_mode = pad_mode
|
||||||
|
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||||
|
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||||
|
self.register_buffer('kernel', kernel_1d)
|
||||||
|
self.channels_last = channels_last
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.channels_last:
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||||
|
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
|
||||||
|
indices = torch.arange(x.shape[1], device=x.device)
|
||||||
|
weight[indices, indices] = self.kernel.to(weight)
|
||||||
|
x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
||||||
|
if self.channels_last:
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def Downsample1d_2(
|
||||||
|
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
|
||||||
|
) -> nn.Module:
|
||||||
|
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
|
||||||
|
|
||||||
|
return nn.Conv1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=factor * kernel_multiplier + 1,
|
||||||
|
stride=factor,
|
||||||
|
padding=factor * (kernel_multiplier // 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def Upsample1d_2(
|
||||||
|
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
|
||||||
|
) -> nn.Module:
|
||||||
|
|
||||||
|
if factor == 1:
|
||||||
|
return nn.Conv1d(
|
||||||
|
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_nearest:
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=factor, mode="nearest"),
|
||||||
|
nn.Conv1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return nn.ConvTranspose1d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=factor * 2,
|
||||||
|
stride=factor,
|
||||||
|
padding=factor // 2 + factor % 2,
|
||||||
|
output_padding=factor % 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def zero_init(layer):
|
||||||
|
nn.init.zeros_(layer.weight)
|
||||||
|
if layer.bias is not None:
|
||||||
|
nn.init.zeros_(layer.bias)
|
||||||
|
return layer
|
||||||
|
|
||||||
|
class AdaRMSNorm(nn.Module):
|
||||||
|
def __init__(self, features, cond_features, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"eps={self.eps},"
|
||||||
|
|
||||||
|
def forward(self, x, cond):
|
||||||
|
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
|
||||||
|
|
||||||
|
def normalize(x, eps=1e-4):
|
||||||
|
dim = list(range(1, x.ndim))
|
||||||
|
n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
|
||||||
|
alpha = np.sqrt(n.numel() / x.numel())
|
||||||
|
return x / torch.add(eps, n, alpha=alpha)
|
||||||
|
|
||||||
|
class ForcedWNConv1d(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size=1):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.training:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.weight.copy_(normalize(self.weight))
|
||||||
|
|
||||||
|
fan_in = self.weight[0].numel()
|
||||||
|
|
||||||
|
w = normalize(self.weight) / math.sqrt(fan_in)
|
||||||
|
|
||||||
|
return F.conv1d(x, w, padding='same')
|
||||||
|
|
||||||
|
# Kernels
|
||||||
|
|
||||||
|
use_compile = True
|
||||||
|
|
||||||
|
def compile(function, *args, **kwargs):
|
||||||
|
if not use_compile:
|
||||||
|
return function
|
||||||
|
try:
|
||||||
|
return torch.compile(function, *args, **kwargs)
|
||||||
|
except RuntimeError:
|
||||||
|
return function
|
||||||
|
|
||||||
|
|
||||||
|
@compile
|
||||||
|
def linear_geglu(x, weight, bias=None):
|
||||||
|
x = x @ weight.mT
|
||||||
|
if bias is not None:
|
||||||
|
x = x + bias
|
||||||
|
x, gate = x.chunk(2, dim=-1)
|
||||||
|
return x * F.gelu(gate)
|
||||||
|
|
||||||
|
|
||||||
|
@compile
|
||||||
|
def rms_norm(x, scale, eps):
|
||||||
|
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
||||||
|
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
||||||
|
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
||||||
|
return x * scale.to(x.dtype)
|
||||||
|
|
||||||
|
# Layers
|
||||||
|
|
||||||
|
class LinearGEGLU(nn.Linear):
|
||||||
|
def __init__(self, in_features, out_features, bias=True):
|
||||||
|
super().__init__(in_features, out_features * 2, bias=bias)
|
||||||
|
self.out_features = out_features
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return linear_geglu(x, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, shape, fix_scale = False, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
if fix_scale:
|
||||||
|
self.register_buffer("scale", torch.ones(shape))
|
||||||
|
else:
|
||||||
|
self.scale = nn.Parameter(torch.ones(shape))
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return rms_norm(x, self.scale, self.eps)
|
||||||
|
|
||||||
|
def snake_beta(x, alpha, beta):
|
||||||
|
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# snake_beta = torch.compile(snake_beta)
|
||||||
|
# except RuntimeError:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
||||||
|
# License available in LICENSES/LICENSE_NVIDIA.txt
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
||||||
|
super(SnakeBeta, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||||
|
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # linear scale alphas initialized to ones
|
||||||
|
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
||||||
|
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
self.beta.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
|
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
beta = torch.exp(beta)
|
||||||
|
x = snake_beta(x, alpha, beta)
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -0,0 +1,355 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from vector_quantize_pytorch import ResidualVQ, FSQ
|
||||||
|
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
|
||||||
|
|
||||||
|
class Bottleneck(nn.Module):
|
||||||
|
def __init__(self, is_discrete: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.is_discrete = is_discrete
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class DiscreteBottleneck(Bottleneck):
|
||||||
|
def __init__(self, num_quantizers, codebook_size, tokens_id):
|
||||||
|
super().__init__(is_discrete=True)
|
||||||
|
|
||||||
|
self.num_quantizers = num_quantizers
|
||||||
|
self.codebook_size = codebook_size
|
||||||
|
self.tokens_id = tokens_id
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class TanhBottleneck(Bottleneck):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(is_discrete=False)
|
||||||
|
self.tanh = nn.Tanh()
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def vae_sample(mean, scale):
|
||||||
|
stdev = nn.functional.softplus(scale) + 1e-4
|
||||||
|
var = stdev * stdev
|
||||||
|
logvar = torch.log(var)
|
||||||
|
latents = torch.randn_like(mean) * stdev + mean
|
||||||
|
|
||||||
|
kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
||||||
|
|
||||||
|
return latents, kl
|
||||||
|
|
||||||
|
class VAEBottleneck(Bottleneck):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(is_discrete=False)
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
mean, scale = x.chunk(2, dim=1)
|
||||||
|
|
||||||
|
x, kl = vae_sample(mean, scale)
|
||||||
|
|
||||||
|
info["kl"] = kl
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def compute_mean_kernel(x, y):
|
||||||
|
kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
|
||||||
|
return torch.exp(-kernel_input).mean()
|
||||||
|
|
||||||
|
def compute_mmd(latents):
|
||||||
|
latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
|
||||||
|
noise = torch.randn_like(latents_reshaped)
|
||||||
|
|
||||||
|
latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
|
||||||
|
noise_kernel = compute_mean_kernel(noise, noise)
|
||||||
|
latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
|
||||||
|
|
||||||
|
mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
|
||||||
|
return mmd.mean()
|
||||||
|
|
||||||
|
class WassersteinBottleneck(Bottleneck):
|
||||||
|
def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
|
||||||
|
super().__init__(is_discrete=False)
|
||||||
|
|
||||||
|
self.noise_augment_dim = noise_augment_dim
|
||||||
|
self.bypass_mmd = bypass_mmd
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
if self.training and return_info:
|
||||||
|
if self.bypass_mmd:
|
||||||
|
mmd = torch.tensor(0.0)
|
||||||
|
else:
|
||||||
|
mmd = compute_mmd(x)
|
||||||
|
|
||||||
|
info["mmd"] = mmd
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
|
||||||
|
if self.noise_augment_dim > 0:
|
||||||
|
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||||
|
x.shape[-1]).type_as(x)
|
||||||
|
x = torch.cat([x, noise], dim=1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class L2Bottleneck(Bottleneck):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(is_discrete=False)
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
x = F.normalize(x, dim=1)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return F.normalize(x, dim=1)
|
||||||
|
|
||||||
|
class RVQBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, **quantizer_kwargs):
|
||||||
|
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
||||||
|
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
||||||
|
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x, indices, loss = self.quantizer(x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
|
||||||
|
info["quantizer_indices"] = indices
|
||||||
|
info["quantizer_loss"] = loss.mean()
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
latents = self.quantizer.get_outputs_from_indices(codes)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
class RVQVAEBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, **quantizer_kwargs):
|
||||||
|
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
||||||
|
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
||||||
|
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
x, kl = vae_sample(*x.chunk(2, dim=1))
|
||||||
|
|
||||||
|
info["kl"] = kl
|
||||||
|
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x, indices, loss = self.quantizer(x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
|
||||||
|
info["quantizer_indices"] = indices
|
||||||
|
info["quantizer_loss"] = loss.mean()
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
latents = self.quantizer.get_outputs_from_indices(codes)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
class DACRVQBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
|
||||||
|
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
||||||
|
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
||||||
|
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
||||||
|
self.quantize_on_decode = quantize_on_decode
|
||||||
|
self.noise_augment_dim = noise_augment_dim
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, **kwargs):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
info["pre_quantizer"] = x
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
return x, info if return_info else x
|
||||||
|
|
||||||
|
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"z": z,
|
||||||
|
"codes": codes,
|
||||||
|
"latents": latents,
|
||||||
|
"vq/commitment_loss": commitment_loss,
|
||||||
|
"vq/codebook_loss": codebook_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
output["vq/commitment_loss"] /= self.num_quantizers
|
||||||
|
output["vq/codebook_loss"] /= self.num_quantizers
|
||||||
|
|
||||||
|
info.update(output)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output["z"], info
|
||||||
|
|
||||||
|
return output["z"]
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
x = self.quantizer(x)[0]
|
||||||
|
|
||||||
|
if self.noise_augment_dim > 0:
|
||||||
|
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||||
|
x.shape[-1]).type_as(x)
|
||||||
|
x = torch.cat([x, noise], dim=1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
latents, _, _ = self.quantizer.from_codes(codes)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
class DACRVQVAEBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
|
||||||
|
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
||||||
|
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
||||||
|
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
||||||
|
self.quantize_on_decode = quantize_on_decode
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False, n_quantizers: int = None):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
mean, scale = x.chunk(2, dim=1)
|
||||||
|
|
||||||
|
x, kl = vae_sample(mean, scale)
|
||||||
|
|
||||||
|
info["pre_quantizer"] = x
|
||||||
|
info["kl"] = kl
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
return x, info if return_info else x
|
||||||
|
|
||||||
|
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"z": z,
|
||||||
|
"codes": codes,
|
||||||
|
"latents": latents,
|
||||||
|
"vq/commitment_loss": commitment_loss,
|
||||||
|
"vq/codebook_loss": codebook_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
output["vq/commitment_loss"] /= self.num_quantizers
|
||||||
|
output["vq/codebook_loss"] /= self.num_quantizers
|
||||||
|
|
||||||
|
info.update(output)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output["z"], info
|
||||||
|
|
||||||
|
return output["z"]
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
x = self.quantizer(x)[0]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, codes, **kwargs):
|
||||||
|
latents, _, _ = self.quantizer.from_codes(codes)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
|
|
||||||
|
class FSQBottleneck(DiscreteBottleneck):
|
||||||
|
def __init__(self, noise_augment_dim=0, **kwargs):
|
||||||
|
super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
|
||||||
|
|
||||||
|
self.noise_augment_dim = noise_augment_dim
|
||||||
|
|
||||||
|
self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
|
||||||
|
|
||||||
|
def encode(self, x, return_info=False):
|
||||||
|
info = {}
|
||||||
|
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.float()
|
||||||
|
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x, indices = self.quantizer(x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
|
||||||
|
x = x.to(orig_dtype)
|
||||||
|
|
||||||
|
# Reorder indices to match the expected format
|
||||||
|
indices = rearrange(indices, "b n q -> b q n")
|
||||||
|
|
||||||
|
info["quantizer_indices"] = indices
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
else:
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
|
||||||
|
if self.noise_augment_dim > 0:
|
||||||
|
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||||
|
x.shape[-1]).type_as(x)
|
||||||
|
x = torch.cat([x, noise], dim=1)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens, **kwargs):
|
||||||
|
latents = self.quantizer.indices_to_codes(tokens)
|
||||||
|
|
||||||
|
return self.decode(latents, **kwargs)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,884 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from functools import partial
|
||||||
|
import numpy as np
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
|
||||||
|
from .conditioners import MultiConditioner
|
||||||
|
from .dit import DiffusionTransformer
|
||||||
|
from .pretransforms import Pretransform
|
||||||
|
|
||||||
|
from .adp import UNetCFG1d, UNet1d
|
||||||
|
|
||||||
|
# Lazy imports for factory functions to avoid circular imports
|
||||||
|
def _get_create_pretransform_from_config():
|
||||||
|
from prismaudio_core.factory import create_pretransform_from_config
|
||||||
|
return create_pretransform_from_config
|
||||||
|
|
||||||
|
def _get_create_multi_conditioner_from_conditioning_config():
|
||||||
|
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
|
||||||
|
return create_multi_conditioner_from_conditioning_config
|
||||||
|
|
||||||
|
class DiffusionModel(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class DiffusionModelWrapper(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: DiffusionModel,
|
||||||
|
io_channels,
|
||||||
|
sample_size,
|
||||||
|
sample_rate,
|
||||||
|
min_input_length,
|
||||||
|
pretransform: tp.Optional[Pretransform] = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.sample_size = sample_size
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.min_input_length = min_input_length
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
self.pretransform = pretransform
|
||||||
|
else:
|
||||||
|
self.pretransform = None
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
return self.model(x, t, **kwargs)
|
||||||
|
|
||||||
|
class ConditionedDiffusionModel(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
*args,
|
||||||
|
supports_cross_attention: bool = False,
|
||||||
|
supports_input_concat: bool = False,
|
||||||
|
supports_global_cond: bool = False,
|
||||||
|
supports_prepend_cond: bool = False,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.supports_cross_attention = supports_cross_attention
|
||||||
|
self.supports_input_concat = supports_input_concat
|
||||||
|
self.supports_global_cond = supports_global_cond
|
||||||
|
self.supports_prepend_cond = supports_prepend_cond
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
cross_attn_cond: torch.Tensor = None,
|
||||||
|
cross_attn_mask: torch.Tensor = None,
|
||||||
|
input_concat_cond: torch.Tensor = None,
|
||||||
|
global_embed: torch.Tensor = None,
|
||||||
|
prepend_cond: torch.Tensor = None,
|
||||||
|
prepend_cond_mask: torch.Tensor = None,
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = False,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
**kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class ConditionedDiffusionModelWrapper(nn.Module):
|
||||||
|
"""
|
||||||
|
A diffusion model that takes in conditioning
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: ConditionedDiffusionModel,
|
||||||
|
conditioner: MultiConditioner,
|
||||||
|
io_channels,
|
||||||
|
sample_rate,
|
||||||
|
min_input_length: int,
|
||||||
|
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
|
||||||
|
zero_init: bool = False,
|
||||||
|
pretransform: tp.Optional[Pretransform] = None,
|
||||||
|
cross_attn_cond_ids: tp.List[str] = [],
|
||||||
|
global_cond_ids: tp.List[str] = [],
|
||||||
|
input_concat_ids: tp.List[str] = [],
|
||||||
|
prepend_cond_ids: tp.List[str] = [],
|
||||||
|
add_cond_ids: tp.List[str] = [],
|
||||||
|
sync_cond_ids: tp.List[str] = [],
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.conditioner = conditioner
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.diffusion_objective = diffusion_objective
|
||||||
|
self.pretransform = pretransform
|
||||||
|
self.cross_attn_cond_ids = cross_attn_cond_ids
|
||||||
|
self.global_cond_ids = global_cond_ids
|
||||||
|
self.input_concat_ids = input_concat_ids
|
||||||
|
self.prepend_cond_ids = prepend_cond_ids
|
||||||
|
self.add_cond_ids = add_cond_ids
|
||||||
|
self.sync_cond_ids = sync_cond_ids
|
||||||
|
self.min_input_length = min_input_length
|
||||||
|
def _basic_init(module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(module.weight)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|
||||||
|
if zero_init is True:
|
||||||
|
self.conditioner.apply(_basic_init)
|
||||||
|
self.model.model.initialize_weights()
|
||||||
|
|
||||||
|
|
||||||
|
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
|
||||||
|
cross_attention_input = None
|
||||||
|
cross_attention_masks = None
|
||||||
|
global_cond = None
|
||||||
|
input_concat_cond = None
|
||||||
|
prepend_cond = None
|
||||||
|
prepend_cond_mask = None
|
||||||
|
add_input = None
|
||||||
|
sync_input = None
|
||||||
|
|
||||||
|
if len(self.cross_attn_cond_ids) > 0:
|
||||||
|
# Concatenate all cross-attention inputs over the sequence dimension
|
||||||
|
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
||||||
|
cross_attention_input = []
|
||||||
|
cross_attention_masks = []
|
||||||
|
|
||||||
|
for key in self.cross_attn_cond_ids:
|
||||||
|
cross_attn_in, cross_attn_mask = conditioning_tensors[key]
|
||||||
|
|
||||||
|
# Add sequence dimension if it's not there
|
||||||
|
if len(cross_attn_in.shape) == 2:
|
||||||
|
cross_attn_in = cross_attn_in.unsqueeze(1)
|
||||||
|
# cross_attn_mask = cross_attn_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
cross_attention_input.append(cross_attn_in)
|
||||||
|
cross_attention_masks.append(cross_attn_mask)
|
||||||
|
|
||||||
|
cross_attention_input = torch.cat(cross_attention_input, dim=1)
|
||||||
|
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
|
||||||
|
|
||||||
|
if len(self.add_cond_ids) > 0:
|
||||||
|
# Concatenate all cross-attention inputs over the sequence dimension
|
||||||
|
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
||||||
|
add_input = []
|
||||||
|
|
||||||
|
for key in self.add_cond_ids:
|
||||||
|
add_in = conditioning_tensors[key][0]
|
||||||
|
|
||||||
|
# Add sequence dimension if it's not there
|
||||||
|
if len(add_in.shape) == 2:
|
||||||
|
add_in = add_in.unsqueeze(1)
|
||||||
|
# add_in = add_in.transpose(1,2)
|
||||||
|
# add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False)
|
||||||
|
# add_in = add_in.transpose(1,2)
|
||||||
|
add_input.append(add_in)
|
||||||
|
|
||||||
|
add_input = torch.cat(add_input, dim=2)
|
||||||
|
|
||||||
|
if len(self.sync_cond_ids) > 0:
|
||||||
|
# Concatenate all cross-attention inputs over the sequence dimension
|
||||||
|
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
||||||
|
sync_input = []
|
||||||
|
|
||||||
|
for key in self.sync_cond_ids:
|
||||||
|
sync_in = conditioning_tensors[key][0]
|
||||||
|
|
||||||
|
# Add sequence dimension if it's not there
|
||||||
|
if len(sync_in.shape) == 2:
|
||||||
|
sync_in = sync_in.unsqueeze(1)
|
||||||
|
sync_input.append(sync_in)
|
||||||
|
|
||||||
|
sync_input = torch.cat(sync_input, dim=2)
|
||||||
|
|
||||||
|
if len(self.global_cond_ids) > 0:
|
||||||
|
# Concatenate all global conditioning inputs over the channel dimension
|
||||||
|
# Assumes that the global conditioning inputs are of shape (batch, channels)
|
||||||
|
global_conds = []
|
||||||
|
for key in self.global_cond_ids:
|
||||||
|
global_cond_input = conditioning_tensors[key][0]
|
||||||
|
if len(global_cond_input.shape) == 2:
|
||||||
|
global_cond_input = global_cond_input.unsqueeze(1)
|
||||||
|
global_conds.append(global_cond_input)
|
||||||
|
|
||||||
|
# # Concatenate over the channel dimension
|
||||||
|
# if global_conds[0].shape[-1] == 768:
|
||||||
|
# global_cond = torch.cat(global_conds, dim=-1)
|
||||||
|
# else:
|
||||||
|
# global_cond = sum(global_conds)
|
||||||
|
global_cond = sum(global_conds)
|
||||||
|
# global_cond = torch.cat(global_conds, dim=-1)
|
||||||
|
|
||||||
|
if len(global_cond.shape) == 3:
|
||||||
|
global_cond = global_cond.squeeze(1)
|
||||||
|
|
||||||
|
if len(self.input_concat_ids) > 0:
|
||||||
|
# Concatenate all input concat conditioning inputs over the channel dimension
|
||||||
|
# Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
|
||||||
|
input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
|
||||||
|
|
||||||
|
if len(self.prepend_cond_ids) > 0:
|
||||||
|
# Concatenate all prepend conditioning inputs over the sequence dimension
|
||||||
|
# Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
|
||||||
|
prepend_conds = []
|
||||||
|
prepend_cond_masks = []
|
||||||
|
|
||||||
|
for key in self.prepend_cond_ids:
|
||||||
|
prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
|
||||||
|
if len(prepend_cond_input.shape) == 2:
|
||||||
|
prepend_cond_input = prepend_cond_input.unsqueeze(1)
|
||||||
|
prepend_conds.append(prepend_cond_input)
|
||||||
|
prepend_cond_masks.append(prepend_cond_mask)
|
||||||
|
|
||||||
|
prepend_cond = torch.cat(prepend_conds, dim=1)
|
||||||
|
prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
|
||||||
|
|
||||||
|
if negative:
|
||||||
|
return {
|
||||||
|
"negative_cross_attn_cond": cross_attention_input,
|
||||||
|
"negative_cross_attn_mask": cross_attention_masks,
|
||||||
|
"negative_global_cond": global_cond,
|
||||||
|
"negative_input_concat_cond": input_concat_cond
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"cross_attn_cond": cross_attention_input,
|
||||||
|
"cross_attn_mask": cross_attention_masks,
|
||||||
|
"global_cond": global_cond,
|
||||||
|
"input_concat_cond": input_concat_cond,
|
||||||
|
"prepend_cond": prepend_cond,
|
||||||
|
"prepend_cond_mask": prepend_cond_mask,
|
||||||
|
"add_cond": add_input,
|
||||||
|
"sync_cond": sync_input
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
||||||
|
return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
|
||||||
|
|
||||||
|
def generate(self, *args, **kwargs):
|
||||||
|
from prismaudio_core.inference.generation import generate_diffusion_cond
|
||||||
|
return generate_diffusion_cond(self, *args, **kwargs)
|
||||||
|
|
||||||
|
class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
|
||||||
|
|
||||||
|
self.model = UNetCFG1d(*args, **kwargs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_cond=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = False,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
negative_global_cond=None,
|
||||||
|
negative_input_concat_cond=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
**kwargs):
|
||||||
|
channels_list = None
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
channels_list = [input_concat_cond]
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
embedding=cross_attn_cond,
|
||||||
|
embedding_mask=cross_attn_mask,
|
||||||
|
features=global_cond,
|
||||||
|
channels_list=channels_list,
|
||||||
|
embedding_scale=cfg_scale,
|
||||||
|
embedding_mask_proba=cfg_dropout_prob,
|
||||||
|
batch_cfg=batch_cfg,
|
||||||
|
rescale_cfg=rescale_cfg,
|
||||||
|
negative_embedding=negative_cross_attn_cond,
|
||||||
|
negative_embedding_mask=negative_cross_attn_mask,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class UNet1DCondWrapper(ConditionedDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
|
||||||
|
|
||||||
|
self.model = UNet1d(*args, **kwargs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_cond=None,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_mask=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = False,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
negative_global_cond=None,
|
||||||
|
negative_input_concat_cond=None,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
channels_list = None
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
|
||||||
|
# Interpolate input_concat_cond to the same length as x
|
||||||
|
if input_concat_cond.shape[2] != x.shape[2]:
|
||||||
|
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
||||||
|
|
||||||
|
channels_list = [input_concat_cond]
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
features=global_cond,
|
||||||
|
channels_list=channels_list,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class UNet1DUncondWrapper(DiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
|
||||||
|
|
||||||
|
self.io_channels = in_channels
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
return self.model(x, t, **kwargs)
|
||||||
|
|
||||||
|
class DAU1DCondWrapper(ConditionedDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
|
||||||
|
|
||||||
|
self.model = DiffusionAttnUnet1D(*args, **kwargs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
input_concat_cond=None,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_mask=None,
|
||||||
|
global_cond=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = False,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
negative_global_cond=None,
|
||||||
|
negative_input_concat_cond=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
return self.model(x, t, cond = input_concat_cond)
|
||||||
|
|
||||||
|
class DiffusionAttnUnet1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
io_channels = 2,
|
||||||
|
depth=14,
|
||||||
|
n_attn_layers = 6,
|
||||||
|
channels = [128, 128, 256, 256] + [512] * 10,
|
||||||
|
cond_dim = 0,
|
||||||
|
cond_noise_aug = False,
|
||||||
|
kernel_size = 5,
|
||||||
|
learned_resample = False,
|
||||||
|
strides = [2] * 13,
|
||||||
|
conv_bias = True,
|
||||||
|
use_snake = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cond_noise_aug = cond_noise_aug
|
||||||
|
|
||||||
|
self.io_channels = io_channels
|
||||||
|
|
||||||
|
if self.cond_noise_aug:
|
||||||
|
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
|
||||||
|
|
||||||
|
self.timestep_embed = FourierFeatures(1, 16)
|
||||||
|
|
||||||
|
attn_layer = depth - n_attn_layers
|
||||||
|
|
||||||
|
strides = [1] + strides
|
||||||
|
|
||||||
|
block = nn.Identity()
|
||||||
|
|
||||||
|
conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
|
||||||
|
|
||||||
|
for i in range(depth, 0, -1):
|
||||||
|
c = channels[i - 1]
|
||||||
|
stride = strides[i-1]
|
||||||
|
if stride > 2 and not learned_resample:
|
||||||
|
raise ValueError("Must have stride 2 without learned resampling")
|
||||||
|
|
||||||
|
if i > 1:
|
||||||
|
c_prev = channels[i - 2]
|
||||||
|
add_attn = i >= attn_layer and n_attn_layers > 0
|
||||||
|
block = SkipBlock(
|
||||||
|
Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
|
||||||
|
conv_block(c_prev, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
block,
|
||||||
|
conv_block(c * 2 if i != depth else c, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
SelfAttention1d(
|
||||||
|
c, c // 32) if add_attn else nn.Identity(),
|
||||||
|
conv_block(c, c, c_prev),
|
||||||
|
SelfAttention1d(c_prev, c_prev //
|
||||||
|
32) if add_attn else nn.Identity(),
|
||||||
|
Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cond_embed_dim = 16 if not self.cond_noise_aug else 32
|
||||||
|
block = nn.Sequential(
|
||||||
|
conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
block,
|
||||||
|
conv_block(c * 2, c, c),
|
||||||
|
conv_block(c, c, c),
|
||||||
|
conv_block(c, c, io_channels, is_last=True),
|
||||||
|
)
|
||||||
|
self.net = block
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.net.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self, x, t, cond=None, cond_aug_scale=None):
|
||||||
|
|
||||||
|
timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
|
||||||
|
|
||||||
|
inputs = [x, timestep_embed]
|
||||||
|
|
||||||
|
if cond is not None:
|
||||||
|
if cond.shape[2] != x.shape[2]:
|
||||||
|
cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
|
||||||
|
|
||||||
|
if self.cond_noise_aug:
|
||||||
|
# Get a random number between 0 and 1, uniformly sampled
|
||||||
|
if cond_aug_scale is None:
|
||||||
|
aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
|
||||||
|
else:
|
||||||
|
aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
|
||||||
|
|
||||||
|
# Add noise to the conditioning signal
|
||||||
|
cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
|
||||||
|
|
||||||
|
# Get embedding for noise cond level, reusing timestamp_embed
|
||||||
|
aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
|
||||||
|
|
||||||
|
inputs.append(aug_level_embed)
|
||||||
|
|
||||||
|
inputs.append(cond)
|
||||||
|
|
||||||
|
outputs = self.net(torch.cat(inputs, dim=1))
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
class DiTWrapper(ConditionedDiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
|
||||||
|
|
||||||
|
self.model = DiffusionTransformer(*args, **kwargs)
|
||||||
|
# with torch.no_grad():
|
||||||
|
# for param in self.model.parameters():
|
||||||
|
# param *= 0.5
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_mask=None,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
negative_input_concat_cond=None,
|
||||||
|
global_cond=None,
|
||||||
|
negative_global_cond=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob: float = 0.0,
|
||||||
|
batch_cfg: bool = True,
|
||||||
|
rescale_cfg: bool = False,
|
||||||
|
scale_phi: float = 0.0,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
|
||||||
|
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
|
||||||
|
|
||||||
|
return self.model(
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=cross_attn_cond,
|
||||||
|
cross_attn_cond_mask=cross_attn_mask,
|
||||||
|
negative_cross_attn_cond=negative_cross_attn_cond,
|
||||||
|
negative_cross_attn_mask=negative_cross_attn_mask,
|
||||||
|
input_concat_cond=input_concat_cond,
|
||||||
|
prepend_cond=prepend_cond,
|
||||||
|
prepend_cond_mask=prepend_cond_mask,
|
||||||
|
cfg_scale=cfg_scale,
|
||||||
|
cfg_dropout_prob=cfg_dropout_prob,
|
||||||
|
scale_phi=scale_phi,
|
||||||
|
global_embed=global_cond,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
||||||
|
"""
|
||||||
|
A diffusion model that takes in conditioning
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
conditioner: MultiConditioner,
|
||||||
|
io_channels,
|
||||||
|
sample_rate,
|
||||||
|
min_input_length: int,
|
||||||
|
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
|
||||||
|
pretransform: tp.Optional[Pretransform] = None,
|
||||||
|
cross_attn_cond_ids: tp.List[str] = [],
|
||||||
|
global_cond_ids: tp.List[str] = [],
|
||||||
|
input_concat_ids: tp.List[str] = [],
|
||||||
|
prepend_cond_ids: tp.List[str] = [],
|
||||||
|
add_cond_ids: tp.List[str] = [],
|
||||||
|
mm_cond_ids: tp.List[str] = [],
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.conditioner = conditioner
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.diffusion_objective = diffusion_objective
|
||||||
|
self.pretransform = pretransform
|
||||||
|
self.cross_attn_cond_ids = cross_attn_cond_ids
|
||||||
|
self.global_cond_ids = global_cond_ids
|
||||||
|
self.input_concat_ids = input_concat_ids
|
||||||
|
self.prepend_cond_ids = prepend_cond_ids
|
||||||
|
self.add_cond_ids = add_cond_ids
|
||||||
|
self.min_input_length = min_input_length
|
||||||
|
self.mm_cond_ids = mm_cond_ids
|
||||||
|
|
||||||
|
assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
|
||||||
|
assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
|
||||||
|
assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
|
||||||
|
assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
|
||||||
|
assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
|
||||||
|
assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
|
||||||
|
assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper"
|
||||||
|
assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
|
||||||
|
assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"
|
||||||
|
# assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper"
|
||||||
|
|
||||||
|
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
|
||||||
|
assert negative == False, "negative conditioning is not supported for MMDiTWrapper"
|
||||||
|
cross_attention_input = None
|
||||||
|
cross_attention_masks = None
|
||||||
|
global_cond = None
|
||||||
|
input_concat_cond = None
|
||||||
|
prepend_cond = None
|
||||||
|
prepend_cond_mask = None
|
||||||
|
add_input = None
|
||||||
|
inpaint_masked_input = None
|
||||||
|
t5_features = None
|
||||||
|
metaclip_global_text_features = None
|
||||||
|
clip_f = conditioning_tensors["metaclip_features"]
|
||||||
|
sync_f = conditioning_tensors["sync_features"]
|
||||||
|
text_f = conditioning_tensors["metaclip_text_features"]
|
||||||
|
if 'inpaint_masked_input' in conditioning_tensors.keys():
|
||||||
|
inpaint_masked_input = conditioning_tensors["inpaint_masked_input"]
|
||||||
|
if 't5_features' in conditioning_tensors.keys():
|
||||||
|
t5_features = conditioning_tensors["t5_features"]
|
||||||
|
if 'metaclip_global_text_features' in conditioning_tensors.keys():
|
||||||
|
metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"]
|
||||||
|
return {
|
||||||
|
"clip_f": clip_f,
|
||||||
|
"sync_f": sync_f,
|
||||||
|
"text_f": text_f,
|
||||||
|
"inpaint_masked_input": inpaint_masked_input,
|
||||||
|
"t5_features": t5_features,
|
||||||
|
"metaclip_global_text_features": metaclip_global_text_features
|
||||||
|
}
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
||||||
|
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
|
||||||
|
|
||||||
|
def generate(self, *args, **kwargs):
|
||||||
|
from prismaudio_core.inference.generation import generate_diffusion_cond
|
||||||
|
return generate_diffusion_cond(self, *args, **kwargs)
|
||||||
|
|
||||||
|
class DiTUncondWrapper(DiffusionModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
io_channels,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs)
|
||||||
|
|
||||||
|
self.io_channels = io_channels
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for param in self.model.parameters():
|
||||||
|
param *= 0.5
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
return self.model(x, t, **kwargs)
|
||||||
|
|
||||||
|
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
|
||||||
|
diffusion_uncond_config = config["model"]
|
||||||
|
|
||||||
|
model_type = diffusion_uncond_config.get('type', None)
|
||||||
|
|
||||||
|
diffusion_config = diffusion_uncond_config.get('config', {})
|
||||||
|
|
||||||
|
assert model_type is not None, "Must specify model type in config"
|
||||||
|
|
||||||
|
pretransform = diffusion_uncond_config.get("pretransform", None)
|
||||||
|
|
||||||
|
sample_size = config.get("sample_size", None)
|
||||||
|
assert sample_size is not None, "Must specify sample size in config"
|
||||||
|
|
||||||
|
sample_rate = config.get("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "Must specify sample rate in config"
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
||||||
|
min_input_length = pretransform.downsampling_ratio
|
||||||
|
else:
|
||||||
|
min_input_length = 1
|
||||||
|
|
||||||
|
if model_type == 'DAU1d':
|
||||||
|
|
||||||
|
model = DiffusionAttnUnet1D(
|
||||||
|
**diffusion_config
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == "adp_uncond_1d":
|
||||||
|
|
||||||
|
model = UNet1DUncondWrapper(
|
||||||
|
**diffusion_config
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == "dit":
|
||||||
|
model = DiTUncondWrapper(
|
||||||
|
**diffusion_config
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
|
|
||||||
|
return DiffusionModelWrapper(model,
|
||||||
|
io_channels=model.io_channels,
|
||||||
|
sample_size=sample_size,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
pretransform=pretransform,
|
||||||
|
min_input_length=min_input_length)
|
||||||
|
|
||||||
|
def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
|
||||||
|
diffusion_uncond_config = config["model"]
|
||||||
|
|
||||||
|
|
||||||
|
diffusion_config = diffusion_uncond_config.get('diffusion', {})
|
||||||
|
model_type = diffusion_config.get('type', None)
|
||||||
|
model_config = diffusion_config.get("config",{})
|
||||||
|
assert model_type is not None, "Must specify model type in config"
|
||||||
|
|
||||||
|
pretransform = diffusion_uncond_config.get("pretransform", None)
|
||||||
|
|
||||||
|
sample_size = config.get("sample_size", None)
|
||||||
|
assert sample_size is not None, "Must specify sample size in config"
|
||||||
|
|
||||||
|
sample_rate = config.get("sample_rate", None)
|
||||||
|
assert sample_rate is not None, "Must specify sample rate in config"
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
||||||
|
min_input_length = pretransform.downsampling_ratio
|
||||||
|
else:
|
||||||
|
min_input_length = 1
|
||||||
|
|
||||||
|
if model_type == 'DAU1d':
|
||||||
|
|
||||||
|
model = DiffusionAttnUnet1D(
|
||||||
|
**model_config
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == "adp_uncond_1d":
|
||||||
|
|
||||||
|
io_channels = model_config.get("io_channels", 64)
|
||||||
|
model = UNet1DUncondWrapper(
|
||||||
|
io_channels = io_channels,
|
||||||
|
**model_config
|
||||||
|
)
|
||||||
|
elif model_type == "dit":
|
||||||
|
model = DiTUncondWrapper(
|
||||||
|
**model_config
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
|
|
||||||
|
return DiffusionModelWrapper(model,
|
||||||
|
io_channels=model.io_channels,
|
||||||
|
sample_size=sample_size,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
pretransform=pretransform,
|
||||||
|
min_input_length=min_input_length)
|
||||||
|
|
||||||
|
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||||
|
|
||||||
|
model_config = config["model"]
|
||||||
|
|
||||||
|
model_type = config["model_type"]
|
||||||
|
|
||||||
|
diffusion_config = model_config.get('diffusion', None)
|
||||||
|
assert diffusion_config is not None, "Must specify diffusion config"
|
||||||
|
|
||||||
|
diffusion_model_type = diffusion_config.get('type', None)
|
||||||
|
assert diffusion_model_type is not None, "Must specify diffusion model type"
|
||||||
|
|
||||||
|
diffusion_model_config = diffusion_config.get('config', None)
|
||||||
|
assert diffusion_model_config is not None, "Must specify diffusion model config"
|
||||||
|
|
||||||
|
if diffusion_model_type == 'adp_cfg_1d':
|
||||||
|
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
|
||||||
|
elif diffusion_model_type == 'adp_1d':
|
||||||
|
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
||||||
|
elif diffusion_model_type == 'dit':
|
||||||
|
diffusion_model = DiTWrapper(**diffusion_model_config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
|
||||||
|
|
||||||
|
io_channels = model_config.get('io_channels', None)
|
||||||
|
assert io_channels is not None, "Must specify io_channels in model config"
|
||||||
|
|
||||||
|
sample_rate = config.get('sample_rate', None)
|
||||||
|
assert sample_rate is not None, "Must specify sample_rate in config"
|
||||||
|
|
||||||
|
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
|
||||||
|
|
||||||
|
conditioning_config = model_config.get('conditioning', None)
|
||||||
|
|
||||||
|
conditioner = None
|
||||||
|
if conditioning_config is not None:
|
||||||
|
conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config)
|
||||||
|
|
||||||
|
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
|
||||||
|
add_cond_ids = diffusion_config.get('add_cond_ids', [])
|
||||||
|
sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
|
||||||
|
global_cond_ids = diffusion_config.get('global_cond_ids', [])
|
||||||
|
input_concat_ids = diffusion_config.get('input_concat_ids', [])
|
||||||
|
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
|
||||||
|
mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
|
||||||
|
zero_init = diffusion_config.get('zero_init', False)
|
||||||
|
pretransform = model_config.get("pretransform", None)
|
||||||
|
|
||||||
|
if pretransform is not None:
|
||||||
|
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
||||||
|
min_input_length = pretransform.downsampling_ratio
|
||||||
|
else:
|
||||||
|
min_input_length = 1
|
||||||
|
|
||||||
|
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
|
||||||
|
min_input_length *= np.prod(diffusion_model_config["factors"])
|
||||||
|
elif diffusion_model_type == "dit":
|
||||||
|
min_input_length *= diffusion_model.model.patch_size
|
||||||
|
|
||||||
|
# Get the proper wrapper class
|
||||||
|
|
||||||
|
extra_kwargs = {}
|
||||||
|
|
||||||
|
if model_type == "mm_diffusion_cond":
|
||||||
|
wrapper_fn = MMConditionedDiffusionModelWrapper
|
||||||
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
|
extra_kwargs["mm_cond_ids"] = mm_cond_ids
|
||||||
|
|
||||||
|
elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
|
||||||
|
wrapper_fn = ConditionedDiffusionModelWrapper
|
||||||
|
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||||
|
|
||||||
|
return wrapper_fn(
|
||||||
|
diffusion_model,
|
||||||
|
conditioner,
|
||||||
|
min_input_length=min_input_length,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
cross_attn_cond_ids=cross_attention_ids,
|
||||||
|
global_cond_ids=global_cond_ids,
|
||||||
|
input_concat_ids=input_concat_ids,
|
||||||
|
prepend_cond_ids=prepend_cond_ids,
|
||||||
|
add_cond_ids=add_cond_ids,
|
||||||
|
sync_cond_ids=sync_cond_ids,
|
||||||
|
pretransform=pretransform,
|
||||||
|
io_channels=io_channels,
|
||||||
|
zero_init=zero_init,
|
||||||
|
**extra_kwargs
|
||||||
|
)
|
||||||
@@ -0,0 +1,539 @@
|
|||||||
|
import typing as tp
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
# from beartype.typing import Tuple
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||||
|
from .blocks import FourierFeatures
|
||||||
|
from .transformer import ContinuousTransformer
|
||||||
|
from .utils import mask_from_frac_lengths, resample
|
||||||
|
|
||||||
|
class DiffusionTransformer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
io_channels=32,
|
||||||
|
patch_size=1,
|
||||||
|
embed_dim=768,
|
||||||
|
cond_token_dim=0,
|
||||||
|
project_cond_tokens=True,
|
||||||
|
global_cond_dim=0,
|
||||||
|
project_global_cond=True,
|
||||||
|
input_concat_dim=0,
|
||||||
|
prepend_cond_dim=0,
|
||||||
|
cond_ctx_dim=0,
|
||||||
|
depth=12,
|
||||||
|
num_heads=8,
|
||||||
|
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
|
||||||
|
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
|
||||||
|
timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
|
||||||
|
add_token_dim=0,
|
||||||
|
sync_token_dim=0,
|
||||||
|
use_mlp=False,
|
||||||
|
use_zero_init=False,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cond_token_dim = cond_token_dim
|
||||||
|
|
||||||
|
# Timestep embeddings
|
||||||
|
timestep_features_dim = 256
|
||||||
|
# Timestep embeddings
|
||||||
|
self.timestep_cond_type = timestep_cond_type
|
||||||
|
self.timestep_features = FourierFeatures(1, timestep_features_dim)
|
||||||
|
|
||||||
|
if timestep_cond_type == "global":
|
||||||
|
timestep_embed_dim = embed_dim
|
||||||
|
elif timestep_cond_type == "input_concat":
|
||||||
|
assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat"
|
||||||
|
input_concat_dim += timestep_embed_dim
|
||||||
|
|
||||||
|
self.to_timestep_embed = nn.Sequential(
|
||||||
|
nn.Linear(timestep_features_dim, embed_dim, bias=True),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=True),
|
||||||
|
)
|
||||||
|
self.use_mlp = use_mlp
|
||||||
|
if cond_token_dim > 0:
|
||||||
|
# Conditioning tokens
|
||||||
|
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
|
||||||
|
self.to_cond_embed = nn.Sequential(
|
||||||
|
nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cond_embed_dim = 0
|
||||||
|
|
||||||
|
if global_cond_dim > 0:
|
||||||
|
# Global conditioning
|
||||||
|
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
|
||||||
|
self.to_global_embed = nn.Sequential(
|
||||||
|
nn.Linear(global_cond_dim, global_embed_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(global_embed_dim, global_embed_dim, bias=False)
|
||||||
|
)
|
||||||
|
if add_token_dim > 0:
|
||||||
|
# Conditioning tokens
|
||||||
|
add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
|
||||||
|
self.to_add_embed = nn.Sequential(
|
||||||
|
nn.Linear(add_token_dim, add_embed_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(add_embed_dim, add_embed_dim, bias=False)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
add_embed_dim = 0
|
||||||
|
|
||||||
|
if sync_token_dim > 0:
|
||||||
|
# Conditioning tokens
|
||||||
|
sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
|
||||||
|
self.to_sync_embed = nn.Sequential(
|
||||||
|
nn.Linear(sync_token_dim, sync_embed_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(sync_embed_dim, sync_embed_dim, bias=False)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sync_embed_dim = 0
|
||||||
|
|
||||||
|
|
||||||
|
if prepend_cond_dim > 0:
|
||||||
|
# Prepend conditioning
|
||||||
|
self.to_prepend_embed = nn.Sequential(
|
||||||
|
nn.Linear(prepend_cond_dim, embed_dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(embed_dim, embed_dim, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_concat_dim = input_concat_dim
|
||||||
|
|
||||||
|
dim_in = io_channels + self.input_concat_dim
|
||||||
|
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
# Transformer
|
||||||
|
|
||||||
|
self.transformer_type = transformer_type
|
||||||
|
|
||||||
|
self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
|
||||||
|
self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
|
||||||
|
self.global_cond_type = global_cond_type
|
||||||
|
if self.transformer_type == "continuous_transformer":
|
||||||
|
|
||||||
|
global_dim = None
|
||||||
|
|
||||||
|
if self.global_cond_type == "adaLN":
|
||||||
|
# The global conditioning is projected to the embed_dim already at this point
|
||||||
|
global_dim = embed_dim
|
||||||
|
|
||||||
|
self.transformer = ContinuousTransformer(
|
||||||
|
dim=embed_dim,
|
||||||
|
depth=depth,
|
||||||
|
dim_heads=embed_dim // num_heads,
|
||||||
|
dim_in=dim_in * patch_size,
|
||||||
|
dim_out=io_channels * patch_size,
|
||||||
|
cross_attend = cond_token_dim > 0,
|
||||||
|
cond_token_dim = cond_embed_dim,
|
||||||
|
global_cond_dim=global_dim,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
|
||||||
|
|
||||||
|
self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
|
||||||
|
self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
|
||||||
|
nn.init.zeros_(self.preprocess_conv.weight)
|
||||||
|
nn.init.zeros_(self.postprocess_conv.weight)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_weights(self):
|
||||||
|
def _basic_init(module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(module.weight)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.constant_(module.bias, 0)
|
||||||
|
|
||||||
|
# if isinstance(module, nn.Conv1d):
|
||||||
|
# if module.bias is not None:
|
||||||
|
# nn.init.constant_(module.bias, 0)
|
||||||
|
|
||||||
|
self.apply(_basic_init)
|
||||||
|
|
||||||
|
# Initialize timestep embedding MLP:
|
||||||
|
nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02)
|
||||||
|
nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02)
|
||||||
|
|
||||||
|
# Zero-out output layers:
|
||||||
|
if self.global_cond_type == "adaLN":
|
||||||
|
for block in self.transformer.layers:
|
||||||
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||||
|
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||||
|
|
||||||
|
nn.init.constant_(self.empty_clip_feat, 0)
|
||||||
|
nn.init.constant_(self.empty_sync_feat, 0)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
mask=None,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_cond_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_embed=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
add_cond=None,
|
||||||
|
add_masks=None,
|
||||||
|
sync_cond=None,
|
||||||
|
return_info=False,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
|
||||||
|
|
||||||
|
if global_embed is not None:
|
||||||
|
# Project the global conditioning to the embedding dimension
|
||||||
|
global_embed = self.to_global_embed(global_embed)
|
||||||
|
|
||||||
|
prepend_inputs = None
|
||||||
|
prepend_mask = None
|
||||||
|
prepend_length = 0
|
||||||
|
if prepend_cond is not None:
|
||||||
|
# Project the prepend conditioning to the embedding dimension
|
||||||
|
prepend_cond = self.to_prepend_embed(prepend_cond)
|
||||||
|
|
||||||
|
prepend_inputs = prepend_cond
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
prepend_mask = prepend_cond_mask
|
||||||
|
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
# reshape from (b, n, c) to (b, c, n)
|
||||||
|
if input_concat_cond.shape[1] != x.shape[1]:
|
||||||
|
input_concat_cond = input_concat_cond.transpose(1,2)
|
||||||
|
# Interpolate input_concat_cond to the same length as x
|
||||||
|
# if input_concat_cond.shape[1] != x.shape[2]:
|
||||||
|
# input_concat_cond = input_concat_cond.transpose(1,2)
|
||||||
|
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
||||||
|
# input_concat_cond = input_concat_cond.transpose(1,2)
|
||||||
|
# if len(global_embed.shape) == 2:
|
||||||
|
# global_embed = global_embed.unsqueeze(1)
|
||||||
|
# global_embed = global_embed + input_concat_cond
|
||||||
|
x = torch.cat([x, input_concat_cond], dim=1)
|
||||||
|
|
||||||
|
# Get the batch of timestep embeddings
|
||||||
|
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
|
||||||
|
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
||||||
|
if self.timestep_cond_type == "global":
|
||||||
|
if global_embed is not None:
|
||||||
|
if len(global_embed.shape) == 3:
|
||||||
|
timestep_embed = timestep_embed.unsqueeze(1)
|
||||||
|
global_embed = global_embed + timestep_embed
|
||||||
|
else:
|
||||||
|
global_embed = timestep_embed
|
||||||
|
elif self.timestep_cond_type == "input_concat":
|
||||||
|
x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1)
|
||||||
|
|
||||||
|
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
|
||||||
|
if self.global_cond_type == "prepend" and global_embed is not None:
|
||||||
|
if prepend_inputs is None:
|
||||||
|
# Prepend inputs are just the global embed, and the mask is all ones
|
||||||
|
if len(global_embed.shape) == 2:
|
||||||
|
prepend_inputs = global_embed.unsqueeze(1)
|
||||||
|
else:
|
||||||
|
prepend_inputs = global_embed
|
||||||
|
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
|
||||||
|
else:
|
||||||
|
# Prepend inputs are the prepend conditioning + the global embed
|
||||||
|
if len(global_embed.shape) == 2:
|
||||||
|
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
|
||||||
|
else:
|
||||||
|
prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1)
|
||||||
|
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
|
||||||
|
|
||||||
|
prepend_length = prepend_inputs.shape[1]
|
||||||
|
|
||||||
|
x = self.preprocess_conv(x) + x
|
||||||
|
x = rearrange(x, "b c t -> b t c")
|
||||||
|
|
||||||
|
extra_args = {}
|
||||||
|
|
||||||
|
if self.global_cond_type == "adaLN":
|
||||||
|
extra_args["global_cond"] = global_embed
|
||||||
|
|
||||||
|
if self.patch_size > 1:
|
||||||
|
b, seq_len, c = x.shape
|
||||||
|
|
||||||
|
# 计算需要填充的数量
|
||||||
|
pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size
|
||||||
|
|
||||||
|
if pad_amount > 0:
|
||||||
|
# 在时间维度上进行填充
|
||||||
|
x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
|
||||||
|
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
|
||||||
|
|
||||||
|
if add_cond is not None:
|
||||||
|
# Interpolate add_cond to the same length as x
|
||||||
|
# if self.use_mlp:
|
||||||
|
add_cond = self.to_add_embed(add_cond)
|
||||||
|
if add_cond.shape[1] != x.shape[1]:
|
||||||
|
add_cond = add_cond.transpose(1,2)
|
||||||
|
add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
|
||||||
|
add_cond = add_cond.transpose(1,2)
|
||||||
|
# add_cond = resample(add_cond, x)
|
||||||
|
|
||||||
|
if sync_cond is not None:
|
||||||
|
sync_cond = self.to_sync_embed(sync_cond)
|
||||||
|
|
||||||
|
if self.transformer_type == "continuous_transformer":
|
||||||
|
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
output, info = output
|
||||||
|
|
||||||
|
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
|
||||||
|
|
||||||
|
if self.patch_size > 1:
|
||||||
|
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
|
||||||
|
# 移除之前添加的填充
|
||||||
|
if pad_amount > 0:
|
||||||
|
output = output[:, :, :seq_len]
|
||||||
|
|
||||||
|
output = self.postprocess_conv(output) + output
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output, info
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=None,
|
||||||
|
cross_attn_cond_mask=None,
|
||||||
|
negative_cross_attn_cond=None,
|
||||||
|
negative_cross_attn_mask=None,
|
||||||
|
input_concat_cond=None,
|
||||||
|
global_embed=None,
|
||||||
|
negative_global_embed=None,
|
||||||
|
prepend_cond=None,
|
||||||
|
prepend_cond_mask=None,
|
||||||
|
add_cond=None,
|
||||||
|
sync_cond=None,
|
||||||
|
cfg_scale=1.0,
|
||||||
|
cfg_dropout_prob=0.0,
|
||||||
|
causal=False,
|
||||||
|
scale_phi=0.0,
|
||||||
|
mask=None,
|
||||||
|
return_info=False,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
assert causal == False, "Causal mode is not supported for DiffusionTransformer"
|
||||||
|
bsz, a, b = x.shape
|
||||||
|
model_dtype = next(self.parameters()).dtype
|
||||||
|
x = x.to(model_dtype)
|
||||||
|
t = t.to(model_dtype)
|
||||||
|
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
cross_attn_cond = cross_attn_cond.to(model_dtype)
|
||||||
|
|
||||||
|
if negative_cross_attn_cond is not None:
|
||||||
|
negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype)
|
||||||
|
|
||||||
|
if input_concat_cond is not None:
|
||||||
|
input_concat_cond = input_concat_cond.to(model_dtype)
|
||||||
|
|
||||||
|
if global_embed is not None:
|
||||||
|
global_embed = global_embed.to(model_dtype)
|
||||||
|
|
||||||
|
if negative_global_embed is not None:
|
||||||
|
negative_global_embed = negative_global_embed.to(model_dtype)
|
||||||
|
|
||||||
|
if prepend_cond is not None:
|
||||||
|
prepend_cond = prepend_cond.to(model_dtype)
|
||||||
|
|
||||||
|
if add_cond is not None:
|
||||||
|
add_cond = add_cond.to(model_dtype)
|
||||||
|
|
||||||
|
if sync_cond is not None:
|
||||||
|
sync_cond = sync_cond.to(model_dtype)
|
||||||
|
|
||||||
|
if cross_attn_cond_mask is not None:
|
||||||
|
cross_attn_cond_mask = cross_attn_cond_mask.bool()
|
||||||
|
|
||||||
|
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
|
||||||
|
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
prepend_cond_mask = prepend_cond_mask.bool()
|
||||||
|
|
||||||
|
|
||||||
|
# CFG dropout
|
||||||
|
if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
|
||||||
|
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
|
||||||
|
|
||||||
|
if prepend_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
|
||||||
|
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
|
||||||
|
|
||||||
|
if add_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool)
|
||||||
|
add_cond = torch.where(dropout_mask, null_embed, add_cond)
|
||||||
|
|
||||||
|
if sync_cond is not None:
|
||||||
|
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
|
||||||
|
dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool)
|
||||||
|
sync_cond = torch.where(dropout_mask, null_embed, sync_cond)
|
||||||
|
|
||||||
|
if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
|
||||||
|
# Classifier-free guidance
|
||||||
|
# Concatenate conditioned and unconditioned inputs on the batch dimension
|
||||||
|
batch_inputs = torch.cat([x, x], dim=0)
|
||||||
|
batch_timestep = torch.cat([t, t], dim=0)
|
||||||
|
if global_embed is not None and global_embed.shape[0] == bsz:
|
||||||
|
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
|
||||||
|
elif global_embed is not None:
|
||||||
|
batch_global_cond = global_embed
|
||||||
|
else:
|
||||||
|
batch_global_cond = None
|
||||||
|
|
||||||
|
if input_concat_cond is not None and input_concat_cond.shape[0] == bsz:
|
||||||
|
batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
|
||||||
|
elif input_concat_cond is not None:
|
||||||
|
batch_input_concat_cond = input_concat_cond
|
||||||
|
else:
|
||||||
|
batch_input_concat_cond = None
|
||||||
|
|
||||||
|
batch_cond = None
|
||||||
|
batch_cond_masks = None
|
||||||
|
|
||||||
|
# Handle CFG for cross-attention conditioning
|
||||||
|
if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz:
|
||||||
|
|
||||||
|
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
||||||
|
|
||||||
|
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
|
||||||
|
if negative_cross_attn_cond is not None:
|
||||||
|
|
||||||
|
# If there's a negative cross-attention mask, set the masked tokens to the null embed
|
||||||
|
if negative_cross_attn_mask is not None:
|
||||||
|
negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
|
||||||
|
|
||||||
|
negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
|
||||||
|
|
||||||
|
batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
|
||||||
|
|
||||||
|
if cross_attn_cond_mask is not None:
|
||||||
|
batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
|
||||||
|
elif cross_attn_cond is not None:
|
||||||
|
batch_cond = cross_attn_cond
|
||||||
|
else:
|
||||||
|
batch_cond = None
|
||||||
|
|
||||||
|
batch_prepend_cond = None
|
||||||
|
batch_prepend_cond_mask = None
|
||||||
|
|
||||||
|
if prepend_cond is not None and prepend_cond.shape[0] == bsz:
|
||||||
|
|
||||||
|
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
||||||
|
|
||||||
|
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
|
||||||
|
|
||||||
|
if prepend_cond_mask is not None:
|
||||||
|
batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
|
||||||
|
elif prepend_cond is not None:
|
||||||
|
batch_prepend_cond = prepend_cond
|
||||||
|
else:
|
||||||
|
batch_prepend_cond = None
|
||||||
|
|
||||||
|
batch_add_cond = None
|
||||||
|
|
||||||
|
# Handle CFG for cross-attention conditioning
|
||||||
|
if add_cond is not None and add_cond.shape[0] == bsz:
|
||||||
|
|
||||||
|
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
|
||||||
|
|
||||||
|
|
||||||
|
batch_add_cond = torch.cat([add_cond, null_embed], dim=0)
|
||||||
|
elif add_cond is not None:
|
||||||
|
batch_add_cond = add_cond
|
||||||
|
else:
|
||||||
|
batch_add_cond = None
|
||||||
|
|
||||||
|
batch_sync_cond = None
|
||||||
|
|
||||||
|
# Handle CFG for cross-attention conditioning
|
||||||
|
if sync_cond is not None and sync_cond.shape[0] == bsz:
|
||||||
|
|
||||||
|
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
|
||||||
|
|
||||||
|
|
||||||
|
batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0)
|
||||||
|
elif sync_cond is not None:
|
||||||
|
batch_sync_cond = sync_cond
|
||||||
|
else:
|
||||||
|
batch_sync_cond = None
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
batch_masks = torch.cat([mask, mask], dim=0)
|
||||||
|
else:
|
||||||
|
batch_masks = None
|
||||||
|
|
||||||
|
batch_output = self._forward(
|
||||||
|
batch_inputs,
|
||||||
|
batch_timestep,
|
||||||
|
cross_attn_cond=batch_cond,
|
||||||
|
cross_attn_cond_mask=batch_cond_masks,
|
||||||
|
mask = batch_masks,
|
||||||
|
input_concat_cond=batch_input_concat_cond,
|
||||||
|
global_embed = batch_global_cond,
|
||||||
|
prepend_cond = batch_prepend_cond,
|
||||||
|
prepend_cond_mask = batch_prepend_cond_mask,
|
||||||
|
add_cond = batch_add_cond,
|
||||||
|
sync_cond = batch_sync_cond,
|
||||||
|
return_info = return_info,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
batch_output, info = batch_output
|
||||||
|
|
||||||
|
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
|
||||||
|
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
|
||||||
|
|
||||||
|
# CFG Rescale
|
||||||
|
if scale_phi != 0.0:
|
||||||
|
cond_out_std = cond_output.std(dim=1, keepdim=True)
|
||||||
|
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
|
||||||
|
output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
|
||||||
|
else:
|
||||||
|
output = cfg_output
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return output, info
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
else:
|
||||||
|
return self._forward(
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
cross_attn_cond=cross_attn_cond,
|
||||||
|
cross_attn_cond_mask=cross_attn_cond_mask,
|
||||||
|
input_concat_cond=input_concat_cond,
|
||||||
|
global_embed=global_embed,
|
||||||
|
prepend_cond=prepend_cond,
|
||||||
|
prepend_cond_mask=prepend_cond_mask,
|
||||||
|
add_cond=add_cond,
|
||||||
|
sync_cond=sync_cond,
|
||||||
|
mask=mask,
|
||||||
|
return_info=return_info,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
@@ -0,0 +1,275 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from .blocks import AdaRMSNorm
|
||||||
|
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
|
||||||
|
from .utils import checkpoint
|
||||||
|
|
||||||
|
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
|
||||||
|
class ContinuousLocalTransformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
dim_in = None,
|
||||||
|
dim_out = None,
|
||||||
|
causal = False,
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
heads = 8,
|
||||||
|
ff_mult = 2,
|
||||||
|
cond_dim = 0,
|
||||||
|
cross_attn_cond_dim = 0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
dim_head = dim//heads
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
|
||||||
|
|
||||||
|
self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
|
||||||
|
|
||||||
|
self.local_attn_window_size = local_attn_window_size
|
||||||
|
|
||||||
|
self.cond_dim = cond_dim
|
||||||
|
|
||||||
|
self.cross_attn_cond_dim = cross_attn_cond_dim
|
||||||
|
|
||||||
|
self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
|
||||||
|
|
||||||
|
for _ in range(depth):
|
||||||
|
|
||||||
|
self.layers.append(nn.ModuleList([
|
||||||
|
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
|
||||||
|
Attention(
|
||||||
|
dim=dim,
|
||||||
|
dim_heads=dim_head,
|
||||||
|
causal=causal,
|
||||||
|
zero_init_output=True,
|
||||||
|
natten_kernel_size=local_attn_window_size,
|
||||||
|
),
|
||||||
|
Attention(
|
||||||
|
dim=dim,
|
||||||
|
dim_heads=dim_head,
|
||||||
|
dim_context = cross_attn_cond_dim,
|
||||||
|
zero_init_output=True
|
||||||
|
) if self.cross_attn_cond_dim > 0 else nn.Identity(),
|
||||||
|
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
|
||||||
|
FeedForward(dim = dim, mult = ff_mult, no_bias=True)
|
||||||
|
]))
|
||||||
|
|
||||||
|
def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
|
||||||
|
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
|
||||||
|
if prepend_cond is not None:
|
||||||
|
x = torch.cat([prepend_cond, x], dim=1)
|
||||||
|
|
||||||
|
pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
||||||
|
|
||||||
|
for attn_norm, attn, xattn, ff_norm, ff in self.layers:
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
if cond is not None:
|
||||||
|
x = checkpoint(attn_norm, x, cond)
|
||||||
|
else:
|
||||||
|
x = checkpoint(attn_norm, x)
|
||||||
|
|
||||||
|
x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
|
||||||
|
|
||||||
|
if cross_attn_cond is not None:
|
||||||
|
x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
|
||||||
|
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
if cond is not None:
|
||||||
|
x = checkpoint(ff_norm, x, cond)
|
||||||
|
else:
|
||||||
|
x = checkpoint(ff_norm, x)
|
||||||
|
|
||||||
|
x = checkpoint(ff, x) + residual
|
||||||
|
|
||||||
|
return checkpoint(self.project_out, x)
|
||||||
|
|
||||||
|
class TransformerDownsampleBlock1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
embed_dim = 768,
|
||||||
|
depth = 3,
|
||||||
|
heads = 12,
|
||||||
|
downsample_ratio = 2,
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.downsample_ratio = downsample_ratio
|
||||||
|
|
||||||
|
self.transformer = ContinuousLocalTransformer(
|
||||||
|
dim=embed_dim,
|
||||||
|
depth=depth,
|
||||||
|
heads=heads,
|
||||||
|
local_attn_window_size=local_attn_window_size,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
|
||||||
|
|
||||||
|
self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
|
||||||
|
# Compute
|
||||||
|
x = self.transformer(x)
|
||||||
|
|
||||||
|
# Trade sequence length for channels
|
||||||
|
x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
|
||||||
|
|
||||||
|
# Project back to embed dim
|
||||||
|
x = checkpoint(self.project_down, x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class TransformerUpsampleBlock1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
embed_dim,
|
||||||
|
depth = 3,
|
||||||
|
heads = 12,
|
||||||
|
upsample_ratio = 2,
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.upsample_ratio = upsample_ratio
|
||||||
|
|
||||||
|
self.transformer = ContinuousLocalTransformer(
|
||||||
|
dim=embed_dim,
|
||||||
|
depth=depth,
|
||||||
|
heads=heads,
|
||||||
|
local_attn_window_size = local_attn_window_size,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
|
||||||
|
|
||||||
|
self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
# Project to embed dim
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
|
||||||
|
# Project to increase channel dim
|
||||||
|
x = checkpoint(self.project_up, x)
|
||||||
|
|
||||||
|
# Trade channels for sequence length
|
||||||
|
x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
|
||||||
|
|
||||||
|
# Compute
|
||||||
|
x = self.transformer(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
embed_dims = [96, 192, 384, 768],
|
||||||
|
heads = [12, 12, 12, 12],
|
||||||
|
depths = [3, 3, 3, 3],
|
||||||
|
ratios = [2, 2, 2, 2],
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
for layer in range(len(depths)):
|
||||||
|
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
TransformerDownsampleBlock1D(
|
||||||
|
in_channels = prev_dim,
|
||||||
|
embed_dim = embed_dims[layer],
|
||||||
|
heads = heads[layer],
|
||||||
|
depth = depths[layer],
|
||||||
|
downsample_ratio = ratios[layer],
|
||||||
|
local_attn_window_size = local_attn_window_size,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
|
||||||
|
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
x = self.layers(x)
|
||||||
|
x = checkpoint(self.project_out, x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder1D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
embed_dims = [768, 384, 192, 96],
|
||||||
|
heads = [12, 12, 12, 12],
|
||||||
|
depths = [3, 3, 3, 3],
|
||||||
|
ratios = [2, 2, 2, 2],
|
||||||
|
local_attn_window_size = 64,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
|
||||||
|
for layer in range(len(depths)):
|
||||||
|
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
|
||||||
|
|
||||||
|
layers.append(
|
||||||
|
TransformerUpsampleBlock1D(
|
||||||
|
in_channels = prev_dim,
|
||||||
|
embed_dim = embed_dims[layer],
|
||||||
|
heads = heads[layer],
|
||||||
|
depth = depths[layer],
|
||||||
|
upsample_ratio = ratios[layer],
|
||||||
|
local_attn_window_size = local_attn_window_size,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
|
||||||
|
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = rearrange(x, "b c n -> b n c")
|
||||||
|
x = checkpoint(self.project_in, x)
|
||||||
|
x = self.layers(x)
|
||||||
|
x = checkpoint(self.project_out, x)
|
||||||
|
x = rearrange(x, "b n c -> b c n")
|
||||||
|
return x
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
# mmmodules package
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
# mmmodules.model package
|
||||||
@@ -0,0 +1,393 @@
|
|||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
from scipy.optimize import fmin
|
||||||
|
from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord
|
||||||
|
|
||||||
|
class PQMF(nn.Module):
|
||||||
|
"""
|
||||||
|
Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
|
||||||
|
Uses polyphase representation which is computationally more efficient for real-time.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
|
||||||
|
- num_bands (int): Number of desired frequency bands. It must be a power of 2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, attenuation, num_bands):
|
||||||
|
super(PQMF, self).__init__()
|
||||||
|
|
||||||
|
# Ensure num_bands is a power of 2
|
||||||
|
is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
|
||||||
|
assert is_power_of_2, "'num_bands' must be a power of 2."
|
||||||
|
|
||||||
|
# Create the prototype filter
|
||||||
|
prototype_filter = design_prototype_filter(attenuation, num_bands)
|
||||||
|
filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
|
||||||
|
padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
|
||||||
|
|
||||||
|
# Register filters and settings
|
||||||
|
self.register_buffer("filter_bank", padded_filter_bank)
|
||||||
|
self.register_buffer("prototype", prototype_filter)
|
||||||
|
self.num_bands = num_bands
|
||||||
|
|
||||||
|
def forward(self, signal):
|
||||||
|
"""Decompose the signal into multiple frequency bands."""
|
||||||
|
# If signal is not a pytorch tensor of Batch x Channels x Length, convert it
|
||||||
|
signal = prepare_signal_dimensions(signal)
|
||||||
|
# The signal length must be a multiple of num_bands. Pad it with zeros.
|
||||||
|
signal = pad_signal(signal, self.num_bands)
|
||||||
|
# run it
|
||||||
|
signal = polyphase_analysis(signal, self.filter_bank)
|
||||||
|
return apply_alias_cancellation(signal)
|
||||||
|
|
||||||
|
def inverse(self, bands):
|
||||||
|
"""Reconstruct the original signal from the frequency bands."""
|
||||||
|
bands = apply_alias_cancellation(bands)
|
||||||
|
return polyphase_synthesis(bands, self.filter_bank)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_signal_dimensions(signal):
|
||||||
|
"""
|
||||||
|
Rearrange signal into Batch x Channels x Length.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal : torch.Tensor or numpy.ndarray
|
||||||
|
The input signal.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Preprocessed signal tensor.
|
||||||
|
"""
|
||||||
|
# Convert numpy to torch tensor
|
||||||
|
if isinstance(signal, np.ndarray):
|
||||||
|
signal = torch.from_numpy(signal)
|
||||||
|
|
||||||
|
# Ensure tensor
|
||||||
|
if not isinstance(signal, torch.Tensor):
|
||||||
|
raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
|
||||||
|
|
||||||
|
# Modify dimension of signal to Batch x Channels x Length
|
||||||
|
if signal.dim() == 1:
|
||||||
|
# This is just a mono signal. Unsqueeze to 1 x 1 x Length
|
||||||
|
signal = signal.unsqueeze(0).unsqueeze(0)
|
||||||
|
elif signal.dim() == 2:
|
||||||
|
# This is a multi-channel signal (e.g. stereo)
|
||||||
|
# Rearrange so that larger dimension (Length) is last
|
||||||
|
if signal.shape[0] > signal.shape[1]:
|
||||||
|
signal = signal.T
|
||||||
|
# Unsqueeze to 1 x Channels x Length
|
||||||
|
signal = signal.unsqueeze(0)
|
||||||
|
return signal
|
||||||
|
|
||||||
|
def pad_signal(signal, num_bands):
|
||||||
|
"""
|
||||||
|
Pads the signal to make its length divisible by the given number of bands.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal : torch.Tensor
|
||||||
|
The input signal tensor, where the last dimension represents the signal length.
|
||||||
|
|
||||||
|
num_bands : int
|
||||||
|
The number of bands by which the signal length should be divisible.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
The padded signal tensor. If the original signal length was already divisible
|
||||||
|
by num_bands, returns the original signal unchanged.
|
||||||
|
"""
|
||||||
|
remainder = signal.shape[-1] % num_bands
|
||||||
|
if remainder > 0:
|
||||||
|
padding_size = num_bands - remainder
|
||||||
|
signal = nn.functional.pad(signal, (0, padding_size))
|
||||||
|
return signal
|
||||||
|
|
||||||
|
def generate_modulated_filter_bank(prototype_filter, num_bands):
|
||||||
|
"""
|
||||||
|
Generate a QMF bank of cosine modulated filters based on a given prototype filter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
prototype_filter : torch.Tensor
|
||||||
|
The prototype filter used as the basis for modulation.
|
||||||
|
num_bands : int
|
||||||
|
The number of desired subbands or filters.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
A bank of cosine modulated filters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Initialize indices for modulation.
|
||||||
|
subband_indices = torch.arange(num_bands).reshape(-1, 1)
|
||||||
|
|
||||||
|
# Calculate the length of the prototype filter.
|
||||||
|
filter_length = prototype_filter.shape[-1]
|
||||||
|
|
||||||
|
# Generate symmetric time indices centered around zero.
|
||||||
|
time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
|
||||||
|
|
||||||
|
# Calculate phase offsets to ensure orthogonality between subbands.
|
||||||
|
phase_offsets = (-1)**subband_indices * np.pi / 4
|
||||||
|
|
||||||
|
# Compute the cosine modulation function.
|
||||||
|
modulation = torch.cos(
|
||||||
|
(2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply modulation to the prototype filter.
|
||||||
|
modulated_filters = 2 * prototype_filter * modulation
|
||||||
|
|
||||||
|
return modulated_filters
|
||||||
|
|
||||||
|
|
||||||
|
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
|
||||||
|
"""
|
||||||
|
Design a lowpass filter using the Kaiser window.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
angular_cutoff : float
|
||||||
|
The angular frequency cutoff of the filter.
|
||||||
|
attenuation : float
|
||||||
|
The desired stopband attenuation in decibels (dB).
|
||||||
|
filter_length : int, optional
|
||||||
|
Desired length of the filter. If not provided, it's computed based on the given specs.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ndarray
|
||||||
|
The designed lowpass filter coefficients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
|
||||||
|
|
||||||
|
# Ensure the estimated length is odd.
|
||||||
|
estimated_length = 2 * (estimated_length // 2) + 1
|
||||||
|
|
||||||
|
if filter_length is None:
|
||||||
|
filter_length = estimated_length
|
||||||
|
|
||||||
|
return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
|
||||||
|
"""
|
||||||
|
Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
angular_cutoff : float
|
||||||
|
Angular frequency cutoff of the filter.
|
||||||
|
attenuation : float
|
||||||
|
Desired stopband attenuation in dB.
|
||||||
|
num_bands : int
|
||||||
|
Number of bands for the multiband filter system.
|
||||||
|
filter_length : int, optional
|
||||||
|
Desired length of the filter.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
The computed objective (loss) value for the given filter specs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
|
||||||
|
convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
|
||||||
|
|
||||||
|
return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
|
||||||
|
|
||||||
|
|
||||||
|
def design_prototype_filter(attenuation, num_bands, filter_length=None):
|
||||||
|
"""
|
||||||
|
Design the optimal prototype filter for a multiband system given the desired specs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
attenuation : float
|
||||||
|
The desired stopband attenuation in dB.
|
||||||
|
num_bands : int
|
||||||
|
Number of bands for the multiband filter system.
|
||||||
|
filter_length : int, optional
|
||||||
|
Desired length of the filter. If not provided, it's computed based on the given specs.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ndarray
|
||||||
|
The optimal prototype filter coefficients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
|
||||||
|
1 / num_bands, disp=0)[0]
|
||||||
|
|
||||||
|
prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
|
||||||
|
return torch.tensor(prototype_filter, dtype=torch.float32)
|
||||||
|
|
||||||
|
def pad_to_nearest_power_of_two(x):
|
||||||
|
"""
|
||||||
|
Pads the input tensor 'x' on both sides such that its last dimension
|
||||||
|
becomes the nearest larger power of two.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
x : torch.Tensor
|
||||||
|
The input tensor to be padded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
The padded tensor.
|
||||||
|
"""
|
||||||
|
current_length = x.shape[-1]
|
||||||
|
target_length = 2**math.ceil(math.log2(current_length))
|
||||||
|
|
||||||
|
total_padding = target_length - current_length
|
||||||
|
left_padding = total_padding // 2
|
||||||
|
right_padding = total_padding - left_padding
|
||||||
|
|
||||||
|
return nn.functional.pad(x, (left_padding, right_padding))
|
||||||
|
|
||||||
|
def apply_alias_cancellation(x):
|
||||||
|
"""
|
||||||
|
Applies alias cancellation by inverting the sign of every
|
||||||
|
second element of every second row, starting from the second
|
||||||
|
row's first element in a tensor.
|
||||||
|
|
||||||
|
This operation helps ensure that the aliasing introduced in
|
||||||
|
each band during the decomposition will be counteracted during
|
||||||
|
the reconstruction.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
x : torch.Tensor
|
||||||
|
The input tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
Tensor with specific elements' sign inverted for alias cancellation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a mask of the same shape as 'x', initialized with all ones
|
||||||
|
mask = torch.ones_like(x)
|
||||||
|
|
||||||
|
# Update specific elements in the mask to -1 to perform inversion
|
||||||
|
mask[..., 1::2, ::2] = -1
|
||||||
|
|
||||||
|
# Apply the mask to the input tensor 'x'
|
||||||
|
return x * mask
|
||||||
|
|
||||||
|
def ensure_odd_length(tensor):
|
||||||
|
"""
|
||||||
|
Pads the last dimension of a tensor to ensure its size is odd.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
tensor : torch.Tensor
|
||||||
|
Input tensor whose last dimension might need padding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
The original tensor if its last dimension was already odd,
|
||||||
|
or the padded tensor with an odd-sized last dimension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_dim_size = tensor.shape[-1]
|
||||||
|
|
||||||
|
if last_dim_size % 2 == 0:
|
||||||
|
tensor = nn.functional.pad(tensor, (0, 1))
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
def polyphase_analysis(signal, filter_bank):
|
||||||
|
"""
|
||||||
|
Applies the polyphase method to efficiently analyze the signal using a filter bank.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
signal : torch.Tensor
|
||||||
|
Input signal tensor with shape (Batch x Channels x Length).
|
||||||
|
|
||||||
|
filter_bank : torch.Tensor
|
||||||
|
Filter bank tensor with shape (Bands x Length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
Signal split into sub-bands. (Batch x Channels x Bands x Length)
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_bands = filter_bank.shape[0]
|
||||||
|
num_channels = signal.shape[1]
|
||||||
|
|
||||||
|
# Rearrange signal for polyphase processing.
|
||||||
|
# Also combine Batch x Channel into one dimension for now.
|
||||||
|
#signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
|
||||||
|
signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
|
||||||
|
|
||||||
|
# Rearrange the filter bank for matching signal shape
|
||||||
|
filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
|
||||||
|
|
||||||
|
# Apply convolution with appropriate padding to maintain spatial dimensions
|
||||||
|
padding = filter_bank.shape[-1] // 2
|
||||||
|
filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
|
||||||
|
|
||||||
|
# Truncate the last dimension post-convolution to adjust the output shape
|
||||||
|
filtered_signal = filtered_signal[..., :-1]
|
||||||
|
# Rearrange the first dimension back into Batch x Channels
|
||||||
|
filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
|
||||||
|
|
||||||
|
return filtered_signal
|
||||||
|
|
||||||
|
def polyphase_synthesis(signal, filter_bank):
|
||||||
|
"""
|
||||||
|
Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal : torch.Tensor
|
||||||
|
Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
|
||||||
|
|
||||||
|
filter_bank : torch.Tensor
|
||||||
|
Analysis filter bank (shape: Bands x Length).
|
||||||
|
|
||||||
|
should_rearrange : bool, optional
|
||||||
|
Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Reconstructed signal (shape: Batch x Channels X Length)
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_bands = filter_bank.shape[0]
|
||||||
|
num_channels = signal.shape[1]
|
||||||
|
|
||||||
|
# Rearrange the filter bank
|
||||||
|
filter_bank = filter_bank.flip(-1)
|
||||||
|
filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
|
||||||
|
|
||||||
|
# Combine Batch x Channels into one dimension for now.
|
||||||
|
signal = rearrange(signal, "b c n t -> (b c) n t")
|
||||||
|
|
||||||
|
# Apply convolution with appropriate padding
|
||||||
|
padding_amount = filter_bank.shape[-1] // 2 + 1
|
||||||
|
reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
|
||||||
|
|
||||||
|
# Scale the result
|
||||||
|
reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
|
||||||
|
|
||||||
|
# Reorganize the output and truncate
|
||||||
|
reconstructed_signal = reconstructed_signal.flip(1)
|
||||||
|
reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
|
||||||
|
reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
|
||||||
|
|
||||||
|
return reconstructed_signal
|
||||||
@@ -0,0 +1,239 @@
|
|||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
class Pretransform(nn.Module):
|
||||||
|
def __init__(self, enable_grad, io_channels, is_discrete):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.is_discrete = is_discrete
|
||||||
|
self.io_channels = io_channels
|
||||||
|
self.encoded_channels = None
|
||||||
|
self.downsampling_ratio = None
|
||||||
|
|
||||||
|
self.enable_grad = enable_grad
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tokenize(self, x):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class AutoencoderPretransform(Pretransform):
|
||||||
|
def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
|
||||||
|
super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
|
||||||
|
self.model = model
|
||||||
|
self.model.requires_grad_(False).eval()
|
||||||
|
self.scale=scale
|
||||||
|
self.downsampling_ratio = model.downsampling_ratio
|
||||||
|
self.io_channels = model.io_channels
|
||||||
|
self.sample_rate = model.sample_rate
|
||||||
|
|
||||||
|
self.model_half = model_half
|
||||||
|
self.iterate_batch = iterate_batch
|
||||||
|
|
||||||
|
self.encoded_channels = model.latent_dim
|
||||||
|
|
||||||
|
self.chunked = chunked
|
||||||
|
self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
|
||||||
|
self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
self.model.half()
|
||||||
|
|
||||||
|
def encode(self, x, **kwargs):
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
x = x.half()
|
||||||
|
self.model.to(torch.float16)
|
||||||
|
|
||||||
|
encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
encoded = encoded.float()
|
||||||
|
|
||||||
|
return encoded / self.scale
|
||||||
|
|
||||||
|
def decode(self, z, **kwargs):
|
||||||
|
z = z * self.scale
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
z = z.half()
|
||||||
|
self.model.to(torch.float16)
|
||||||
|
|
||||||
|
decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
|
||||||
|
|
||||||
|
if self.model_half:
|
||||||
|
decoded = decoded.float()
|
||||||
|
|
||||||
|
return decoded
|
||||||
|
|
||||||
|
def tokenize(self, x, **kwargs):
|
||||||
|
assert self.model.is_discrete, "Cannot tokenize with a continuous model"
|
||||||
|
|
||||||
|
_, info = self.model.encode(x, return_info = True, **kwargs)
|
||||||
|
|
||||||
|
return info[self.model.bottleneck.tokens_id]
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens, **kwargs):
|
||||||
|
assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
|
||||||
|
|
||||||
|
return self.model.decode_tokens(tokens, **kwargs)
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict, strict=True):
|
||||||
|
self.model.load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
|
class PQMFPretransform(Pretransform):
|
||||||
|
def __init__(self, attenuation=100, num_bands=16):
|
||||||
|
# TODO: Fix PQMF to take in in-channels
|
||||||
|
super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
|
||||||
|
from .pqmf import PQMF
|
||||||
|
self.pqmf = PQMF(attenuation, num_bands)
|
||||||
|
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
# x is (Batch x Channels x Time)
|
||||||
|
x = self.pqmf.forward(x)
|
||||||
|
# pqmf.forward returns (Batch x Channels x Bands x Time)
|
||||||
|
# but Pretransform needs Batch x Channels x Time
|
||||||
|
# so concatenate channels and bands into one axis
|
||||||
|
return rearrange(x, "b c n t -> b (c n) t")
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
# x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
|
||||||
|
x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
|
||||||
|
# returns (Batch x Channels x Time)
|
||||||
|
return self.pqmf.inverse(x)
|
||||||
|
|
||||||
|
class PretrainedDACPretransform(Pretransform):
|
||||||
|
def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
|
||||||
|
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
|
||||||
|
|
||||||
|
import dac
|
||||||
|
|
||||||
|
model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
|
||||||
|
|
||||||
|
self.model = dac.DAC.load(model_path)
|
||||||
|
|
||||||
|
self.quantize_on_decode = quantize_on_decode
|
||||||
|
|
||||||
|
if model_type == "44khz":
|
||||||
|
self.downsampling_ratio = 512
|
||||||
|
else:
|
||||||
|
self.downsampling_ratio = 320
|
||||||
|
|
||||||
|
self.io_channels = 1
|
||||||
|
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
self.chunked = chunked
|
||||||
|
|
||||||
|
self.encoded_channels = self.model.latent_dim
|
||||||
|
|
||||||
|
self.num_quantizers = self.model.n_codebooks
|
||||||
|
|
||||||
|
self.codebook_size = self.model.codebook_size
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
|
||||||
|
latents = self.model.encoder(x)
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
output = latents
|
||||||
|
else:
|
||||||
|
z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
|
||||||
|
output = z
|
||||||
|
|
||||||
|
if self.scale != 1.0:
|
||||||
|
output = output / self.scale
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
|
||||||
|
if self.scale != 1.0:
|
||||||
|
z = z * self.scale
|
||||||
|
|
||||||
|
if self.quantize_on_decode:
|
||||||
|
z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
|
||||||
|
|
||||||
|
return self.model.decode(z)
|
||||||
|
|
||||||
|
def tokenize(self, x):
|
||||||
|
return self.model.encode(x)[1]
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens):
|
||||||
|
latents = self.model.quantizer.from_codes(tokens)
|
||||||
|
return self.model.decode(latents)
|
||||||
|
|
||||||
|
class AudiocraftCompressionPretransform(Pretransform):
|
||||||
|
def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
|
||||||
|
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from audiocraft.models import CompressionModel
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
|
||||||
|
|
||||||
|
self.model = CompressionModel.get_pretrained(model_type)
|
||||||
|
|
||||||
|
self.quantize_on_decode = quantize_on_decode
|
||||||
|
|
||||||
|
self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
|
||||||
|
|
||||||
|
self.sample_rate = self.model.sample_rate
|
||||||
|
|
||||||
|
self.io_channels = self.model.channels
|
||||||
|
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
#self.encoded_channels = self.model.latent_dim
|
||||||
|
|
||||||
|
self.num_quantizers = self.model.num_codebooks
|
||||||
|
|
||||||
|
self.codebook_size = self.model.cardinality
|
||||||
|
|
||||||
|
self.model.to(torch.float16).eval().requires_grad_(False)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
|
||||||
|
assert False, "Audiocraft compression models do not support continuous encoding"
|
||||||
|
|
||||||
|
# latents = self.model.encoder(x)
|
||||||
|
|
||||||
|
# if self.quantize_on_decode:
|
||||||
|
# output = latents
|
||||||
|
# else:
|
||||||
|
# z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
|
||||||
|
# output = z
|
||||||
|
|
||||||
|
# if self.scale != 1.0:
|
||||||
|
# output = output / self.scale
|
||||||
|
|
||||||
|
# return output
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
|
||||||
|
assert False, "Audiocraft compression models do not support continuous decoding"
|
||||||
|
|
||||||
|
# if self.scale != 1.0:
|
||||||
|
# z = z * self.scale
|
||||||
|
|
||||||
|
# if self.quantize_on_decode:
|
||||||
|
# z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
|
||||||
|
|
||||||
|
# return self.model.decode(z)
|
||||||
|
|
||||||
|
def tokenize(self, x):
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
return self.model.encode(x.to(torch.float16))[0]
|
||||||
|
|
||||||
|
def decode_tokens(self, tokens):
|
||||||
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
return self.model.decode(tokens)
|
||||||
@@ -0,0 +1,989 @@
|
|||||||
|
from functools import reduce, partial
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn, einsum
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||||
|
from typing import Callable, Literal
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
||||||
|
HAS_FLASH_ATTN = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_FLASH_ATTN = False
|
||||||
|
flash_attn_kvpacked_func = None
|
||||||
|
flash_attn_func = None
|
||||||
|
|
||||||
|
from .utils import compile, checkpoint
|
||||||
|
try:
|
||||||
|
import natten
|
||||||
|
except ImportError:
|
||||||
|
natten = None
|
||||||
|
|
||||||
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||||
|
return x * (1 + scale) + shift
|
||||||
|
|
||||||
|
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
|
||||||
|
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
|
||||||
|
|
||||||
|
def create_causal_mask(i, j, device):
|
||||||
|
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
|
||||||
|
|
||||||
|
def or_reduce(masks):
|
||||||
|
head, *body = masks
|
||||||
|
for rest in body:
|
||||||
|
head = head | rest
|
||||||
|
return head
|
||||||
|
|
||||||
|
# positional embeddings
|
||||||
|
|
||||||
|
class AbsolutePositionalEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, max_seq_len):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim ** -0.5
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.emb = nn.Embedding(max_seq_len, dim)
|
||||||
|
|
||||||
|
def forward(self, x, pos = None, seq_start_pos = None):
|
||||||
|
seq_len, device = x.shape[1], x.device
|
||||||
|
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
|
||||||
|
|
||||||
|
if pos is None:
|
||||||
|
pos = torch.arange(seq_len, device = device)
|
||||||
|
|
||||||
|
if seq_start_pos is not None:
|
||||||
|
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
|
||||||
|
|
||||||
|
pos_emb = self.emb(pos)
|
||||||
|
pos_emb = pos_emb * self.scale
|
||||||
|
return pos_emb
|
||||||
|
|
||||||
|
class ScaledSinusoidalEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, theta = 10000):
|
||||||
|
super().__init__()
|
||||||
|
assert (dim % 2) == 0, 'dimension must be divisible by 2'
|
||||||
|
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
|
||||||
|
|
||||||
|
half_dim = dim // 2
|
||||||
|
freq_seq = torch.arange(half_dim).float() / half_dim
|
||||||
|
inv_freq = theta ** -freq_seq
|
||||||
|
self.register_buffer('inv_freq', inv_freq, persistent = False)
|
||||||
|
|
||||||
|
def forward(self, x, pos = None, seq_start_pos = None):
|
||||||
|
seq_len, device = x.shape[1], x.device
|
||||||
|
|
||||||
|
if pos is None:
|
||||||
|
pos = torch.arange(seq_len, device = device)
|
||||||
|
|
||||||
|
if seq_start_pos is not None:
|
||||||
|
pos = pos - seq_start_pos[..., None]
|
||||||
|
|
||||||
|
emb = einsum('i, j -> i j', pos, self.inv_freq)
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||||
|
return emb * self.scale
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
use_xpos = False,
|
||||||
|
scale_base = 512,
|
||||||
|
interpolation_factor = 1.,
|
||||||
|
base = 10000,
|
||||||
|
base_rescale_factor = 1.
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||||
|
# has some connection to NTK literature
|
||||||
|
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||||
|
base *= base_rescale_factor ** (dim / (dim - 2))
|
||||||
|
|
||||||
|
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||||
|
self.register_buffer('inv_freq', inv_freq)
|
||||||
|
|
||||||
|
assert interpolation_factor >= 1.
|
||||||
|
self.interpolation_factor = interpolation_factor
|
||||||
|
|
||||||
|
if not use_xpos:
|
||||||
|
self.register_buffer('scale', None)
|
||||||
|
return
|
||||||
|
|
||||||
|
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||||
|
|
||||||
|
self.scale_base = scale_base
|
||||||
|
self.register_buffer('scale', scale)
|
||||||
|
|
||||||
|
def forward_from_seq_len(self, seq_len):
|
||||||
|
device = self.inv_freq.device
|
||||||
|
|
||||||
|
t = torch.arange(seq_len, device = device)
|
||||||
|
return self.forward(t)
|
||||||
|
|
||||||
|
@autocast(enabled = False)
|
||||||
|
def forward(self, t):
|
||||||
|
device = self.inv_freq.device
|
||||||
|
|
||||||
|
t = t.to(torch.float32)
|
||||||
|
|
||||||
|
t = t / self.interpolation_factor
|
||||||
|
|
||||||
|
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
||||||
|
freqs = torch.cat((freqs, freqs), dim = -1)
|
||||||
|
|
||||||
|
if self.scale is None:
|
||||||
|
return freqs, 1.
|
||||||
|
|
||||||
|
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||||
|
scale = self.scale ** rearrange(power, 'n -> n 1')
|
||||||
|
scale = torch.cat((scale, scale), dim = -1)
|
||||||
|
|
||||||
|
return freqs, scale
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
x = rearrange(x, '... (j d) -> ... j d', j = 2)
|
||||||
|
x1, x2 = x.unbind(dim = -2)
|
||||||
|
return torch.cat((-x2, x1), dim = -1)
|
||||||
|
|
||||||
|
@autocast(enabled = False)
|
||||||
|
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
||||||
|
out_dtype = t.dtype
|
||||||
|
|
||||||
|
# cast to float32 if necessary for numerical stability
|
||||||
|
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
|
||||||
|
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
|
||||||
|
freqs, t = freqs.to(dtype), t.to(dtype)
|
||||||
|
freqs = freqs[-seq_len:, :]
|
||||||
|
|
||||||
|
if t.ndim == 4 and freqs.ndim == 3:
|
||||||
|
freqs = rearrange(freqs, 'b n d -> b 1 n d')
|
||||||
|
|
||||||
|
# partial rotary embeddings, Wang et al. GPT-J
|
||||||
|
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
|
||||||
|
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
||||||
|
|
||||||
|
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
|
||||||
|
|
||||||
|
return torch.cat((t, t_unrotated), dim = -1)
|
||||||
|
|
||||||
|
# norms
|
||||||
|
class DynamicTanh(nn.Module):
|
||||||
|
def __init__(self, dim, init_alpha=10.0):
|
||||||
|
super().__init__()
|
||||||
|
self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
|
||||||
|
self.gamma = nn.Parameter(torch.ones(dim))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.tanh(self.alpha * x)
|
||||||
|
return self.gamma * x + self.beta
|
||||||
|
|
||||||
|
class RunningInstanceNorm(nn.Module):
|
||||||
|
def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("running_mean", torch.zeros(1,1,dim))
|
||||||
|
self.register_buffer("running_std", torch.ones(1,1,dim))
|
||||||
|
self.saturate = saturate
|
||||||
|
self.eps = eps
|
||||||
|
self.momentum = momentum
|
||||||
|
self.dim = dim
|
||||||
|
self.trainable_gain = trainable_gain
|
||||||
|
if self.trainable_gain:
|
||||||
|
self.gain = nn.Parameter(torch.ones(1))
|
||||||
|
|
||||||
|
def _update_stats(self, x):
|
||||||
|
self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)
|
||||||
|
self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.training:
|
||||||
|
self._update_stats(x)
|
||||||
|
x = (x - self.running_mean) / self.running_std
|
||||||
|
if self.saturate:
|
||||||
|
x = torch.asinh(x)
|
||||||
|
if self.trainable_gain:
|
||||||
|
x = x * self.gain
|
||||||
|
return x
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
|
||||||
|
"""
|
||||||
|
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if fix_scale:
|
||||||
|
self.register_buffer("gamma", torch.ones(dim))
|
||||||
|
else:
|
||||||
|
self.gamma = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.beta = nn.Parameter(torch.zeros(dim))
|
||||||
|
else:
|
||||||
|
self.register_buffer("beta", torch.zeros(dim))
|
||||||
|
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.force_fp32 = force_fp32
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.force_fp32:
|
||||||
|
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps)
|
||||||
|
else:
|
||||||
|
output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
|
||||||
|
return output.to(x.dtype)
|
||||||
|
|
||||||
|
class LayerScale(nn.Module):
|
||||||
|
def __init__(self, dim, init_val = 1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = nn.Parameter(torch.full([dim], init_val))
|
||||||
|
def forward(self, x):
|
||||||
|
return x * self.scale
|
||||||
|
|
||||||
|
class GLU(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim_in,
|
||||||
|
dim_out,
|
||||||
|
activation: Callable,
|
||||||
|
use_conv = False,
|
||||||
|
conv_kernel_size = 3,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.act = activation
|
||||||
|
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
|
||||||
|
self.use_conv = use_conv
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_conv:
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.proj(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
else:
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
x, gate = x.chunk(2, dim = -1)
|
||||||
|
return x * self.act(gate)
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_out = None,
|
||||||
|
mult = 4,
|
||||||
|
no_bias = False,
|
||||||
|
glu = True,
|
||||||
|
use_conv = False,
|
||||||
|
conv_kernel_size = 3,
|
||||||
|
zero_init_output = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
|
||||||
|
# Default to SwiGLU
|
||||||
|
|
||||||
|
activation = nn.SiLU()
|
||||||
|
|
||||||
|
dim_out = dim if dim_out is None else dim_out
|
||||||
|
|
||||||
|
if glu:
|
||||||
|
linear_in = GLU(dim, inner_dim, activation)
|
||||||
|
else:
|
||||||
|
linear_in = nn.Sequential(
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
activation
|
||||||
|
)
|
||||||
|
|
||||||
|
linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
|
||||||
|
|
||||||
|
# init last linear layer to 0
|
||||||
|
if zero_init_output:
|
||||||
|
nn.init.zeros_(linear_out.weight)
|
||||||
|
if not no_bias:
|
||||||
|
nn.init.zeros_(linear_out.bias)
|
||||||
|
|
||||||
|
|
||||||
|
self.ff = nn.Sequential(
|
||||||
|
linear_in,
|
||||||
|
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||||
|
linear_out,
|
||||||
|
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.ff(x)
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_heads = 64,
|
||||||
|
dim_context = None,
|
||||||
|
causal = False,
|
||||||
|
zero_init_output=True,
|
||||||
|
qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
|
||||||
|
differential = False,
|
||||||
|
feat_scale = False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.dim_heads = dim_heads
|
||||||
|
self.differential = differential
|
||||||
|
|
||||||
|
dim_kv = dim_context if dim_context is not None else dim
|
||||||
|
|
||||||
|
self.num_heads = dim // dim_heads
|
||||||
|
self.kv_heads = dim_kv // dim_heads
|
||||||
|
|
||||||
|
if dim_context is not None:
|
||||||
|
if differential:
|
||||||
|
self.to_q = nn.Linear(dim, dim * 2, bias=False)
|
||||||
|
self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
|
||||||
|
else:
|
||||||
|
self.to_q = nn.Linear(dim, dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
|
||||||
|
else:
|
||||||
|
if differential:
|
||||||
|
self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
|
||||||
|
else:
|
||||||
|
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||||
|
|
||||||
|
self.to_out = nn.Linear(dim, dim, bias=False)
|
||||||
|
|
||||||
|
if zero_init_output:
|
||||||
|
nn.init.zeros_(self.to_out.weight)
|
||||||
|
|
||||||
|
if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
|
||||||
|
raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}')
|
||||||
|
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
if self.qk_norm == "ln":
|
||||||
|
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
||||||
|
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
||||||
|
elif self.qk_norm == 'rns':
|
||||||
|
self.q_norm = nn.RMSNorm(dim_heads)
|
||||||
|
self.k_norm = nn.RMSNorm(dim_heads)
|
||||||
|
elif self.qk_norm == 'dyt':
|
||||||
|
self.q_norm = DynamicTanh(dim_heads)
|
||||||
|
self.k_norm = DynamicTanh(dim_heads)
|
||||||
|
|
||||||
|
self.sdp_kwargs = dict(
|
||||||
|
enable_flash = True,
|
||||||
|
enable_math = True,
|
||||||
|
enable_mem_efficient = True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.feat_scale = feat_scale
|
||||||
|
|
||||||
|
if self.feat_scale:
|
||||||
|
self.lambda_dc = nn.Parameter(torch.zeros(dim))
|
||||||
|
self.lambda_hf = nn.Parameter(torch.zeros(dim))
|
||||||
|
|
||||||
|
self.causal = causal
|
||||||
|
|
||||||
|
@compile
|
||||||
|
def apply_qk_layernorm(self, q, k):
|
||||||
|
q_type = q.dtype
|
||||||
|
k_type = k.dtype
|
||||||
|
q = self.q_norm(q).to(q_type)
|
||||||
|
k = self.k_norm(k).to(k_type)
|
||||||
|
return q, k
|
||||||
|
|
||||||
|
|
||||||
|
def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
|
||||||
|
|
||||||
|
if self.num_heads != self.kv_heads:
|
||||||
|
# Repeat interleave kv_heads to match q_heads for grouped query attention
|
||||||
|
heads_per_kv_head = self.num_heads // self.kv_heads
|
||||||
|
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||||
|
|
||||||
|
flash_attn_available = HAS_FLASH_ATTN
|
||||||
|
|
||||||
|
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
|
||||||
|
flex_attention_block_mask = None
|
||||||
|
flex_attention_score_mod = None
|
||||||
|
|
||||||
|
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"FlexAttention is not available in this build. "
|
||||||
|
"flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
|
||||||
|
)
|
||||||
|
elif flash_attn_available:
|
||||||
|
fa_dtype_in = q.dtype
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
|
||||||
|
|
||||||
|
if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
|
||||||
|
q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))
|
||||||
|
|
||||||
|
out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])
|
||||||
|
|
||||||
|
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
|
||||||
|
else:
|
||||||
|
out = F.scaled_dot_product_attention(q, k, v, is_causal = causal)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
#@compile
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
context = None,
|
||||||
|
rotary_pos_emb = None,
|
||||||
|
causal = None,
|
||||||
|
flex_attention_block_mask = None,
|
||||||
|
flex_attention_score_mod = None,
|
||||||
|
flash_attn_sliding_window = None
|
||||||
|
):
|
||||||
|
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||||
|
|
||||||
|
kv_input = context if has_context else x
|
||||||
|
|
||||||
|
if hasattr(self, 'to_q'):
|
||||||
|
# Use separate linear projections for q and k/v
|
||||||
|
if self.differential:
|
||||||
|
q, q_diff = self.to_q(x).chunk(2, dim=-1)
|
||||||
|
q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
|
||||||
|
q = torch.stack([q, q_diff], dim = 1)
|
||||||
|
k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
|
||||||
|
k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
|
||||||
|
k = torch.stack([k, k_diff], dim = 1)
|
||||||
|
else:
|
||||||
|
q = self.to_q(x)
|
||||||
|
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
|
||||||
|
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||||
|
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
|
||||||
|
else:
|
||||||
|
# Use fused linear projection
|
||||||
|
if self.differential:
|
||||||
|
q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
|
||||||
|
q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
|
||||||
|
q = torch.stack([q, q_diff], dim = 1)
|
||||||
|
k = torch.stack([k, k_diff], dim = 1)
|
||||||
|
else:
|
||||||
|
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||||
|
|
||||||
|
# Normalize q and k for cosine sim attention
|
||||||
|
if self.qk_norm == "l2":
|
||||||
|
q = F.normalize(q, dim=-1)
|
||||||
|
k = F.normalize(k, dim=-1)
|
||||||
|
elif self.qk_norm != "none":
|
||||||
|
q, k = self.apply_qk_layernorm(q, k)
|
||||||
|
|
||||||
|
if rotary_pos_emb is not None:
|
||||||
|
freqs, _ = rotary_pos_emb
|
||||||
|
q_dtype = q.dtype
|
||||||
|
k_dtype = k.dtype
|
||||||
|
q = q.to(torch.float32)
|
||||||
|
k = k.to(torch.float32)
|
||||||
|
freqs = freqs.to(torch.float32)
|
||||||
|
if q.shape[-2] >= k.shape[-2]:
|
||||||
|
ratio = q.shape[-2] / k.shape[-2]
|
||||||
|
q_freqs, k_freqs = freqs, ratio * freqs
|
||||||
|
else:
|
||||||
|
ratio = k.shape[-2] / q.shape[-2]
|
||||||
|
q_freqs, k_freqs = ratio * freqs, freqs
|
||||||
|
q = apply_rotary_pos_emb(q, q_freqs)
|
||||||
|
k = apply_rotary_pos_emb(k, k_freqs)
|
||||||
|
q = q.to(v.dtype)
|
||||||
|
k = k.to(v.dtype)
|
||||||
|
|
||||||
|
n, device = q.shape[-2], q.device
|
||||||
|
|
||||||
|
causal = self.causal if causal is None else causal
|
||||||
|
|
||||||
|
if n == 1 and causal:
|
||||||
|
causal = False
|
||||||
|
|
||||||
|
if self.differential:
|
||||||
|
q, q_diff = q.unbind(dim = 1)
|
||||||
|
k, k_diff = k.unbind(dim = 1)
|
||||||
|
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
||||||
|
out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
||||||
|
out = out - out_diff
|
||||||
|
else:
|
||||||
|
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
||||||
|
|
||||||
|
# merge heads
|
||||||
|
out = rearrange(out, ' b h n d -> b n (h d)')
|
||||||
|
|
||||||
|
# Communicate between heads
|
||||||
|
|
||||||
|
# with autocast(enabled = False):
|
||||||
|
# out_dtype = out.dtype
|
||||||
|
# out = out.to(torch.float32)
|
||||||
|
# out = self.to_out(out).to(out_dtype)
|
||||||
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
if self.feat_scale:
|
||||||
|
out_dc = out.mean(dim=-2, keepdim=True)
|
||||||
|
out_hf = out - out_dc
|
||||||
|
|
||||||
|
# Selectively modulate DC and high frequency components
|
||||||
|
out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class ConformerModule(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
norm_kwargs = {},
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
self.in_norm = LayerNorm(dim, **norm_kwargs)
|
||||||
|
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||||
|
self.glu = GLU(dim, dim, nn.SiLU())
|
||||||
|
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
|
||||||
|
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
|
||||||
|
self.swish = nn.SiLU()
|
||||||
|
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||||
|
|
||||||
|
#@compile
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.in_norm(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.pointwise_conv(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
x = self.glu(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.depthwise_conv(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
x = self.mid_norm(x)
|
||||||
|
x = self.swish(x)
|
||||||
|
x = rearrange(x, 'b n d -> b d n')
|
||||||
|
x = self.pointwise_conv_2(x)
|
||||||
|
x = rearrange(x, 'b d n -> b n d')
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
dim_heads = 64,
|
||||||
|
cross_attend = False,
|
||||||
|
dim_context = None,
|
||||||
|
global_cond_dim = None,
|
||||||
|
causal = False,
|
||||||
|
zero_init_branch_outputs = True,
|
||||||
|
conformer = False,
|
||||||
|
layer_ix = -1,
|
||||||
|
remove_norms = False,
|
||||||
|
add_rope = False,
|
||||||
|
layer_scale = False,
|
||||||
|
use_sync_block_film = False,
|
||||||
|
attn_kwargs = {},
|
||||||
|
ff_kwargs = {},
|
||||||
|
norm_kwargs = {}
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.dim_heads = min(dim_heads,dim)
|
||||||
|
self.cross_attend = cross_attend
|
||||||
|
self.dim_context = dim_context
|
||||||
|
self.causal = causal
|
||||||
|
if layer_scale and zero_init_branch_outputs:
|
||||||
|
zero_init_branch_outputs = False
|
||||||
|
|
||||||
|
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
||||||
|
|
||||||
|
self.add_rope = add_rope
|
||||||
|
|
||||||
|
self.self_attn = Attention(
|
||||||
|
dim,
|
||||||
|
dim_heads = self.dim_heads,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_output=zero_init_branch_outputs,
|
||||||
|
**attn_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
||||||
|
|
||||||
|
self.cross_attend = cross_attend
|
||||||
|
if cross_attend:
|
||||||
|
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
||||||
|
self.cross_attn = Attention(
|
||||||
|
dim,
|
||||||
|
dim_heads = self.dim_heads,
|
||||||
|
dim_context=dim_context,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_output=zero_init_branch_outputs,
|
||||||
|
**attn_kwargs
|
||||||
|
)
|
||||||
|
self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
||||||
|
|
||||||
|
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
||||||
|
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
|
||||||
|
self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
||||||
|
|
||||||
|
self.layer_ix = layer_ix
|
||||||
|
|
||||||
|
self.conformer = None
|
||||||
|
if conformer:
|
||||||
|
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
|
||||||
|
self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
||||||
|
|
||||||
|
self.global_cond_dim = global_cond_dim
|
||||||
|
if global_cond_dim is not None:
|
||||||
|
self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)
|
||||||
|
|
||||||
|
self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None
|
||||||
|
|
||||||
|
if use_sync_block_film:
|
||||||
|
self.sync_film_generator = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
||||||
|
)
|
||||||
|
|
||||||
|
@compile
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
context = None,
|
||||||
|
global_cond=None,
|
||||||
|
rotary_pos_emb = None,
|
||||||
|
self_attention_block_mask = None,
|
||||||
|
self_attention_score_mod = None,
|
||||||
|
cross_attention_block_mask = None,
|
||||||
|
cross_attention_score_mod = None,
|
||||||
|
self_attention_flash_sliding_window = None,
|
||||||
|
cross_attention_flash_sliding_window = None,
|
||||||
|
sync_cond = None,
|
||||||
|
prepend_length=0
|
||||||
|
):
|
||||||
|
if rotary_pos_emb is None and self.add_rope:
|
||||||
|
rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])
|
||||||
|
|
||||||
|
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||||
|
if len(global_cond.shape) == 2:
|
||||||
|
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1)
|
||||||
|
else:
|
||||||
|
|
||||||
|
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1)
|
||||||
|
|
||||||
|
# self-attention with adaLN
|
||||||
|
residual = x
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
x = x * (1 + scale_self) + shift_self
|
||||||
|
x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
|
||||||
|
x = x * torch.sigmoid(1 - gate_self)
|
||||||
|
x = self.self_attn_scale(x)
|
||||||
|
x = x + residual
|
||||||
|
|
||||||
|
if context is not None and self.cross_attend:
|
||||||
|
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
|
||||||
|
|
||||||
|
if self.conformer is not None:
|
||||||
|
x = x + self.conformer_scale(self.conformer(x))
|
||||||
|
|
||||||
|
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
|
||||||
|
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
||||||
|
x = x * (1 + scale) + shift
|
||||||
|
|
||||||
|
# feedforward with adaLN
|
||||||
|
residual = x
|
||||||
|
x = self.ff_norm(x)
|
||||||
|
x = x * (1 + scale_ff) + shift_ff
|
||||||
|
x = self.ff(x)
|
||||||
|
x = x * torch.sigmoid(1 - gate_ff)
|
||||||
|
x = self.ff_scale(x)
|
||||||
|
x = x + residual
|
||||||
|
|
||||||
|
else:
|
||||||
|
x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))
|
||||||
|
|
||||||
|
if context is not None and self.cross_attend:
|
||||||
|
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
|
||||||
|
|
||||||
|
if self.conformer is not None:
|
||||||
|
x = x + self.conformer_scale(self.conformer(x))
|
||||||
|
|
||||||
|
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
|
||||||
|
prepend_part = x[:, :prepend_length, :]
|
||||||
|
audio_part = x[:, prepend_length:, :]
|
||||||
|
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
||||||
|
modulated_audio_part = audio_part * (1 + scale) + shift
|
||||||
|
x = torch.cat([prepend_part, modulated_audio_part], dim=1)
|
||||||
|
|
||||||
|
x = x + self.ff_scale(self.ff(self.ff_norm(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
class ContinuousTransformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
*,
|
||||||
|
dim_in = None,
|
||||||
|
dim_out = None,
|
||||||
|
dim_heads = 64,
|
||||||
|
cross_attend=False,
|
||||||
|
cond_token_dim=None,
|
||||||
|
pre_cross_attn_ix=-1,
|
||||||
|
final_cross_attn_ix=-1,
|
||||||
|
global_cond_dim=None,
|
||||||
|
causal=False,
|
||||||
|
rotary_pos_emb=True,
|
||||||
|
zero_init_branch_outputs=True,
|
||||||
|
conformer=False,
|
||||||
|
use_sinusoidal_emb=False,
|
||||||
|
use_abs_pos_emb=False,
|
||||||
|
abs_pos_emb_max_length=10000,
|
||||||
|
num_memory_tokens=0,
|
||||||
|
sliding_window=None,
|
||||||
|
use_mlp=False,
|
||||||
|
use_add_norm=False,
|
||||||
|
use_gated=False,
|
||||||
|
use_final_layer=False,
|
||||||
|
use_zeros=False,
|
||||||
|
use_conv=False,
|
||||||
|
use_fusion_mlp=False,
|
||||||
|
use_film=False,
|
||||||
|
use_sync_film=False,
|
||||||
|
use_sync_gated=False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.depth = depth
|
||||||
|
self.causal = causal
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
if use_mlp:
|
||||||
|
self.project_in = nn.Sequential(
|
||||||
|
nn.Linear(dim_in, dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim, bias=False)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
|
||||||
|
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
|
||||||
|
self.video_temporal_conv = None
|
||||||
|
self.audio_temporal_conv = None
|
||||||
|
self.fusion_mlp = None
|
||||||
|
if use_conv:
|
||||||
|
self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
||||||
|
self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
||||||
|
if use_fusion_mlp:
|
||||||
|
self.fusion_mlp = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
if rotary_pos_emb:
|
||||||
|
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
||||||
|
else:
|
||||||
|
self.rotary_pos_emb = None
|
||||||
|
self.num_memory_tokens = num_memory_tokens
|
||||||
|
if num_memory_tokens > 0:
|
||||||
|
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
||||||
|
|
||||||
|
self.use_sinusoidal_emb = use_sinusoidal_emb
|
||||||
|
if use_sinusoidal_emb:
|
||||||
|
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
||||||
|
|
||||||
|
self.use_abs_pos_emb = use_abs_pos_emb
|
||||||
|
if use_abs_pos_emb:
|
||||||
|
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)
|
||||||
|
|
||||||
|
self.adaLN_modulation = None
|
||||||
|
if global_cond_dim is not None:
|
||||||
|
if use_final_layer:
|
||||||
|
self.norm_final = LayerNorm(dim)
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(
|
||||||
|
dim, 2 * dim, bias=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_zeros:
|
||||||
|
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
||||||
|
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
||||||
|
nn.init.constant_(self.project_out.weight, 0)
|
||||||
|
self.global_cond_embedder = nn.Sequential(
|
||||||
|
nn.Linear(global_cond_dim, dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim * 6)
|
||||||
|
)
|
||||||
|
if use_zeros:
|
||||||
|
nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
|
||||||
|
nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
|
||||||
|
nn.init.constant_(self.global_cond_embedder[0].weight, 0)
|
||||||
|
nn.init.constant_(self.global_cond_embedder[0].bias, 0)
|
||||||
|
|
||||||
|
self.final_cross_attn_ix = final_cross_attn_ix
|
||||||
|
self.use_gated = use_gated
|
||||||
|
self.use_film = use_film
|
||||||
|
self.use_add_norm = use_add_norm
|
||||||
|
if self.use_add_norm:
|
||||||
|
self.add_norm = nn.LayerNorm(dim)
|
||||||
|
if use_gated:
|
||||||
|
self.gate = nn.Parameter(torch.ones(1, 1, dim))
|
||||||
|
|
||||||
|
if use_film:
|
||||||
|
self.film_generator = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.film_generator = None
|
||||||
|
|
||||||
|
if use_sync_film:
|
||||||
|
self.sync_film_generator = nn.Sequential(
|
||||||
|
nn.Linear(dim, dim, bias=False),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.sync_film_generator = None
|
||||||
|
if use_sync_gated:
|
||||||
|
self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
else:
|
||||||
|
self.sync_gate = None
|
||||||
|
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
|
||||||
|
for i in range(depth):
|
||||||
|
should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
|
||||||
|
# print(f"Layer {i} cross attends: {should_cross_attend}")
|
||||||
|
self.layers.append(
|
||||||
|
TransformerBlock(
|
||||||
|
dim,
|
||||||
|
dim_heads = dim_heads,
|
||||||
|
cross_attend = should_cross_attend,
|
||||||
|
dim_context = cond_token_dim,
|
||||||
|
global_cond_dim = global_cond_dim,
|
||||||
|
causal = causal,
|
||||||
|
zero_init_branch_outputs = zero_init_branch_outputs,
|
||||||
|
conformer=conformer,
|
||||||
|
layer_ix=i,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
mask = None,
|
||||||
|
prepend_embeds = None,
|
||||||
|
prepend_mask = None,
|
||||||
|
add_cond = None,
|
||||||
|
sync_cond = None,
|
||||||
|
global_cond = None,
|
||||||
|
return_info = False,
|
||||||
|
use_checkpointing = True,
|
||||||
|
exit_layer_ix = None,
|
||||||
|
video_dropout_prob = 0.0,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
batch, seq, device = *x.shape[:2], x.device
|
||||||
|
model_dtype = next(self.parameters()).dtype
|
||||||
|
x = x.to(model_dtype)
|
||||||
|
|
||||||
|
prepend_length = 0
|
||||||
|
|
||||||
|
info = {
|
||||||
|
"hidden_states": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
x = self.project_in(x)
|
||||||
|
if add_cond is not None:
|
||||||
|
if self.use_gated:
|
||||||
|
gate = torch.sigmoid(self.gate)
|
||||||
|
x = x + gate * add_cond
|
||||||
|
elif self.use_film:
|
||||||
|
scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
|
||||||
|
x = x * (1 + scale) + shift
|
||||||
|
else:
|
||||||
|
x = x + add_cond
|
||||||
|
|
||||||
|
if self.use_add_norm:
|
||||||
|
x = self.add_norm(x)
|
||||||
|
if self.fusion_mlp is not None:
|
||||||
|
x = self.fusion_mlp(x)
|
||||||
|
|
||||||
|
if sync_cond is not None:
|
||||||
|
# Resample sync_cond to match audio sequence length if needed
|
||||||
|
if sync_cond.shape[1] != x.shape[1]:
|
||||||
|
sync_cond = torch.nn.functional.interpolate(
|
||||||
|
sync_cond.transpose(1, 2), size=x.shape[1],
|
||||||
|
mode='linear', align_corners=False,
|
||||||
|
).transpose(1, 2)
|
||||||
|
if self.sync_film_generator is not None:
|
||||||
|
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
||||||
|
x = x * (1 + scale) + shift
|
||||||
|
elif self.sync_gate is not None:
|
||||||
|
gate_value = torch.sigmoid(self.sync_gate)
|
||||||
|
x = x + gate_value * sync_cond
|
||||||
|
# else:
|
||||||
|
# x = x + sync_cond
|
||||||
|
|
||||||
|
if prepend_embeds is not None:
|
||||||
|
prepend_length, prepend_dim = prepend_embeds.shape[1:]
|
||||||
|
|
||||||
|
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
|
||||||
|
|
||||||
|
x = torch.cat((prepend_embeds, x), dim = -2)
|
||||||
|
|
||||||
|
if self.num_memory_tokens > 0:
|
||||||
|
memory_tokens = self.memory_tokens.expand(batch, -1, -1)
|
||||||
|
x = torch.cat((memory_tokens, x), dim=1)
|
||||||
|
|
||||||
|
if self.rotary_pos_emb is not None:
|
||||||
|
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
||||||
|
else:
|
||||||
|
rotary_pos_emb = None
|
||||||
|
|
||||||
|
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||||
|
x = x + self.pos_emb(x)
|
||||||
|
|
||||||
|
if global_cond is not None and self.global_cond_embedder is not None:
|
||||||
|
global_cond_embed = self.global_cond_embedder(global_cond)
|
||||||
|
else:
|
||||||
|
global_cond_embed = global_cond
|
||||||
|
# Iterate over the transformer layers
|
||||||
|
for layer_ix, layer in enumerate(self.layers):
|
||||||
|
if use_checkpointing:
|
||||||
|
x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
|
||||||
|
else:
|
||||||
|
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
info["hidden_states"].append(x)
|
||||||
|
|
||||||
|
if exit_layer_ix is not None and layer_ix == exit_layer_ix:
|
||||||
|
x = x[:, self.num_memory_tokens:, :]
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
x = x[:, self.num_memory_tokens:, :]
|
||||||
|
if global_cond is not None and self.adaLN_modulation is not None:
|
||||||
|
if len(global_cond.shape) == 2:
|
||||||
|
global_cond = global_cond.unsqueeze(1)
|
||||||
|
shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
|
||||||
|
x = modulate(self.norm_final(x), shift, scale)
|
||||||
|
x = self.project_out(x)
|
||||||
|
|
||||||
|
if return_info:
|
||||||
|
return x, info
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -0,0 +1,180 @@
|
|||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
|
||||||
|
#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
|
||||||
|
from torch.nn.utils import remove_weight_norm
|
||||||
|
|
||||||
|
def load_ckpt_state_dict(ckpt_path, prefix=None):
|
||||||
|
if ckpt_path.endswith(".safetensors"):
|
||||||
|
state_dict = load_file(ckpt_path)
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||||
|
|
||||||
|
# 过滤特定前缀的state_dict
|
||||||
|
filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict
|
||||||
|
|
||||||
|
return filtered_state_dict
|
||||||
|
|
||||||
|
def remove_weight_norm_from_model(model):
|
||||||
|
for module in model.modules():
|
||||||
|
if hasattr(module, "weight"):
|
||||||
|
remove_weight_norm(module)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
|
||||||
|
# License can be found in LICENSES/LICENSE_META.txt
|
||||||
|
|
||||||
|
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
||||||
|
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (torch.Tensor): The input tensor containing probabilities.
|
||||||
|
num_samples (int): Number of samples to draw.
|
||||||
|
replacement (bool): Whether to draw with replacement or not.
|
||||||
|
Keywords args:
|
||||||
|
generator (torch.Generator): A pseudorandom number generator for sampling.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Last dimension contains num_samples indices
|
||||||
|
sampled from the multinomial probability distribution
|
||||||
|
located in the last dimension of tensor input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if num_samples == 1:
|
||||||
|
q = torch.empty_like(input).exponential_(1, generator=generator)
|
||||||
|
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
|
||||||
|
|
||||||
|
input_ = input.reshape(-1, input.shape[-1])
|
||||||
|
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
||||||
|
output = output_.reshape(*list(input.shape[:-1]), -1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
|
||||||
|
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
||||||
|
k (int): The k in “top-k”.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Sampled tokens.
|
||||||
|
"""
|
||||||
|
top_k_value, _ = torch.topk(probs, k, dim=-1)
|
||||||
|
min_value_top_k = top_k_value[..., [-1]]
|
||||||
|
probs *= (probs >= min_value_top_k).float()
|
||||||
|
probs.div_(probs.sum(dim=-1, keepdim=True))
|
||||||
|
next_token = multinomial(probs, num_samples=1)
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
|
||||||
|
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
||||||
|
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
||||||
|
p (int): The p in “top-p”.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Sampled tokens.
|
||||||
|
"""
|
||||||
|
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||||
|
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||||
|
mask = probs_sum - probs_sort > p
|
||||||
|
probs_sort *= (~mask).float()
|
||||||
|
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||||
|
next_token = multinomial(probs_sort, num_samples=1)
|
||||||
|
next_token = torch.gather(probs_idx, -1, next_token)
|
||||||
|
return next_token
|
||||||
|
|
||||||
|
def next_power_of_two(n):
|
||||||
|
return 2 ** (n - 1).bit_length()
|
||||||
|
|
||||||
|
def next_multiple_of_64(n):
|
||||||
|
return ((n + 63) // 64) * 64
|
||||||
|
|
||||||
|
|
||||||
|
# mask construction helpers
|
||||||
|
|
||||||
|
def mask_from_start_end_indices(
|
||||||
|
seq_len: int,
|
||||||
|
start: Tensor,
|
||||||
|
end: Tensor
|
||||||
|
):
|
||||||
|
assert start.shape == end.shape
|
||||||
|
device = start.device
|
||||||
|
|
||||||
|
seq = torch.arange(seq_len, device = device, dtype = torch.long)
|
||||||
|
seq = seq.reshape(*((-1,) * start.ndim), seq_len)
|
||||||
|
seq = seq.expand(*start.shape, seq_len)
|
||||||
|
|
||||||
|
mask = seq >= start[..., None].long()
|
||||||
|
mask &= seq < end[..., None].long()
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def mask_from_frac_lengths(
|
||||||
|
seq_len: int,
|
||||||
|
frac_lengths: Tensor
|
||||||
|
):
|
||||||
|
device = frac_lengths.device
|
||||||
|
|
||||||
|
lengths = (frac_lengths * seq_len).long()
|
||||||
|
max_start = seq_len - lengths
|
||||||
|
|
||||||
|
rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
|
||||||
|
start = (max_start * rand).clamp(min = 0)
|
||||||
|
end = start + lengths
|
||||||
|
|
||||||
|
return mask_from_start_end_indices(seq_len, start, end)
|
||||||
|
|
||||||
|
def _build_spline(video_feat, video_t, target_t):
|
||||||
|
# 三次样条插值核心实现
|
||||||
|
coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1))
|
||||||
|
spline = NaturalCubicSpline(coeffs)
|
||||||
|
return spline.evaluate(target_t).permute(0,2,1)
|
||||||
|
|
||||||
|
def resample(video_feat, audio_latent):
|
||||||
|
"""
|
||||||
|
9s
|
||||||
|
video_feat: [B, 72, D]
|
||||||
|
audio_latent: [B, D', 194] or int
|
||||||
|
"""
|
||||||
|
B, Tv, D = video_feat.shape
|
||||||
|
|
||||||
|
if isinstance(audio_latent, torch.Tensor):
|
||||||
|
# audio_latent is a tensor
|
||||||
|
if audio_latent.shape[1] != 64:
|
||||||
|
Ta = audio_latent.shape[1]
|
||||||
|
else:
|
||||||
|
Ta = audio_latent.shape[2]
|
||||||
|
elif isinstance(audio_latent, int):
|
||||||
|
# audio_latent is an int
|
||||||
|
Ta = audio_latent
|
||||||
|
else:
|
||||||
|
raise TypeError("audio_latent must be either a tensor or an int")
|
||||||
|
|
||||||
|
# 构建时间戳 (关键改进点)
|
||||||
|
video_time = torch.linspace(0, 9, Tv, device=video_feat.device)
|
||||||
|
audio_time = torch.linspace(0, 9, Ta, device=video_feat.device)
|
||||||
|
|
||||||
|
# 三维化处理 (Batch, Feature, Time)
|
||||||
|
video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv]
|
||||||
|
|
||||||
|
# 三次样条插值
|
||||||
|
aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta]
|
||||||
|
return aligned_video.permute(0, 2, 1) # [B, Ta, D]
|
||||||
|
|
||||||
|
def checkpoint(function, *args, **kwargs):
|
||||||
|
kwargs.setdefault("use_reentrant", False)
|
||||||
|
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||||
|
|
||||||
|
import os
|
||||||
|
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"
|
||||||
|
|
||||||
|
def compile(function, *args, **kwargs):
|
||||||
|
|
||||||
|
if enable_torch_compile:
|
||||||
|
try:
|
||||||
|
return torch.compile(function, *args, **kwargs)
|
||||||
|
except RuntimeError:
|
||||||
|
return function
|
||||||
|
|
||||||
|
return function
|
||||||
@@ -1,5 +1,12 @@
|
|||||||
einops>=0.7.0
|
einops>=0.7.0
|
||||||
|
einops-exts
|
||||||
|
safetensors
|
||||||
huggingface_hub
|
huggingface_hub
|
||||||
transformers>=4.52.3
|
transformers>=4.52.3
|
||||||
|
k-diffusion>=0.1.1
|
||||||
|
alias-free-torch
|
||||||
|
descript-audio-codec
|
||||||
|
vector-quantize-pytorch
|
||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
|
torchaudio
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
name: prismaudio-extract
|
||||||
|
channels:
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- python=3.10
|
||||||
|
- pip
|
||||||
|
- ffmpeg<7
|
||||||
|
- pip:
|
||||||
|
- torch>=2.6.0
|
||||||
|
- torchaudio>=2.6.0
|
||||||
|
- torchvision>=0.21.0
|
||||||
|
- tensorflow-cpu==2.15.0
|
||||||
|
- jax
|
||||||
|
- jaxlib
|
||||||
|
- transformers>=4.52.3
|
||||||
|
- decord
|
||||||
|
- einops>=0.7.0
|
||||||
|
- numpy
|
||||||
|
- mediapy
|
||||||
|
- git+https://github.com/google-deepmind/videoprism.git
|
||||||
Executable
+170
@@ -0,0 +1,170 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Standalone PrismAudio feature extraction script.
|
||||||
|
Runs in a separate Python env with JAX/TF installed (auto-created by PrismAudioFeatureExtractor).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Add plugin root to sys.path so data_utils (and prismaudio_core) are importable
|
||||||
|
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
_PLUGIN_DIR = os.path.dirname(_SCRIPT_DIR)
|
||||||
|
if _PLUGIN_DIR not in sys.path:
|
||||||
|
sys.path.insert(0, _PLUGIN_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def _step(n, total, label):
|
||||||
|
"""Print step header and return start time."""
|
||||||
|
print(f"[extract] Step {n}/{total} — {label}...", flush=True)
|
||||||
|
return time.perf_counter()
|
||||||
|
|
||||||
|
|
||||||
|
def _done(t0, extra=""):
|
||||||
|
elapsed = time.perf_counter() - t0
|
||||||
|
suffix = f" {extra}" if extra else ""
|
||||||
|
print(f"[extract] done in {elapsed:.1f}s{suffix}", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
t_total = time.perf_counter()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
|
||||||
|
parser.add_argument("--video", required=True, help="Path to input video")
|
||||||
|
parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
|
||||||
|
parser.add_argument("--output", required=True, help="Output .npz path")
|
||||||
|
parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
|
||||||
|
parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
|
||||||
|
parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)")
|
||||||
|
parser.add_argument("--clip_fps", type=float, default=4.0)
|
||||||
|
parser.add_argument("--clip_size", type=int, default=288)
|
||||||
|
parser.add_argument("--sync_fps", type=float, default=25.0)
|
||||||
|
parser.add_argument("--sync_size", type=int, default=224)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"[extract] Python : {sys.executable}", flush=True)
|
||||||
|
print(f"[extract] Video : {args.video}", flush=True)
|
||||||
|
print(f"[extract] Output : {args.output}", flush=True)
|
||||||
|
print(f"[extract] CoT text : {args.cot_text[:80]}{'...' if len(args.cot_text) > 80 else ''}", flush=True)
|
||||||
|
|
||||||
|
if not os.path.exists(args.video):
|
||||||
|
print(f"[extract] ERROR: video not found: {args.video}", flush=True)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"[extract] Device : {'cuda' if torch.cuda.is_available() else 'cpu'}", flush=True)
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
t0 = _step(1, 6, "importing dependencies")
|
||||||
|
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
|
||||||
|
import torchvision.transforms as T
|
||||||
|
_done(t0)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
t0 = _step(2, 6, "loading models (T5, VideoPrism, Synchformer)")
|
||||||
|
feat_utils = FeaturesUtils(
|
||||||
|
vae_config_path=args.vae_config,
|
||||||
|
synchformer_ckpt=args.synchformer_ckpt,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
_done(t0)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
t0 = _step(3, 6, "reading and preprocessing video")
|
||||||
|
if args.video.endswith(".npy"):
|
||||||
|
all_frames = np.load(args.video) # [T, H, W, C] uint8
|
||||||
|
fps = args.source_fps
|
||||||
|
total_frames = all_frames.shape[0]
|
||||||
|
duration = total_frames / fps
|
||||||
|
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||||
|
|
||||||
|
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||||
|
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||||
|
clip_frames = all_frames[clip_indices]
|
||||||
|
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||||
|
|
||||||
|
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||||
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
|
sync_frames = all_frames[sync_indices]
|
||||||
|
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||||
|
else:
|
||||||
|
from decord import VideoReader, cpu
|
||||||
|
vr = VideoReader(args.video, ctx=cpu(0))
|
||||||
|
fps = vr.get_avg_fps()
|
||||||
|
total_frames = len(vr)
|
||||||
|
duration = total_frames / fps
|
||||||
|
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||||
|
|
||||||
|
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||||
|
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||||
|
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||||
|
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||||
|
|
||||||
|
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||||
|
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||||
|
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||||
|
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||||
|
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||||
|
|
||||||
|
clip_transform = T.Compose([
|
||||||
|
T.ToPILImage(),
|
||||||
|
T.Resize(args.clip_size),
|
||||||
|
T.CenterCrop(args.clip_size),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
sync_transform = T.Compose([
|
||||||
|
T.ToPILImage(),
|
||||||
|
T.Resize(args.sync_size),
|
||||||
|
T.CenterCrop(args.sync_size),
|
||||||
|
T.ToTensor(),
|
||||||
|
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device)
|
||||||
|
_done(t0)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
t0 = _step(4, 6, "encoding text with T5-Gemma")
|
||||||
|
text_features = feat_utils.encode_t5_text([args.cot_text])
|
||||||
|
_done(t0, f"shape={tuple(text_features.shape)}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
t0 = _step(5, 6, "encoding video with VideoPrism")
|
||||||
|
global_video_features, video_features, global_text_features = \
|
||||||
|
feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
|
||||||
|
_done(t0, f"video={tuple(video_features.shape)} global={tuple(global_video_features.shape)}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
t0 = _step(6, 6, "encoding video with Synchformer")
|
||||||
|
sync_features = feat_utils.encode_video_with_sync(sync_input)
|
||||||
|
_done(t0, f"shape={tuple(sync_features.shape)}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
print(f"[extract] Saving features to {args.output} ...", flush=True)
|
||||||
|
np.savez(
|
||||||
|
args.output,
|
||||||
|
video_features=video_features.cpu().float().numpy(),
|
||||||
|
global_video_features=global_video_features.cpu().float().numpy(),
|
||||||
|
text_features=text_features.cpu().float().numpy(),
|
||||||
|
global_text_features=global_text_features.cpu().float().numpy(),
|
||||||
|
sync_features=sync_features.cpu().float().numpy(),
|
||||||
|
caption_cot=args.cot_text,
|
||||||
|
duration=duration,
|
||||||
|
)
|
||||||
|
print(f"[extract] Saved in {time.perf_counter() - t0:.1f}s", flush=True)
|
||||||
|
print(f"[extract] Total time: {time.perf_counter() - t_total:.1f}s", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Executable
+44
@@ -0,0 +1,44 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Install the PrismAudio feature-extraction environment using pip venv.
|
||||||
|
# Use this instead of environment.yml when conda is unavailable (e.g. NVIDIA Docker).
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# bash scripts/install_extract_env.sh [/path/to/venv]
|
||||||
|
#
|
||||||
|
# Default venv path: /opt/prismaudio-extract
|
||||||
|
# After installation, point the PrismAudioFeatureExtractor node's python_env to:
|
||||||
|
# <venv>/bin/python (Linux/Mac)
|
||||||
|
# <venv>\Scripts\python.exe (Windows)
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
VENV_DIR="${1:-/opt/prismaudio-extract}"
|
||||||
|
|
||||||
|
echo "[PrismAudio] Creating venv at: ${VENV_DIR}"
|
||||||
|
python3 -m venv "${VENV_DIR}"
|
||||||
|
|
||||||
|
PIP="${VENV_DIR}/bin/pip"
|
||||||
|
|
||||||
|
echo "[PrismAudio] Upgrading pip..."
|
||||||
|
"${PIP}" install --upgrade pip
|
||||||
|
|
||||||
|
echo "[PrismAudio] Installing PyTorch stack..."
|
||||||
|
"${PIP}" install torch torchaudio torchvision
|
||||||
|
|
||||||
|
echo "[PrismAudio] Installing feature-extraction dependencies..."
|
||||||
|
"${PIP}" install \
|
||||||
|
"tensorflow-cpu>=2.16.0" \
|
||||||
|
"jax[cpu]" \
|
||||||
|
"jaxlib" \
|
||||||
|
"transformers" \
|
||||||
|
"decord" \
|
||||||
|
"einops" \
|
||||||
|
"numpy" \
|
||||||
|
"mediapy"
|
||||||
|
|
||||||
|
echo "[PrismAudio] Installing VideoPrism..."
|
||||||
|
"${PIP}" install "git+https://github.com/google-deepmind/videoprism.git"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "[PrismAudio] Done. Set python_env in PrismAudioFeatureExtractor to:"
|
||||||
|
echo " ${VENV_DIR}/bin/python"
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# Vendored from https://github.com/jnwnlee/selva
|
|
||||||
# Pinned commit: d7d40a992aab58e7cf246055681a657e5d8b4a4d
|
|
||||||
# Imports rewritten from selva.* → selva_core.*
|
|
||||||
@@ -1,190 +0,0 @@
|
|||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from fractions import Fraction
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import av
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from av import AudioFrame
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VideoInfo:
|
|
||||||
duration_sec: float
|
|
||||||
fps: Fraction
|
|
||||||
clip_frames: torch.Tensor
|
|
||||||
sync_frames: torch.Tensor
|
|
||||||
all_frames: Optional[list[np.ndarray]]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def height(self):
|
|
||||||
return self.all_frames[0].shape[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def width(self):
|
|
||||||
return self.all_frames[0].shape[1]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
|
|
||||||
fps: Fraction) -> 'VideoInfo':
|
|
||||||
num_frames = int(duration_sec * fps)
|
|
||||||
all_frames = [image_info.original_frame] * num_frames
|
|
||||||
return cls(duration_sec=duration_sec,
|
|
||||||
fps=fps,
|
|
||||||
clip_frames=image_info.clip_frames,
|
|
||||||
sync_frames=image_info.sync_frames,
|
|
||||||
all_frames=all_frames)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ImageInfo:
|
|
||||||
clip_frames: torch.Tensor
|
|
||||||
sync_frames: torch.Tensor
|
|
||||||
original_frame: Optional[np.ndarray]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def height(self):
|
|
||||||
return self.original_frame.shape[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def width(self):
|
|
||||||
return self.original_frame.shape[1]
|
|
||||||
|
|
||||||
|
|
||||||
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
|
|
||||||
need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
|
|
||||||
output_frames = [[] for _ in list_of_fps]
|
|
||||||
next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
|
|
||||||
time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
|
|
||||||
all_frames = []
|
|
||||||
|
|
||||||
# container = av.open(video_path)
|
|
||||||
with av.open(video_path) as container:
|
|
||||||
stream = container.streams.video[0]
|
|
||||||
fps = stream.guessed_rate
|
|
||||||
stream.thread_type = 'AUTO'
|
|
||||||
for packet in container.demux(stream):
|
|
||||||
for frame in packet.decode():
|
|
||||||
frame_time = frame.time
|
|
||||||
if frame_time < start_sec:
|
|
||||||
continue
|
|
||||||
if frame_time > end_sec:
|
|
||||||
break
|
|
||||||
|
|
||||||
frame_np = None
|
|
||||||
if need_all_frames:
|
|
||||||
frame_np = frame.to_ndarray(format='rgb24')
|
|
||||||
all_frames.append(frame_np)
|
|
||||||
|
|
||||||
for i, _ in enumerate(list_of_fps):
|
|
||||||
this_time = frame_time
|
|
||||||
while this_time >= next_frame_time_for_each_fps[i]:
|
|
||||||
if frame_np is None:
|
|
||||||
frame_np = frame.to_ndarray(format='rgb24')
|
|
||||||
|
|
||||||
output_frames[i].append(frame_np)
|
|
||||||
next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
|
|
||||||
|
|
||||||
output_frames = [np.stack(frames) for frames in output_frames]
|
|
||||||
return output_frames, all_frames, fps
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_video_chunk(video_chunk: torch.Tensor,
|
|
||||||
expected_length: int,
|
|
||||||
*,
|
|
||||||
n_tolerance_frame: int = 1,
|
|
||||||
desc: str = "") \
|
|
||||||
-> torch.Tensor:
|
|
||||||
# video_chunk: [T, H, W, C]
|
|
||||||
if video_chunk.shape[0] < expected_length:
|
|
||||||
if expected_length - video_chunk.shape[0] <= n_tolerance_frame:
|
|
||||||
# copy the last frame to make it the right length
|
|
||||||
log.warning(f'Video too short {desc}, padding {expected_length - video_chunk.shape[0]} frames with the last frame')
|
|
||||||
video_chunk = torch.cat([video_chunk, video_chunk[-1:].repeat(expected_length - video_chunk.shape[0], 1, 1, 1)])
|
|
||||||
assert video_chunk.shape[0] == expected_length
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'Video too short {desc}, expected {expected_length}, got {video_chunk.shape[0]}'
|
|
||||||
)
|
|
||||||
video_chunk = video_chunk[:expected_length]
|
|
||||||
if video_chunk.shape[0] != expected_length:
|
|
||||||
raise RuntimeError(f'Video wrong length {desc}, '
|
|
||||||
f'expected {expected_length}, '
|
|
||||||
f'got {video_chunk.shape[0]}')
|
|
||||||
|
|
||||||
return video_chunk
|
|
||||||
|
|
||||||
|
|
||||||
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
|
|
||||||
sampling_rate: int):
|
|
||||||
container = av.open(output_path, 'w')
|
|
||||||
output_video_stream = container.add_stream('h264', video_info.fps)
|
|
||||||
output_video_stream.codec_context.bit_rate = 10 * 1e6 # 10 Mbps
|
|
||||||
output_video_stream.width = video_info.width
|
|
||||||
output_video_stream.height = video_info.height
|
|
||||||
output_video_stream.pix_fmt = 'yuv420p'
|
|
||||||
|
|
||||||
output_audio_stream = container.add_stream('aac', sampling_rate)
|
|
||||||
|
|
||||||
# encode video
|
|
||||||
for image in video_info.all_frames:
|
|
||||||
image = av.VideoFrame.from_ndarray(image)
|
|
||||||
packet = output_video_stream.encode(image)
|
|
||||||
container.mux(packet)
|
|
||||||
|
|
||||||
for packet in output_video_stream.encode():
|
|
||||||
container.mux(packet)
|
|
||||||
|
|
||||||
# convert float tensor audio to numpy array
|
|
||||||
audio_np = audio.numpy().astype(np.float32)
|
|
||||||
audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
|
||||||
audio_frame.sample_rate = sampling_rate
|
|
||||||
|
|
||||||
for packet in output_audio_stream.encode(audio_frame):
|
|
||||||
container.mux(packet)
|
|
||||||
|
|
||||||
for packet in output_audio_stream.encode():
|
|
||||||
container.mux(packet)
|
|
||||||
|
|
||||||
container.close()
|
|
||||||
|
|
||||||
|
|
||||||
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
|
|
||||||
"""
|
|
||||||
NOTE: I don't think we can get the exact video duration right without re-encoding
|
|
||||||
so we are not using this but keeping it here for reference
|
|
||||||
"""
|
|
||||||
video = av.open(video_path)
|
|
||||||
output = av.open(output_path, 'w')
|
|
||||||
input_video_stream = video.streams.video[0]
|
|
||||||
output_video_stream = output.add_stream(template=input_video_stream)
|
|
||||||
output_audio_stream = output.add_stream('aac', sampling_rate)
|
|
||||||
|
|
||||||
duration_sec = audio.shape[-1] / sampling_rate
|
|
||||||
|
|
||||||
for packet in video.demux(input_video_stream):
|
|
||||||
# We need to skip the "flushing" packets that `demux` generates.
|
|
||||||
if packet.dts is None:
|
|
||||||
continue
|
|
||||||
# We need to assign the packet to the new stream.
|
|
||||||
packet.stream = output_video_stream
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
# convert float tensor audio to numpy array
|
|
||||||
audio_np = audio.numpy().astype(np.float32)
|
|
||||||
audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
|
||||||
audio_frame.sample_rate = sampling_rate
|
|
||||||
|
|
||||||
for packet in output_audio_stream.encode(audio_frame):
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
for packet in output_audio_stream.encode():
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
video.close()
|
|
||||||
output.close()
|
|
||||||
|
|
||||||
output.close()
|
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
import logging
|
|
||||||
import random
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from omegaconf import DictConfig, open_dict
|
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
from torch.utils.data.dataloader import default_collate
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
|
|
||||||
from selva_core.data.vgg_sound import VGGSound
|
|
||||||
from selva_core.data.eval.eval_video_dataset import VGGSound as VGGSoundEval
|
|
||||||
from selva_core.data.eval.eval_video_dataset import InferenceVideoData, VGGMonoAudioBench
|
|
||||||
from selva_core.data.eval.audiocaps import AudioCapsData
|
|
||||||
from selva_core.data.mm_dataset import MultiModalDataset
|
|
||||||
from selva_core.data.mixup import DataMixupCollate
|
|
||||||
from selva_core.utils.dist_utils import local_rank
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
# Re-seed randomness every time we start a worker
|
|
||||||
def worker_init_fn(worker_id: int):
|
|
||||||
worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
|
|
||||||
np.random.seed(worker_seed)
|
|
||||||
random.seed(worker_seed)
|
|
||||||
log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
|
|
||||||
|
|
||||||
|
|
||||||
def load_video_data(cfg: DictConfig, data_cfg: DictConfig, normalize_audio: bool = False,
|
|
||||||
) -> Dataset:
|
|
||||||
dataset = VGGSound(root=data_cfg.root,
|
|
||||||
tsv_path=data_cfg.subset_name,
|
|
||||||
sample_rate=16_000,
|
|
||||||
duration_sec=8.0,
|
|
||||||
normalize_audio=normalize_audio,
|
|
||||||
mmap_dir=data_cfg.memmap_dir,
|
|
||||||
tsv_tsynch_path=data_cfg.tsv_tsynch,
|
|
||||||
mmap_tsync_dir=data_cfg.memmap_dir_tsynch,
|
|
||||||
data_dim=cfg.data_dim
|
|
||||||
)
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
|
|
||||||
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
|
|
||||||
raise NotImplementedError('Audio data loading is not implemented yet')
|
|
||||||
|
|
||||||
|
|
||||||
def setup_training_datasets(cfg: DictConfig,
|
|
||||||
generator: torch.Generator,
|
|
||||||
) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
|
||||||
if cfg.mini_train:
|
|
||||||
vgg = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=True)
|
|
||||||
dataset = MultiModalDataset([vgg], [])
|
|
||||||
if cfg.example_train:
|
|
||||||
video = load_video_data(cfg, cfg.data.Example_video, normalize_audio=True)
|
|
||||||
dataset = MultiModalDataset([video], [])
|
|
||||||
else:
|
|
||||||
vgg = load_video_data(cfg, cfg.data.VGGSound, normalize_audio=True)
|
|
||||||
# load the largest one first
|
|
||||||
# you can add more video/audio data upon demand, such as
|
|
||||||
# clotho = load_audio_data(cfg, cfg.data.Clotho)
|
|
||||||
dataset = MultiModalDataset([vgg], [])
|
|
||||||
|
|
||||||
batch_size = cfg.batch_size
|
|
||||||
num_workers = cfg.num_workers
|
|
||||||
pin_memory = cfg.pin_memory
|
|
||||||
|
|
||||||
if cfg.mixup.domain == 'data':
|
|
||||||
mixup_params = cfg.mixup.params
|
|
||||||
collate_fn = DataMixupCollate(generator=generator,
|
|
||||||
**mixup_params)
|
|
||||||
else:
|
|
||||||
collate_fn = None
|
|
||||||
|
|
||||||
sampler, loader = construct_loader(dataset,
|
|
||||||
batch_size,
|
|
||||||
num_workers,
|
|
||||||
shuffle=True,
|
|
||||||
drop_last=True,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
collate_fn=collate_fn)
|
|
||||||
|
|
||||||
return dataset, sampler, loader
|
|
||||||
|
|
||||||
|
|
||||||
def setup_test_datasets(cfg: DictConfig,
|
|
||||||
generator: torch.Generator,
|
|
||||||
) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
|
||||||
if cfg.example_train:
|
|
||||||
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False, split='test')
|
|
||||||
elif cfg.dataset.startswith('vggsound'):
|
|
||||||
dataset = load_video_data(cfg, cfg.data.VGGSound_test, normalize_audio=False, split='test')
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown dataset for test: {cfg.dataset}')
|
|
||||||
|
|
||||||
batch_size = cfg.batch_size
|
|
||||||
num_workers = cfg.get('num_workers_val', cfg.num_workers)
|
|
||||||
pin_memory = cfg.pin_memory
|
|
||||||
|
|
||||||
if cfg.mixup.domain == 'data':
|
|
||||||
mixup_config = cfg.mixup.params
|
|
||||||
collate_fn = DataMixupCollate(generator=generator,
|
|
||||||
**mixup_config)
|
|
||||||
else:
|
|
||||||
collate_fn = None
|
|
||||||
|
|
||||||
sampler, loader = construct_loader(dataset,
|
|
||||||
batch_size,
|
|
||||||
num_workers,
|
|
||||||
shuffle=False,
|
|
||||||
drop_last=False,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
collate_fn=collate_fn)
|
|
||||||
|
|
||||||
return dataset, sampler, loader
|
|
||||||
|
|
||||||
|
|
||||||
def setup_val_datasets(cfg: DictConfig,
|
|
||||||
generator: torch.Generator,
|
|
||||||
) -> tuple[Dataset, DataLoader, DataLoader]:
|
|
||||||
if cfg.example_train:
|
|
||||||
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
|
|
||||||
else:
|
|
||||||
dataset = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=False)
|
|
||||||
|
|
||||||
val_batch_size = cfg.batch_size
|
|
||||||
val_eval_batch_size = cfg.eval_batch_size
|
|
||||||
num_workers = cfg.get('num_workers_val', cfg.num_workers)
|
|
||||||
pin_memory = cfg.pin_memory
|
|
||||||
|
|
||||||
if cfg.mixup.domain == 'data':
|
|
||||||
mixup_config = cfg.mixup.params
|
|
||||||
collate_fn = DataMixupCollate(generator=generator,
|
|
||||||
**mixup_config)
|
|
||||||
else:
|
|
||||||
collate_fn = None
|
|
||||||
|
|
||||||
_, val_loader = construct_loader(dataset,
|
|
||||||
val_batch_size,
|
|
||||||
num_workers,
|
|
||||||
shuffle=False,
|
|
||||||
drop_last=False,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
collate_fn=collate_fn)
|
|
||||||
_, eval_loader = construct_loader(dataset,
|
|
||||||
val_eval_batch_size,
|
|
||||||
num_workers,
|
|
||||||
shuffle=False,
|
|
||||||
drop_last=False,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
collate_fn=collate_fn)
|
|
||||||
|
|
||||||
return dataset, val_loader, eval_loader
|
|
||||||
|
|
||||||
|
|
||||||
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
|
|
||||||
if dataset_name.startswith('audiocaps_full'):
|
|
||||||
dataset = AudioCapsData(cfg.eval_data.audiocaps_full.audio_path,
|
|
||||||
cfg.eval_data.audiocaps_full.csv_path)
|
|
||||||
elif dataset_name.startswith('audiocaps'):
|
|
||||||
dataset = AudioCapsData(cfg.eval_data.audiocaps.audio_path,
|
|
||||||
cfg.eval_data.audiocaps.csv_path)
|
|
||||||
elif dataset_name.startswith('vggsound'):
|
|
||||||
dataset = VGGSound(cfg.eval_data.vggsound.video_path,
|
|
||||||
cfg.eval_data.vggsound.csv_path,
|
|
||||||
duration_sec=cfg.duration_s)
|
|
||||||
elif dataset_name.startswith('infer_video'):
|
|
||||||
dataset = InferenceVideoData(cfg.eval_data.infer_video.video_path,
|
|
||||||
cfg.eval_data.infer_video.jsonl_path,
|
|
||||||
duration_sec=cfg.duration_s)
|
|
||||||
cfg.batch_size = 1
|
|
||||||
elif dataset_name.startswith('example_video'):
|
|
||||||
dataset = VGGSoundEval(cfg.eval_data.Example_video.video_path,
|
|
||||||
cfg.eval_data.Example_video.csv_path,
|
|
||||||
duration_sec=cfg.duration_s)
|
|
||||||
elif dataset_name in ['vgg_monoaudio_intra', 'vgg_monoaudio_inter']:
|
|
||||||
dataset = VGGMonoAudioBench(cfg.eval_data[dataset_name].video_path,
|
|
||||||
cfg.eval_data[dataset_name].csv_path,
|
|
||||||
duration_sec=cfg.duration_s)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Invalid dataset name: {dataset_name}')
|
|
||||||
|
|
||||||
batch_size = cfg.batch_size
|
|
||||||
num_workers = cfg.num_workers
|
|
||||||
pin_memory = cfg.pin_memory
|
|
||||||
_, loader = construct_loader(dataset,
|
|
||||||
batch_size,
|
|
||||||
num_workers,
|
|
||||||
shuffle=False,
|
|
||||||
drop_last=False,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
error_avoidance=True)
|
|
||||||
return dataset, loader
|
|
||||||
|
|
||||||
|
|
||||||
def error_avoidance_collate(batch):
|
|
||||||
# Filter our None values
|
|
||||||
batch = [item for item in batch if item is not None]
|
|
||||||
if len(batch) == 0:
|
|
||||||
return None
|
|
||||||
return default_collate(batch)
|
|
||||||
|
|
||||||
|
|
||||||
def construct_loader(dataset: Dataset,
|
|
||||||
batch_size: int,
|
|
||||||
num_workers: int,
|
|
||||||
*,
|
|
||||||
shuffle: bool = True,
|
|
||||||
drop_last: bool = True,
|
|
||||||
pin_memory: bool = False,
|
|
||||||
error_avoidance: bool = False,
|
|
||||||
collate_fn = None) -> tuple[DistributedSampler, DataLoader]:
|
|
||||||
train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
|
|
||||||
train_loader = DataLoader(dataset,
|
|
||||||
batch_size,
|
|
||||||
sampler=train_sampler,
|
|
||||||
num_workers=num_workers,
|
|
||||||
worker_init_fn=worker_init_fn,
|
|
||||||
drop_last=drop_last,
|
|
||||||
persistent_workers=num_workers > 0,
|
|
||||||
pin_memory=pin_memory,
|
|
||||||
collate_fn=error_avoidance_collate if error_avoidance else collate_fn)
|
|
||||||
return train_sampler, train_loader
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
class AudioCapsData(Dataset):
|
|
||||||
|
|
||||||
def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
|
|
||||||
df = pd.read_csv(csv_path).to_dict(orient='records')
|
|
||||||
|
|
||||||
audio_files = sorted(os.listdir(audio_path))
|
|
||||||
audio_files = set(
|
|
||||||
[Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
|
|
||||||
|
|
||||||
self.data = []
|
|
||||||
for row in df:
|
|
||||||
self.data.append({
|
|
||||||
'name': row['name'],
|
|
||||||
'caption': row['caption'],
|
|
||||||
})
|
|
||||||
|
|
||||||
self.audio_path = Path(audio_path)
|
|
||||||
self.csv_path = Path(csv_path)
|
|
||||||
|
|
||||||
log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> torch.Tensor:
|
|
||||||
return self.data[idx]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data)
|
|
||||||
@@ -1,237 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
from torchvision.transforms import v2
|
|
||||||
from torio.io import StreamingMediaDecoder
|
|
||||||
|
|
||||||
from selva_core.data.av_utils import normalize_video_chunk
|
|
||||||
from selva_core.utils.dist_utils import local_rank
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
_CLIP_SIZE = 384
|
|
||||||
_CLIP_FPS = 8.0
|
|
||||||
|
|
||||||
_SYNC_SIZE = 224
|
|
||||||
_SYNC_FPS = 25.0
|
|
||||||
|
|
||||||
|
|
||||||
class VideoDataset(Dataset):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
video_root: Union[str, Path],
|
|
||||||
*,
|
|
||||||
duration_sec: float = 8.0,
|
|
||||||
clip_video_required: bool = False,
|
|
||||||
):
|
|
||||||
self.video_root = Path(video_root)
|
|
||||||
self.duration_sec = duration_sec
|
|
||||||
self.clip_video_required = clip_video_required
|
|
||||||
|
|
||||||
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
|
||||||
self.sync_transform = v2.Compose([
|
|
||||||
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
|
||||||
# v2.CenterCrop(_SYNC_SIZE),
|
|
||||||
v2.ToImage(),
|
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
|
||||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
||||||
])
|
|
||||||
|
|
||||||
if self.clip_video_required:
|
|
||||||
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
|
||||||
self.clip_transform = v2.Compose([
|
|
||||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
|
||||||
v2.ToImage(),
|
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
|
||||||
])
|
|
||||||
|
|
||||||
# to be implemented by subclasses
|
|
||||||
self.captions = {}
|
|
||||||
self.negative_captions = {}
|
|
||||||
self.videos = sorted(list(self.captions.keys()))
|
|
||||||
|
|
||||||
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
|
||||||
video_id = self.videos[idx]
|
|
||||||
caption = self.captions[video_id]
|
|
||||||
negative_caption = self.negative_captions.get(video_id, None)
|
|
||||||
|
|
||||||
reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
|
|
||||||
reader.add_basic_video_stream(
|
|
||||||
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
|
||||||
frame_rate=_SYNC_FPS,
|
|
||||||
format='rgb24',
|
|
||||||
)
|
|
||||||
if self.clip_video_required:
|
|
||||||
reader.add_basic_video_stream(
|
|
||||||
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
|
||||||
frame_rate=_CLIP_FPS,
|
|
||||||
format='rgb24',
|
|
||||||
)
|
|
||||||
|
|
||||||
reader.fill_buffer()
|
|
||||||
data_chunk = reader.pop_chunks()
|
|
||||||
|
|
||||||
sync_chunk = data_chunk[0]
|
|
||||||
if sync_chunk is None:
|
|
||||||
raise RuntimeError(f'Sync video returned None {video_id}')
|
|
||||||
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
|
||||||
n_tolerance_frame=3, desc=video_id)
|
|
||||||
sync_chunk = self.sync_transform(sync_chunk)
|
|
||||||
|
|
||||||
if self.clip_video_required:
|
|
||||||
clip_chunk = data_chunk[1]
|
|
||||||
if clip_chunk is None:
|
|
||||||
raise RuntimeError(f'CLIP video returned None {video_id}')
|
|
||||||
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
|
||||||
n_tolerance_frame=1, desc=video_id)
|
|
||||||
clip_chunk = self.clip_transform(clip_chunk)
|
|
||||||
|
|
||||||
data = {
|
|
||||||
'name': video_id,
|
|
||||||
'caption': caption,
|
|
||||||
'sync_video': sync_chunk,
|
|
||||||
}
|
|
||||||
if self.clip_video_required:
|
|
||||||
data['clip_video'] = clip_chunk
|
|
||||||
if negative_caption is not None:
|
|
||||||
data['negative_caption'] = negative_caption
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
|
||||||
try:
|
|
||||||
return self.sample(idx)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.captions)
|
|
||||||
|
|
||||||
|
|
||||||
class VGGSound(VideoDataset):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
video_root: Union[str, Path],
|
|
||||||
csv_path: Union[str, Path],
|
|
||||||
*,
|
|
||||||
duration_sec: float = 8.0,
|
|
||||||
clip_video_required: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__(video_root, duration_sec=duration_sec,
|
|
||||||
clip_video_required=clip_video_required)
|
|
||||||
self.video_root = Path(video_root)
|
|
||||||
self.csv_path = Path(csv_path)
|
|
||||||
|
|
||||||
videos = sorted(os.listdir(self.video_root))
|
|
||||||
if local_rank == 0:
|
|
||||||
log.info(f'{len(videos)} videos found in {video_root}')
|
|
||||||
self.captions = {}
|
|
||||||
|
|
||||||
df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
|
|
||||||
'split']).to_dict(orient='records')
|
|
||||||
|
|
||||||
videos_no_found = []
|
|
||||||
for row in df:
|
|
||||||
if row['split'] == 'test':
|
|
||||||
start_sec = int(row['sec'])
|
|
||||||
video_id = str(row['id'])
|
|
||||||
# this is how our videos are named
|
|
||||||
video_name = f'{video_id}_{start_sec:06d}'
|
|
||||||
if video_name + '.mp4' not in videos:
|
|
||||||
videos_no_found.append(video_name)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.captions[video_name] = row['caption']
|
|
||||||
|
|
||||||
if local_rank == 0:
|
|
||||||
log.info(f'{len(videos)} videos found in {video_root}')
|
|
||||||
log.info(f'{len(self.captions)} useable videos found')
|
|
||||||
if videos_no_found:
|
|
||||||
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
|
|
||||||
log.info(
|
|
||||||
'A small amount is expected, as not all videos are still available on YouTube')
|
|
||||||
|
|
||||||
self.videos = sorted(list(self.captions.keys()))
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceVideoData(VideoDataset):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
video_root: Union[str, Path],
|
|
||||||
jsonl_root: Union[str, Path],
|
|
||||||
*,
|
|
||||||
duration_sec: float = 10.0,
|
|
||||||
clip_video_required: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__(video_root, duration_sec=duration_sec,
|
|
||||||
clip_video_required=clip_video_required)
|
|
||||||
self.video_root = Path(video_root)
|
|
||||||
self.jsonl_root = Path(jsonl_root)
|
|
||||||
|
|
||||||
videos = sorted(os.listdir(self.video_root))
|
|
||||||
videos = [v[:-4] for v in videos] # remove extensions
|
|
||||||
self.captions = {}
|
|
||||||
|
|
||||||
for v in videos:
|
|
||||||
with open(self.jsonl_root / (v + '.jsonl')) as f:
|
|
||||||
data = json.load(f)
|
|
||||||
self.captions[v] = data['audio_prompt']
|
|
||||||
self.negative_captions[v] = data.get('negative_audio_prompt', None)
|
|
||||||
|
|
||||||
if local_rank == 0:
|
|
||||||
log.info(f'{len(videos)} videos found in {video_root}')
|
|
||||||
|
|
||||||
self.videos = videos
|
|
||||||
|
|
||||||
|
|
||||||
class VGGMonoAudioBench(VideoDataset):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
video_root: Union[str, Path],
|
|
||||||
csv_path: Union[str, Path],
|
|
||||||
*,
|
|
||||||
duration_sec: float = 8.0,
|
|
||||||
clip_video_required: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__(video_root, duration_sec=duration_sec,
|
|
||||||
clip_video_required=clip_video_required)
|
|
||||||
self.video_root = Path(video_root)
|
|
||||||
self.csv_path = Path(csv_path)
|
|
||||||
|
|
||||||
videos = sorted(os.listdir(self.video_root))
|
|
||||||
if local_rank == 0:
|
|
||||||
log.info(f'{len(videos)} videos found in {video_root}')
|
|
||||||
self.captions = {}
|
|
||||||
self.negative_captions = {}
|
|
||||||
|
|
||||||
df = pd.read_csv(csv_path, header=0, usecols=['file_name', 'label', 'paired_label']
|
|
||||||
).to_dict(orient='records')
|
|
||||||
|
|
||||||
videos_no_found = []
|
|
||||||
for row in df:
|
|
||||||
video_name = str(Path(row['file_name']).stem)
|
|
||||||
if video_name + '.mp4' not in videos:
|
|
||||||
videos_no_found.append(video_name)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.captions[video_name] = row['label']
|
|
||||||
self.negative_captions[video_name] = row['paired_label']
|
|
||||||
|
|
||||||
if local_rank == 0:
|
|
||||||
log.info(f'{len(videos)} videos found in {video_root}')
|
|
||||||
log.info(f'{len(self.captions)} useable videos found')
|
|
||||||
if videos_no_found:
|
|
||||||
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}!')
|
|
||||||
|
|
||||||
self.videos = sorted(list(self.captions.keys()))
|
|
||||||
@@ -1,194 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
from torchvision.transforms import v2
|
|
||||||
from torio.io import StreamingMediaDecoder
|
|
||||||
|
|
||||||
from selva_core.data.av_utils import normalize_video_chunk
|
|
||||||
from selva_core.utils.dist_utils import local_rank
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
_CLIP_SIZE = 384
|
|
||||||
_CLIP_FPS = 8.0
|
|
||||||
|
|
||||||
_SYNC_SIZE = 224
|
|
||||||
_SYNC_FPS = 25.0
|
|
||||||
|
|
||||||
|
|
||||||
class VGGSound(Dataset):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
root: Union[str, Path],
|
|
||||||
*,
|
|
||||||
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
|
||||||
audio_required: bool = True,
|
|
||||||
sample_rate: int = 16_000,
|
|
||||||
duration_sec: float = 8.0,
|
|
||||||
audio_samples: Optional[int] = None,
|
|
||||||
normalize_audio: bool = False,
|
|
||||||
clip_video_required: bool = True,
|
|
||||||
):
|
|
||||||
self.root = Path(root)
|
|
||||||
self.audio_required = audio_required
|
|
||||||
if audio_required:
|
|
||||||
self.normalize_audio = normalize_audio
|
|
||||||
if audio_samples is None:
|
|
||||||
self.audio_samples = int(sample_rate * duration_sec)
|
|
||||||
else:
|
|
||||||
self.audio_samples = audio_samples
|
|
||||||
effective_duration = audio_samples / sample_rate
|
|
||||||
# make sure the duration is close enough, within 15ms
|
|
||||||
assert abs(effective_duration - duration_sec) < 0.015, \
|
|
||||||
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
|
||||||
self.clip_video_required = clip_video_required
|
|
||||||
|
|
||||||
videos = sorted(os.listdir(self.root))
|
|
||||||
videos = set([Path(v).stem for v in videos]) # remove extensions
|
|
||||||
self.labels = {}
|
|
||||||
self.videos = []
|
|
||||||
missing_videos = []
|
|
||||||
|
|
||||||
# read the tsv for subset information
|
|
||||||
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
|
||||||
for record in df_list:
|
|
||||||
id = record['id']
|
|
||||||
label = record['label']
|
|
||||||
if id in videos:
|
|
||||||
self.labels[id] = label
|
|
||||||
self.videos.append(id)
|
|
||||||
else:
|
|
||||||
missing_videos.append(id)
|
|
||||||
|
|
||||||
if local_rank == 0:
|
|
||||||
log.info(f'{len(videos)} videos found in {root}')
|
|
||||||
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
|
||||||
log.info(f'{len(missing_videos)} videos missing in {root}')
|
|
||||||
|
|
||||||
self.sample_rate = sample_rate
|
|
||||||
self.duration_sec = duration_sec
|
|
||||||
|
|
||||||
if audio_required:
|
|
||||||
self.expected_audio_length = self.audio_samples
|
|
||||||
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
|
||||||
if clip_video_required:
|
|
||||||
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
|
||||||
|
|
||||||
self.sync_transform = v2.Compose([
|
|
||||||
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
|
||||||
# v2.CenterCrop(_SYNC_SIZE),
|
|
||||||
v2.ToImage(),
|
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
|
||||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
||||||
])
|
|
||||||
|
|
||||||
if clip_video_required:
|
|
||||||
self.clip_transform = v2.Compose([
|
|
||||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
|
||||||
v2.ToImage(),
|
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
|
||||||
])
|
|
||||||
if audio_required:
|
|
||||||
self.resampler = {}
|
|
||||||
|
|
||||||
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
|
||||||
video_id = self.videos[idx]
|
|
||||||
|
|
||||||
label = self.labels[video_id]
|
|
||||||
|
|
||||||
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
|
||||||
reader.add_basic_video_stream(
|
|
||||||
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
|
||||||
frame_rate=_SYNC_FPS,
|
|
||||||
format='rgb24',
|
|
||||||
)
|
|
||||||
if self.audio_required:
|
|
||||||
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
|
|
||||||
if self.clip_video_required:
|
|
||||||
reader.add_basic_video_stream(
|
|
||||||
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
|
||||||
frame_rate=_CLIP_FPS,
|
|
||||||
format='rgb24',
|
|
||||||
)
|
|
||||||
|
|
||||||
reader.fill_buffer()
|
|
||||||
data_chunk = reader.pop_chunks()
|
|
||||||
|
|
||||||
sync_chunk = data_chunk[0]
|
|
||||||
if sync_chunk is None:
|
|
||||||
raise RuntimeError(f'Sync video returned None {video_id}')
|
|
||||||
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
|
||||||
n_tolerance_frame=3, desc=video_id)
|
|
||||||
sync_chunk = self.sync_transform(sync_chunk)
|
|
||||||
|
|
||||||
if self.audio_required:
|
|
||||||
audio_chunk = data_chunk[1]
|
|
||||||
|
|
||||||
if self.clip_video_required:
|
|
||||||
clip_chunk = data_chunk[2 if self.audio_required else 1]
|
|
||||||
if clip_chunk is None:
|
|
||||||
raise RuntimeError(f'CLIP video returned None {video_id}')
|
|
||||||
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
|
||||||
n_tolerance_frame=1, desc=video_id)
|
|
||||||
clip_chunk = self.clip_transform(clip_chunk)
|
|
||||||
|
|
||||||
# process audio
|
|
||||||
if self.audio_required:
|
|
||||||
sample_rate = int(reader.get_out_stream_info(1).sample_rate)
|
|
||||||
audio_chunk = audio_chunk.transpose(0, 1)
|
|
||||||
audio_chunk = audio_chunk.mean(dim=0) # mono
|
|
||||||
if self.normalize_audio:
|
|
||||||
abs_max = audio_chunk.abs().max()
|
|
||||||
audio_chunk = audio_chunk * (0.95 / abs_max)
|
|
||||||
if abs_max <= 1e-6:
|
|
||||||
raise RuntimeError(f'Audio is silent {video_id}')
|
|
||||||
|
|
||||||
# resample
|
|
||||||
if sample_rate == self.sample_rate:
|
|
||||||
audio_chunk = audio_chunk
|
|
||||||
else:
|
|
||||||
if sample_rate not in self.resampler:
|
|
||||||
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
|
||||||
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
|
||||||
sample_rate,
|
|
||||||
self.sample_rate,
|
|
||||||
lowpass_filter_width=64,
|
|
||||||
rolloff=0.9475937167399596,
|
|
||||||
resampling_method='sinc_interp_kaiser',
|
|
||||||
beta=14.769656459379492,
|
|
||||||
)
|
|
||||||
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
|
||||||
|
|
||||||
if audio_chunk.shape[0] < self.expected_audio_length:
|
|
||||||
raise RuntimeError(f'Audio too short {video_id}')
|
|
||||||
audio_chunk = audio_chunk[:self.expected_audio_length]
|
|
||||||
|
|
||||||
data = {
|
|
||||||
'id': video_id,
|
|
||||||
'caption': label,
|
|
||||||
'sync_video': sync_chunk,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.audio_required:
|
|
||||||
data['audio'] = audio_chunk
|
|
||||||
if self.clip_video_required:
|
|
||||||
data['clip_video'] = clip_chunk
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
|
||||||
try:
|
|
||||||
return self.sample(idx)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.labels)
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import open_clip
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
class WavTextClipsDataset(Dataset):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
root: Union[str, Path],
|
|
||||||
*,
|
|
||||||
captions_tsv: Union[str, Path],
|
|
||||||
clips_tsv: Union[str, Path],
|
|
||||||
sample_rate: int,
|
|
||||||
num_samples: int,
|
|
||||||
normalize_audio: bool = False,
|
|
||||||
reject_silent: bool = False,
|
|
||||||
tokenizer_id: str = 'ViT-H-14-378-quickgelu',
|
|
||||||
):
|
|
||||||
self.root = Path(root)
|
|
||||||
self.sample_rate = sample_rate
|
|
||||||
self.num_samples = num_samples
|
|
||||||
self.normalize_audio = normalize_audio
|
|
||||||
self.reject_silent = reject_silent
|
|
||||||
self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
|
|
||||||
|
|
||||||
audios = sorted(os.listdir(self.root))
|
|
||||||
audios = set([
|
|
||||||
Path(audio).stem for audio in audios
|
|
||||||
if audio.endswith('.wav') or audio.endswith('.flac')
|
|
||||||
])
|
|
||||||
self.captions = {}
|
|
||||||
|
|
||||||
# read the caption tsv
|
|
||||||
df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
|
|
||||||
for record in df_list:
|
|
||||||
id = record['id']
|
|
||||||
caption = record['caption']
|
|
||||||
self.captions[id] = caption
|
|
||||||
|
|
||||||
# read the clip tsv
|
|
||||||
df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
|
|
||||||
'id': str,
|
|
||||||
'name': str
|
|
||||||
}).to_dict('records')
|
|
||||||
self.clips = []
|
|
||||||
for record in df_list:
|
|
||||||
record['id'] = record['id']
|
|
||||||
record['name'] = record['name']
|
|
||||||
id = record['id']
|
|
||||||
name = record['name']
|
|
||||||
record['caption'] = self.captions[name]
|
|
||||||
self.clips.append(record)
|
|
||||||
|
|
||||||
log.info(f'Found {len(self.clips)} audio files in {self.root}')
|
|
||||||
|
|
||||||
self.resampler = {}
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> torch.Tensor:
|
|
||||||
try:
|
|
||||||
clip = self.clips[idx]
|
|
||||||
audio_name = clip['name']
|
|
||||||
audio_id = clip['id']
|
|
||||||
caption = clip['caption']
|
|
||||||
start_sample = clip['start_sample']
|
|
||||||
end_sample = clip['end_sample']
|
|
||||||
|
|
||||||
audio_path = self.root / f'{audio_name}.flac'
|
|
||||||
if not audio_path.exists():
|
|
||||||
audio_path = self.root / f'{audio_name}.wav'
|
|
||||||
assert audio_path.exists()
|
|
||||||
|
|
||||||
audio_chunk, sample_rate = torchaudio.load(audio_path)
|
|
||||||
audio_chunk = audio_chunk.mean(dim=0) # mono
|
|
||||||
abs_max = audio_chunk.abs().max()
|
|
||||||
if self.normalize_audio:
|
|
||||||
audio_chunk = audio_chunk / abs_max * 0.95
|
|
||||||
|
|
||||||
if self.reject_silent and abs_max < 1e-6:
|
|
||||||
log.warning(f'Rejecting silent audio')
|
|
||||||
return None
|
|
||||||
|
|
||||||
audio_chunk = audio_chunk[start_sample:end_sample]
|
|
||||||
|
|
||||||
# resample
|
|
||||||
if sample_rate == self.sample_rate:
|
|
||||||
audio_chunk = audio_chunk
|
|
||||||
else:
|
|
||||||
if sample_rate not in self.resampler:
|
|
||||||
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
|
||||||
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
|
||||||
sample_rate,
|
|
||||||
self.sample_rate,
|
|
||||||
lowpass_filter_width=64,
|
|
||||||
rolloff=0.9475937167399596,
|
|
||||||
resampling_method='sinc_interp_kaiser',
|
|
||||||
beta=14.769656459379492,
|
|
||||||
)
|
|
||||||
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
|
||||||
|
|
||||||
if audio_chunk.shape[0] < self.num_samples:
|
|
||||||
raise ValueError('Audio is too short')
|
|
||||||
audio_chunk = audio_chunk[:self.num_samples]
|
|
||||||
|
|
||||||
tokens = self.tokenizer([caption])[0]
|
|
||||||
|
|
||||||
output = {
|
|
||||||
'waveform': audio_chunk,
|
|
||||||
'id': audio_id,
|
|
||||||
'caption': caption,
|
|
||||||
'tokens': tokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
return output
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f'Error reading {audio_path}: {e}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.clips)
|
|
||||||
@@ -1,338 +0,0 @@
|
|||||||
""" Embedding Mixup
|
|
||||||
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
|
|
||||||
"""
|
|
||||||
from typing import Literal, Tuple, Union, List, Optional
|
|
||||||
from functools import partial
|
|
||||||
import gc
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch.utils.data.dataloader import default_collate
|
|
||||||
from torchvision.transforms import v2
|
|
||||||
from einops import rearrange
|
|
||||||
from omegaconf import DictConfig
|
|
||||||
|
|
||||||
from selva_core.data.vgg_sound import _SYNC_SIZE
|
|
||||||
|
|
||||||
|
|
||||||
class MixupBase:
|
|
||||||
""" Base class for mixup on either data or feature domain.
|
|
||||||
Applies different params to each element or whole batch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
|
||||||
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
|
||||||
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
|
||||||
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
|
||||||
prob (float): Probability of applying mixup per batch or element
|
|
||||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
|
||||||
eps (float): Small epsilon value to avoid zero lambda
|
|
||||||
"""
|
|
||||||
def __init__(self, generator:torch.Generator,
|
|
||||||
*,
|
|
||||||
modality:Literal['video', 'audio', 'both'],
|
|
||||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
|
||||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
|
||||||
eps:float=0.05
|
|
||||||
):
|
|
||||||
self.modality = modality
|
|
||||||
self.mixup_lambda:float = mixup_lambda
|
|
||||||
self.mixup_alpha:float = mixup_alpha
|
|
||||||
self.mix_prob:float = prob
|
|
||||||
self.mode:str = mode
|
|
||||||
self.eps:float = eps
|
|
||||||
self.mixup_enabled:bool = True # set to false to disable mixing (intended to be set by train loop)
|
|
||||||
if generator.device.type == 'cuda':
|
|
||||||
self.generator_cuda = generator
|
|
||||||
generator_seed = generator.initial_seed()
|
|
||||||
self.generator = torch.Generator(device='cpu')
|
|
||||||
self.generator.manual_seed(generator_seed)
|
|
||||||
else:
|
|
||||||
self.generator = generator
|
|
||||||
|
|
||||||
if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
|
|
||||||
raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
|
|
||||||
if not self.mixup_alpha >= 0.:
|
|
||||||
raise ValueError(f"mixup_alpha {self.mixup_alpha} >= 0. should be true.")
|
|
||||||
if (self.mixup_alpha > 0. and self.mixup_lambda < 1.) or (self.mixup_alpha == 0. and self.mixup_lambda == 1.):
|
|
||||||
raise ValueError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
|
|
||||||
|
|
||||||
def _params_per_elem(self, batch_size:int) -> np.ndarray:
|
|
||||||
lam:np.ndarray = np.ones(batch_size, dtype=np.float32)
|
|
||||||
if self.mixup_enabled:
|
|
||||||
if self.mixup_lambda < 1.: # constant lambda
|
|
||||||
lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
|
|
||||||
elif self.mixup_alpha > 0.: # sampled lambda
|
|
||||||
# Use torch's beta distribution with generator
|
|
||||||
lam_mix = torch.distributions.Beta(
|
|
||||||
torch.tensor([self.mixup_alpha]),
|
|
||||||
torch.tensor([self.mixup_alpha]),
|
|
||||||
).sample([batch_size]).numpy().astype(np.float32).reshape(-1)
|
|
||||||
else:
|
|
||||||
assert False, f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
|
|
||||||
lam_mix[lam_mix < self.eps] = self.eps
|
|
||||||
|
|
||||||
# Use torch's random with generator for the random comparison
|
|
||||||
rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
|
|
||||||
lam = np.where(rand_vals < self.mix_prob, lam_mix, lam)
|
|
||||||
return lam
|
|
||||||
|
|
||||||
def _params_per_batch(self) -> float:
|
|
||||||
lam:float = 1.
|
|
||||||
if self.mixup_enabled:
|
|
||||||
if self.mixup_lambda < 1.: # constant lambda
|
|
||||||
lam = self.mixup_lambda
|
|
||||||
elif self.mixup_alpha > 0.: # sampled lambda
|
|
||||||
lam = torch.distributions.Beta(
|
|
||||||
torch.tensor([self.mixup_alpha]),
|
|
||||||
torch.tensor([self.mixup_alpha]),
|
|
||||||
).sample().item()
|
|
||||||
else:
|
|
||||||
assert False, f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
|
|
||||||
if lam < self.eps: lam = self.eps
|
|
||||||
lam = float(lam)
|
|
||||||
return lam
|
|
||||||
|
|
||||||
|
|
||||||
class DataMixupCollate(MixupBase):
|
|
||||||
""" Mixup video in data domain.
|
|
||||||
Applies different params to each element or whole batch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
|
||||||
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
|
||||||
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
|
||||||
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
|
||||||
prob (float): Probability of applying mixup per batch or element
|
|
||||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
|
||||||
eps (float): Small epsilon value to avoid zero lambda
|
|
||||||
"""
|
|
||||||
def __init__(self, generator:torch.Generator,
|
|
||||||
*,
|
|
||||||
modality:Literal['video', 'audio', 'both']='video',
|
|
||||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
|
||||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
|
||||||
eps:float=0.05
|
|
||||||
):
|
|
||||||
super().__init__(generator, modality=modality,
|
|
||||||
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
|
|
||||||
mode=mode, eps=eps)
|
|
||||||
|
|
||||||
self.source_video_key= 'sync_video'
|
|
||||||
self.source_audio_key = 'audio'
|
|
||||||
self.target_video_key = 'sync_video_mixed'
|
|
||||||
self.target_audio_key = 'audio_mixed'
|
|
||||||
|
|
||||||
if not mode == 'batch':
|
|
||||||
raise ValueError(f"Mode {mode} is not supported for data domain.")
|
|
||||||
self.sync_transform = v2.Compose([
|
|
||||||
v2.CenterCrop(_SYNC_SIZE),
|
|
||||||
v2.ToImage(),
|
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
|
||||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
||||||
])
|
|
||||||
|
|
||||||
def _concat_video_frames(self, batch:list, target_key:str='sync_video_mixed', source_key:str='sync_video') -> float:
|
|
||||||
# only batch mode supported
|
|
||||||
batch_size:int = len(batch)
|
|
||||||
lam:float = self._params_per_batch()
|
|
||||||
|
|
||||||
if lam == 1.:
|
|
||||||
# no mixup, just return
|
|
||||||
for i in range(batch_size):
|
|
||||||
batch[i][target_key] = batch[i][source_key]
|
|
||||||
return lam
|
|
||||||
|
|
||||||
# Randomly choose between horizontal and vertical resizing using
|
|
||||||
orig_size = int(lam * _SYNC_SIZE)
|
|
||||||
is_horizontal = True # torch.rand(1, generator=self.generator).item() < 0.5
|
|
||||||
if is_horizontal:
|
|
||||||
# Horizontal resize
|
|
||||||
resize_shape_orig = (_SYNC_SIZE, orig_size)
|
|
||||||
resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE-orig_size)
|
|
||||||
else:
|
|
||||||
# Vertical resize
|
|
||||||
resize_shape_orig = (orig_size, _SYNC_SIZE)
|
|
||||||
resize_shape_pair = (_SYNC_SIZE-orig_size, _SYNC_SIZE)
|
|
||||||
sync_resize_orig = v2.Compose([
|
|
||||||
v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC),
|
|
||||||
])
|
|
||||||
sync_resize_pair = v2.Compose([
|
|
||||||
v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC),
|
|
||||||
])
|
|
||||||
|
|
||||||
batch_videos_orig = torch.stack([batch[i][source_key] for i in range(batch_size)], dim=0)
|
|
||||||
batch_videos_pair = torch.stack([batch[batch_size - i - 1][source_key] for i in range(batch_size)], dim=0)
|
|
||||||
# (B, T, C, H, W)
|
|
||||||
# pass through resize, transform and concat
|
|
||||||
batch_videos_orig = sync_resize_orig(batch_videos_orig)
|
|
||||||
batch_videos_pair = sync_resize_pair(batch_videos_pair)
|
|
||||||
batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair), dim=-1 if is_horizontal else -2)
|
|
||||||
batch_videos_concat = self.sync_transform(batch_videos_concat)
|
|
||||||
|
|
||||||
num_mixup = int(self.mix_prob * batch_size)
|
|
||||||
for i in range(num_mixup):
|
|
||||||
batch[i][target_key] = batch_videos_concat[i]
|
|
||||||
for i in range(num_mixup, batch_size):
|
|
||||||
batch[i][target_key] = batch[i][source_key] # no mixup
|
|
||||||
|
|
||||||
del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
return lam
|
|
||||||
|
|
||||||
def _mix_audio_samples(self, batch:list, target_key:str='audio_mixed', source_key:str='audio',
|
|
||||||
normalize:bool = True) -> float:
|
|
||||||
# assume source_key audios are normalized
|
|
||||||
batch_size:int = len(batch)
|
|
||||||
lam:float = self._params_per_batch()
|
|
||||||
|
|
||||||
if lam == 1.:
|
|
||||||
# no mixup, just return
|
|
||||||
for i in range(batch_size):
|
|
||||||
batch[i][target_key] = batch[i][source_key]
|
|
||||||
return lam
|
|
||||||
|
|
||||||
num_mixup = int(self.mix_prob * batch_size)
|
|
||||||
for i in range(num_mixup):
|
|
||||||
batch[i][target_key] = batch[i][source_key] * lam + batch[batch_size - i - 1][source_key] * (1 - lam)
|
|
||||||
if normalize:
|
|
||||||
source_abs_max = batch[i][source_key].abs().max()
|
|
||||||
target_abs_max = batch[i][target_key].abs().max()
|
|
||||||
batch[i][target_key] = batch[i][target_key] * (source_abs_max / target_abs_max)
|
|
||||||
for i in range(num_mixup, batch_size):
|
|
||||||
batch[i][target_key] = batch[i][source_key] # no mixup
|
|
||||||
|
|
||||||
return lam
|
|
||||||
|
|
||||||
def __call__(self, batch:list, _=None) -> torch.tensor:
|
|
||||||
batch_size:int = len(batch)
|
|
||||||
assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'
|
|
||||||
half = 'half' in self.mode
|
|
||||||
if half:
|
|
||||||
batch_size //= 2
|
|
||||||
|
|
||||||
if self.modality == 'video' or self.modality == 'both':
|
|
||||||
lam = self._concat_video_frames(batch, target_key=self.target_video_key, source_key=self.source_video_key)
|
|
||||||
if self.modality == 'audio' or self.modality == 'both':
|
|
||||||
# raise NotImplementedError('Audio mixup is not implemented yet.')
|
|
||||||
lam = self._mix_audio_samples(batch, target_key=self.target_audio_key, source_key=self.source_audio_key)
|
|
||||||
|
|
||||||
return default_collate(batch)
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMixup(MixupBase):
|
|
||||||
""" Mixup video in feature domain.
|
|
||||||
Applies different params to each element or whole batch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
|
||||||
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
|
||||||
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
|
||||||
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
|
||||||
prob (float): Probability of applying mixup per batch or element
|
|
||||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
|
||||||
eps (float): Small epsilon value to avoid zero lambda
|
|
||||||
"""
|
|
||||||
def __init__(self, generator:torch.Generator,
|
|
||||||
*,
|
|
||||||
modality:Literal['video', 'audio', 'both']='video',
|
|
||||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
|
||||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
|
||||||
eps:float=0.05
|
|
||||||
):
|
|
||||||
super().__init__(generator, modality=modality,
|
|
||||||
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
|
|
||||||
mode=mode, eps=eps)
|
|
||||||
self.source_video_key= 'sync_f_vid_orig'
|
|
||||||
self.source_audio_key = 'sync_f_aud_orig'
|
|
||||||
self.target_video_key = 'sync_f_vid_mixed'
|
|
||||||
self.target_audio_key = 'sync_f_aud_mixed'
|
|
||||||
|
|
||||||
def _mix_elem_collate(self, batch:dict,
|
|
||||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig'],
|
|
||||||
half:bool=False) -> torch.tensor:
|
|
||||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
|
||||||
batch_size:int = len(batch['id'])
|
|
||||||
num_elem:int = batch_size // 2 if half else batch_size
|
|
||||||
lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(num_elem))
|
|
||||||
|
|
||||||
indices = torch.arange(num_elem)
|
|
||||||
mix_indices = batch_size - indices - 1
|
|
||||||
mix_mask = lam_batch < 1
|
|
||||||
active_indices = indices[mix_mask]
|
|
||||||
active_mix_indices = mix_indices[mix_mask]
|
|
||||||
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
|
|
||||||
for target_key, source_key in zip(target_keys, source_keys):
|
|
||||||
batch[target_key][active_indices] = (
|
|
||||||
batch[source_key][active_indices] * active_lambdas +
|
|
||||||
batch[source_key][active_mix_indices] * (1 - active_lambdas)
|
|
||||||
)
|
|
||||||
batch[target_key][~indices[mix_mask]] = batch[source_key][~indices[mix_mask]]
|
|
||||||
if half:
|
|
||||||
lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
|
|
||||||
return lam_batch.unsqueeze(1)
|
|
||||||
|
|
||||||
def _mix_pair_collate(self, batch:dict,
|
|
||||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> torch.tensor:
|
|
||||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
|
||||||
batch_size:int = len(batch['id'])
|
|
||||||
lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(batch_size // 2))
|
|
||||||
|
|
||||||
indices = torch.arange(batch_size // 2)
|
|
||||||
mix_indices = batch_size - indices - 1
|
|
||||||
mix_mask = lam_batch < 1
|
|
||||||
active_indices = indices[mix_mask]
|
|
||||||
active_mix_indices = mix_indices[mix_mask]
|
|
||||||
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
|
|
||||||
for target_key, source_key in zip(target_keys, source_keys):
|
|
||||||
batch[target_key][active_indices] = (
|
|
||||||
batch[source_key][active_indices] * active_lambdas +
|
|
||||||
batch[source_key][active_mix_indices] * (1 - active_lambdas)
|
|
||||||
)
|
|
||||||
batch[target_key][active_mix_indices] = (
|
|
||||||
batch[source_key][active_mix_indices] * active_lambdas +
|
|
||||||
batch[source_key][active_indices] * (1 - active_lambdas)
|
|
||||||
)
|
|
||||||
batch[target_key][~indices[mix_mask]] = batch[source_key][~indices[mix_mask]]
|
|
||||||
batch[target_key][~mix_indices[mix_mask]] = batch[source_key][~mix_indices[mix_mask]]
|
|
||||||
lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
|
|
||||||
return lam_batch.unsqueeze(1)
|
|
||||||
|
|
||||||
def _mix_batch_collate(self, batch:dict,
|
|
||||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> float:
|
|
||||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
|
||||||
lam:float = self._params_per_batch()
|
|
||||||
|
|
||||||
for target_key, source_key in zip(target_keys, source_keys):
|
|
||||||
num_mixup = int(self.mix_prob * batch[source_key].shape[0])
|
|
||||||
flipped_source = torch.flip(batch[source_key], dims=[0])
|
|
||||||
batch[target_key] = batch[source_key] * lam + flipped_source * (1 - lam)
|
|
||||||
batch[target_key][num_mixup:] = batch[source_key][num_mixup:] # no mixup
|
|
||||||
return lam
|
|
||||||
|
|
||||||
def __call__(self, batch:dict, _=None) -> None:
|
|
||||||
batch_size:int = len(batch['id'])
|
|
||||||
assert batch_size % 2 == 0, f'Batch size(={batch_size}) should be even when using this'
|
|
||||||
half = 'half' in self.mode
|
|
||||||
if half:
|
|
||||||
batch_size //= 2
|
|
||||||
|
|
||||||
# Mixup
|
|
||||||
if self.mode == 'elem' or self.mode == 'half':
|
|
||||||
collate_fn = partial(self._mix_elem_collate, half=half)
|
|
||||||
elif self.mode == 'pair':
|
|
||||||
collate_fn = self._mix_pair_collate
|
|
||||||
else:
|
|
||||||
collate_fn = self._mix_batch_collate
|
|
||||||
|
|
||||||
if self.modality == 'both':
|
|
||||||
target_keys, source_keys = [self.target_video_key, self.target_audio_key], [self.source_video_key, self.source_audio_key]
|
|
||||||
elif self.modality == 'video':
|
|
||||||
target_keys, source_keys = [self.target_video_key], [self.source_video_key]
|
|
||||||
elif self.modality == 'audio':
|
|
||||||
target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
|
|
||||||
lam = collate_fn(batch, target_keys=target_keys, source_keys=source_keys)
|
|
||||||
|
|
||||||
# return batch
|
|
||||||
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
import bisect
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
|
|
||||||
|
|
||||||
# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
|
|
||||||
class MultiModalDataset(Dataset):
|
|
||||||
datasets: list[Dataset]
|
|
||||||
cumulative_sizes: list[int]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def cumsum(sequence):
|
|
||||||
r, s = [], 0
|
|
||||||
for e in sequence:
|
|
||||||
l = len(e)
|
|
||||||
r.append(l + s)
|
|
||||||
s += l
|
|
||||||
return r
|
|
||||||
|
|
||||||
def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
|
|
||||||
super().__init__()
|
|
||||||
self.video_datasets = list(video_datasets)
|
|
||||||
self.audio_datasets = list(audio_datasets)
|
|
||||||
self.datasets = self.video_datasets + self.audio_datasets
|
|
||||||
|
|
||||||
self.cumulative_sizes = self.cumsum(self.datasets)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.cumulative_sizes[-1]
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
if idx < 0:
|
|
||||||
if -idx > len(self):
|
|
||||||
raise ValueError("absolute value of index should not exceed dataset length")
|
|
||||||
idx = len(self) + idx
|
|
||||||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
|
||||||
if dataset_idx == 0:
|
|
||||||
sample_idx = idx
|
|
||||||
else:
|
|
||||||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
|
||||||
return self.datasets[dataset_idx][sample_idx]
|
|
||||||
|
|
||||||
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
return self.video_datasets[0].compute_latent_stats()
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from tensordict import MemoryMappedTensor
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from selva_core.utils.dist_utils import local_rank, world_size
|
|
||||||
|
|
||||||
scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
|
|
||||||
shm_path = Path('/dev/shm')
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
|
|
||||||
def reseed(seed):
|
|
||||||
random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def local_scatter_torch(obj: Optional[Any]):
|
|
||||||
if world_size == 1:
|
|
||||||
# Just one worker. Do nothing.
|
|
||||||
return obj
|
|
||||||
|
|
||||||
array = [obj] * world_size
|
|
||||||
target_array = [None]
|
|
||||||
if local_rank == 0:
|
|
||||||
dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
|
|
||||||
else:
|
|
||||||
dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
|
|
||||||
return target_array[0]
|
|
||||||
|
|
||||||
|
|
||||||
class ShardDataset(Dataset):
|
|
||||||
|
|
||||||
def __init__(self, root):
|
|
||||||
self.root = root
|
|
||||||
self.shards = sorted(os.listdir(root))
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.shards)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
|
|
||||||
|
|
||||||
|
|
||||||
def get_tmp_dir(in_memory: bool) -> Path:
|
|
||||||
return shm_path if in_memory else scratch_path
|
|
||||||
|
|
||||||
|
|
||||||
def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
|
|
||||||
in_memory: bool) -> MemoryMappedTensor:
|
|
||||||
if local_rank == 0:
|
|
||||||
with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
|
|
||||||
log.info(f'Loading shards from {data_path} into {f.name}...')
|
|
||||||
data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
|
|
||||||
data = share_tensor_to_all(data)
|
|
||||||
torch.distributed.barrier()
|
|
||||||
f.close() # why does the context manager not close the file for me?
|
|
||||||
else:
|
|
||||||
log.info('Waiting for the data to be shared with me...')
|
|
||||||
data = share_tensor_to_all(None)
|
|
||||||
torch.distributed.barrier()
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def load_shards(
|
|
||||||
data_path: Union[str, Path],
|
|
||||||
ids: list[int],
|
|
||||||
*,
|
|
||||||
tmp_file_path: str,
|
|
||||||
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
|
|
||||||
|
|
||||||
id_set = set(ids)
|
|
||||||
shards = sorted(os.listdir(data_path))
|
|
||||||
log.info(f'Found {len(shards)} shards in {data_path}.')
|
|
||||||
first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
|
|
||||||
|
|
||||||
log.info(f'Rank {local_rank} created file {tmp_file_path}')
|
|
||||||
first_item = next(iter(first_shard.values()))
|
|
||||||
log.info(f'First item shape: {first_item.shape}')
|
|
||||||
mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
|
|
||||||
dtype=torch.float32,
|
|
||||||
filename=tmp_file_path,
|
|
||||||
existsok=True)
|
|
||||||
total_count = 0
|
|
||||||
used_index = set()
|
|
||||||
id_indexing = {i: idx for idx, i in enumerate(ids)}
|
|
||||||
# faster with no workers; otherwise we need to set_sharing_strategy('file_system')
|
|
||||||
loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
|
|
||||||
for data in tqdm(loader, desc='Loading shards'):
|
|
||||||
for i, v in data.items():
|
|
||||||
if i not in id_set:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# tensor_index = ids.index(i)
|
|
||||||
tensor_index = id_indexing[i]
|
|
||||||
if tensor_index in used_index:
|
|
||||||
raise ValueError(f'Duplicate id {i} found in {data_path}.')
|
|
||||||
used_index.add(tensor_index)
|
|
||||||
mm_tensor[tensor_index] = v
|
|
||||||
total_count += 1
|
|
||||||
|
|
||||||
assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
|
|
||||||
log.info(f'Loaded {total_count} tensors from {data_path}.')
|
|
||||||
|
|
||||||
return mm_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
|
|
||||||
"""
|
|
||||||
x: the tensor to be shared; None if local_rank != 0
|
|
||||||
return: the shared tensor
|
|
||||||
"""
|
|
||||||
|
|
||||||
# there is no need to share your stuff with anyone if you are alone; must be in memory
|
|
||||||
if world_size == 1:
|
|
||||||
return x
|
|
||||||
|
|
||||||
if local_rank == 0:
|
|
||||||
assert x is not None, 'x must not be None if local_rank == 0'
|
|
||||||
else:
|
|
||||||
assert x is None, 'x must be None if local_rank != 0'
|
|
||||||
|
|
||||||
if local_rank == 0:
|
|
||||||
filename = x.filename
|
|
||||||
meta_information = (filename, x.shape, x.dtype)
|
|
||||||
else:
|
|
||||||
meta_information = None
|
|
||||||
|
|
||||||
filename, data_shape, data_type = local_scatter_torch(meta_information)
|
|
||||||
if local_rank == 0:
|
|
||||||
data = x
|
|
||||||
else:
|
|
||||||
data = MemoryMappedTensor.from_filename(filename=filename,
|
|
||||||
dtype=data_type,
|
|
||||||
shape=data_shape)
|
|
||||||
|
|
||||||
return data
|
|
||||||
@@ -1,299 +0,0 @@
|
|||||||
import logging
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
from torchvision.transforms import v2
|
|
||||||
from torio.io import StreamingMediaDecoder
|
|
||||||
from tensordict import TensorDict
|
|
||||||
|
|
||||||
from selva_core.data.av_utils import normalize_video_chunk
|
|
||||||
from selva_core.utils.dist_utils import local_rank
|
|
||||||
|
|
||||||
log = logging.getLogger()
|
|
||||||
|
|
||||||
_CLIP_SIZE = 384
|
|
||||||
_CLIP_FPS = 8.0
|
|
||||||
|
|
||||||
_SYNC_SIZE = 224
|
|
||||||
_SYNC_FPS = 25.0
|
|
||||||
|
|
||||||
|
|
||||||
class VGGSound(Dataset):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
root: Union[str, Path],
|
|
||||||
*,
|
|
||||||
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
|
||||||
for_generator: bool = True,
|
|
||||||
audio_required: bool = False,
|
|
||||||
sample_rate: int = 16_000,
|
|
||||||
duration_sec: float = 8.0,
|
|
||||||
audio_samples: Optional[int] = None,
|
|
||||||
normalize_audio: bool = False,
|
|
||||||
clip_video_required: bool = False,
|
|
||||||
mmap_dir: Union[str, Path] = None,
|
|
||||||
tsv_tsynch_path: Union[str, Path] = None,
|
|
||||||
mmap_tsync_dir: Union[str, Path] = None,
|
|
||||||
data_dim: dict[str, int] = None,
|
|
||||||
):
|
|
||||||
self.root = Path(root)
|
|
||||||
self.audio_required = audio_required
|
|
||||||
if audio_required:
|
|
||||||
self.normalize_audio = normalize_audio
|
|
||||||
if audio_samples is None:
|
|
||||||
self.audio_samples = int(sample_rate * duration_sec)
|
|
||||||
else:
|
|
||||||
self.audio_samples = audio_samples
|
|
||||||
effective_duration = audio_samples / sample_rate
|
|
||||||
# make sure the duration is close enough, within 15ms
|
|
||||||
assert abs(effective_duration - duration_sec) < 0.015, \
|
|
||||||
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
|
||||||
self.clip_video_required = clip_video_required
|
|
||||||
self.for_generator = for_generator
|
|
||||||
|
|
||||||
videos = sorted(os.listdir(self.root))
|
|
||||||
videos = set([Path(v).stem for v in videos]) # remove extensions
|
|
||||||
self.labels = {}
|
|
||||||
self.videos = []
|
|
||||||
missing_videos = []
|
|
||||||
|
|
||||||
# read the tsv for subset information
|
|
||||||
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
|
||||||
for record in df_list:
|
|
||||||
id = record['id']
|
|
||||||
label = record['label']
|
|
||||||
if id in videos:
|
|
||||||
self.labels[id] = label
|
|
||||||
self.videos.append(id)
|
|
||||||
else:
|
|
||||||
missing_videos.append(id)
|
|
||||||
|
|
||||||
if local_rank == 0:
|
|
||||||
log.info(f'{len(videos)} videos found in {root}')
|
|
||||||
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
|
||||||
log.info(f'{len(missing_videos)} videos missing in {root}')
|
|
||||||
|
|
||||||
self.sample_rate = sample_rate
|
|
||||||
self.duration_sec = duration_sec
|
|
||||||
|
|
||||||
if audio_required:
|
|
||||||
self.expected_audio_length = self.audio_samples
|
|
||||||
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
|
||||||
if clip_video_required:
|
|
||||||
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
|
||||||
|
|
||||||
self.sync_transform = v2.Compose([
|
|
||||||
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
|
||||||
# v2.CenterCrop(_SYNC_SIZE),
|
|
||||||
v2.ToImage(),
|
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
|
||||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
||||||
])
|
|
||||||
|
|
||||||
if clip_video_required:
|
|
||||||
self.clip_transform = v2.Compose([
|
|
||||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
|
||||||
v2.ToImage(),
|
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
|
||||||
])
|
|
||||||
if audio_required:
|
|
||||||
self.resampler = {}
|
|
||||||
|
|
||||||
# mmap
|
|
||||||
log.info(f'Loading precomputed mmap from {mmap_dir}')
|
|
||||||
mmap_dir = Path(mmap_dir)
|
|
||||||
td = TensorDict.load_memmap(mmap_dir)
|
|
||||||
log.info(f'Loaded precomputed mmap from {mmap_dir}')
|
|
||||||
self.sync_features = td['sync_features']
|
|
||||||
if for_generator:
|
|
||||||
self.mean = td['mean']
|
|
||||||
self.std = td['std']
|
|
||||||
self.text_clip_features = td['text_features']
|
|
||||||
if clip_video_required:
|
|
||||||
self.clip_features = td['clip_features']
|
|
||||||
else:
|
|
||||||
self.clip_features = None
|
|
||||||
self.id2idx_mmap = {d['id']: i for i, d in enumerate(df_list)}
|
|
||||||
|
|
||||||
mmap_tsync_dir = Path(mmap_tsync_dir)
|
|
||||||
td_tsync = TensorDict.load_memmap(mmap_tsync_dir)
|
|
||||||
log.info(f'Loaded precomputed tsync mmap from {mmap_tsync_dir}')
|
|
||||||
self.text_features = td_tsync['text_features']
|
|
||||||
self.text_masks = td_tsync['text_masks']
|
|
||||||
df_list_tsync = pd.read_csv(tsv_tsynch_path, sep='\t').to_dict('records')
|
|
||||||
self.id2idx_mmap_tsync = {d['id']: i for i, d in enumerate(df_list_tsync)}
|
|
||||||
|
|
||||||
if local_rank == 0:
|
|
||||||
log.info(f'Loaded {len(self)} samples.')
|
|
||||||
log.info(f'Loaded sync_features: {self.sync_features.shape}.')
|
|
||||||
log.info(f'Loaded text_features: {self.text_features.shape}.')
|
|
||||||
log.info(f'Loaded text_masks: {self.text_masks.shape}.')
|
|
||||||
if for_generator:
|
|
||||||
log.info(f'Loaded mean: {self.mean.shape}.')
|
|
||||||
log.info(f'Loaded std: {self.std.shape}.')
|
|
||||||
log.info(f'Loaded text_clip_features: {self.text_clip_features.shape}.')
|
|
||||||
if clip_video_required:
|
|
||||||
log.info(f'Loaded clip_features: {self.clip_features.shape}.')
|
|
||||||
|
|
||||||
assert self.sync_features.shape[1] == data_dim['sync_seq_len'], \
|
|
||||||
f'{self.sync_features.shape[1]} != {data_dim["sync_seq_len"]}'
|
|
||||||
assert self.text_features.shape[1] <= data_dim['text_flant5_max_seq_len'], \
|
|
||||||
f'{self.text_features.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
|
|
||||||
assert self.text_masks.shape[1] <= data_dim['text_flant5_max_seq_len'], \
|
|
||||||
f'{self.text_masks.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
|
|
||||||
assert self.sync_features.shape[-1] == data_dim['sync_dim'], \
|
|
||||||
f'{self.sync_features.shape[-1]} != {data_dim["sync_dim"]}'
|
|
||||||
assert self.text_features.shape[-1] == data_dim['text_flant5_dim'], \
|
|
||||||
f'{self.text_features.shape[-1]} != {data_dim["text_flant5_dim"]}'
|
|
||||||
if for_generator:
|
|
||||||
assert self.mean.shape[1] == data_dim['latent_seq_len'], \
|
|
||||||
f'{self.mean.shape[1]} != {data_dim["latent_seq_len"]}'
|
|
||||||
assert self.std.shape[1] == data_dim['latent_seq_len'], \
|
|
||||||
f'{self.std.shape[1]} != {data_dim["latent_seq_len"]}'
|
|
||||||
assert self.text_clip_features.shape[1] == data_dim['text_clip_seq_len'], \
|
|
||||||
f'{self.text_clip_features.shape[1]} != {data_dim["text_clip_seq_len"]}'
|
|
||||||
assert self.text_clip_features.shape[-1] == data_dim['text_clip_dim'], \
|
|
||||||
f'{self.text_clip_features.shape[-1]} != {data_dim["text_clip_dim"]}'
|
|
||||||
if clip_video_required:
|
|
||||||
assert self.clip_features.shape[1] == data_dim['clip_seq_len'], \
|
|
||||||
f'{self.clip_features.shape[1]} != {data_dim["clip_seq_len"]}'
|
|
||||||
assert self.clip_features.shape[-1] == data_dim['clip_dim'], \
|
|
||||||
f'{self.clip_features.shape[-1]} != {data_dim["clip_dim"]}'
|
|
||||||
|
|
||||||
self.video_exist = torch.tensor(1, dtype=torch.bool)
|
|
||||||
self.text_exist = torch.tensor(1, dtype=torch.bool)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: # mmap
|
|
||||||
latents = self.mean
|
|
||||||
return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
|
|
||||||
|
|
||||||
def get_memory_mapped_tensor(self) -> TensorDict:
|
|
||||||
td = TensorDict({
|
|
||||||
'sync_features': self.sync_features,
|
|
||||||
'text_features': self.text_features,
|
|
||||||
'text_masks': self.text_masks,
|
|
||||||
})
|
|
||||||
if self.for_generator:
|
|
||||||
td['mean'] = self.mean
|
|
||||||
td['std'] = self.std
|
|
||||||
td['text_clip_features'] = self.text_clip_features
|
|
||||||
if self.clip_video_required:
|
|
||||||
td['clip_features'] = self.clip_features
|
|
||||||
return td
|
|
||||||
|
|
||||||
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
|
||||||
video_id = self.videos[idx]
|
|
||||||
|
|
||||||
if video_id in self.captions and torch.rand(1).item() < self.autoacd_sample_prob:
|
|
||||||
label = self.captions[video_id]
|
|
||||||
else:
|
|
||||||
label = self.labels[video_id]
|
|
||||||
|
|
||||||
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
|
||||||
reader.add_basic_video_stream(
|
|
||||||
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
|
||||||
frame_rate=_SYNC_FPS,
|
|
||||||
format='rgb24',
|
|
||||||
)
|
|
||||||
if self.audio_required:
|
|
||||||
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
|
|
||||||
if self.clip_video_required:
|
|
||||||
reader.add_basic_video_stream(
|
|
||||||
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
|
||||||
frame_rate=_CLIP_FPS,
|
|
||||||
format='rgb24',
|
|
||||||
)
|
|
||||||
|
|
||||||
reader.fill_buffer()
|
|
||||||
data_chunk = reader.pop_chunks()
|
|
||||||
|
|
||||||
sync_chunk = data_chunk[0]
|
|
||||||
if sync_chunk is None:
|
|
||||||
raise RuntimeError(f'Sync video returned None {video_id}')
|
|
||||||
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
|
||||||
n_tolerance_frame=3, desc=video_id)
|
|
||||||
sync_chunk = self.sync_transform(sync_chunk)
|
|
||||||
|
|
||||||
if self.audio_required:
|
|
||||||
audio_chunk = data_chunk[1]
|
|
||||||
|
|
||||||
if self.clip_video_required:
|
|
||||||
clip_chunk = data_chunk[2 if self.audio_required else 1]
|
|
||||||
if clip_chunk is None:
|
|
||||||
raise RuntimeError(f'CLIP video returned None {video_id}')
|
|
||||||
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
|
||||||
n_tolerance_frame=1, desc=video_id)
|
|
||||||
clip_chunk = self.clip_transform(clip_chunk)
|
|
||||||
|
|
||||||
# process audio
|
|
||||||
if self.audio_required:
|
|
||||||
sample_rate = int(reader.get_out_stream_info(1).sample_rate)
|
|
||||||
audio_chunk = audio_chunk.transpose(0, 1)
|
|
||||||
audio_chunk = audio_chunk.mean(dim=0) # mono
|
|
||||||
if self.normalize_audio:
|
|
||||||
abs_max = audio_chunk.abs().max()
|
|
||||||
audio_chunk = audio_chunk * (0.95 / abs_max)
|
|
||||||
if abs_max <= 1e-6:
|
|
||||||
raise RuntimeError(f'Audio is silent {video_id}')
|
|
||||||
|
|
||||||
# resample
|
|
||||||
if sample_rate == self.sample_rate:
|
|
||||||
audio_chunk = audio_chunk
|
|
||||||
else:
|
|
||||||
if sample_rate not in self.resampler:
|
|
||||||
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
|
||||||
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
|
||||||
sample_rate,
|
|
||||||
self.sample_rate,
|
|
||||||
lowpass_filter_width=64,
|
|
||||||
rolloff=0.9475937167399596,
|
|
||||||
resampling_method='sinc_interp_kaiser',
|
|
||||||
beta=14.769656459379492,
|
|
||||||
)
|
|
||||||
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
|
||||||
|
|
||||||
if audio_chunk.shape[0] < self.expected_audio_length:
|
|
||||||
raise RuntimeError(f'Audio too short {video_id}')
|
|
||||||
audio_chunk = audio_chunk[:self.expected_audio_length]
|
|
||||||
|
|
||||||
data = {
|
|
||||||
'id': video_id,
|
|
||||||
'caption': label,
|
|
||||||
'sync_video': sync_chunk,
|
|
||||||
'sync_f_vid_orig': self.sync_features[self.id2idx_mmap[video_id]],
|
|
||||||
'text_features': self.text_features[self.id2idx_mmap_tsync[video_id]],
|
|
||||||
'text_masks': self.text_masks[self.id2idx_mmap_tsync[video_id]],
|
|
||||||
'video_exist': self.video_exist,
|
|
||||||
'text_exist': self.text_exist,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.for_generator:
|
|
||||||
data['a_mean'] = self.mean[self.id2idx_mmap[video_id]]
|
|
||||||
data['a_std'] = self.std[self.id2idx_mmap[video_id]]
|
|
||||||
data['text_clip_features'] = self.text_clip_features[self.id2idx_mmap[video_id]]
|
|
||||||
|
|
||||||
if self.audio_required:
|
|
||||||
data['audio'] = audio_chunk
|
|
||||||
|
|
||||||
if self.clip_video_required:
|
|
||||||
data['clip_video'] = clip_chunk
|
|
||||||
data['clip_features'] = self.clip_features[self.id2idx_mmap[video_id]],
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
|
||||||
try:
|
|
||||||
return self.sample(idx)
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
|
||||||
return None
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.labels)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
from .autoencoder import AutoEncoderModule
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from selva_core.ext.autoencoder.vae import VAE, get_my_vae
|
|
||||||
from selva_core.ext.bigvgan import BigVGAN
|
|
||||||
from selva_core.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
|
|
||||||
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
|
|
||||||
|
|
||||||
|
|
||||||
class AutoEncoderModule(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
*,
|
|
||||||
vae_ckpt_path,
|
|
||||||
vocoder_ckpt_path: Optional[str] = None,
|
|
||||||
mode: Literal['16k', '44k'],
|
|
||||||
need_vae_encoder: bool = True):
|
|
||||||
super().__init__()
|
|
||||||
self.vae: VAE = get_my_vae(mode).eval()
|
|
||||||
vae_state_dict = torch.load(vae_ckpt_path, weights_only=False, map_location='cpu')
|
|
||||||
self.vae.load_state_dict(vae_state_dict)
|
|
||||||
self.vae.remove_weight_norm()
|
|
||||||
|
|
||||||
if mode == '16k':
|
|
||||||
assert vocoder_ckpt_path is not None
|
|
||||||
self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
|
|
||||||
elif mode == '44k':
|
|
||||||
self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
|
|
||||||
use_cuda_kernel=False)
|
|
||||||
self.vocoder.remove_weight_norm()
|
|
||||||
else:
|
|
||||||
raise ValueError(f'Unknown mode: {mode}')
|
|
||||||
|
|
||||||
for param in self.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
if not need_vae_encoder:
|
|
||||||
del self.vae.encoder
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
|
|
||||||
return self.vae.encode(x)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.vae.decode(z)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def vocode(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.vocoder(spec)
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user