Compare commits
69 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 95136b53a0 | |||
| 8f31d00beb | |||
| 3ee1893e10 | |||
| c86258d48f | |||
| 8338560600 | |||
| 8ae0ba3c7d | |||
| 2b2b438307 | |||
| 39984f73c2 | |||
| 1f8cd6f930 | |||
| 20f8138146 | |||
| 09b3b94ddd | |||
| 3f67de694c | |||
| 423e174b88 | |||
| 4806daa4ca | |||
| 16b3eb11cc | |||
| 004ea63f62 | |||
| afb3242eca | |||
| 849f31e2a6 | |||
| 505d445eb3 | |||
| 8fade1b0e3 | |||
| ad57432803 | |||
| 43f732f904 | |||
| 6b9adf0816 | |||
| 52434a053a | |||
| 56c8d5d6b4 | |||
| b430953602 | |||
| 57cd3dd4b4 | |||
| f206a1b38c | |||
| 2f4641247a | |||
| 8e9114b92c | |||
| 63b4391573 | |||
| 89af5a468c | |||
| c88e27742c | |||
| cbcd154c96 | |||
| 1eb82d8050 | |||
| cde280049b | |||
| 437c62b28f | |||
| b519b042e2 | |||
| f28759f1e3 | |||
| 3dd6badfd9 | |||
| 8bb2fb7015 | |||
| f4a7292cde | |||
| bd53744e2d | |||
| 429810db5b | |||
| 57f56c04e2 | |||
| ff26d0b87d | |||
| 83b1da9520 | |||
| 679a607a85 | |||
| d495939367 | |||
| 982d66e078 | |||
| b4124f58b3 | |||
| 2c9d521565 | |||
| 28229d62ce | |||
| 92593189f0 | |||
| 614a2e02aa | |||
| 40388ba6de | |||
| 789e09535d | |||
| 4da4858e4a | |||
| ab8e1e5b7b | |||
| e3a3384727 | |||
| 9a985499e7 | |||
| 27b4424e1a | |||
| 0e417f4078 | |||
| 6474e2816c | |||
| c23d210ab2 | |||
| b59b657b6f | |||
| 578b501d38 | |||
| fe94438356 | |||
| 6bc3fd6443 |
@@ -0,0 +1,392 @@
|
|||||||
|
# LoRA Training for SelVA
|
||||||
|
|
||||||
|
LoRA lets you teach the model new or partially-known sound classes using a small set of video+audio pairs. Only ~10 MB of adapter weights are trained instead of the full 4.4 GB model.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Training is split into two steps:
|
||||||
|
|
||||||
|
1. **Dataset preparation** (in ComfyUI) — extract visual features from your video clips using the `SelVA Feature Extractor` node, and collect clean matching audio files.
|
||||||
|
2. **Training** (in ComfyUI or command line) — run the `SelVA LoRA Trainer` node or `train_lora.py`.
|
||||||
|
|
||||||
|
The training script only loads the generator and the VAE encoder. CLIP visual features and sync features come pre-computed from the `.npz` files, so Synchformer and T5 are not loaded during training, saving 3–4 GB of VRAM.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
Same environment as SelVA inference. Additional Python packages:
|
||||||
|
|
||||||
|
```
|
||||||
|
torchaudio
|
||||||
|
soundfile
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Step 1 — Prepare the dataset
|
||||||
|
|
||||||
|
### 1.1 Extract visual features in ComfyUI
|
||||||
|
|
||||||
|
For each video clip you want to train on:
|
||||||
|
|
||||||
|
1. Load the video with a VHS LoadVideo node.
|
||||||
|
2. Connect it to **SelVA Feature Extractor**.
|
||||||
|
3. Set **`cache_dir`** to a dedicated dataset folder, e.g. `dataset/my_sound`.
|
||||||
|
4. Set **`name`** to a short descriptive label, e.g. `dog_bark`. The node will save `dog_bark_001.npz`, then `dog_bark_002.npz`, etc. automatically as you process more clips.
|
||||||
|
5. Set the **`prompt`** to describe the sound (e.g. `a dog barking`). This prompt conditions the Synchformer sync features — be as specific as possible (see prompt guide below).
|
||||||
|
6. Optionally connect a **mask** to isolate the sound source in frame (strongly recommended when multiple objects are visible — see masking note below).
|
||||||
|
|
||||||
|
> **Tip:** The prompt used for feature extraction conditions the *visual sync features*. You can use a different, more precise prompt at training time — see Step 2.
|
||||||
|
|
||||||
|
### Prompt guide
|
||||||
|
|
||||||
|
The prompt is not just a label — it directly shapes what the Synchformer pays attention to in the video. Imprecise prompts produce unfocused sync features, which the LoRA then has to compensate for, often introducing noise.
|
||||||
|
|
||||||
|
**Good prompts are specific about:**
|
||||||
|
- The sound source (what object is making the sound)
|
||||||
|
- The acoustic character (loud/quiet, sharp/soft, wet/dry)
|
||||||
|
- The action producing the sound (if applicable)
|
||||||
|
|
||||||
|
| Sound | Weak prompt | Strong prompt |
|
||||||
|
|---|---|---|
|
||||||
|
| Dog bark | `dog` | `a large dog barking loudly` |
|
||||||
|
| Footsteps | `walking` | `heavy boots on a wooden floor` |
|
||||||
|
| Water | `water` | `water dripping into a metal bucket` |
|
||||||
|
| Explosion | `explosion` | `a large explosion with deep bass rumble` |
|
||||||
|
| Door | `door` | `a heavy wooden door slamming shut` |
|
||||||
|
|
||||||
|
**Rules of thumb:**
|
||||||
|
- Describe the *sound*, not the visual scene. `a person hitting a drum` is better than `a drummer on stage`.
|
||||||
|
- Keep prompts consistent across all clips for the same sound class. Mixing `a dog barking` and `loud barking dog` in the same dataset creates conflicting sync features.
|
||||||
|
- Avoid negations (`no background noise`) — the model does not understand negations in sync feature conditioning.
|
||||||
|
|
||||||
|
### Masking note
|
||||||
|
|
||||||
|
If the video frame contains multiple moving objects, CLIP and sync features will be diluted by irrelevant motion. Use a segmentation mask (SAM2 or Grounding DINO+SAM) to isolate the sound source:
|
||||||
|
|
||||||
|
- Connect the mask to the **`mask`** input on SelVA Feature Extractor.
|
||||||
|
- Leave `mask_strength` at `1.0` for clean isolation; lower it only if the masked region is very small and the model loses context.
|
||||||
|
- Re-extract features with a mask even if you already have `.npz` files — better features directly reduce training noise.
|
||||||
|
|
||||||
|
### 1.2 Collect clean audio
|
||||||
|
|
||||||
|
For each `.npz` file, place a matching audio file with the **same filename stem** in the same directory:
|
||||||
|
|
||||||
|
```
|
||||||
|
dataset/my_sound/
|
||||||
|
dog_bark_001.npz ← from SelVA Feature Extractor
|
||||||
|
dog_bark_001.wav ← clean isolated audio recording
|
||||||
|
dog_bark_002.npz
|
||||||
|
dog_bark_002.wav
|
||||||
|
dog_bark_003.npz
|
||||||
|
dog_bark_003.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported audio formats: `.wav`, `.flac`, `.ogg`, `.aiff`, `.aif`
|
||||||
|
|
||||||
|
> `.mp3` is not recommended — lossy compression degrades training quality. Use `.flac` or `.wav`.
|
||||||
|
|
||||||
|
The audio will be automatically resampled and trimmed/padded to match the model's expected duration. Use clean, isolated recordings — no background noise.
|
||||||
|
|
||||||
|
### 1.3 Optional: prompts.txt
|
||||||
|
|
||||||
|
If you want a different prompt at training time than the one embedded in the `.npz`, create a `prompts.txt` file in the dataset directory:
|
||||||
|
|
||||||
|
```
|
||||||
|
# One line per file: filename: prompt text
|
||||||
|
dog_bark.npz: a large dog barking aggressively
|
||||||
|
dog_bark_001.npz: a dog barking in the distance
|
||||||
|
```
|
||||||
|
|
||||||
|
Priority: `prompts.txt` > prompt embedded in `.npz` > directory name as fallback.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Step 2 — Train
|
||||||
|
|
||||||
|
### Option A — SelVA LoRA Trainer node (ComfyUI)
|
||||||
|
|
||||||
|
Connect the node and set parameters directly in the UI. The node outputs the trained model ready to wire into the Sampler, and saves loss curve images to the output directory.
|
||||||
|
|
||||||
|
```
|
||||||
|
SelVA Model Loader → SelVA LoRA Trainer → SelVA Sampler
|
||||||
|
```
|
||||||
|
|
||||||
|
### Option B — Command line
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python train_lora.py \
|
||||||
|
--data_dir dataset/my_sound \
|
||||||
|
--output_dir lora_output/my_sound \
|
||||||
|
--variant large_44k \
|
||||||
|
--selva_dir /path/to/ComfyUI/models/selva \
|
||||||
|
--rank 16 \
|
||||||
|
--steps 4000 \
|
||||||
|
--batch_size 4 \
|
||||||
|
--lr 1e-4
|
||||||
|
```
|
||||||
|
|
||||||
|
The script will:
|
||||||
|
1. Load the VAE, CLIP text encoder, and generator.
|
||||||
|
2. Pre-load all clips (audio encoded to latents, features loaded from `.npz`).
|
||||||
|
3. Train LoRA adapters for the specified number of steps.
|
||||||
|
4. Save a checkpoint every `--save_every` steps, a final `adapter_final.pt`, and loss curve images.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## CLI Reference
|
||||||
|
|
||||||
|
| Argument | Default | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `--data_dir` | required | Directory containing `.npz` + audio pairs |
|
||||||
|
| `--output_dir` | `lora_output` | Where to save adapter checkpoints |
|
||||||
|
| `--variant` | `large_44k` | Model variant: `small_16k`, `small_44k`, `medium_44k`, `large_44k` |
|
||||||
|
| `--selva_dir` | required | Path to SelVA model weights directory |
|
||||||
|
| `--rank` | `16` | LoRA rank — higher = more capacity, more VRAM |
|
||||||
|
| `--alpha` | `rank` | LoRA alpha scaling. Default (= rank) means scale = 1.0 |
|
||||||
|
| `--target` | `attn.qkv` | Which layers to adapt. Add `linear1` for post-attention projections |
|
||||||
|
| `--lr` | `1e-4` | Learning rate |
|
||||||
|
| `--steps` | `2000` | Total training steps |
|
||||||
|
| `--warmup_steps` | `100` | Linear LR warmup steps |
|
||||||
|
| `--batch_size` | `4` | Clips per training step — higher is more stable, uses more VRAM |
|
||||||
|
| `--grad_accum` | `1` | Gradient accumulation steps (use when batch_size is already > 1) |
|
||||||
|
| `--save_every` | `500` | Save a checkpoint every N steps |
|
||||||
|
| `--resume` | `None` | Path to a step checkpoint to resume from (e.g. `lora_output/adapter_step04000.pt`) |
|
||||||
|
| `--precision` | `bf16` | Mixed precision: `bf16`, `fp16`, `fp32` |
|
||||||
|
| `--seed` | `42` | Random seed |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Step 3 — Load the adapter in ComfyUI
|
||||||
|
|
||||||
|
Connect **SelVA LoRA Loader** between the model loader and the sampler:
|
||||||
|
|
||||||
|
```
|
||||||
|
SelVA Model Loader → SelVA LoRA Loader → SelVA Sampler
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Important:** Wire the LoRA Loader output to the **Sampler**, not the Feature Extractor. The LoRA adapts the generator which only runs in the Sampler.
|
||||||
|
|
||||||
|
| Input | Description |
|
||||||
|
|---|---|
|
||||||
|
| `model` | SELVA_MODEL from the model loader |
|
||||||
|
| `adapter_path` | Path to `adapter_final.pt` or any `adapter_stepXXXXX.pt` |
|
||||||
|
| `strength` | 0.0 = adapter disabled, 1.0 = full strength, >1.0 = exaggerated |
|
||||||
|
|
||||||
|
The loader reads rank, alpha, and target layers from the metadata embedded in the `.pt` file — no need to set them manually.
|
||||||
|
|
||||||
|
> The base model is not modified. The loader returns a shallow copy with a deep-copied generator so the original stays intact.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tuning Guide
|
||||||
|
|
||||||
|
### Clip length
|
||||||
|
|
||||||
|
The model has a **fixed input duration of 8 seconds** for all variants (both 16k and 44k). This is not a parameter you can change.
|
||||||
|
|
||||||
|
- Audio shorter than 8 s is **zero-padded** (silence appended). The model will learn the sound but may also learn silence as part of the pattern — keep in mind for very short sounds.
|
||||||
|
- Audio longer than 8 s is **trimmed** at 8 s. Content beyond that is lost.
|
||||||
|
- Video shorter than 8 s has its **last frame repeated** to fill the clip.
|
||||||
|
|
||||||
|
**Practical recommendations:**
|
||||||
|
|
||||||
|
| Sound type | Clip strategy |
|
||||||
|
|---|---|
|
||||||
|
| Continuous sound (rain, engine, wind) | 8 s recordings, as many positions in the audio as possible |
|
||||||
|
| Single event < 2 s (click, bark, knock) | Center the event — pad deliberately with silence before/after, or loop the event 2–3 times per clip |
|
||||||
|
| Repeating event (footsteps, dripping) | Record full 8 s with natural repetition at the intended cadence |
|
||||||
|
| Sound with a clear onset (explosion, splash) | Put the onset at ~1–2 s from the start, not at 0 s — gives the model context |
|
||||||
|
|
||||||
|
> **Tip:** When extracting features in ComfyUI, set `duration` to 0 to use the full video length up to 8 s. Clips longer than 8 s are automatically clamped.
|
||||||
|
|
||||||
|
### How many clips do I need?
|
||||||
|
|
||||||
|
The table below gives a rough scaling guide. Quality and diversity of recordings matter more than raw count.
|
||||||
|
|
||||||
|
| Dataset size | Scenario | Expected result |
|
||||||
|
|---|---|---|
|
||||||
|
| **5–10 clips** | Quick test / proof of concept | May work if the model already partially knows the sound; often underfits |
|
||||||
|
| **15–30 clips** | Fine-tuning a sound the model knows but gets wrong | Good starting point — covers the main variations |
|
||||||
|
| **30–60 clips** | Teaching a new but acoustically simple sound class | Reliable convergence with default hyperparameters |
|
||||||
|
| **60–150 clips** | Unusual or complex sounds, strong style shift | Needed for stable generalization across video contexts |
|
||||||
|
| **150–300 clips** | Sounds the model has never encountered | Required to avoid overfitting; increase rank to 32 |
|
||||||
|
| **300+** | Large-scale domain shift | Consider also targeting `linear1` in addition to `attn.qkv` |
|
||||||
|
|
||||||
|
**Diversity beats quantity.** Ten clips of a dog barking in different environments (indoors, outdoors, distant, close) train better than fifty clips of the same recording. Vary: distance, room acoustics, intensity, speed.
|
||||||
|
|
||||||
|
### Batch size
|
||||||
|
|
||||||
|
| Batch size | VRAM (large_44k) | Use case |
|
||||||
|
|---|---|---|
|
||||||
|
| `1` | ~9 GB | Minimal VRAM, noisy gradients |
|
||||||
|
| `4` | ~12 GB | Good default — stable gradients, reasonable speed |
|
||||||
|
| `8` | ~15 GB | Better convergence on larger datasets |
|
||||||
|
| `16` | ~20 GB | Best gradient quality when VRAM allows |
|
||||||
|
|
||||||
|
Higher batch size gives smoother loss curves and faster convergence. If you have headroom, prefer larger batches over more steps.
|
||||||
|
|
||||||
|
**Observed results:** batch 16 reaches the same loss in ~2600 steps that batch 1 needed 8000+ steps to reach, with a near-perfectly smooth curve. On a 24 GB GPU, batch 16 is the recommended default for `large_44k`.
|
||||||
|
|
||||||
|
### Rank
|
||||||
|
|
||||||
|
| Rank | Use case |
|
||||||
|
|---|---|
|
||||||
|
| `8` | Fine details on a sound the model already knows well |
|
||||||
|
| `16` | Default — good balance of capacity and VRAM |
|
||||||
|
| `32` | Harder sounds or larger style shifts (30+ clips recommended) |
|
||||||
|
|
||||||
|
Higher rank increases VRAM usage and overfitting risk on small datasets.
|
||||||
|
|
||||||
|
### Steps
|
||||||
|
|
||||||
|
With `batch_size=4` as the default, these are rough guidelines:
|
||||||
|
|
||||||
|
| Dataset size | Recommended steps |
|
||||||
|
|---|---|
|
||||||
|
| 10–20 clips | 2000–4000 |
|
||||||
|
| 20–50 clips | 4000–8000 |
|
||||||
|
| 50+ clips | 6000–15000 |
|
||||||
|
|
||||||
|
Watch the loss curve — if the smoothed line has been flat for 2000+ steps, training has converged for your dataset size. Adding more clips will let it go lower.
|
||||||
|
|
||||||
|
### Learning rate
|
||||||
|
|
||||||
|
`1e-4` is the recommended default for any batch size. If training is unstable (loss spikes in the first 200 steps), try `5e-5`. If convergence is very slow, try `2e-4`.
|
||||||
|
|
||||||
|
Warmup (default 100 steps) ramps the LR from 0 to avoid instability at the start.
|
||||||
|
|
||||||
|
### Target layers
|
||||||
|
|
||||||
|
`attn.qkv` (default) adapts only the self-attention QKV projections. This is the recommended starting point for all dataset sizes.
|
||||||
|
|
||||||
|
Add `linear1` to also adapt post-attention projections for large-scale domain shifts or when `attn.qkv` alone plateaus too early:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--target attn.qkv linear1
|
||||||
|
```
|
||||||
|
|
||||||
|
Only add `linear1` once you have 150+ clips — it doubles the adapted parameter count and overfits faster on small datasets.
|
||||||
|
|
||||||
|
### Adapter strength at inference
|
||||||
|
|
||||||
|
| Strength | Effect |
|
||||||
|
|---|---|
|
||||||
|
| `0.5–0.7` | Conservative — blends adapter with base model, less noise |
|
||||||
|
| `1.0` | Full adapter strength (default) |
|
||||||
|
| `>1.0` | Exaggerated effect, may introduce artifacts |
|
||||||
|
|
||||||
|
If the generated audio has noticeable white noise or artifacts, lower the strength to `0.6–0.7` before adjusting anything else. Also try lowering CFG scale in the Sampler.
|
||||||
|
|
||||||
|
### Loss interpretation
|
||||||
|
|
||||||
|
A typical loss curve:
|
||||||
|
- Starts around `0.8–1.0`
|
||||||
|
- Should reach `0.55–0.65` after convergence on a clean sound class with 10–30 clips
|
||||||
|
- Below `0.4` indicates strong learning — usually requires 50+ diverse clips
|
||||||
|
- Below `0.1` on a small dataset means overfitting
|
||||||
|
|
||||||
|
The smoothed curve flattening for 2000+ steps is the clearest sign to stop or add more data.
|
||||||
|
|
||||||
|
### Precision
|
||||||
|
|
||||||
|
Use `bf16` on Ampere+ GPUs (RTX 3xxx/4xxx, A100). Fall back to `fp16` on older GPUs. `fp32` is only needed for debugging — 2× more VRAM.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Output files
|
||||||
|
|
||||||
|
```
|
||||||
|
lora_output/my_sound/
|
||||||
|
adapter_step00500.pt ← step checkpoint (includes optimizer state for resume)
|
||||||
|
adapter_step01000.pt
|
||||||
|
...
|
||||||
|
adapter_final.pt ← final adapter with embedded metadata (inference only)
|
||||||
|
meta.json ← human-readable metadata
|
||||||
|
sample_step00500.wav ← quick eval sample at each checkpoint
|
||||||
|
loss_raw.png ← raw loss curve
|
||||||
|
loss_smoothed.png ← EMA-smoothed loss curve
|
||||||
|
```
|
||||||
|
|
||||||
|
`adapter_final.pt` format:
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"state_dict": { "blocks.0.attn.qkv.lora_A": ..., ... },
|
||||||
|
"meta": {
|
||||||
|
"variant": "large_44k",
|
||||||
|
"rank": 16,
|
||||||
|
"alpha": 16.0,
|
||||||
|
"target": ["attn.qkv"],
|
||||||
|
"steps": 2000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Step checkpoints (e.g. `adapter_step01000.pt`) additionally contain `optimizer` and `scheduler` state for resuming.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
**`No layers matched target=...`**
|
||||||
|
The `--target` suffixes do not match any layer names. The default `attn.qkv` targets `SelfAttention.qkv` in all transformer blocks. If you changed `--target`, verify the layer names with `model.named_modules()`.
|
||||||
|
|
||||||
|
**`No .npz files found in ...`**
|
||||||
|
The `--data_dir` path is wrong or no `.npz` files were extracted there yet. Run SelVA Feature Extractor in ComfyUI first with the matching `cache_dir`.
|
||||||
|
|
||||||
|
**`No audio file found for clip.npz`**
|
||||||
|
Place an audio file with the exact same stem next to the `.npz`: `clip.wav`, `clip.flac`, etc.
|
||||||
|
|
||||||
|
**The sound is audible but there is white noise on top**
|
||||||
|
Lower the adapter strength to `0.6–0.7` in SelVA LoRA Loader. Also try lowering CFG scale in the Sampler. This is normal when the model hasn't fully converged — more clips and more steps will reduce it.
|
||||||
|
|
||||||
|
**LoRA appears to have no effect**
|
||||||
|
Make sure the SelVA LoRA Loader output is wired to the **Sampler** input, not the Feature Extractor. The Feature Extractor does not use the generator.
|
||||||
|
|
||||||
|
**Loss does not decrease**
|
||||||
|
- Increase `batch_size` for more stable gradients.
|
||||||
|
- Try a higher learning rate (`2e-4`) or check that warmup isn't too long.
|
||||||
|
- Check that the audio files are clean and actually contain the target sound.
|
||||||
|
- Check that the `.npz` features were extracted with a relevant prompt.
|
||||||
|
|
||||||
|
**Loss explodes or NaN**
|
||||||
|
- Lower the learning rate (`5e-5`).
|
||||||
|
- Make sure audio is normalized to `[-1, 1]`. PCM files with 16-bit integer encoding may need to be converted: `ffmpeg -i input.wav -ar 44100 -sample_fmt s16 output.wav`
|
||||||
|
|
||||||
|
**Loss plateaus early (above 0.7)**
|
||||||
|
Dataset is the bottleneck. Add more clips — diversity matters more than quantity.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Observations (work in progress)
|
||||||
|
|
||||||
|
These are empirical findings from ongoing experiments. They will be promoted to the main guide once more validated.
|
||||||
|
|
||||||
|
### Precision and batch size
|
||||||
|
|
||||||
|
| Config | Smoothed loss at step 2000 | Notes |
|
||||||
|
|---|---|---|
|
||||||
|
| bf16 batch 1 | ~0.73 | Noisy gradients, slow |
|
||||||
|
| bf16 batch 16 | ~0.65 | Stable, plateaued around step 6000–8000 at ~0.59 |
|
||||||
|
| bf16 batch 16 logit_normal | ~0.47 | Lower loss floor, similar or marginally better audio |
|
||||||
|
| fp32 batch 32 | ~0.58 | Matches bf16 batch 16 at step 6000 already at step 2000 |
|
||||||
|
|
||||||
|
**Key finding:** fp32 batch 32 converges to the same perceptual quality point in ~2000 steps that bf16 batch 16 needs 6000+ steps to reach. However, fp32 batch 32 continues descending well past that point on small datasets (10 clips), eventually overfitting. **Stop fp32 batch 32 around step 2000 on a 10-clip dataset** — later checkpoints sound worse despite lower loss.
|
||||||
|
|
||||||
|
**Lower loss ≠ better audio.** Once overfitting begins the model memorizes training clips rather than generalizing to new video inputs. Test intermediate checkpoints (e.g. step 500, 1000, 2000) to find the perceptual sweet spot.
|
||||||
|
|
||||||
|
### logit_normal vs uniform
|
||||||
|
|
||||||
|
logit_normal consistently reaches a lower loss floor than uniform. However perceptual improvement is dataset-dependent — on 10 clips the difference is marginal. May be more impactful with larger datasets. No conclusion yet.
|
||||||
|
|
||||||
|
### White noise
|
||||||
|
|
||||||
|
Residual white noise on generated audio is primarily a **dataset** problem, not a training one. Observed with all configs on 10 clips. Likely causes:
|
||||||
|
- Too few clips for the model to confidently predict the target sound
|
||||||
|
- Imprecise extraction prompts producing unfocused sync features
|
||||||
|
- Missing mask when multiple objects are in frame
|
||||||
|
|
||||||
|
CFG scale amplifies any adapter noise bias. Reducing CFG to 3.0–3.5 or adapter strength to 0.6–0.7 helps at inference.
|
||||||
@@ -1,156 +1,134 @@
|
|||||||
# ComfyUI-PrismAudio
|
# ComfyUI-SelVA
|
||||||
|
|
||||||
Custom nodes for [PrismAudio](https://huggingface.co/FunAudioLLM/PrismAudio) (ICLR 2026) — video-to-audio and text-to-audio generation using decomposed Chain-of-Thought reasoning with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
|
Custom nodes for [SelVA](https://github.com/jnwnlee/selva) — video-to-audio generation driven by text prompts. SelVA conditions audio synthesis on both visual content and natural language, letting you describe *what* sounds to generate rather than just *when*.
|
||||||
|
|
||||||
## Installation
|
Built on [MMAudio](https://github.com/hkchengrex/MMAudio) with a TextSynchformer encoder that injects text guidance directly into the visual sync stream.
|
||||||
|
|
||||||
Clone into your ComfyUI custom nodes directory:
|
---
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ComfyUI/custom_nodes
|
|
||||||
git clone https://github.com/Ethanfel/ComfyUI-Prismaudio.git ComfyUI-PrismAudio
|
|
||||||
pip install -r ComfyUI-PrismAudio/requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
**flash-attn** is optional — detected at runtime, falls back to PyTorch SDPA if unavailable.
|
|
||||||
|
|
||||||
## Nodes
|
## Nodes
|
||||||
|
|
||||||
### PrismAudio Model Loader
|
### SelVA Model Loader
|
||||||
|
|
||||||
Loads the DiT diffusion model and VAE. Auto-downloads weights from HuggingFace on first use.
|
Loads the generator, TextSynchformer encoder, and all feature utilities (CLIP, T5, Synchformer, VAE). Weights are auto-downloaded from HuggingFace on first use.
|
||||||
|
|
||||||
| Input | Options | Description |
|
| Input | Options | Description |
|
||||||
|-------|---------|-------------|
|
|-------|---------|-------------|
|
||||||
| `precision` | auto / fp32 / fp16 / bf16 | DiT and conditioner dtype. VAE is always fp32. |
|
| `variant` | small_16k / small_44k / medium_44k / large_44k | Model size and output sample rate |
|
||||||
| `offload_strategy` | auto / keep_in_vram / offload_to_cpu | Memory management. |
|
| `precision` | bf16 / fp16 / fp32 | Compute dtype |
|
||||||
|
| `offload_strategy` | auto / keep_in_vram / offload_to_cpu | Memory management |
|
||||||
|
|
||||||
|
**Output:** `model` (SELVA_MODEL)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### PrismAudio Feature Extractor
|
### SelVA Feature Extractor
|
||||||
|
|
||||||
Extracts video features (VideoPrism LvT, Synchformer) and text features (T5-Gemma) from a video in a subprocess. Results are cached on disk.
|
Extracts CLIP visual features and text-guided sync features from a video. Results are cached on disk — re-running with the same inputs is instant.
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
|
| `model` | From SelVA Model Loader |
|
||||||
| `video` | IMAGE tensor from any ComfyUI video loader |
|
| `video` | IMAGE tensor from any ComfyUI video loader |
|
||||||
| `caption_cot` | Chain-of-thought description of the audio scene |
|
| `prompt` | Text description of the audio to generate |
|
||||||
| `video_info` | *(optional)* `VHS_VIDEOINFO` from VHS LoadVideo — sets fps automatically |
|
| `video_info` | *(optional)* VHS_VIDEOINFO from VHS LoadVideo — sets fps automatically |
|
||||||
| `fps` | Source fps — ignored if `video_info` is connected |
|
| `fps` | Source fps — ignored if `video_info` is connected |
|
||||||
| `python_env` | `managed_env` (auto-created isolated venv, recommended) or `comfyui_env` (current Python, see warning below) |
|
| `duration` | Override clip duration in seconds. `0` = infer from video length |
|
||||||
| `cache_dir` | Directory for cached `.npz` files. Empty = system temp dir. |
|
| `cache_dir` | Directory for cached `.npz` files. Empty = system temp dir |
|
||||||
| `hf_token` | HuggingFace token for gated models. Prefer `HF_TOKEN` env var instead. |
|
| `mask` | *(optional)* Segmentation mask `[T,H,W]` float [0,1] — static (1 frame) or per-frame |
|
||||||
|
| `mask_strength` | Background suppression strength. `1.0` = full neutral fill, `0.0` = no effect |
|
||||||
|
| `mask_clip` | Apply mask to CLIP features (384px path). Disable to let CLIP see the full scene |
|
||||||
|
| `mask_sync` | Apply mask to TextSynchformer sync features (224px path) |
|
||||||
|
|
||||||
**Outputs:** `features` (PRISMAUDIO_FEATURES), `fps` (FLOAT)
|
**Outputs:** `features` (SELVA_FEATURES), `fps` (FLOAT), `prompt` (STRING)
|
||||||
|
|
||||||
**`managed_env`** auto-creates a venv at `_extract_env/` inside the plugin directory on first use and installs JAX, TF, VideoPrism, and Synchformer. This takes several minutes the first time.
|
Connect `prompt` output to the Sampler's `prompt` input to avoid entering it twice.
|
||||||
|
|
||||||
**`comfyui_env`** uses the current ComfyUI Python — JAX/TF/videoprism must already be installed. Installing them into the ComfyUI environment may conflict with existing packages.
|
#### Masking
|
||||||
|
|
||||||
|
Connect a segmentation mask (SAM2, Grounding DINO+SAM, or any ComfyUI mask node) to isolate a specific object's motion before encoding. Background pixels are filled with a neutral value (0.5) rather than zeroed — this keeps them in-distribution for CLIP and maps to exactly 0 after sync's `[-1,1]` normalization, minimising the influence of background motion on the generated audio.
|
||||||
|
|
||||||
|
Use `mask_sync=true, mask_clip=false` if you want sync features focused on the target object while CLIP still sees the full scene for broader context. Changing any mask parameter correctly busts the feature cache.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### PrismAudio Feature Loader
|
### SelVA Sampler
|
||||||
|
|
||||||
Loads a pre-computed `.npz` feature file. Use this to re-use extracted features without re-running the extractor.
|
Generates audio from video features. Runs the rectified flow ODE with classifier-free guidance.
|
||||||
|
|
||||||
| Input | Description |
|
| Input | Description |
|
||||||
|-------|-------------|
|
|-------|-------------|
|
||||||
| `npz_path` | Path to a `.npz` file produced by the Feature Extractor |
|
| `model` | From SelVA Model Loader |
|
||||||
|
| `features` | From SelVA Feature Extractor |
|
||||||
---
|
| `prompt` | Text description — leave empty to use the prompt stored in features |
|
||||||
|
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
|
||||||
### PrismAudio Sampler
|
| `duration` | Audio duration in seconds. `0` = use duration from features |
|
||||||
|
| `steps` | Sampling steps (default: 25) |
|
||||||
Video-to-audio generation. Takes model + features, produces AUDIO.
|
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
|
||||||
|
|
||||||
| Input | Description |
|
|
||||||
|-------|-------------|
|
|
||||||
| `model` | From Model Loader |
|
|
||||||
| `features` | From Feature Extractor or Feature Loader |
|
|
||||||
| `duration` | Audio duration in seconds. Set to `0` to use the video duration from features automatically. |
|
|
||||||
| `steps` | Sampling steps (default: 100) |
|
|
||||||
| `cfg_scale` | Classifier-free guidance scale (default: 7.0) |
|
|
||||||
| `seed` | RNG seed |
|
| `seed` | RNG seed |
|
||||||
|
| `normalize` | Peak-normalize output to [-1, 1] (default: true) |
|
||||||
|
|
||||||
|
**Output:** `AUDIO`
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### PrismAudio Text Only
|
## Workflow
|
||||||
|
|
||||||
Text-to-audio generation without video. Uses the T5-Gemma encoder.
|
```
|
||||||
|
VHS LoadVideo ──► SelVA Feature Extractor ──────────────────────► SelVA Sampler ──► Save Audio
|
||||||
|
│ (video_info) ─► (fps auto) ▲
|
||||||
|
│ (features) ────────────────────────────────────►│
|
||||||
|
│ (prompt) ──────────────────────────────────────►│
|
||||||
|
```
|
||||||
|
|
||||||
| Input | Description |
|
Connect the `prompt` output of Feature Extractor directly to Sampler's `prompt` to keep them in sync. Leave Sampler's `prompt` empty and it will use whatever was stored during extraction.
|
||||||
|-------|-------------|
|
|
||||||
| `model` | From Model Loader |
|
|
||||||
| `text_prompt` | Chain-of-thought audio scene description. Longer, more detailed prompts produce better results. |
|
|
||||||
| `duration` | Audio duration in seconds |
|
|
||||||
| `steps` | Sampling steps (default: 100) |
|
|
||||||
| `cfg_scale` | Classifier-free guidance scale (default: 7.0) |
|
|
||||||
| `seed` | RNG seed |
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Workflows
|
## Installation
|
||||||
|
|
||||||
### Video-to-Audio
|
```bash
|
||||||
|
cd ComfyUI/custom_nodes
|
||||||
```
|
git clone https://github.com/Ethanfel/ComfyUI-SelVA.git
|
||||||
VHS LoadVideo ──► PrismAudio Feature Extractor ──► PrismAudio Sampler ──► Save Audio
|
pip install -r ComfyUI-SelVA/requirements.txt
|
||||||
(video_info) ──────────────────► (fps auto)
|
|
||||||
(features) ────────────────────► (features)
|
|
||||||
duration=0 ─────────────────────► (auto from features)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Pre-computed Features
|
---
|
||||||
|
|
||||||
```
|
## Model Weights
|
||||||
PrismAudio Feature Loader (.npz) ──► PrismAudio Sampler ──► Save Audio
|
|
||||||
```
|
|
||||||
|
|
||||||
### Text-to-Audio
|
Weights are auto-downloaded to `ComfyUI/models/selva/` on first load. No manual setup required.
|
||||||
|
|
||||||
```
|
|
||||||
PrismAudio Text Only ──► Save Audio
|
|
||||||
```
|
|
||||||
|
|
||||||
## HuggingFace Authentication
|
|
||||||
|
|
||||||
Required for T5-Gemma (gated model) and PrismAudio weights.
|
|
||||||
|
|
||||||
1. Visit <https://huggingface.co/FunAudioLLM/PrismAudio> and accept the license.
|
|
||||||
2. Authenticate via one of:
|
|
||||||
- **Environment variable:** `export HF_TOKEN=hf_...`
|
|
||||||
- **CLI login:** `huggingface-cli login`
|
|
||||||
|
|
||||||
There is no `hf_token` widget on the main nodes by design — ComfyUI saves all STRING values to workflow JSON, which would expose your token. The Feature Extractor has an `hf_token` input as a convenience but using `HF_TOKEN` env var is preferred.
|
|
||||||
|
|
||||||
## Model Files
|
|
||||||
|
|
||||||
Weights are auto-downloaded to `ComfyUI/models/prismaudio/`:
|
|
||||||
|
|
||||||
| File | Size | Description |
|
| File | Size | Description |
|
||||||
|------|------|-------------|
|
|------|------|-------------|
|
||||||
| `prismaudio.ckpt` | ~2.7 GB | Diffusion model (DiT) |
|
| `video_enc_sup_5.pth` | ~300 MB | TextSynchformer encoder |
|
||||||
| `vae.ckpt` | ~2.5 GB | Stable Audio 2.0 VAE |
|
| `generator_small_16k_sup_5.pth` | ~340 MB | Small generator, 16 kHz output |
|
||||||
| `synchformer_state_dict.pth` | ~950 MB | Synchformer visual encoder |
|
| `generator_small_44k_sup_5.pth` | ~340 MB | Small generator, 44.1 kHz output |
|
||||||
|
| `generator_medium_44k_sup_5.pth` | ~860 MB | Medium generator, 44.1 kHz output |
|
||||||
|
| `generator_large_44k_sup_5.pth` | ~2.0 GB | Large generator, 44.1 kHz output |
|
||||||
|
| `v1-16.pth` | ~1.1 GB | VAE for 16 kHz |
|
||||||
|
| `v1-44.pth` | ~1.1 GB | VAE for 44.1 kHz |
|
||||||
|
| `best_netG.pt` | ~90 MB | BigVGAN vocoder for 16 kHz |
|
||||||
|
| `synchformer_state_dict.pth` | ~950 MB | Synchformer (shared with PrismAudio if present) |
|
||||||
|
|
||||||
T5-Gemma and VideoPrism LvT are cached in `~/.cache/huggingface/`.
|
CLIP (DFN5B-ViT-H-14-384) and T5 (flan-t5-base) are downloaded automatically from HuggingFace to `~/.cache/huggingface/`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## VRAM Requirements
|
## VRAM Requirements
|
||||||
|
|
||||||
| VRAM | Recommended settings |
|
| VRAM | Recommended settings |
|
||||||
|------|----------------------|
|
|------|----------------------|
|
||||||
| 24 GB+ | `keep_in_vram`, any precision |
|
| 24 GB+ | `keep_in_vram`, any variant |
|
||||||
| 12–24 GB | `offload_to_cpu`, bf16/fp16 |
|
| 12–24 GB | `offload_to_cpu`, medium or smaller |
|
||||||
| 8–12 GB | `offload_to_cpu`, fp16 |
|
| 8–12 GB | `offload_to_cpu`, small variant, fp16 |
|
||||||
| < 8 GB | May work with `offload_to_cpu` + fp16 |
|
|
||||||
|
|
||||||
## Troubleshooting
|
The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available, otherwise `offload_to_cpu`.
|
||||||
|
|
||||||
- **Gated model errors** — Accept the license at <https://huggingface.co/FunAudioLLM/PrismAudio> and set `HF_TOKEN`.
|
---
|
||||||
- **VRAM errors** — Switch `offload_strategy` to `offload_to_cpu` and/or use `fp16` precision.
|
|
||||||
- **Feature extraction fails** — Ensure `synchformer_state_dict.pth` is in `models/prismaudio/`. On first run with `managed_env`, installation takes several minutes.
|
|
||||||
- **flash-attn** — Optional. Auto-detected at runtime; falls back to PyTorch SDPA.
|
|
||||||
|
|
||||||
## Credits
|
## Credits
|
||||||
|
|
||||||
PrismAudio by [FunAudioLLM](https://github.com/FunAudioLLM) (ICLR 2026). [Model & weights](https://huggingface.co/FunAudioLLM/PrismAudio).
|
- [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training
|
||||||
|
- [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework
|
||||||
|
- [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis
|
||||||
|
|||||||
+1
-1
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
|
ComfyUI-SelVA: Text-guided video-to-audio generation using SelVA / MMAudio.
|
||||||
"""
|
"""
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -1,337 +0,0 @@
|
|||||||
"""
|
|
||||||
PrismAudio feature extraction utilities.
|
|
||||||
|
|
||||||
Implements FeaturesUtils used by scripts/extract_features.py to extract:
|
|
||||||
- Text features via T5-Gemma (transformers)
|
|
||||||
- Video features via VideoPrism (JAX/Flax, google-deepmind/videoprism)
|
|
||||||
- Sync features via Synchformer visual encoder (PyTorch)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class FeaturesUtils:
|
|
||||||
def __init__(self, vae_config_path=None, synchformer_ckpt=None, device=None):
|
|
||||||
self.device = device or torch.device("cpu")
|
|
||||||
self._t5_tokenizer = None
|
|
||||||
self._t5_encoder = None
|
|
||||||
self._vp_model = None
|
|
||||||
self._vp_state = None
|
|
||||||
self._vp_text_tokenizer = None
|
|
||||||
self._sync_model = None
|
|
||||||
|
|
||||||
self._synchformer_ckpt = synchformer_ckpt
|
|
||||||
self._load_synchformer()
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# T5-Gemma text encoding
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _ensure_t5(self):
|
|
||||||
if self._t5_encoder is not None:
|
|
||||||
return
|
|
||||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
||||||
model_id = "google/t5gemma-l-l-ul2-it"
|
|
||||||
print(f"[FeaturesUtils] Loading T5-Gemma: {model_id}")
|
|
||||||
self._t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
||||||
self._t5_encoder = (
|
|
||||||
AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
|
||||||
.get_encoder()
|
|
||||||
.to(self.device)
|
|
||||||
.eval()
|
|
||||||
)
|
|
||||||
|
|
||||||
def encode_t5_text(self, texts):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
texts: list of str
|
|
||||||
Returns:
|
|
||||||
Tensor [seq_len, 1024]
|
|
||||||
"""
|
|
||||||
self._ensure_t5()
|
|
||||||
tokens = self._t5_tokenizer(
|
|
||||||
texts, return_tensors="pt", padding=True
|
|
||||||
).to(self.device)
|
|
||||||
with torch.no_grad():
|
|
||||||
out = self._t5_encoder(**tokens)
|
|
||||||
# Move encoder off GPU to save VRAM
|
|
||||||
self._t5_encoder.to("cpu")
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return out.last_hidden_state.squeeze(0) # [seq_len, 1024]
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# VideoPrism video + text encoding (JAX)
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _ensure_videoprism(self):
|
|
||||||
if self._vp_model is not None:
|
|
||||||
return
|
|
||||||
from videoprism import models as vp
|
|
||||||
import jax
|
|
||||||
model_name = "videoprism_lvt_public_v1_large"
|
|
||||||
print(f"[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...")
|
|
||||||
self._vp_model = vp.get_model(model_name)
|
|
||||||
self._vp_state = vp.load_pretrained_weights(model_name)
|
|
||||||
self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en")
|
|
||||||
jax_dev = jax.devices()[0]
|
|
||||||
self._jax_forward = jax.jit(
|
|
||||||
lambda x, y, z: self._vp_model.apply(
|
|
||||||
self._vp_state, x, y, z, train=False, return_intermediate=True
|
|
||||||
),
|
|
||||||
device=jax_dev,
|
|
||||||
)
|
|
||||||
|
|
||||||
def encode_video_and_text_with_videoprism(self, clip_input, texts):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
|
|
||||||
texts: list of str — CoT captions, passed to VideoPrism LvT text tower
|
|
||||||
Returns:
|
|
||||||
global_video_features: Tensor [1, D]
|
|
||||||
video_features: Tensor [T, D] — per-frame L2-normalized embeddings
|
|
||||||
global_text_features: Tensor [1, D]
|
|
||||||
"""
|
|
||||||
self._ensure_videoprism()
|
|
||||||
import jax.numpy as jnp
|
|
||||||
from videoprism import models as vp
|
|
||||||
|
|
||||||
# Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
|
|
||||||
frames = clip_input.squeeze(0) # [T, C, H, W]
|
|
||||||
frames = (frames + 1.0) / 2.0 # [-1,1] → [0,1]
|
|
||||||
frames = frames.permute(0, 2, 3, 1) # [T, H, W, C]
|
|
||||||
frames_np = frames.cpu().numpy().astype(np.float32)
|
|
||||||
frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C]
|
|
||||||
|
|
||||||
# Tokenize text (padding value 1.0 = pad, 0.0 = real token)
|
|
||||||
text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts)
|
|
||||||
|
|
||||||
# Joint video+text forward with intermediate outputs
|
|
||||||
video_embeddings, text_embeddings, outputs = self._jax_forward(
|
|
||||||
frames_jax, text_ids, text_paddings
|
|
||||||
)
|
|
||||||
|
|
||||||
# Per-frame features: [B, T, 1024] L2-normalized
|
|
||||||
frame_embed_np = np.array(outputs["frame_embeddings"]) # [1, T, 1024]
|
|
||||||
per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device) # [T, 1024]
|
|
||||||
|
|
||||||
# Global video embedding: [1024] → [1, 1024]
|
|
||||||
global_video = torch.from_numpy(
|
|
||||||
np.array(video_embeddings[0])
|
|
||||||
).unsqueeze(0).to(self.device) # [1, 1024]
|
|
||||||
|
|
||||||
# Global text embedding: [1024] → [1, 1024]
|
|
||||||
global_text = torch.from_numpy(
|
|
||||||
np.array(text_embeddings[0])
|
|
||||||
).unsqueeze(0).to(self.device) # [1, 1024]
|
|
||||||
|
|
||||||
return global_video, per_frame, global_text
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Synchformer sync feature encoding
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _load_synchformer(self):
|
|
||||||
if not self._synchformer_ckpt or not os.path.exists(self._synchformer_ckpt):
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"[FeaturesUtils] Loading Synchformer from: {self._synchformer_ckpt}")
|
|
||||||
state = torch.load(self._synchformer_ckpt, map_location="cpu", weights_only=False)
|
|
||||||
|
|
||||||
# Checkpoint may be raw state_dict or wrapped in {"model": ...}
|
|
||||||
if isinstance(state, dict) and "model" in state:
|
|
||||||
state_dict = state["model"]
|
|
||||||
else:
|
|
||||||
state_dict = state
|
|
||||||
|
|
||||||
self._sync_model = _SynchformerVisualEncoder(state_dict, self.device)
|
|
||||||
self._sync_model.eval()
|
|
||||||
|
|
||||||
def encode_video_with_sync(self, sync_input):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
sync_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
|
|
||||||
Returns:
|
|
||||||
sync_features: Tensor [num_segments, 768]
|
|
||||||
"""
|
|
||||||
if self._sync_model is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"[FeaturesUtils] Synchformer checkpoint not loaded. "
|
|
||||||
"Pass synchformer_ckpt to FeaturesUtils or set --synchformer_ckpt."
|
|
||||||
)
|
|
||||||
frames = sync_input.squeeze(0).to(self.device) # [T, C, H, W]
|
|
||||||
with torch.no_grad():
|
|
||||||
return self._sync_model(frames)
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Synchformer visual encoder — TimeSformer-style ViT-B/16
|
|
||||||
# Architecture reverse-engineered from synchformer_state_dict.pth
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class _PatchEmbed(nn.Module):
|
|
||||||
"""2D patch embedding: [B, 3, 224, 224] → [B, 196, 768]."""
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.proj(x).flatten(2).transpose(1, 2)
|
|
||||||
|
|
||||||
|
|
||||||
class _ViTAttn(nn.Module):
|
|
||||||
"""ViT-style QKV attention (timm convention: qkv as single Linear)."""
|
|
||||||
def __init__(self, dim=768, num_heads=12):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = dim // num_heads
|
|
||||||
self.scale = self.head_dim ** -0.5
|
|
||||||
self.qkv = nn.Linear(dim, dim * 3)
|
|
||||||
self.proj = nn.Linear(dim, dim)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
B, N, D = x.shape
|
|
||||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
|
||||||
q, k, v = qkv.unbind(0)
|
|
||||||
attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
|
|
||||||
return self.proj((attn @ v).transpose(1, 2).reshape(B, N, D))
|
|
||||||
|
|
||||||
|
|
||||||
class _BlockMLP(nn.Module):
|
|
||||||
"""Two-layer MLP with GELU, keys fc1/fc2 to match checkpoint."""
|
|
||||||
def __init__(self, dim=768, mlp_dim=3072):
|
|
||||||
super().__init__()
|
|
||||||
self.fc1 = nn.Linear(dim, mlp_dim)
|
|
||||||
self.fc2 = nn.Linear(mlp_dim, dim)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.fc2(F.gelu(self.fc1(x)))
|
|
||||||
|
|
||||||
|
|
||||||
class _TimeSformerBlock(nn.Module):
|
|
||||||
"""
|
|
||||||
Factorized space-time attention block.
|
|
||||||
norm1 → spatial attn → norm3 → temporal attn → norm2 → MLP
|
|
||||||
"""
|
|
||||||
def __init__(self, dim=768, num_heads=12):
|
|
||||||
super().__init__()
|
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
|
||||||
self.attn = _ViTAttn(dim, num_heads)
|
|
||||||
self.norm3 = nn.LayerNorm(dim)
|
|
||||||
self.timeattn = _ViTAttn(dim, num_heads)
|
|
||||||
self.norm2 = nn.LayerNorm(dim)
|
|
||||||
self.mlp = _BlockMLP(dim)
|
|
||||||
|
|
||||||
def forward(self, x, T):
|
|
||||||
# x: [T, N, D] (T frames treated as batch, N=197 spatial tokens)
|
|
||||||
x = x + self.attn(self.norm1(x))
|
|
||||||
# Temporal attention: for each spatial position, attend across T frames
|
|
||||||
# [T, N, D] → [N, T, D] → attend → [N, T, D] → [T, N, D]
|
|
||||||
xt = x.permute(1, 0, 2)
|
|
||||||
xt = xt + self.timeattn(self.norm3(xt))
|
|
||||||
x = xt.permute(1, 0, 2)
|
|
||||||
x = x + self.mlp(self.norm2(x))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class _SpatialAttnAgg(nn.Module):
|
|
||||||
"""
|
|
||||||
Aggregates 196 spatial patches → 1 feature per frame using a
|
|
||||||
TransformerEncoderLayer with a learnable CLS token.
|
|
||||||
Key names match nn.TransformerEncoderLayer: self_attn, linear1, linear2, norm1, norm2.
|
|
||||||
"""
|
|
||||||
def __init__(self, dim=768, num_heads=12):
|
|
||||||
super().__init__()
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
|
|
||||||
self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
|
|
||||||
self.linear1 = nn.Linear(dim, dim * 4)
|
|
||||||
self.linear2 = nn.Linear(dim * 4, dim)
|
|
||||||
self.norm1 = nn.LayerNorm(dim)
|
|
||||||
self.norm2 = nn.LayerNorm(dim)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# x: [T, 196, 768] — spatial patches (CLS stripped)
|
|
||||||
T = x.shape[0]
|
|
||||||
cls = self.cls_token.expand(T, -1, -1)
|
|
||||||
x = torch.cat([cls, x], dim=1) # [T, 197, 768]
|
|
||||||
xn = self.norm1(x)
|
|
||||||
x = x + self.self_attn(xn, xn, xn)[0]
|
|
||||||
x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
|
|
||||||
return x[:, 0, :] # [T, 768] — CLS per frame
|
|
||||||
|
|
||||||
|
|
||||||
class _SynchformerVisualEncoder(nn.Module):
|
|
||||||
"""
|
|
||||||
TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.
|
|
||||||
Processes video in segments of 8 frames → [T_aligned, 768] per-frame features.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, state_dict, device):
|
|
||||||
super().__init__()
|
|
||||||
self.device = device
|
|
||||||
self.segment_frames = 8
|
|
||||||
|
|
||||||
self.patch_embed = _PatchEmbed()
|
|
||||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
|
|
||||||
self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
|
|
||||||
self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
|
|
||||||
self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
|
|
||||||
self.norm = nn.LayerNorm(768)
|
|
||||||
self.spatial_attn_agg = _SpatialAttnAgg()
|
|
||||||
|
|
||||||
# Load weights from vfeat_extractor.* prefix
|
|
||||||
prefix = "vfeat_extractor."
|
|
||||||
sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
|
|
||||||
# Exclude 3D patch embed (we use 2D only)
|
|
||||||
sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
|
|
||||||
missing, unexpected = self.load_state_dict(sub, strict=False)
|
|
||||||
print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
|
|
||||||
if missing:
|
|
||||||
print(f"[FeaturesUtils] missing keys (first 5): {missing[:5]}")
|
|
||||||
|
|
||||||
self.to(device)
|
|
||||||
|
|
||||||
def forward(self, frames):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
frames: [T, C, H, W] float32 in [-1, 1], at 25fps
|
|
||||||
Returns:
|
|
||||||
[T_aligned, 768] — per-frame features (T_aligned = floor(T/8)*8)
|
|
||||||
"""
|
|
||||||
T = frames.shape[0]
|
|
||||||
seg = self.segment_frames
|
|
||||||
num_seg = max(1, T // seg)
|
|
||||||
T_aligned = num_seg * seg
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for i in range(num_seg):
|
|
||||||
chunk = frames[i * seg:(i + 1) * seg] # [8, C, H, W]
|
|
||||||
results.append(self._forward_segment(chunk))
|
|
||||||
return torch.cat(results, dim=0) # [T_aligned, 768]
|
|
||||||
|
|
||||||
def _forward_segment(self, x):
|
|
||||||
# x: [8, 3, 224, 224]
|
|
||||||
T = x.shape[0] # 8
|
|
||||||
|
|
||||||
# Patch embedding + CLS token
|
|
||||||
x = self.patch_embed(x) # [8, 196, 768]
|
|
||||||
cls = self.cls_token.expand(T, -1, -1)
|
|
||||||
x = torch.cat([cls, x], dim=1) # [8, 197, 768]
|
|
||||||
|
|
||||||
# Positional + temporal embeddings
|
|
||||||
x = x + self.pos_embed # broadcast (1,197,768)
|
|
||||||
x = x + self.temp_embed.squeeze(0).unsqueeze(1) # (8,1,768) broadcast
|
|
||||||
|
|
||||||
# Transformer blocks (factorized space-time)
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x, T)
|
|
||||||
|
|
||||||
x = self.norm(x)
|
|
||||||
|
|
||||||
# Aggregate spatial patches → 1 feature per frame
|
|
||||||
return self.spatial_attn_agg(x[:, 1:, :]) # [8, 768]
|
|
||||||
@@ -1,194 +0,0 @@
|
|||||||
# ComfyUI-PrismAudio Design Document
|
|
||||||
|
|
||||||
**Date:** 2026-03-27
|
|
||||||
**Status:** Approved
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. No JAX/TensorFlow in the ComfyUI environment. Feature extraction via separate isolated environment.
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
ComfyUI-PrismAudio/
|
|
||||||
├── __init__.py # Node registration
|
|
||||||
├── nodes/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── model_loader.py # PrismAudioModelLoader
|
|
||||||
│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz)
|
|
||||||
│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge)
|
|
||||||
│ ├── sampler.py # PrismAudioSampler
|
|
||||||
│ ├── text_only.py # PrismAudioTextOnly
|
|
||||||
│ └── utils.py # Shared helpers
|
|
||||||
├── prismaudio_core/ # Extracted inference code from PrismAudio
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── configs/
|
|
||||||
│ │ └── prismaudio.json
|
|
||||||
│ ├── models/ # DiT, conditioners, autoencoders, etc.
|
|
||||||
│ ├── inference/ # sampling.py, generation.py
|
|
||||||
│ └── factory.py # create_model_from_config
|
|
||||||
├── scripts/
|
|
||||||
│ ├── extract_features.py # Standalone VideoPrism feature extraction
|
|
||||||
│ └── environment.yml # Conda env for extraction (JAX + TF)
|
|
||||||
├── requirements.txt # PyTorch-only deps (no JAX/TF)
|
|
||||||
└── README.md
|
|
||||||
```
|
|
||||||
|
|
||||||
## Nodes
|
|
||||||
|
|
||||||
### PrismAudioModelLoader
|
|
||||||
|
|
||||||
Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally.
|
|
||||||
|
|
||||||
| Field | Type | Details |
|
|
||||||
|-------|------|---------|
|
|
||||||
| **Inputs** | | |
|
|
||||||
| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability |
|
|
||||||
| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] |
|
|
||||||
| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | |
|
|
||||||
| **Output** | | |
|
|
||||||
| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config |
|
|
||||||
|
|
||||||
**Token resolution order** (no widget — env/CLI only for security):
|
|
||||||
1. `HF_TOKEN` environment variable
|
|
||||||
2. `huggingface-cli login` cached token
|
|
||||||
3. None — fails on gated models with clear error message linking to license page
|
|
||||||
|
|
||||||
**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually.
|
|
||||||
|
|
||||||
### PrismAudioFeatureLoader
|
|
||||||
|
|
||||||
Loads pre-computed `.npz` feature files for maximum quality video-to-audio.
|
|
||||||
|
|
||||||
| Field | Type | Details |
|
|
||||||
|-------|------|---------|
|
|
||||||
| **Inputs** | | |
|
|
||||||
| npz_path | STRING | Path to .npz file |
|
|
||||||
| **Output** | | |
|
|
||||||
| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features |
|
|
||||||
|
|
||||||
### PrismAudioFeatureExtractor
|
|
||||||
|
|
||||||
Subprocess bridge — extracts features from video using VideoPrism in an isolated environment.
|
|
||||||
|
|
||||||
| Field | Type | Details |
|
|
||||||
|-------|------|---------|
|
|
||||||
| **Inputs** | | |
|
|
||||||
| video | IMAGE | ComfyUI video frames tensor |
|
|
||||||
| caption_cot | STRING | CoT description text |
|
|
||||||
| python_env | STRING | Path to python binary with JAX/TF (default: "python") |
|
|
||||||
| output_dir | STRING | Cache directory for .npz files (default: temp dir) |
|
|
||||||
| **Output** | | |
|
|
||||||
| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output |
|
|
||||||
|
|
||||||
**Caching:** Hashes video + text to avoid re-extraction on repeated runs.
|
|
||||||
|
|
||||||
### PrismAudioSampler
|
|
||||||
|
|
||||||
Main generation node — takes model + features, produces audio.
|
|
||||||
|
|
||||||
| Field | Type | Details |
|
|
||||||
|-------|------|---------|
|
|
||||||
| **Inputs** | | |
|
|
||||||
| model | PRISMAUDIO_MODEL | From ModelLoader |
|
|
||||||
| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor |
|
|
||||||
| cot_description | STRING | Multiline CoT text |
|
|
||||||
| duration | FLOAT | 1.0-30.0, defaults to video length |
|
|
||||||
| steps | INT | 1-100, default 24 |
|
|
||||||
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
|
|
||||||
| seed | INT | Controls noise generation |
|
|
||||||
| **Output** | | |
|
|
||||||
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
|
|
||||||
|
|
||||||
**Pipeline:**
|
|
||||||
1. Encode CoT text via T5-Gemma -> text_features
|
|
||||||
2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond)
|
|
||||||
3. Compute latent_seq_len = round(44100 / 2048 * duration)
|
|
||||||
4. Generate noise [1, 64, latent_seq_len] from seed
|
|
||||||
5. Discrete Euler sampling (rectified flow) with CFG
|
|
||||||
6. VAE decode -> stereo waveform at 44100 Hz
|
|
||||||
7. Normalize to [-1, 1], return as AUDIO
|
|
||||||
|
|
||||||
### PrismAudioTextOnly
|
|
||||||
|
|
||||||
Text-to-audio without video input.
|
|
||||||
|
|
||||||
| Field | Type | Details |
|
|
||||||
|-------|------|---------|
|
|
||||||
| **Inputs** | | |
|
|
||||||
| model | PRISMAUDIO_MODEL | From ModelLoader |
|
|
||||||
| text_prompt | STRING | Text description |
|
|
||||||
| duration | FLOAT | 1.0-30.0 |
|
|
||||||
| steps | INT | 1-100, default 24 |
|
|
||||||
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
|
|
||||||
| seed | INT | Controls noise generation |
|
|
||||||
| **Output** | | |
|
|
||||||
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
|
|
||||||
|
|
||||||
Uses empty tensors for video/sync features, T5-Gemma encodes the text prompt.
|
|
||||||
|
|
||||||
## VRAM Management
|
|
||||||
|
|
||||||
Adaptive strategy using `comfy.model_management`:
|
|
||||||
|
|
||||||
| Available VRAM | Behavior |
|
|
||||||
|---|---|
|
|
||||||
| 24GB+ | Keep diffusion + VAE in VRAM |
|
|
||||||
| 12-24GB | Sequential offload between stages |
|
|
||||||
| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced |
|
|
||||||
| <8GB | Warn user, attempt with aggressive offload + fp16 |
|
|
||||||
|
|
||||||
Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()`
|
|
||||||
|
|
||||||
## Feature Extraction Paths
|
|
||||||
|
|
||||||
### Path 1: Pre-computed .npz (FeatureLoader)
|
|
||||||
User runs `scripts/extract_features.py` externally in the extraction conda env. Loads result into ComfyUI. Original VideoPrism quality, zero ComfyUI env risk.
|
|
||||||
|
|
||||||
### Path 2: Subprocess bridge (FeatureExtractor)
|
|
||||||
Node calls extraction script as subprocess using a user-specified Python binary. Seamless in-ComfyUI experience, JAX runs isolated. Caches results by content hash.
|
|
||||||
|
|
||||||
### Path 3: Text-only (TextOnly node)
|
|
||||||
No video features needed. T5-Gemma text encoding only (PyTorch-native).
|
|
||||||
|
|
||||||
## Dependencies
|
|
||||||
|
|
||||||
### ComfyUI environment (`requirements.txt`)
|
|
||||||
```
|
|
||||||
einops>=0.7.0
|
|
||||||
safetensors
|
|
||||||
huggingface_hub
|
|
||||||
transformers>=4.52.3
|
|
||||||
k-diffusion>=0.1.1
|
|
||||||
```
|
|
||||||
|
|
||||||
flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`.
|
|
||||||
|
|
||||||
### Extraction environment (`scripts/environment.yml`)
|
|
||||||
Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as ready-made conda env file for one-command setup.
|
|
||||||
|
|
||||||
## Model Files
|
|
||||||
|
|
||||||
Stored in `ComfyUI/models/prismaudio/`:
|
|
||||||
|
|
||||||
| File | Size | Source |
|
|
||||||
|------|------|--------|
|
|
||||||
| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio |
|
|
||||||
| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio |
|
|
||||||
| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio |
|
|
||||||
|
|
||||||
T5-Gemma (`google/t5gemma-l-l-ul2-it`) cached in standard HuggingFace cache.
|
|
||||||
|
|
||||||
Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)`
|
|
||||||
|
|
||||||
## Design Decisions
|
|
||||||
|
|
||||||
- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes.
|
|
||||||
- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env.
|
|
||||||
- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes — better models available than bundled Qwen2.5-VL.
|
|
||||||
- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`.
|
|
||||||
- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback.
|
|
||||||
File diff suppressed because it is too large
Load Diff
+6
-6
@@ -2,11 +2,11 @@ NODE_CLASS_MAPPINGS = {}
|
|||||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||||
|
|
||||||
_NODES = {
|
_NODES = {
|
||||||
"PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"),
|
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
|
||||||
"PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"),
|
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
|
||||||
"PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
|
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
||||||
"PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
|
"SelvaLoraLoader": (".selva_lora_loader", "SelvaLoraLoader", "SelVA LoRA Loader"),
|
||||||
"PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
|
"SelvaLoraTrainer": (".selva_lora_trainer", "SelvaLoraTrainer", "SelVA LoRA Trainer"),
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||||
@@ -16,4 +16,4 @@ for key, (module_path, class_name, display_name) in _NODES.items():
|
|||||||
NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
|
NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
|
||||||
NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
|
NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
print(f"[PrismAudio] Skipping {key}: {e}")
|
print(f"[SelVA] Skipping {key}: {e}")
|
||||||
|
|||||||
@@ -1,207 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import hashlib
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .utils import PRISMAUDIO_CATEGORY
|
|
||||||
from .feature_loader import PrismAudioFeatureLoader
|
|
||||||
|
|
||||||
# Managed venv created automatically when python_env is left as default
|
|
||||||
_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
|
|
||||||
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
|
|
||||||
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
|
|
||||||
|
|
||||||
_EXTRACT_PACKAGES = [
|
|
||||||
"torch", "torchaudio", "torchvision",
|
|
||||||
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
|
|
||||||
"tensorflow-cpu>=2.16.0",
|
|
||||||
# jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed)
|
|
||||||
"jax[cuda13]", "flax",
|
|
||||||
"transformers", "decord", "einops", "numpy", "mediapy",
|
|
||||||
"git+https://github.com/google-deepmind/videoprism.git",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _pip_install(pip, *packages, label=None):
|
|
||||||
"""Install one or more packages with visible output; raise on failure."""
|
|
||||||
tag = label or packages[0]
|
|
||||||
print(f"[PrismAudio] installing {tag} ...", flush=True)
|
|
||||||
result = subprocess.run(
|
|
||||||
[pip, "install", "--progress-bar", "on"] + list(packages),
|
|
||||||
capture_output=False,
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"[PrismAudio] Failed to install {tag} (exit {result.returncode}). "
|
|
||||||
"See pip output above for details."
|
|
||||||
)
|
|
||||||
print(f"[PrismAudio] {tag} OK", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _ensure_extract_env():
|
|
||||||
"""Create and populate the managed venv on first use."""
|
|
||||||
if os.path.exists(_MANAGED_PYTHON):
|
|
||||||
return _MANAGED_PYTHON
|
|
||||||
|
|
||||||
import shutil
|
|
||||||
if os.path.exists(_MANAGED_VENV):
|
|
||||||
print("[PrismAudio] Removing incomplete venv and retrying...", flush=True)
|
|
||||||
shutil.rmtree(_MANAGED_VENV)
|
|
||||||
|
|
||||||
print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True)
|
|
||||||
subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True)
|
|
||||||
|
|
||||||
pip = os.path.join(_MANAGED_VENV, "bin", "pip")
|
|
||||||
|
|
||||||
print("[PrismAudio] Upgrading pip...", flush=True)
|
|
||||||
subprocess.run([pip, "install", "--upgrade", "pip"], check=True)
|
|
||||||
|
|
||||||
total = len(_EXTRACT_PACKAGES)
|
|
||||||
print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True)
|
|
||||||
|
|
||||||
for i, pkg in enumerate(_EXTRACT_PACKAGES, 1):
|
|
||||||
label = pkg.split("/")[-1] if pkg.startswith("git+") else pkg.split(">=")[0].split("==")[0].split("[")[0]
|
|
||||||
print(f"[PrismAudio] [{i}/{total}] {label}", flush=True)
|
|
||||||
_pip_install(pip, pkg, label=label)
|
|
||||||
|
|
||||||
print("[PrismAudio] Feature-extraction env ready.", flush=True)
|
|
||||||
return _MANAGED_PYTHON
|
|
||||||
|
|
||||||
|
|
||||||
def _hash_inputs(video_tensor, cot_text):
|
|
||||||
"""Create a hash of the inputs for caching."""
|
|
||||||
h = hashlib.sha256()
|
|
||||||
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
|
||||||
h.update(cot_text.encode())
|
|
||||||
return h.hexdigest()[:16]
|
|
||||||
|
|
||||||
|
|
||||||
def _save_frames_to_npy(video_tensor, output_path):
|
|
||||||
"""Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8.
|
|
||||||
|
|
||||||
Lossless — avoids H.264 encode/decode roundtrip.
|
|
||||||
"""
|
|
||||||
import numpy as np
|
|
||||||
frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8")
|
|
||||||
np.save(output_path, frames_np)
|
|
||||||
|
|
||||||
|
|
||||||
class PrismAudioFeatureExtractor:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"video": ("IMAGE",),
|
|
||||||
"caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info output to auto-set fps."}),
|
|
||||||
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Ignored if video_info is connected."}),
|
|
||||||
"python_env": (["managed_env", "comfyui_env"], {"tooltip": "managed_env: auto-created isolated venv with JAX/TF (recommended). comfyui_env: current ComfyUI Python — WARNING: may conflict with existing packages and destabilize ComfyUI."}),
|
|
||||||
"cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
|
|
||||||
"hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("PRISMAUDIO_FEATURES", "FLOAT")
|
|
||||||
RETURN_NAMES = ("features", "fps")
|
|
||||||
FUNCTION = "extract_features"
|
|
||||||
CATEGORY = PRISMAUDIO_CATEGORY
|
|
||||||
|
|
||||||
def extract_features(self, video, caption_cot, video_info=None, fps=30.0, python_env="managed_env", cache_dir="", hf_token=""):
|
|
||||||
# Resolve fps from VHS video_info if connected
|
|
||||||
if video_info is not None:
|
|
||||||
fps = video_info["loaded_fps"]
|
|
||||||
|
|
||||||
# Resolve python binary
|
|
||||||
if python_env == "comfyui_env":
|
|
||||||
print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
|
|
||||||
"Installing them here may conflict with existing packages and destabilize ComfyUI.", flush=True)
|
|
||||||
python_bin = sys.executable
|
|
||||||
else:
|
|
||||||
python_bin = _ensure_extract_env()
|
|
||||||
|
|
||||||
# Determine cache directory
|
|
||||||
if not cache_dir:
|
|
||||||
cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Check cache
|
|
||||||
cache_hash = _hash_inputs(video, caption_cot)
|
|
||||||
cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
|
|
||||||
if os.path.exists(cached_path):
|
|
||||||
print(f"[PrismAudio] Using cached features: {cached_path}")
|
|
||||||
loader = PrismAudioFeatureLoader()
|
|
||||||
features, = loader.load_features(cached_path)
|
|
||||||
return (features, float(fps))
|
|
||||||
|
|
||||||
# Save frames to temp file (lossless .npy, no codec roundtrip)
|
|
||||||
import time
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
frames = video.shape[0]
|
|
||||||
print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True)
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
|
|
||||||
tmp_video = tmp.name
|
|
||||||
_save_frames_to_npy(video, tmp_video)
|
|
||||||
print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True)
|
|
||||||
|
|
||||||
# Build subprocess command
|
|
||||||
script_path = os.path.join(
|
|
||||||
os.path.dirname(os.path.dirname(__file__)),
|
|
||||||
"scripts", "extract_features.py"
|
|
||||||
)
|
|
||||||
|
|
||||||
import folder_paths
|
|
||||||
synchformer_ckpt = os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth")
|
|
||||||
if not os.path.exists(synchformer_ckpt):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"[PrismAudio] Synchformer checkpoint not found: {synchformer_ckpt}\n"
|
|
||||||
"Download synchformer_state_dict.pth from FunAudioLLM/PrismAudio and place it in models/prismaudio/."
|
|
||||||
)
|
|
||||||
|
|
||||||
cmd = [
|
|
||||||
python_bin,
|
|
||||||
script_path,
|
|
||||||
"--video", tmp_video,
|
|
||||||
"--cot_text", caption_cot,
|
|
||||||
"--output", cached_path,
|
|
||||||
"--source_fps", str(fps),
|
|
||||||
"--synchformer_ckpt", synchformer_ckpt,
|
|
||||||
]
|
|
||||||
|
|
||||||
# Build env: inherit current env, inject HF token if provided
|
|
||||||
import copy
|
|
||||||
env = copy.copy(os.environ)
|
|
||||||
token = hf_token.strip() if hf_token else os.environ.get("HF_TOKEN", "")
|
|
||||||
if token:
|
|
||||||
env["HF_TOKEN"] = token
|
|
||||||
env["HUGGING_FACE_HUB_TOKEN"] = token
|
|
||||||
else:
|
|
||||||
print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. "
|
|
||||||
"Add your token in the hf_token input or set HF_TOKEN env var.", flush=True)
|
|
||||||
|
|
||||||
print(f"[PrismAudio] Extracting features via subprocess (output streams live)...")
|
|
||||||
try:
|
|
||||||
# capture_output=False: let stdout/stderr stream directly to ComfyUI logs
|
|
||||||
result = subprocess.run(
|
|
||||||
cmd,
|
|
||||||
capture_output=False,
|
|
||||||
timeout=600, # 10 minute timeout
|
|
||||||
env=env,
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. "
|
|
||||||
"See output above for details."
|
|
||||||
)
|
|
||||||
print("[PrismAudio] Feature extraction subprocess finished successfully.")
|
|
||||||
finally:
|
|
||||||
if os.path.exists(tmp_video):
|
|
||||||
os.unlink(tmp_video)
|
|
||||||
|
|
||||||
# Load the extracted features
|
|
||||||
loader = PrismAudioFeatureLoader()
|
|
||||||
features, = loader.load_features(cached_path)
|
|
||||||
return (features, float(fps))
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from .utils import PRISMAUDIO_CATEGORY
|
|
||||||
|
|
||||||
# Keys consumed by the conditioners (video_features, text_features, sync_features)
|
|
||||||
# global_video_features and global_text_features are NOT consumed by any conditioner
|
|
||||||
# in the prismaudio.json config — they are unused.
|
|
||||||
REQUIRED_KEYS = [
|
|
||||||
"video_features",
|
|
||||||
"text_features",
|
|
||||||
"sync_features",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class PrismAudioFeatureLoader:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
|
|
||||||
RETURN_NAMES = ("features",)
|
|
||||||
FUNCTION = "load_features"
|
|
||||||
CATEGORY = PRISMAUDIO_CATEGORY
|
|
||||||
|
|
||||||
def load_features(self, npz_path):
|
|
||||||
if not os.path.exists(npz_path):
|
|
||||||
raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")
|
|
||||||
|
|
||||||
data = np.load(npz_path, allow_pickle=True)
|
|
||||||
|
|
||||||
features = {}
|
|
||||||
for key in REQUIRED_KEYS:
|
|
||||||
if key in data:
|
|
||||||
features[key] = torch.from_numpy(data[key]).float()
|
|
||||||
else:
|
|
||||||
print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
|
|
||||||
# Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None
|
|
||||||
# Sync_MLP requires length divisible by 8 (segments of 8 frames)
|
|
||||||
if key == "sync_features":
|
|
||||||
features[key] = torch.zeros(8, 768)
|
|
||||||
else:
|
|
||||||
features[key] = torch.zeros(1, 1024)
|
|
||||||
|
|
||||||
# Load duration if present
|
|
||||||
if "duration" in data:
|
|
||||||
features["duration"] = float(data["duration"])
|
|
||||||
|
|
||||||
return (features,)
|
|
||||||
@@ -1,154 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
import folder_paths
|
|
||||||
import comfy.model_management as mm
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
from .utils import (
|
|
||||||
PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
|
|
||||||
get_device, get_offload_device, determine_precision, determine_offload_strategy,
|
|
||||||
soft_empty_cache, resolve_hf_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
# HuggingFace repo for auto-download
|
|
||||||
HF_REPO_ID = "FunAudioLLM/PrismAudio"
|
|
||||||
REQUIRED_FILES = {
|
|
||||||
"diffusion": "prismaudio.ckpt",
|
|
||||||
"vae": "vae.ckpt",
|
|
||||||
"synchformer": "synchformer_state_dict.pth",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _download_if_missing(filename, model_dir, hf_token=None):
|
|
||||||
"""Download a model file from HuggingFace if not present locally."""
|
|
||||||
filepath = os.path.join(model_dir, filename)
|
|
||||||
if os.path.exists(filepath):
|
|
||||||
return filepath
|
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
|
|
||||||
try:
|
|
||||||
downloaded = hf_hub_download(
|
|
||||||
repo_id=HF_REPO_ID,
|
|
||||||
filename=filename,
|
|
||||||
local_dir=model_dir,
|
|
||||||
token=hf_token or None,
|
|
||||||
)
|
|
||||||
return downloaded
|
|
||||||
except Exception as e:
|
|
||||||
if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
|
|
||||||
raise RuntimeError(
|
|
||||||
f"[PrismAudio] Model '{filename}' requires license acceptance. "
|
|
||||||
f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
|
|
||||||
f"then set HF_TOKEN env var or run: huggingface-cli login"
|
|
||||||
) from e
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class PrismAudioModelLoader:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
register_model_folder()
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"precision": (["auto", "fp32", "fp16", "bf16"],),
|
|
||||||
"offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("PRISMAUDIO_MODEL",)
|
|
||||||
RETURN_NAMES = ("model",)
|
|
||||||
FUNCTION = "load_model"
|
|
||||||
CATEGORY = PRISMAUDIO_CATEGORY
|
|
||||||
|
|
||||||
def load_model(self, precision, offload_strategy):
|
|
||||||
device = get_device()
|
|
||||||
dtype = determine_precision(precision, device)
|
|
||||||
strategy = determine_offload_strategy(offload_strategy)
|
|
||||||
token = resolve_hf_token()
|
|
||||||
model_dir = get_prismaudio_model_dir()
|
|
||||||
|
|
||||||
# Auto-download missing files
|
|
||||||
for key, filename in REQUIRED_FILES.items():
|
|
||||||
_download_if_missing(filename, model_dir, hf_token=token)
|
|
||||||
|
|
||||||
# Load config
|
|
||||||
config_path = os.path.join(
|
|
||||||
os.path.dirname(os.path.dirname(__file__)),
|
|
||||||
"prismaudio_core", "configs", "prismaudio.json"
|
|
||||||
)
|
|
||||||
with open(config_path) as f:
|
|
||||||
model_config = json.load(f)
|
|
||||||
|
|
||||||
# Create model from config
|
|
||||||
from prismaudio_core.factory import create_model_from_config
|
|
||||||
model = create_model_from_config(model_config)
|
|
||||||
|
|
||||||
# Load diffusion weights
|
|
||||||
diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
|
|
||||||
diffusion_state = comfy.utils.load_torch_file(diffusion_path)
|
|
||||||
# Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
|
|
||||||
if "state_dict" in diffusion_state:
|
|
||||||
diffusion_state = diffusion_state["state_dict"]
|
|
||||||
diff_result = model.load_state_dict(diffusion_state, strict=False)
|
|
||||||
print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
|
|
||||||
print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True)
|
|
||||||
if diff_result.missing_keys:
|
|
||||||
print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True)
|
|
||||||
if diff_result.unexpected_keys:
|
|
||||||
print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True)
|
|
||||||
# Sample a few ckpt keys to verify prefix alignment
|
|
||||||
sample_keys = list(diffusion_state.keys())[:5]
|
|
||||||
print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True)
|
|
||||||
|
|
||||||
# Load VAE weights separately
|
|
||||||
# Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
|
|
||||||
vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
|
|
||||||
vae_full_state = comfy.utils.load_torch_file(vae_path)
|
|
||||||
print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)
|
|
||||||
# Sample raw keys to see actual prefix
|
|
||||||
vae_sample_keys = list(vae_full_state.keys())[:8]
|
|
||||||
print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True)
|
|
||||||
# Strip "autoencoder." prefix from keys
|
|
||||||
vae_state = {}
|
|
||||||
prefix = "autoencoder."
|
|
||||||
for k, v in vae_full_state.items():
|
|
||||||
if k.startswith(prefix):
|
|
||||||
vae_state[k[len(prefix):]] = v
|
|
||||||
else:
|
|
||||||
vae_state[k] = v
|
|
||||||
print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
|
|
||||||
# Sample model keys to compare
|
|
||||||
model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
|
|
||||||
print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True)
|
|
||||||
# strict=False: vae.ckpt is a training checkpoint that also contains
|
|
||||||
# discriminator, loss modules, and EMA wrappers not present in the
|
|
||||||
# inference AudioAutoencoder — ignore those extra keys.
|
|
||||||
# Load directly into the inner AudioAutoencoder to get IncompatibleKeys back
|
|
||||||
# (AutoencoderPretransform.load_state_dict doesn't return the result)
|
|
||||||
vae_result = model.pretransform.model.load_state_dict(vae_state, strict=False)
|
|
||||||
print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
|
|
||||||
if vae_result.missing_keys:
|
|
||||||
print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)
|
|
||||||
|
|
||||||
# Apply precision: DiT + conditioners in user-selected dtype,
|
|
||||||
# but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
|
|
||||||
model.model.to(dtype) # DiTWrapper
|
|
||||||
model.conditioner.to(dtype) # MultiConditioner
|
|
||||||
# model.pretransform stays in fp32
|
|
||||||
|
|
||||||
if strategy == "keep_in_vram":
|
|
||||||
model = model.to(device)
|
|
||||||
else:
|
|
||||||
model = model.to(get_offload_device())
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
return ({
|
|
||||||
"model": model,
|
|
||||||
"dtype": dtype,
|
|
||||||
"strategy": strategy,
|
|
||||||
"config": model_config,
|
|
||||||
"model_dir": model_dir,
|
|
||||||
},)
|
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
import torch
|
|
||||||
import comfy.model_management as mm
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
from .utils import (
|
|
||||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
|
||||||
get_device, get_offload_device, soft_empty_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PrismAudioSampler:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("PRISMAUDIO_MODEL",),
|
|
||||||
"features": ("PRISMAUDIO_FEATURES",),
|
|
||||||
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
|
|
||||||
"steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
|
|
||||||
"cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
|
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
CATEGORY = PRISMAUDIO_CATEGORY
|
|
||||||
|
|
||||||
def generate(self, model, features, duration, steps, cfg_scale, seed):
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
strategy = model["strategy"]
|
|
||||||
diffusion = model["model"]
|
|
||||||
|
|
||||||
# Resolve duration: 0 means use video duration from features
|
|
||||||
if duration <= 0:
|
|
||||||
if "duration" not in features:
|
|
||||||
raise ValueError("[PrismAudio] duration=0 but features contain no duration. Set duration manually or use PrismAudioFeatureExtractor.")
|
|
||||||
duration = features["duration"]
|
|
||||||
print(f"[PrismAudio] Using video duration from features: {duration:.2f}s", flush=True)
|
|
||||||
|
|
||||||
# Compute latent dimensions
|
|
||||||
latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
|
|
||||||
|
|
||||||
# Note: no seq length config needed — the model adapts to input tensor shapes
|
|
||||||
# dynamically via its transformer architecture.
|
|
||||||
|
|
||||||
# Determine if video features are present (not all zeros)
|
|
||||||
has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
|
|
||||||
|
|
||||||
video_feat = features["video_features"].to(device, dtype=dtype)
|
|
||||||
sync_feat = features["sync_features"].to(device, dtype=dtype)
|
|
||||||
|
|
||||||
# Build metadata as a TUPLE of dicts (one per batch sample)
|
|
||||||
# MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
|
|
||||||
sample_meta = {
|
|
||||||
"video_features": video_feat,
|
|
||||||
"text_features": features["text_features"].to(device, dtype=dtype),
|
|
||||||
"sync_features": sync_feat,
|
|
||||||
"video_exist": torch.tensor(has_video),
|
|
||||||
}
|
|
||||||
metadata = (sample_meta,)
|
|
||||||
|
|
||||||
# Move model to device if offloaded
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
diffusion.model.to(device)
|
|
||||||
diffusion.conditioner.to(device)
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
|
|
||||||
# Run conditioning
|
|
||||||
conditioning = diffusion.conditioner(metadata, device)
|
|
||||||
|
|
||||||
# Handle missing video: substitute learned empty embeddings
|
|
||||||
if not has_video:
|
|
||||||
_substitute_empty_features(diffusion, conditioning, device, dtype)
|
|
||||||
|
|
||||||
# Assemble conditioning inputs for the DiT
|
|
||||||
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
|
||||||
|
|
||||||
# Generate noise from seed (MPS doesn't support torch.Generator)
|
|
||||||
gen_device = "cpu" if device.type == "mps" else device
|
|
||||||
generator = torch.Generator(device=gen_device).manual_seed(seed)
|
|
||||||
noise = torch.randn(
|
|
||||||
[1, IO_CHANNELS, latent_length],
|
|
||||||
generator=generator,
|
|
||||||
device=gen_device,
|
|
||||||
).to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
# Sample with progress bar
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
|
||||||
|
|
||||||
from prismaudio_core.inference.sampling import sample_discrete_euler
|
|
||||||
|
|
||||||
def on_step(info):
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
fakes = sample_discrete_euler(
|
|
||||||
diffusion.model,
|
|
||||||
noise,
|
|
||||||
steps,
|
|
||||||
callback=on_step,
|
|
||||||
**cond_inputs,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
batch_cfg=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
fakes_f = fakes.float()
|
|
||||||
print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)
|
|
||||||
|
|
||||||
# Offload diffusion model and conditioner before VAE decode
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
diffusion.model.to(get_offload_device())
|
|
||||||
diffusion.conditioner.to(get_offload_device())
|
|
||||||
soft_empty_cache()
|
|
||||||
diffusion.pretransform.to(device)
|
|
||||||
|
|
||||||
# VAE decode in fp32 (snake activations overflow in fp16)
|
|
||||||
with torch.amp.autocast(device_type=device.type, enabled=False):
|
|
||||||
audio = diffusion.pretransform.decode(fakes_f)
|
|
||||||
|
|
||||||
# Offload VAE
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
diffusion.pretransform.to(get_offload_device())
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
# Peak normalize then clamp (matching reference: div by max abs before clamp)
|
|
||||||
audio = audio.float()
|
|
||||||
pre_norm_std = audio.std().item()
|
|
||||||
pre_norm_peak = audio.abs().max().item()
|
|
||||||
peak = audio.abs().max().clamp(min=1e-8)
|
|
||||||
audio = (audio / peak).clamp(-1, 1)
|
|
||||||
print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
|
|
||||||
|
|
||||||
# Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
|
|
||||||
return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
|
|
||||||
|
|
||||||
|
|
||||||
def _substitute_empty_features(diffusion, conditioning, device, dtype):
|
|
||||||
"""Replace video/sync conditioning with learned empty embeddings when video is absent.
|
|
||||||
|
|
||||||
empty_clip_feat and empty_sync_feat are learned null embeddings in the conditioner
|
|
||||||
output space (1024-dim). Passing zero features through bias-free Cond_MLP produces
|
|
||||||
near-zero activations, NOT the learned null signal the model was trained with.
|
|
||||||
|
|
||||||
The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
|
|
||||||
"""
|
|
||||||
dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model
|
|
||||||
|
|
||||||
# Substitute video_features with learned empty_clip_feat
|
|
||||||
if hasattr(dit, 'empty_clip_feat') and 'video_features' in conditioning:
|
|
||||||
empty = dit.empty_clip_feat.to(device, dtype=dtype) # [1, 1024]
|
|
||||||
batch_size = conditioning['video_features'][0].shape[0]
|
|
||||||
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
|
||||||
conditioning['video_features'][0] = empty_expanded
|
|
||||||
conditioning['video_features'][1] = torch.ones(batch_size, 1, device=device)
|
|
||||||
|
|
||||||
# Substitute sync_features with learned empty_sync_feat
|
|
||||||
if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
|
|
||||||
empty = dit.empty_sync_feat.to(device, dtype=dtype) # [1, 1024]
|
|
||||||
batch_size = conditioning['sync_features'][0].shape[0]
|
|
||||||
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
|
||||||
conditioning['sync_features'][0] = empty_expanded
|
|
||||||
conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
|
|
||||||
@@ -0,0 +1,288 @@
|
|||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||||
|
|
||||||
|
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
|
||||||
|
_CLIP_SIZE = 384
|
||||||
|
_SYNC_SIZE = 224
|
||||||
|
_CLIP_FPS = 8
|
||||||
|
_SYNC_FPS = 25
|
||||||
|
|
||||||
|
# Sync normalization applied externally: maps [0,1] → [-1,1] with mean=std=0.5
|
||||||
|
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
||||||
|
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_frames(video, source_fps, target_fps, duration):
|
||||||
|
"""Sample frames from [T,H,W,C] float32 at target_fps; returns [N,H,W,C]."""
|
||||||
|
T = video.shape[0]
|
||||||
|
n_out = max(1, int(duration * target_fps))
|
||||||
|
indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
|
||||||
|
return video[indices]
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_frames(frames, size):
|
||||||
|
"""Resize [N,H,W,C] float32 [0,1] → [N,C,H,W] at target size."""
|
||||||
|
x = frames.permute(0, 3, 1, 2) # [N, C, H, W]
|
||||||
|
x = F.interpolate(x.float(), size=(size, size), mode="bicubic", align_corners=False)
|
||||||
|
return x.clamp(0.0, 1.0) # [N, C, H, W]
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
|
||||||
|
"""
|
||||||
|
Apply a ComfyUI MASK to resized frames.
|
||||||
|
|
||||||
|
frames: [N, C, H, W] float [0,1]
|
||||||
|
mask: [M, H', W'] float [0,1] — M=1 static or M=T per-frame
|
||||||
|
source_fps: original video fps (for accurate temporal sampling)
|
||||||
|
target_fps: sampling fps of this frame set (CLIP_FPS or SYNC_FPS)
|
||||||
|
mask_strength: 0=no effect, 1=full masking; background filled with 0.5 (neutral gray)
|
||||||
|
|
||||||
|
Background pixels are filled with 0.5 rather than 0 — less out-of-distribution
|
||||||
|
for CLIP, and maps to 0 (neutral) after [-1,1] normalization on the sync path.
|
||||||
|
"""
|
||||||
|
N, C, H, W = frames.shape
|
||||||
|
M = mask.shape[0]
|
||||||
|
mask_f = mask.float().unsqueeze(1) # [M, 1, H', W']
|
||||||
|
if mask_f.shape[2] != H or mask_f.shape[3] != W:
|
||||||
|
mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact") # [M, 1, H, W]
|
||||||
|
|
||||||
|
# Temporal sampling — use same index formula as _sample_frames for accuracy
|
||||||
|
if M == 1:
|
||||||
|
mask_f = mask_f.expand(N, -1, -1, -1)
|
||||||
|
else:
|
||||||
|
indices = [min(int(i / target_fps * source_fps), M - 1) for i in range(N)]
|
||||||
|
mask_f = mask_f[indices] # [N, 1, H, W]
|
||||||
|
|
||||||
|
mask_f = mask_f.to(frames.device)
|
||||||
|
|
||||||
|
# alpha=1 on foreground, (1-strength) on background → blend toward neutral gray
|
||||||
|
alpha = 1.0 - mask_strength * (1.0 - mask_f)
|
||||||
|
return frames * alpha + 0.5 * (1.0 - alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_named_path(cache_dir: str, name: str) -> str:
|
||||||
|
"""Return cache_dir/name.npz, incrementing to name_001.npz etc. if the file already exists."""
|
||||||
|
# Sanitize: replace path separators so the name stays inside cache_dir
|
||||||
|
name = name.replace("/", "_").replace("\\", "_").replace("\x00", "_")
|
||||||
|
i = 1
|
||||||
|
while True:
|
||||||
|
p = os.path.join(cache_dir, f"{name}_{i:03d}.npz")
|
||||||
|
if not os.path.exists(p):
|
||||||
|
return p
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
|
||||||
|
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
||||||
|
h = hashlib.sha256()
|
||||||
|
raw = video_tensor.cpu().numpy().tobytes()
|
||||||
|
n = len(raw)
|
||||||
|
chunk = 512 * 1024 # 512 KB per sample
|
||||||
|
h.update(raw[:chunk])
|
||||||
|
h.update(raw[n // 2: n // 2 + chunk])
|
||||||
|
h.update(raw[max(0, n - chunk):])
|
||||||
|
if mask is not None:
|
||||||
|
raw_m = mask.cpu().numpy().tobytes()
|
||||||
|
nm = len(raw_m)
|
||||||
|
chunk_m = 256 * 1024
|
||||||
|
h.update(raw_m[:chunk_m])
|
||||||
|
h.update(raw_m[nm // 2: nm // 2 + chunk_m])
|
||||||
|
h.update(raw_m[max(0, nm - chunk_m):])
|
||||||
|
h.update(str(round(mask_strength, 4)).encode())
|
||||||
|
h.update(str(mask_clip).encode())
|
||||||
|
h.update(str(mask_sync).encode())
|
||||||
|
h.update(prompt.encode())
|
||||||
|
h.update(str(fps).encode())
|
||||||
|
h.update(str(round(duration, 3)).encode()) # resolved duration affects frame count
|
||||||
|
h.update(variant.encode())
|
||||||
|
return h.hexdigest()[:32]
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaFeatureExtractor:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("SELVA_MODEL",),
|
||||||
|
"video": ("IMAGE",),
|
||||||
|
"prompt": ("STRING", {
|
||||||
|
"default": "", "multiline": True,
|
||||||
|
"tooltip": "Describes the sounds to generate. Used to focus the visual sync features on motion relevant to the prompt — more specific prompts produce cleaner audio sync. Wire the prompt output directly to the Sampler so you only type it once.",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"video_info": ("VHS_VIDEOINFO", {
|
||||||
|
"tooltip": "VHS_VIDEOINFO from VHS LoadVideo. Automatically sets the correct source fps — always connect this when loading video with VHS nodes.",
|
||||||
|
}),
|
||||||
|
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001,
|
||||||
|
"tooltip": "Source fps of the input video. Ignored when video_info is connected."}),
|
||||||
|
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
|
||||||
|
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
|
||||||
|
"cache_dir": ("STRING", {"default": "",
|
||||||
|
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
|
||||||
|
"name": ("STRING", {"default": "",
|
||||||
|
"tooltip": "Optional filename for the saved .npz (without extension). If provided, features are always saved with this name instead of a content hash — useful for building a named training dataset. Auto-increments: dog_bark → dog_bark_001 → dog_bark_002 if the file already exists. Leave empty to use the default content-hash cache."}),
|
||||||
|
"mask": ("MASK", {
|
||||||
|
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are zeroed before encoding — useful when multiple objects compete for the same sound. Static (1-frame) or per-frame masks both supported. Connect SAM2 or Grounding DINO+SAM output.",
|
||||||
|
}),
|
||||||
|
"mask_strength": ("FLOAT", {
|
||||||
|
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||||
|
"tooltip": "How strongly to suppress the background. 1.0 = full neutral fill; 0.0 = no masking effect. Values in between blend smoothly.",
|
||||||
|
}),
|
||||||
|
"mask_clip": ("BOOLEAN", {
|
||||||
|
"default": True,
|
||||||
|
"tooltip": "Apply the mask to CLIP visual features (384px). Disable if you want CLIP to see the full scene context while sync features stay focused.",
|
||||||
|
}),
|
||||||
|
"mask_sync": ("BOOLEAN", {
|
||||||
|
"default": True,
|
||||||
|
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
|
||||||
|
RETURN_NAMES = ("features", "fps", "prompt")
|
||||||
|
OUTPUT_TOOLTIPS = (
|
||||||
|
"Extracted feature bundle — connect to Sampler.",
|
||||||
|
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
|
||||||
|
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
|
||||||
|
)
|
||||||
|
OUTPUT_NODE = True # always execute: the node's purpose is saving .npz files to disk
|
||||||
|
FUNCTION = "extract_features"
|
||||||
|
CATEGORY = SELVA_CATEGORY
|
||||||
|
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
|
||||||
|
|
||||||
|
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||||
|
duration=0.0, cache_dir="", name="", mask=None,
|
||||||
|
mask_strength=1.0, mask_clip=True, mask_sync=True):
|
||||||
|
if video_info is not None:
|
||||||
|
fps = video_info["loaded_fps"]
|
||||||
|
|
||||||
|
T = video.shape[0]
|
||||||
|
if duration <= 0:
|
||||||
|
duration = T / fps
|
||||||
|
duration = min(duration, T / fps) # clamp to actual video length
|
||||||
|
|
||||||
|
if not prompt.strip():
|
||||||
|
print("[SelVA] Warning: empty prompt — TextSynchformer sync features will be unfocused.", flush=True)
|
||||||
|
|
||||||
|
# Cache
|
||||||
|
if not cache_dir:
|
||||||
|
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if name.strip():
|
||||||
|
# Named mode: always extract and save to an incremented filename
|
||||||
|
cached_path = _resolve_named_path(cache_dir, name.strip())
|
||||||
|
else:
|
||||||
|
# Hash mode: skip extraction if identical inputs were already processed
|
||||||
|
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
|
||||||
|
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync)
|
||||||
|
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
||||||
|
if os.path.exists(cached_path):
|
||||||
|
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
|
||||||
|
cached = _load_cached(cached_path)
|
||||||
|
return (cached, float(fps), cached.get("prompt", prompt))
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
feature_utils = model["feature_utils"]
|
||||||
|
net_video_enc = model["video_enc"]
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
feature_utils.to(device)
|
||||||
|
net_video_enc.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
||||||
|
pbar = comfy.utils.ProgressBar(3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||||
|
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||||
|
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||||
|
if mask is not None and mask_clip:
|
||||||
|
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
|
||||||
|
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
||||||
|
_clip_tag = f"(masked strength={mask_strength})" if mask is not None and mask_clip else ("(mask skipped)" if mask is not None else "")
|
||||||
|
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {_clip_tag}", flush=True)
|
||||||
|
|
||||||
|
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
|
||||||
|
sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration) # [N, H, W, C]
|
||||||
|
sync_frames = _resize_frames(sync_frames, _SYNC_SIZE) # [N, C, 224, 224]
|
||||||
|
if mask is not None and mask_sync:
|
||||||
|
sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
|
||||||
|
# Pad to minimum 16 frames (TextSynchformer segment size)
|
||||||
|
if sync_frames.shape[0] < 16:
|
||||||
|
pad = 16 - sync_frames.shape[0]
|
||||||
|
sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
|
||||||
|
# Normalize [0,1] → [-1,1]
|
||||||
|
mean = _SYNC_MEAN.to(sync_frames.device)
|
||||||
|
std = _SYNC_STD.to(sync_frames.device)
|
||||||
|
sync_frames = (sync_frames - mean) / std
|
||||||
|
sync_input = sync_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 224, 224]
|
||||||
|
_sync_tag = f"(masked strength={mask_strength})" if mask is not None and mask_sync else ("(mask skipped)" if mask is not None else "")
|
||||||
|
print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {_sync_tag}", flush=True)
|
||||||
|
|
||||||
|
# Encode T5 text + prepend supplementary tokens → text-conditioned sync features
|
||||||
|
text_f, text_mask = feature_utils.encode_text_t5([prompt]) # [1, L, D], [1, L]
|
||||||
|
pbar.update(1)
|
||||||
|
text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
|
||||||
|
sync_features = net_video_enc.encode_video_with_sync(
|
||||||
|
sync_input, text_f=text_f, text_mask=text_mask
|
||||||
|
) # [1, T_sync, 768]
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
|
||||||
|
print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
feature_utils.to(get_offload_device())
|
||||||
|
net_video_enc.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
np.savez(
|
||||||
|
cached_path,
|
||||||
|
clip_features=clip_features.cpu().float().numpy(),
|
||||||
|
sync_features=sync_features.cpu().float().numpy(),
|
||||||
|
duration=float(duration),
|
||||||
|
prompt=np.array(prompt),
|
||||||
|
variant=np.array(model["variant"]),
|
||||||
|
)
|
||||||
|
print(f"[SelVA] Features cached: {cached_path}", flush=True)
|
||||||
|
|
||||||
|
return ({
|
||||||
|
"clip_features": clip_features.cpu(),
|
||||||
|
"sync_features": sync_features.cpu(),
|
||||||
|
"duration": float(duration),
|
||||||
|
"prompt": prompt,
|
||||||
|
"variant": model["variant"],
|
||||||
|
}, float(fps), prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cached(path):
|
||||||
|
data = np.load(path, allow_pickle=False)
|
||||||
|
features = {
|
||||||
|
"clip_features": torch.from_numpy(data["clip_features"]),
|
||||||
|
"sync_features": torch.from_numpy(data["sync_features"]),
|
||||||
|
"duration": float(data["duration"]),
|
||||||
|
}
|
||||||
|
if "prompt" in data:
|
||||||
|
features["prompt"] = str(data["prompt"])
|
||||||
|
if "variant" in data:
|
||||||
|
features["variant"] = str(data["variant"])
|
||||||
|
return features
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY
|
||||||
|
from selva_core.model.lora import apply_lora, load_lora
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaLoraLoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("SELVA_MODEL",),
|
||||||
|
"adapter_path": ("STRING", {
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Path to a LoRA adapter .pt file produced by train_lora.py.",
|
||||||
|
}),
|
||||||
|
"strength": ("FLOAT", {
|
||||||
|
"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05,
|
||||||
|
"tooltip": "Scale applied to all LoRA contributions. "
|
||||||
|
"1.0 = full adapter strength. "
|
||||||
|
"0.0 = effectively disables the adapter. "
|
||||||
|
"Values above 1.0 exaggerate the effect.",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("SELVA_MODEL",)
|
||||||
|
RETURN_NAMES = ("model",)
|
||||||
|
OUTPUT_TOOLTIPS = ("Model with LoRA adapter applied — connect to Sampler.",)
|
||||||
|
FUNCTION = "load"
|
||||||
|
CATEGORY = SELVA_CATEGORY
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Loads a LoRA adapter produced by train_lora.py and applies it to the generator. "
|
||||||
|
"The base model is not modified — a shallow copy of the model bundle is returned."
|
||||||
|
)
|
||||||
|
|
||||||
|
def load(self, model: dict, adapter_path: str, strength: float) -> tuple:
|
||||||
|
if not adapter_path.strip():
|
||||||
|
raise ValueError("[SelVA LoRA] adapter_path is empty.")
|
||||||
|
|
||||||
|
# Resolve path: allow absolute or relative to ComfyUI base
|
||||||
|
from pathlib import Path
|
||||||
|
p = Path(adapter_path)
|
||||||
|
if not p.is_absolute():
|
||||||
|
p = Path(folder_paths.base_path) / p
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"[SelVA LoRA] Adapter not found: {p}")
|
||||||
|
|
||||||
|
checkpoint = torch.load(str(p), map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
# Support both raw state_dict and {state_dict, meta} formats
|
||||||
|
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
|
||||||
|
state_dict = checkpoint["state_dict"]
|
||||||
|
meta = checkpoint.get("meta", {})
|
||||||
|
else:
|
||||||
|
state_dict = checkpoint
|
||||||
|
meta = {}
|
||||||
|
|
||||||
|
rank = int(meta.get("rank", 16))
|
||||||
|
alpha = float(meta.get("alpha", float(rank)))
|
||||||
|
target = list(meta.get("target", ["attn.qkv"]))
|
||||||
|
|
||||||
|
print(f"[SelVA LoRA] Loading adapter: {p.name}", flush=True)
|
||||||
|
print(f"[SelVA LoRA] rank={rank} alpha={alpha} target={target} strength={strength}",
|
||||||
|
flush=True)
|
||||||
|
|
||||||
|
# Shallow-copy the model bundle so the original generator is not mutated
|
||||||
|
patched = {**model}
|
||||||
|
generator = copy.deepcopy(model["generator"])
|
||||||
|
|
||||||
|
n = apply_lora(generator, rank=rank, alpha=alpha, target_suffixes=tuple(target))
|
||||||
|
if n == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[SelVA LoRA] No layers matched target={target}. "
|
||||||
|
"Check that the adapter was trained with the same target suffixes."
|
||||||
|
)
|
||||||
|
load_lora(generator, state_dict)
|
||||||
|
|
||||||
|
# Sanity check: confirm lora_A weights are non-zero (lora_B starts at zero by design)
|
||||||
|
norms = [p.norm().item() for name, p in generator.named_parameters()
|
||||||
|
if "lora_A" in name]
|
||||||
|
if norms:
|
||||||
|
print(f"[SelVA LoRA] lora_A weight norms: min={min(norms):.4f} "
|
||||||
|
f"max={max(norms):.4f} mean={sum(norms)/len(norms):.4f}", flush=True)
|
||||||
|
else:
|
||||||
|
print("[SelVA LoRA] WARNING: no lora_A params found after loading!", flush=True)
|
||||||
|
|
||||||
|
# Apply strength scaling: multiply all lora_B params by strength
|
||||||
|
# (lora_B is initialised to zero, so scaling A is equivalent but less clean)
|
||||||
|
if strength != 1.0:
|
||||||
|
with torch.no_grad():
|
||||||
|
for name, param in generator.named_parameters():
|
||||||
|
if "lora_B" in name:
|
||||||
|
param.mul_(strength)
|
||||||
|
|
||||||
|
generator.to(model["generator"].parameters().__next__().device)
|
||||||
|
patched["generator"] = generator
|
||||||
|
|
||||||
|
print(f"[SelVA LoRA] Applied {n} LoRA layers.", flush=True)
|
||||||
|
return (patched,)
|
||||||
@@ -0,0 +1,617 @@
|
|||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import traceback
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY, get_device, soft_empty_cache
|
||||||
|
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||||
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
from selva_core.model.lora import apply_lora, get_lora_state_dict, load_lora
|
||||||
|
|
||||||
|
|
||||||
|
_AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".aiff", ".aif"}
|
||||||
|
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data helpers (mirror train_lora.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _load_prompts(data_dir: Path) -> dict:
|
||||||
|
p = data_dir / "prompts.txt"
|
||||||
|
if not p.exists():
|
||||||
|
return {}
|
||||||
|
mapping = {}
|
||||||
|
for line in p.read_text(encoding="utf-8").splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
if ":" in line:
|
||||||
|
fname, prompt = line.split(":", 1)
|
||||||
|
mapping[fname.strip()] = prompt.strip()
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def _find_audio(npz_path: Path) -> Path | None:
|
||||||
|
for ext in _AUDIO_EXTS:
|
||||||
|
c = npz_path.with_suffix(ext)
|
||||||
|
if c.exists():
|
||||||
|
return c
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_audio(path: Path, target_sr: int, duration: float) -> torch.Tensor:
|
||||||
|
try:
|
||||||
|
waveform, sr = torchaudio.load(str(path))
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "torchcodec" not in str(e).lower() and "libtorchcodec" not in str(e).lower():
|
||||||
|
raise
|
||||||
|
# torchcodec unavailable (FFmpeg shared libs missing) — fall back to soundfile
|
||||||
|
import soundfile as sf
|
||||||
|
data, sr = sf.read(str(path), always_2d=True) # [frames, channels]
|
||||||
|
waveform = torch.from_numpy(data.T).float() # [channels, frames]
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = waveform.mean(0, keepdim=True)
|
||||||
|
waveform = waveform.squeeze(0).float()
|
||||||
|
if sr != target_sr:
|
||||||
|
waveform = torchaudio.functional.resample(
|
||||||
|
waveform.unsqueeze(0), sr, target_sr).squeeze(0)
|
||||||
|
target_len = int(duration * target_sr)
|
||||||
|
if waveform.shape[0] >= target_len:
|
||||||
|
return waveform[:target_len]
|
||||||
|
return F.pad(waveform, (0, target_len - waveform.shape[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_npz(path: Path) -> dict:
|
||||||
|
data = np.load(str(path), allow_pickle=False)
|
||||||
|
bundle = {
|
||||||
|
"clip_features": torch.from_numpy(data["clip_features"]),
|
||||||
|
"sync_features": torch.from_numpy(data["sync_features"]),
|
||||||
|
}
|
||||||
|
if "prompt" in data:
|
||||||
|
bundle["prompt"] = str(data["prompt"])
|
||||||
|
return bundle
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Eval sample
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _eval_sample(generator, feature_utils_orig, dataset, seq_cfg, device, dtype,
|
||||||
|
num_steps: int = 8):
|
||||||
|
"""Run a quick no-CFG inference pass on a random training clip.
|
||||||
|
|
||||||
|
Returns (waveform [1, L] float32 cpu, sample_rate) or (None, None) on failure.
|
||||||
|
Uses fewer ODE steps than inference (8 vs 25) for speed.
|
||||||
|
"""
|
||||||
|
generator.eval()
|
||||||
|
try:
|
||||||
|
_, clip_f_cpu, sync_f_cpu, text_clip_cpu = random.choice(dataset)
|
||||||
|
clip_f = clip_f_cpu.to(device, dtype)
|
||||||
|
sync_f = sync_f_cpu.to(device, dtype)
|
||||||
|
text_clip = text_clip_cpu.to(device, dtype)
|
||||||
|
|
||||||
|
x0 = torch.randn(1, seq_cfg.latent_seq_len, generator.latent_dim,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
|
||||||
|
eval_fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=num_steps)
|
||||||
|
|
||||||
|
def velocity_fn(t, x):
|
||||||
|
return generator.forward(x, clip_f, sync_f, text_clip,
|
||||||
|
t.reshape(1).to(device, dtype))
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
x1_pred = eval_fm.to_data(velocity_fn, x0)
|
||||||
|
x1_unnorm = generator.unnormalize(x1_pred)
|
||||||
|
|
||||||
|
# feature_utils_orig may be on CPU (offload strategy) — move temporarily
|
||||||
|
orig_device = next(feature_utils_orig.parameters()).device
|
||||||
|
if orig_device != device:
|
||||||
|
feature_utils_orig.to(device)
|
||||||
|
try:
|
||||||
|
spec = feature_utils_orig.decode(x1_unnorm)
|
||||||
|
audio = feature_utils_orig.vocode(spec)
|
||||||
|
finally:
|
||||||
|
if orig_device != device:
|
||||||
|
feature_utils_orig.to(orig_device)
|
||||||
|
|
||||||
|
audio = audio.float().cpu()
|
||||||
|
if audio.dim() == 2:
|
||||||
|
audio = audio.unsqueeze(1)
|
||||||
|
elif audio.dim() == 3 and audio.shape[1] != 1:
|
||||||
|
audio = audio.mean(dim=1, keepdim=True)
|
||||||
|
|
||||||
|
peak = audio.abs().max().clamp(min=1e-8)
|
||||||
|
audio = (audio / peak).clamp(-1, 1)
|
||||||
|
return audio.squeeze(0), seq_cfg.sampling_rate # [1, L]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[LoRA Trainer] Eval sample failed: {e}", flush=True)
|
||||||
|
return None, None
|
||||||
|
finally:
|
||||||
|
generator.train()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Loss curve rendering
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _smooth_losses(losses: list[float], beta: float = 0.9) -> list[float]:
|
||||||
|
"""Exponential moving average smoothing."""
|
||||||
|
smoothed, ema = [], None
|
||||||
|
for v in losses:
|
||||||
|
ema = v if ema is None else beta * ema + (1 - beta) * v
|
||||||
|
smoothed.append(ema)
|
||||||
|
return smoothed
|
||||||
|
|
||||||
|
|
||||||
|
def _draw_loss_curve(losses: list[float], log_interval: int,
|
||||||
|
start_step: int = 0, smoothed: list[float] | None = None) -> Image.Image:
|
||||||
|
"""Render a loss curve as a PIL Image."""
|
||||||
|
W, H = 800, 380
|
||||||
|
pl, pr, pt, pb = 70, 20, 25, 45
|
||||||
|
|
||||||
|
img = Image.new("RGB", (W, H), (255, 255, 255))
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
|
pw = W - pl - pr
|
||||||
|
ph = H - pt - pb
|
||||||
|
|
||||||
|
if len(losses) >= 2:
|
||||||
|
lo, hi = min(losses), max(losses)
|
||||||
|
if hi == lo:
|
||||||
|
hi = lo + 1e-6
|
||||||
|
rng = hi - lo
|
||||||
|
|
||||||
|
# Horizontal grid + y-axis labels
|
||||||
|
for i in range(5):
|
||||||
|
y = pt + int(i * ph / 4)
|
||||||
|
val = hi - i * rng / 4
|
||||||
|
draw.line([(pl, y), (W - pr, y)], fill=(220, 220, 220), width=1)
|
||||||
|
draw.text((2, y - 7), f"{val:.4f}", fill=(120, 120, 120))
|
||||||
|
|
||||||
|
# Raw loss line
|
||||||
|
n = len(losses)
|
||||||
|
pts = []
|
||||||
|
for i, v in enumerate(losses):
|
||||||
|
x = pl + int(i * pw / max(n - 1, 1))
|
||||||
|
y = pt + int((1.0 - (v - lo) / rng) * ph)
|
||||||
|
pts.append((x, y))
|
||||||
|
draw.line(pts, fill=(200, 220, 255), width=1)
|
||||||
|
|
||||||
|
# Smoothed overlay
|
||||||
|
if smoothed is not None and len(smoothed) >= 2:
|
||||||
|
spts = []
|
||||||
|
for i, v in enumerate(smoothed):
|
||||||
|
x = pl + int(i * pw / max(n - 1, 1))
|
||||||
|
y = pt + int((1.0 - (v - lo) / rng) * ph)
|
||||||
|
spts.append((x, y))
|
||||||
|
draw.line(spts, fill=(66, 133, 244), width=2)
|
||||||
|
|
||||||
|
# x-axis step labels — account for start_step so resumed runs are correct
|
||||||
|
first_step = start_step + log_interval
|
||||||
|
last_step = start_step + n * log_interval
|
||||||
|
for i in range(5):
|
||||||
|
x = pl + int(i * pw / 4)
|
||||||
|
step = int(first_step + i * (last_step - first_step) / 4)
|
||||||
|
draw.text((x - 12, H - pb + 5), str(step), fill=(120, 120, 120))
|
||||||
|
|
||||||
|
# Axes
|
||||||
|
draw.line([(pl, pt), (pl, H - pb)], fill=(40, 40, 40), width=1)
|
||||||
|
draw.line([(pl, H - pb), (W - pr, H - pb)], fill=(40, 40, 40), width=1)
|
||||||
|
draw.text((pl + 4, 5), "Training Loss", fill=(40, 40, 40))
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
|
||||||
|
"""Convert a PIL Image to a [1, H, W, 3] float32 IMAGE tensor for ComfyUI."""
|
||||||
|
arr = np.array(img).astype(np.float32) / 255.0
|
||||||
|
return torch.from_numpy(arr).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Node
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class SelvaLoraTrainer:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("SELVA_MODEL",),
|
||||||
|
"data_dir": ("STRING", {
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Directory containing .npz feature files and paired audio files.",
|
||||||
|
}),
|
||||||
|
"output_dir": ("STRING", {
|
||||||
|
"default": "lora_output",
|
||||||
|
"tooltip": "Where to save adapter checkpoints.",
|
||||||
|
}),
|
||||||
|
"steps": ("INT", {
|
||||||
|
"default": 2000, "min": 100, "max": 100000,
|
||||||
|
"tooltip": "Total training steps.",
|
||||||
|
}),
|
||||||
|
"rank": ("INT", {
|
||||||
|
"default": 16, "min": 1, "max": 128,
|
||||||
|
"tooltip": "LoRA rank. Higher = more capacity, more VRAM. 16 is a safe default.",
|
||||||
|
}),
|
||||||
|
"lr": ("FLOAT", {
|
||||||
|
"default": 1e-4, "min": 1e-6, "max": 1e-2, "step": 1e-6,
|
||||||
|
"tooltip": "Learning rate.",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"alpha": ("FLOAT", {
|
||||||
|
"default": 0.0, "min": 0.0, "max": 256.0, "step": 0.5,
|
||||||
|
"tooltip": "LoRA alpha. 0 = use rank value (scale = 1.0).",
|
||||||
|
}),
|
||||||
|
"target": ("STRING", {
|
||||||
|
"default": "attn.qkv",
|
||||||
|
"tooltip": "Space-separated layer name suffixes to wrap. Default targets all QKV projections. Add 'linear1' for post-attention projections.",
|
||||||
|
}),
|
||||||
|
"batch_size": ("INT", {"default": 4, "min": 1, "max": 32,
|
||||||
|
"tooltip": "Number of clips per training step. Higher = more stable gradients, more VRAM."}),
|
||||||
|
"warmup_steps": ("INT", {"default": 100, "min": 0, "max": 5000}),
|
||||||
|
"grad_accum": ("INT", {"default": 1, "min": 1, "max": 32,
|
||||||
|
"tooltip": "Gradient accumulation steps. Usually 1 when batch_size > 1."}),
|
||||||
|
"save_every": ("INT", {"default": 500, "min": 50, "max": 10000}),
|
||||||
|
"resume_path": ("STRING", {
|
||||||
|
"default": "",
|
||||||
|
"tooltip": "Path to a step checkpoint (.pt) to resume training from.",
|
||||||
|
}),
|
||||||
|
"seed": ("INT", {"default": 42}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("SELVA_MODEL", "STRING", "IMAGE")
|
||||||
|
RETURN_NAMES = ("model", "adapter_path", "loss_curve")
|
||||||
|
OUTPUT_TOOLTIPS = (
|
||||||
|
"Model with trained LoRA adapter applied — connect directly to Sampler.",
|
||||||
|
"Path to adapter_final.pt — use with SelVA LoRA Loader in future sessions.",
|
||||||
|
"Training loss curve.",
|
||||||
|
)
|
||||||
|
FUNCTION = "train"
|
||||||
|
CATEGORY = SELVA_CATEGORY
|
||||||
|
DESCRIPTION = (
|
||||||
|
"Trains a LoRA adapter on a dataset of .npz feature files + paired audio files. "
|
||||||
|
"Blocks the queue for the duration of training. "
|
||||||
|
"Prepare the dataset with SelVA Feature Extractor (set a name to get numbered .npz files) "
|
||||||
|
"and pair each .npz with a clean audio file of the same stem."
|
||||||
|
)
|
||||||
|
|
||||||
|
def train(self, model, data_dir, output_dir, steps, rank, lr,
|
||||||
|
alpha=0.0, target="attn.qkv", batch_size=4, warmup_steps=100,
|
||||||
|
grad_accum=1, save_every=500, resume_path="", seed=42):
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
variant = model["variant"]
|
||||||
|
mode = model["mode"]
|
||||||
|
seq_cfg = model["seq_cfg"]
|
||||||
|
feature_utils_orig = model["feature_utils"]
|
||||||
|
|
||||||
|
data_dir = Path(data_dir.strip())
|
||||||
|
|
||||||
|
_out_str = output_dir.strip()
|
||||||
|
_out_p = Path(_out_str)
|
||||||
|
# On Windows a Unix-style path like "/lora_output" is technically absolute
|
||||||
|
# (drive-relative) but the user almost certainly meant a subfolder of the
|
||||||
|
# ComfyUI output directory. Treat any non-absolute path AND any path whose
|
||||||
|
# only "absolute" anchor is a leading slash (no drive letter) as relative to
|
||||||
|
# the ComfyUI output folder.
|
||||||
|
import sys as _sys
|
||||||
|
_unix_style_on_windows = (
|
||||||
|
_sys.platform == "win32"
|
||||||
|
and _out_p.is_absolute()
|
||||||
|
and not _out_p.drive # e.g. Path("/foo").drive == "" on Windows
|
||||||
|
)
|
||||||
|
if not _out_p.is_absolute() or _unix_style_on_windows:
|
||||||
|
_out_p = Path(folder_paths.get_output_directory()) / _out_p.relative_to(_out_p.anchor)
|
||||||
|
print(f"[LoRA Trainer] output_dir resolved to: {_out_p}", flush=True)
|
||||||
|
output_dir = _out_p
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
alpha_val = float(alpha) if alpha > 0.0 else float(rank)
|
||||||
|
target_suffixes = tuple(target.strip().split())
|
||||||
|
|
||||||
|
# --- Load VAE encoder (not present in inference model) ---
|
||||||
|
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
||||||
|
vae_path = _SELVA_DIR / "ext" / vae_name
|
||||||
|
if not vae_path.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"[LoRA Trainer] VAE weight not found: {vae_path}. "
|
||||||
|
"Run SelVA Model Loader first to auto-download weights."
|
||||||
|
)
|
||||||
|
print("[LoRA Trainer] Loading VAE encoder...", flush=True)
|
||||||
|
# Keep VAE in float32: mel_converter uses torch.stft which requires float32 input.
|
||||||
|
vae_utils = FeaturesUtils(
|
||||||
|
tod_vae_ckpt=str(vae_path),
|
||||||
|
enable_conditions=False,
|
||||||
|
mode=mode,
|
||||||
|
need_vae_encoder=True,
|
||||||
|
).to(device).eval()
|
||||||
|
|
||||||
|
# --- Pre-load dataset ---
|
||||||
|
npz_files = sorted(data_dir.glob("*.npz"))
|
||||||
|
if not npz_files:
|
||||||
|
raise ValueError(f"[LoRA Trainer] No .npz files found in {data_dir}")
|
||||||
|
|
||||||
|
prompt_map = _load_prompts(data_dir)
|
||||||
|
default_prompt = data_dir.name
|
||||||
|
|
||||||
|
print(f"[LoRA Trainer] Pre-loading {len(npz_files)} clip(s)...", flush=True)
|
||||||
|
pbar_load = comfy.utils.ProgressBar(len(npz_files))
|
||||||
|
dataset = []
|
||||||
|
|
||||||
|
for npz_path in npz_files:
|
||||||
|
audio_path = _find_audio(npz_path)
|
||||||
|
if audio_path is None:
|
||||||
|
print(f" [LoRA Trainer] Warning: no audio for {npz_path.name} — skipping", flush=True)
|
||||||
|
pbar_load.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
bundle = _load_npz(npz_path)
|
||||||
|
prompt = prompt_map.get(npz_path.name, bundle.get("prompt", default_prompt))
|
||||||
|
print(f" {npz_path.name} + {audio_path.name}: '{prompt}'", flush=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
audio = _load_audio(audio_path, seq_cfg.sampling_rate, seq_cfg.duration)
|
||||||
|
|
||||||
|
# Audio → latent via VAE (float32: mel_converter/stft require float32)
|
||||||
|
# encode_audio is @inference_mode — .clone() exits inference mode
|
||||||
|
audio_b = audio.unsqueeze(0).to(device)
|
||||||
|
dist = vae_utils.encode_audio(audio_b)
|
||||||
|
# VAE outputs [B, latent_dim, T]; generator expects [B, T, latent_dim]
|
||||||
|
x1 = dist.mode().clone().transpose(1, 2).cpu()
|
||||||
|
# STFT rounding can produce ±1 frame — pad or trim to exact seq length
|
||||||
|
tgt = seq_cfg.latent_seq_len
|
||||||
|
if x1.shape[1] < tgt:
|
||||||
|
x1 = F.pad(x1, (0, 0, 0, tgt - x1.shape[1]))
|
||||||
|
elif x1.shape[1] > tgt:
|
||||||
|
x1 = x1[:, :tgt, :]
|
||||||
|
|
||||||
|
# Text → CLIP features (reuse already-loaded CLIP from inference model)
|
||||||
|
text_clip = feature_utils_orig.encode_text_clip([prompt]).cpu()
|
||||||
|
|
||||||
|
# Pad/trim clip and sync features to fixed seq lengths — clips from
|
||||||
|
# shorter videos have fewer frames and would cause stack() to fail
|
||||||
|
clip_f = bundle["clip_features"] # [1, N_clip, 1024]
|
||||||
|
c_tgt = seq_cfg.clip_seq_len
|
||||||
|
if clip_f.shape[1] < c_tgt:
|
||||||
|
clip_f = F.pad(clip_f, (0, 0, 0, c_tgt - clip_f.shape[1]))
|
||||||
|
elif clip_f.shape[1] > c_tgt:
|
||||||
|
clip_f = clip_f[:, :c_tgt, :]
|
||||||
|
|
||||||
|
sync_f = bundle["sync_features"] # [1, N_sync, 768]
|
||||||
|
s_tgt = seq_cfg.sync_seq_len
|
||||||
|
if sync_f.shape[1] < s_tgt:
|
||||||
|
sync_f = F.pad(sync_f, (0, 0, 0, s_tgt - sync_f.shape[1]))
|
||||||
|
elif sync_f.shape[1] > s_tgt:
|
||||||
|
sync_f = sync_f[:, :s_tgt, :]
|
||||||
|
|
||||||
|
dataset.append((x1, clip_f, sync_f, text_clip))
|
||||||
|
except Exception as e:
|
||||||
|
print(f" [LoRA Trainer] Warning: failed {npz_path.name}: {e}", flush=True)
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
pbar_load.update(1)
|
||||||
|
|
||||||
|
# VAE no longer needed — free memory
|
||||||
|
del vae_utils
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError("[LoRA Trainer] No clips could be loaded.")
|
||||||
|
print(f"[LoRA Trainer] {len(dataset)} clip(s) ready.", flush=True)
|
||||||
|
|
||||||
|
# ComfyUI executes nodes inside torch.inference_mode(). Inference tensors
|
||||||
|
# can't participate in autograd even with enable_grad — disable inference
|
||||||
|
# mode entirely so deepcopy, apply_lora, and the training loop all run
|
||||||
|
# with a clean autograd context.
|
||||||
|
with torch.inference_mode(False), torch.enable_grad():
|
||||||
|
return self._train_inner(
|
||||||
|
model, dataset, feature_utils_orig, seq_cfg,
|
||||||
|
device, dtype, variant, mode,
|
||||||
|
data_dir, output_dir, steps, rank, lr,
|
||||||
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
|
grad_accum, save_every, resume_path, seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _train_inner(
|
||||||
|
self, model, dataset, feature_utils_orig, seq_cfg,
|
||||||
|
device, dtype, variant, mode,
|
||||||
|
data_dir, output_dir, steps, rank, lr,
|
||||||
|
alpha_val, target_suffixes, batch_size, warmup_steps,
|
||||||
|
grad_accum, save_every, resume_path, seed,
|
||||||
|
):
|
||||||
|
# --- Prepare generator copy with LoRA ---
|
||||||
|
generator = copy.deepcopy(model["generator"]).to(device, dtype)
|
||||||
|
|
||||||
|
n_lora = apply_lora(generator, rank=rank, alpha=alpha_val,
|
||||||
|
target_suffixes=target_suffixes)
|
||||||
|
if n_lora == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"[LoRA Trainer] No layers matched target={target_suffixes}. "
|
||||||
|
"Check the 'target' field."
|
||||||
|
)
|
||||||
|
print(f"[LoRA Trainer] Wrapped {n_lora} layers (rank={rank}, alpha={alpha_val})", flush=True)
|
||||||
|
|
||||||
|
for name, p in generator.named_parameters():
|
||||||
|
p.requires_grad_("lora_" in name)
|
||||||
|
|
||||||
|
generator.update_seq_lengths(
|
||||||
|
latent_seq_len=seq_cfg.latent_seq_len,
|
||||||
|
clip_seq_len=seq_cfg.clip_seq_len,
|
||||||
|
sync_seq_len=seq_cfg.sync_seq_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Optimizer + scheduler ---
|
||||||
|
lora_params = [p for p in generator.parameters() if p.requires_grad]
|
||||||
|
optimizer = torch.optim.AdamW(lora_params, lr=lr, weight_decay=1e-2)
|
||||||
|
|
||||||
|
def lr_lambda(s):
|
||||||
|
return s / max(1, warmup_steps) if s < warmup_steps else 1.0
|
||||||
|
|
||||||
|
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
||||||
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=25)
|
||||||
|
|
||||||
|
# --- Resume ---
|
||||||
|
start_step = 0
|
||||||
|
if resume_path.strip():
|
||||||
|
ckpt = torch.load(resume_path.strip(), map_location="cpu", weights_only=False)
|
||||||
|
if "step" not in ckpt:
|
||||||
|
raise ValueError("[LoRA Trainer] Checkpoint has no step info.")
|
||||||
|
start_step = ckpt["step"]
|
||||||
|
if start_step >= steps:
|
||||||
|
raise ValueError(
|
||||||
|
f"[LoRA Trainer] Checkpoint already at step {start_step} >= steps {steps}."
|
||||||
|
)
|
||||||
|
load_lora(generator, ckpt["state_dict"])
|
||||||
|
optimizer.load_state_dict(ckpt["optimizer"])
|
||||||
|
scheduler.load_state_dict(ckpt["scheduler"])
|
||||||
|
print(f"[LoRA Trainer] Resumed from step {start_step}.", flush=True)
|
||||||
|
|
||||||
|
# --- Training loop ---
|
||||||
|
generator.train()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
log_interval = 50
|
||||||
|
remaining = steps - start_step
|
||||||
|
pbar_train = comfy.utils.ProgressBar(remaining)
|
||||||
|
loss_history = []
|
||||||
|
running_loss = 0.0
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
"variant": variant,
|
||||||
|
"rank": rank,
|
||||||
|
"alpha": alpha_val,
|
||||||
|
"target": list(target_suffixes),
|
||||||
|
"steps": steps,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"\n[LoRA Trainer] Training {remaining} steps "
|
||||||
|
f"(step {start_step + 1} → {steps}, batch_size={batch_size})\n", flush=True)
|
||||||
|
|
||||||
|
last_step = start_step
|
||||||
|
completed = False
|
||||||
|
try:
|
||||||
|
for step in range(start_step + 1, steps + 1):
|
||||||
|
batch = random.choices(dataset, k=batch_size)
|
||||||
|
x1_list, clip_list, sync_list, text_list = zip(*batch)
|
||||||
|
|
||||||
|
x1 = torch.stack([x.squeeze(0) for x in x1_list]).to(device, dtype)
|
||||||
|
clip_f = torch.stack([x.squeeze(0) for x in clip_list]).to(device, dtype)
|
||||||
|
sync_f = torch.stack([x.squeeze(0) for x in sync_list]).to(device, dtype)
|
||||||
|
text_clip = torch.stack([x.squeeze(0) for x in text_list]).to(device, dtype)
|
||||||
|
|
||||||
|
generator.normalize(x1)
|
||||||
|
|
||||||
|
t = torch.rand(batch_size, device=device, dtype=dtype)
|
||||||
|
x0 = torch.randn_like(x1)
|
||||||
|
xt = fm.get_conditional_flow(x0, x1, t)
|
||||||
|
|
||||||
|
v_pred = generator.forward(xt, clip_f, sync_f, text_clip, t)
|
||||||
|
loss = fm.loss(v_pred, x0, x1).mean() / grad_accum
|
||||||
|
loss.backward()
|
||||||
|
running_loss += loss.item() * grad_accum
|
||||||
|
|
||||||
|
if step % grad_accum == 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(lora_params, max_norm=1.0)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if step % log_interval == 0:
|
||||||
|
avg = running_loss / log_interval
|
||||||
|
loss_history.append(avg)
|
||||||
|
lr_now = scheduler.get_last_lr()[0]
|
||||||
|
print(f"[LoRA Trainer] step {step:5d}/{steps} "
|
||||||
|
f"loss={avg:.4f} lr={lr_now:.2e} bs={batch_size}", flush=True)
|
||||||
|
running_loss = 0.0
|
||||||
|
|
||||||
|
# Live preview: send updated loss curve to ComfyUI frontend
|
||||||
|
preview_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
||||||
|
smoothed=_smooth_losses(loss_history))
|
||||||
|
pbar_train.update_absolute(
|
||||||
|
step - start_step, remaining, ("JPEG", preview_img, 800)
|
||||||
|
)
|
||||||
|
|
||||||
|
if step % save_every == 0 or step == steps:
|
||||||
|
ckpt_path = output_dir / f"adapter_step{step:05d}.pt"
|
||||||
|
torch.save({
|
||||||
|
"state_dict": get_lora_state_dict(generator),
|
||||||
|
"optimizer": optimizer.state_dict(),
|
||||||
|
"scheduler": scheduler.state_dict(),
|
||||||
|
"step": step,
|
||||||
|
"meta": meta,
|
||||||
|
}, ckpt_path)
|
||||||
|
print(f"[LoRA Trainer] Saved {ckpt_path}", flush=True)
|
||||||
|
|
||||||
|
# Save a quick eval sample next to the checkpoint
|
||||||
|
wav, sr = _eval_sample(generator, feature_utils_orig,
|
||||||
|
dataset, seq_cfg, device, dtype)
|
||||||
|
if wav is not None:
|
||||||
|
wav_path = output_dir / f"sample_step{step:05d}.wav"
|
||||||
|
try:
|
||||||
|
torchaudio.save(str(wav_path), wav, sr)
|
||||||
|
except RuntimeError:
|
||||||
|
import soundfile as sf
|
||||||
|
sf.write(str(wav_path), wav.squeeze(0).numpy(), sr)
|
||||||
|
print(f"[LoRA Trainer] Sample saved: {wav_path}", flush=True)
|
||||||
|
|
||||||
|
last_step = step
|
||||||
|
pbar_train.update(1)
|
||||||
|
|
||||||
|
completed = True
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Save adapter and loss curves whether training completed or was cancelled.
|
||||||
|
# Skip if we never completed a single step (nothing useful to save).
|
||||||
|
if loss_history:
|
||||||
|
if completed:
|
||||||
|
# Normal completion — use adapter_final.pt (increment if exists)
|
||||||
|
final_path = output_dir / "adapter_final.pt"
|
||||||
|
if final_path.exists():
|
||||||
|
i = 1
|
||||||
|
while (output_dir / f"adapter_final_{i:03d}.pt").exists():
|
||||||
|
i += 1
|
||||||
|
final_path = output_dir / f"adapter_final_{i:03d}.pt"
|
||||||
|
label = "Done"
|
||||||
|
else:
|
||||||
|
# Cancelled — include the step number so the file is useful for resume
|
||||||
|
final_path = output_dir / f"adapter_cancelled_step{last_step:05d}.pt"
|
||||||
|
label = f"Cancelled at step {last_step}"
|
||||||
|
|
||||||
|
torch.save({"state_dict": get_lora_state_dict(generator), "meta": meta}, final_path)
|
||||||
|
(output_dir / "meta.json").write_text(json.dumps(meta, indent=2))
|
||||||
|
print(f"\n[LoRA Trainer] {label}. Adapter saved to {final_path}", flush=True)
|
||||||
|
|
||||||
|
smoothed = _smooth_losses(loss_history)
|
||||||
|
raw_img = _draw_loss_curve(loss_history, log_interval, start_step)
|
||||||
|
smoothed_img = _draw_loss_curve(loss_history, log_interval, start_step,
|
||||||
|
smoothed=smoothed)
|
||||||
|
raw_img.save(str(output_dir / "loss_raw.png"))
|
||||||
|
smoothed_img.save(str(output_dir / "loss_smoothed.png"))
|
||||||
|
print(f"[LoRA Trainer] Loss curves saved to {output_dir}", flush=True)
|
||||||
|
|
||||||
|
# Reached only on normal completion (exception re-raises past this point)
|
||||||
|
generator.eval()
|
||||||
|
generator.to(next(model["generator"].parameters()).device)
|
||||||
|
patched = {**model, "generator": generator}
|
||||||
|
|
||||||
|
loss_curve = _pil_to_tensor(smoothed_img)
|
||||||
|
return (patched, str(final_path), loss_curve)
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy
|
||||||
|
|
||||||
|
# Variant → (generator filename, mode, has_bigvgan)
|
||||||
|
_VARIANTS = {
|
||||||
|
"small_16k": ("generator_small_16k_sup_5.pth", "16k", True),
|
||||||
|
"small_44k": ("generator_small_44k_sup_5.pth", "44k", False),
|
||||||
|
"medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
|
||||||
|
"large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
|
||||||
|
}
|
||||||
|
|
||||||
|
_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
|
||||||
|
_PRISMAUDIO_DIR = Path(folder_paths.models_dir) / "prismaudio"
|
||||||
|
|
||||||
|
|
||||||
|
_HF_REPO = "jnwnlee/SelVA"
|
||||||
|
|
||||||
|
# filename → (hf_repo_path, expected_md5 or None to skip check)
|
||||||
|
# Note: 44k generators are named 44khz in the HF repo; md5=None since the
|
||||||
|
# original download_utils had the wrong filenames so those md5s are unverified.
|
||||||
|
_WEIGHTS = {
|
||||||
|
"video_enc_sup_5.pth": ("weights/video_enc_sup_5.pth", "ff09a6dc36148536ee4db97eba081d05"),
|
||||||
|
"generator_small_16k_sup_5.pth": ("weights/generator_small_16k_sup_5.pth", "1cb0f0deec52de37f67b1fd9965337d0"),
|
||||||
|
"generator_small_44k_sup_5.pth": ("weights/generator_small_44khz_sup_5.pth", None),
|
||||||
|
"generator_medium_44k_sup_5.pth":("weights/generator_medium_44khz_sup_5.pth", None),
|
||||||
|
"generator_large_44k_sup_5.pth": ("weights/generator_large_44khz_sup_5.pth", None),
|
||||||
|
"v1-16.pth": ("ext_weights/v1-16.pth", "69f56803f59a549a1a507c93859fd4d7"),
|
||||||
|
"v1-44.pth": ("ext_weights/v1-44.pth", "fab020275fa44c6589820ce025191600"),
|
||||||
|
"best_netG.pt": ("ext_weights/best_netG.pt", "eeaf372a38a9c31c362120aba2dde292"),
|
||||||
|
"synchformer_state_dict.pth": ("ext_weights/synchformer_state_dict.pth", "5b2f5594b0730f70e41e549b7c94390c"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _md5(path):
|
||||||
|
import hashlib
|
||||||
|
h = hashlib.md5()
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(8 * 1024 * 1024), b""):
|
||||||
|
h.update(chunk)
|
||||||
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure(filename, subdir=None):
|
||||||
|
"""Return path to weight file. Re-downloads if missing or MD5 mismatch."""
|
||||||
|
import shutil
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
dest_dir = _SELVA_DIR / subdir if subdir else _SELVA_DIR
|
||||||
|
dest_path = dest_dir / filename
|
||||||
|
|
||||||
|
entry = _WEIGHTS.get(filename)
|
||||||
|
if entry is None:
|
||||||
|
raise ValueError(f"[SelVA] Unknown weight file: {filename}")
|
||||||
|
repo_path, expected_md5 = entry
|
||||||
|
|
||||||
|
if dest_path.exists():
|
||||||
|
if expected_md5 is None:
|
||||||
|
return str(dest_path)
|
||||||
|
actual = _md5(dest_path)
|
||||||
|
if actual == expected_md5:
|
||||||
|
return str(dest_path)
|
||||||
|
print(f"[SelVA] {filename}: MD5 mismatch ({actual} ≠ {expected_md5}), re-downloading...", flush=True)
|
||||||
|
dest_path.unlink()
|
||||||
|
|
||||||
|
print(f"[SelVA] Downloading {filename} from {_HF_REPO}...", flush=True)
|
||||||
|
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
cached = hf_hub_download(repo_id=_HF_REPO, filename=repo_path)
|
||||||
|
shutil.copy2(cached, dest_path)
|
||||||
|
print(f"[SelVA] Saved to {dest_path}", flush=True)
|
||||||
|
return str(dest_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _synchformer_path():
|
||||||
|
"""Return synchformer path, reusing models/prismaudio/ if already present."""
|
||||||
|
prismaudio_path = _PRISMAUDIO_DIR / "synchformer_state_dict.pth"
|
||||||
|
if prismaudio_path.exists():
|
||||||
|
return str(prismaudio_path)
|
||||||
|
return _ensure("synchformer_state_dict.pth")
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaModelLoader:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"variant": (list(_VARIANTS.keys()), {
|
||||||
|
"tooltip": "Model size and output sample rate. small_16k is fastest (16 kHz). 44k variants output 44.1 kHz. larger = better quality, more VRAM.",
|
||||||
|
}),
|
||||||
|
"precision": (["bf16", "fp16", "fp32"], {
|
||||||
|
"tooltip": "Compute dtype. bf16 is recommended on Ampere+ GPUs. fp16 for older NVIDIA hardware. fp32 if you see NaN outputs.",
|
||||||
|
}),
|
||||||
|
"offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"], {
|
||||||
|
"tooltip": "auto picks keep_in_vram if ≥16 GB VRAM is free, otherwise offload_to_cpu. offload_to_cpu moves weights to RAM between nodes, saving VRAM at the cost of speed.",
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("SELVA_MODEL",)
|
||||||
|
RETURN_NAMES = ("model",)
|
||||||
|
OUTPUT_TOOLTIPS = ("Loaded model bundle — connect to Feature Extractor and Sampler.",)
|
||||||
|
FUNCTION = "load_model"
|
||||||
|
CATEGORY = SELVA_CATEGORY
|
||||||
|
DESCRIPTION = "Loads the SelVA generator, TextSynchformer encoder, CLIP, T5, and VAE. Weights are auto-downloaded from HuggingFace on first use."
|
||||||
|
|
||||||
|
def load_model(self, variant, precision, offload_strategy):
|
||||||
|
from selva_core.model.networks_generator import get_my_mmaudio
|
||||||
|
from selva_core.model.networks_video_enc import get_my_textsynch
|
||||||
|
from selva_core.model.utils.features_utils import FeaturesUtils
|
||||||
|
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K
|
||||||
|
|
||||||
|
gen_filename, mode, has_bigvgan = _VARIANTS[variant]
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
if precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
|
||||||
|
print("[SelVA] Warning: bf16 not supported on this GPU — falling back to fp16.", flush=True)
|
||||||
|
precision = "fp16"
|
||||||
|
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
|
||||||
|
strategy = determine_offload_strategy(offload_strategy)
|
||||||
|
|
||||||
|
print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True)
|
||||||
|
video_enc_path = _ensure("video_enc_sup_5.pth")
|
||||||
|
gen_path = _ensure(gen_filename)
|
||||||
|
vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
|
||||||
|
vae_path = _ensure(vae_name, subdir="ext")
|
||||||
|
synch_path = _synchformer_path()
|
||||||
|
bigvgan_path = _ensure("best_netG.pt", subdir="ext") if has_bigvgan else None
|
||||||
|
|
||||||
|
print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
|
||||||
|
net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
|
||||||
|
net_video_enc.load_weights(
|
||||||
|
torch.load(video_enc_path, map_location="cpu", weights_only=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
|
||||||
|
seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
|
||||||
|
net_generator = get_my_mmaudio(variant).to(device, dtype).eval()
|
||||||
|
net_generator.load_weights(
|
||||||
|
torch.load(gen_path, map_location="cpu", weights_only=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
|
||||||
|
feature_utils = FeaturesUtils(
|
||||||
|
tod_vae_ckpt=vae_path,
|
||||||
|
synchformer_ckpt=synch_path,
|
||||||
|
enable_conditions=True,
|
||||||
|
mode=mode,
|
||||||
|
bigvgan_vocoder_ckpt=bigvgan_path,
|
||||||
|
need_vae_encoder=False,
|
||||||
|
).to(device, dtype).eval()
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
net_generator.to(get_offload_device())
|
||||||
|
net_video_enc.to(get_offload_device())
|
||||||
|
feature_utils.to(get_offload_device())
|
||||||
|
|
||||||
|
print(f"[SelVA] Model ready: variant={variant} dtype={dtype} strategy={strategy}", flush=True)
|
||||||
|
|
||||||
|
return ({
|
||||||
|
"generator": net_generator,
|
||||||
|
"video_enc": net_video_enc,
|
||||||
|
"feature_utils": feature_utils,
|
||||||
|
"variant": variant,
|
||||||
|
"mode": mode,
|
||||||
|
"strategy": strategy,
|
||||||
|
"dtype": dtype,
|
||||||
|
"seq_cfg": seq_cfg,
|
||||||
|
},)
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||||
|
|
||||||
|
|
||||||
|
class SelvaSampler:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("SELVA_MODEL",),
|
||||||
|
"features": ("SELVA_FEATURES",),
|
||||||
|
"prompt": ("STRING", {
|
||||||
|
"default": "", "multiline": True,
|
||||||
|
"tooltip": "Sound description for CLIP text conditioning. Leave empty to reuse the prompt from the Feature Extractor (wire its prompt output here). Changing this without re-extracting features shifts CLIP conditioning but not sync features.",
|
||||||
|
}),
|
||||||
|
"negative_prompt": ("STRING", {
|
||||||
|
"default": "", "multiline": False,
|
||||||
|
"tooltip": "Sounds to suppress, e.g. 'speech, music, wind noise'. Steered away from via CFG. Leave empty for unconditional guidance baseline.",
|
||||||
|
}),
|
||||||
|
"duration": ("FLOAT", {
|
||||||
|
"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
|
||||||
|
"tooltip": "Output audio length in seconds. 0 = match the video duration stored in features.",
|
||||||
|
}),
|
||||||
|
"steps": ("INT", {"default": 25, "min": 1, "max": 200,
|
||||||
|
"tooltip": "Euler steps for the flow matching ODE. 25 is the SelVA default. Diminishing returns above 50; below 10 may sound rough."}),
|
||||||
|
"cfg_strength": ("FLOAT", {"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1,
|
||||||
|
"tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strictly but can introduce artifacts. SelVA default is 4.5; useful range is roughly 3–7."}),
|
||||||
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"normalize": ("BOOLEAN", {
|
||||||
|
"default": True,
|
||||||
|
"tooltip": "Peak-normalize output to [-1, 1]. Disable to preserve the raw decoder output level.",
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO",)
|
||||||
|
RETURN_NAMES = ("audio",)
|
||||||
|
OUTPUT_TOOLTIPS = ("Generated audio waveform — connect to VHS_VideoCombine or Save Audio.",)
|
||||||
|
FUNCTION = "generate"
|
||||||
|
CATEGORY = SELVA_CATEGORY
|
||||||
|
DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."
|
||||||
|
|
||||||
|
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True):
|
||||||
|
import dataclasses
|
||||||
|
from selva_core.model.flow_matching import FlowMatching
|
||||||
|
|
||||||
|
device = get_device()
|
||||||
|
dtype = model["dtype"]
|
||||||
|
strategy = model["strategy"]
|
||||||
|
net_generator = model["generator"]
|
||||||
|
feature_utils = model["feature_utils"]
|
||||||
|
|
||||||
|
# Validate that features were extracted with the same model variant
|
||||||
|
feat_variant = features.get("variant")
|
||||||
|
if feat_variant is not None and feat_variant != model["variant"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"[SelVA] Variant mismatch: features were extracted with '{feat_variant}' "
|
||||||
|
f"but model is '{model['variant']}'. Re-run the Feature Extractor with the current model."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resolve prompt: use override if given, otherwise fall back to features prompt
|
||||||
|
if not prompt or not prompt.strip():
|
||||||
|
prompt = features.get("prompt", "")
|
||||||
|
if prompt:
|
||||||
|
print(f"[SelVA] Using prompt from features: '{prompt[:60]}'", flush=True)
|
||||||
|
else:
|
||||||
|
print("[SelVA] Warning: no prompt in features or sampler — CLIP text conditioning will be empty.", flush=True)
|
||||||
|
|
||||||
|
# Resolve duration
|
||||||
|
if duration <= 0:
|
||||||
|
if "duration" not in features:
|
||||||
|
raise ValueError("[SelVA] duration=0 but features contain no duration field.")
|
||||||
|
duration = features["duration"]
|
||||||
|
print(f"[SelVA] Using video duration from features: {duration:.2f}s", flush=True)
|
||||||
|
|
||||||
|
# Derive sequence config for this duration from the model's mode template
|
||||||
|
seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
|
||||||
|
sample_rate = seq_cfg.sampling_rate
|
||||||
|
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
net_generator.to(device)
|
||||||
|
feature_utils.to(device)
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
try:
|
||||||
|
clip_f = features["clip_features"].to(device, dtype) # [1, T_clip, 1024]
|
||||||
|
sync_f = features["sync_features"].to(device, dtype) # [1, T_sync, 768]
|
||||||
|
|
||||||
|
print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
|
||||||
|
|
||||||
|
# Update model rotary position embeddings for actual feature shapes and duration.
|
||||||
|
# Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches.
|
||||||
|
net_generator.update_seq_lengths(
|
||||||
|
latent_seq_len=seq_cfg.latent_seq_len,
|
||||||
|
clip_seq_len=clip_f.shape[1],
|
||||||
|
sync_seq_len=sync_f.shape[1],
|
||||||
|
)
|
||||||
|
print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Encode text conditioning
|
||||||
|
text_clip = feature_utils.encode_text_clip([prompt]) # [1, 77, D]
|
||||||
|
|
||||||
|
# Encode negative prompt (or use empty conditions)
|
||||||
|
neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
|
||||||
|
if negative_prompt.strip() else None
|
||||||
|
|
||||||
|
conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
|
||||||
|
empty_conditions = net_generator.get_empty_conditions(
|
||||||
|
bs=1, negative_text_features=neg_text_clip
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initial noise (MPS doesn't support torch.Generator on device)
|
||||||
|
gen_device = "cpu" if device.type == "mps" else device
|
||||||
|
rng = torch.Generator(device=gen_device).manual_seed(seed)
|
||||||
|
x0 = torch.randn(
|
||||||
|
1, seq_cfg.latent_seq_len, net_generator.latent_dim,
|
||||||
|
device=gen_device, dtype=dtype, generator=rng,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
# Flow matching ODE (Euler)
|
||||||
|
fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
|
||||||
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
|
def ode_wrapper_tracked(t, x):
|
||||||
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
pbar.update(1)
|
||||||
|
return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)
|
||||||
|
|
||||||
|
try:
|
||||||
|
x1 = fm.to_data(ode_wrapper_tracked, x0)
|
||||||
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
|
||||||
|
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)
|
||||||
|
|
||||||
|
# Decode: latent → mel → audio
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
x1_unnorm = net_generator.unnormalize(x1)
|
||||||
|
spec = feature_utils.decode(x1_unnorm) # latent → mel spectrogram
|
||||||
|
audio = feature_utils.vocode(spec) # mel → waveform
|
||||||
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy "
|
||||||
|
"to 'offload_to_cpu', using a smaller variant, or reducing duration."
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if strategy == "offload_to_cpu":
|
||||||
|
net_generator.to(get_offload_device())
|
||||||
|
feature_utils.to(get_offload_device())
|
||||||
|
soft_empty_cache()
|
||||||
|
|
||||||
|
# Ensure [1, 1, samples] and normalize to [-1,1]
|
||||||
|
audio = audio.float()
|
||||||
|
if audio.dim() == 2:
|
||||||
|
audio = audio.unsqueeze(1)
|
||||||
|
elif audio.dim() == 3 and audio.shape[1] != 1:
|
||||||
|
audio = audio.mean(dim=1, keepdim=True) # stereo → mono
|
||||||
|
|
||||||
|
if normalize:
|
||||||
|
peak = audio.abs().max().clamp(min=1e-8)
|
||||||
|
audio = (audio / peak).clamp(-1, 1)
|
||||||
|
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
||||||
|
|
||||||
|
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
import torch
|
|
||||||
import comfy.model_management as mm
|
|
||||||
import comfy.utils
|
|
||||||
|
|
||||||
from .utils import (
|
|
||||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
|
||||||
get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
|
|
||||||
)
|
|
||||||
from .sampler import _substitute_empty_features
|
|
||||||
|
|
||||||
|
|
||||||
class PrismAudioTextOnly:
|
|
||||||
@classmethod
|
|
||||||
def INPUT_TYPES(cls):
|
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"model": ("PRISMAUDIO_MODEL",),
|
|
||||||
"text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}),
|
|
||||||
"duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
|
|
||||||
"steps": ("INT", {"default": 100, "min": 1, "max": 100}),
|
|
||||||
"cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1}),
|
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("AUDIO",)
|
|
||||||
RETURN_NAMES = ("audio",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
CATEGORY = PRISMAUDIO_CATEGORY
|
|
||||||
|
|
||||||
def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
|
|
||||||
device = get_device()
|
|
||||||
dtype = model["dtype"]
|
|
||||||
strategy = model["strategy"]
|
|
||||||
diffusion = model["model"]
|
|
||||||
|
|
||||||
latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
|
|
||||||
|
|
||||||
# Encode text with T5-Gemma
|
|
||||||
text_features = _encode_text_t5(text_prompt, device, dtype)
|
|
||||||
|
|
||||||
# Build metadata: tuple of one dict per sample
|
|
||||||
# Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence)
|
|
||||||
# Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768]
|
|
||||||
# These will be substituted with learned empty embeddings after conditioning
|
|
||||||
sample_meta = {
|
|
||||||
"video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
|
|
||||||
"text_features": text_features.to(device, dtype=dtype),
|
|
||||||
"sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
|
|
||||||
"video_exist": torch.tensor(False),
|
|
||||||
}
|
|
||||||
metadata = (sample_meta,)
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
diffusion.model.to(device)
|
|
||||||
diffusion.conditioner.to(device)
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
|
|
||||||
conditioning = diffusion.conditioner(metadata, device)
|
|
||||||
|
|
||||||
# Substitute empty features for video/sync
|
|
||||||
_substitute_empty_features(diffusion, conditioning, device, dtype)
|
|
||||||
|
|
||||||
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
|
||||||
|
|
||||||
# Generate noise from seed (MPS doesn't support torch.Generator)
|
|
||||||
gen_device = "cpu" if device.type == "mps" else device
|
|
||||||
generator = torch.Generator(device=gen_device).manual_seed(seed)
|
|
||||||
noise = torch.randn(
|
|
||||||
[1, IO_CHANNELS, latent_length],
|
|
||||||
generator=generator,
|
|
||||||
device=gen_device,
|
|
||||||
).to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
|
||||||
|
|
||||||
from prismaudio_core.inference.sampling import sample_discrete_euler
|
|
||||||
|
|
||||||
def on_step(info):
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
fakes = sample_discrete_euler(
|
|
||||||
diffusion.model,
|
|
||||||
noise,
|
|
||||||
steps,
|
|
||||||
callback=on_step,
|
|
||||||
**cond_inputs,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
batch_cfg=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
fakes_f = fakes.float()
|
|
||||||
print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
diffusion.model.to(get_offload_device())
|
|
||||||
diffusion.conditioner.to(get_offload_device())
|
|
||||||
soft_empty_cache()
|
|
||||||
diffusion.pretransform.to(device)
|
|
||||||
|
|
||||||
# VAE decode in fp32 (snake activations overflow in fp16)
|
|
||||||
with torch.amp.autocast(device_type=device.type, enabled=False):
|
|
||||||
audio = diffusion.pretransform.decode(fakes_f)
|
|
||||||
|
|
||||||
if strategy == "offload_to_cpu":
|
|
||||||
diffusion.pretransform.to(get_offload_device())
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
# Peak normalize then clamp
|
|
||||||
audio = audio.float()
|
|
||||||
pre_norm_std = audio.std().item()
|
|
||||||
pre_norm_peak = audio.abs().max().item()
|
|
||||||
peak = audio.abs().max().clamp(min=1e-8)
|
|
||||||
audio = (audio / peak).clamp(-1, 1)
|
|
||||||
print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
|
|
||||||
print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)
|
|
||||||
|
|
||||||
return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
|
|
||||||
|
|
||||||
|
|
||||||
# T5-Gemma encoder singleton
|
|
||||||
_t5_model = None
|
|
||||||
_t5_tokenizer = None
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_text_t5(text, device, dtype):
|
|
||||||
"""Encode text using T5-Gemma.
|
|
||||||
|
|
||||||
Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
|
|
||||||
FeaturesUtils.encode_t5_text() implementation.
|
|
||||||
No truncation applied (matching reference behavior).
|
|
||||||
"""
|
|
||||||
global _t5_model, _t5_tokenizer
|
|
||||||
|
|
||||||
if _t5_model is None:
|
|
||||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
||||||
model_id = "google/t5gemma-l-l-ul2-it"
|
|
||||||
token = resolve_hf_token()
|
|
||||||
print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
|
|
||||||
_t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
|
|
||||||
_t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
|
|
||||||
_t5_model.eval()
|
|
||||||
|
|
||||||
_t5_model.to(device, dtype=dtype)
|
|
||||||
|
|
||||||
tokens = _t5_tokenizer(
|
|
||||||
text,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
).to(device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = _t5_model(**tokens)
|
|
||||||
|
|
||||||
# Move T5 off GPU after encoding to save VRAM
|
|
||||||
_t5_model.to("cpu")
|
|
||||||
soft_empty_cache()
|
|
||||||
|
|
||||||
return outputs.last_hidden_state.squeeze(0) # [seq_len, dim]
|
|
||||||
+3
-46
@@ -1,21 +1,7 @@
|
|||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
import folder_paths
|
|
||||||
import comfy.model_management as mm
|
import comfy.model_management as mm
|
||||||
|
|
||||||
PRISMAUDIO_CATEGORY = "PrismAudio"
|
SELVA_CATEGORY = "SelVA"
|
||||||
SAMPLE_RATE = 44100
|
|
||||||
DOWNSAMPLING_RATIO = 2048
|
|
||||||
IO_CHANNELS = 64
|
|
||||||
|
|
||||||
def get_prismaudio_model_dir():
|
|
||||||
model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
|
||||||
return model_dir
|
|
||||||
|
|
||||||
def register_model_folder():
|
|
||||||
model_dir = get_prismaudio_model_dir()
|
|
||||||
folder_paths.add_model_folder_path("prismaudio", model_dir)
|
|
||||||
|
|
||||||
def get_device():
|
def get_device():
|
||||||
return mm.get_torch_device()
|
return mm.get_torch_device()
|
||||||
@@ -23,42 +9,13 @@ def get_device():
|
|||||||
def get_offload_device():
|
def get_offload_device():
|
||||||
return mm.unet_offload_device()
|
return mm.unet_offload_device()
|
||||||
|
|
||||||
def get_free_memory(device=None):
|
|
||||||
if device is None:
|
|
||||||
device = get_device()
|
|
||||||
return mm.get_free_memory(device)
|
|
||||||
|
|
||||||
def soft_empty_cache():
|
def soft_empty_cache():
|
||||||
mm.soft_empty_cache()
|
mm.soft_empty_cache()
|
||||||
|
|
||||||
def determine_precision(preference, device):
|
|
||||||
if preference != "auto":
|
|
||||||
return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
|
|
||||||
if device.type == "cpu":
|
|
||||||
return torch.float32
|
|
||||||
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
|
||||||
return torch.bfloat16
|
|
||||||
return torch.float16
|
|
||||||
|
|
||||||
def determine_offload_strategy(preference):
|
def determine_offload_strategy(preference):
|
||||||
if preference != "auto":
|
if preference != "auto":
|
||||||
return preference
|
return preference
|
||||||
free_mem = get_free_memory()
|
free_mem = mm.get_free_memory(get_device())
|
||||||
gb = free_mem / (1024 ** 3)
|
if free_mem / (1024 ** 3) >= 16:
|
||||||
if gb >= 24:
|
|
||||||
return "keep_in_vram"
|
return "keep_in_vram"
|
||||||
else:
|
|
||||||
return "offload_to_cpu"
|
return "offload_to_cpu"
|
||||||
|
|
||||||
def try_import_flash_attn():
|
|
||||||
try:
|
|
||||||
import flash_attn
|
|
||||||
return flash_attn
|
|
||||||
except ImportError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def resolve_hf_token():
|
|
||||||
env_token = os.environ.get("HF_TOKEN")
|
|
||||||
if env_token:
|
|
||||||
return env_token
|
|
||||||
return None
|
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
"""
|
|
||||||
PrismAudio core inference modules.
|
|
||||||
Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
|
|
||||||
Only inference-critical code — no training, no JAX/TF dependencies.
|
|
||||||
"""
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
{
|
|
||||||
"model_type": "diffusion_cond",
|
|
||||||
"sample_size": 397312,
|
|
||||||
"sample_rate": 44100,
|
|
||||||
"audio_channels": 2,
|
|
||||||
"model": {
|
|
||||||
"pretransform": {
|
|
||||||
"type": "autoencoder",
|
|
||||||
"iterate_batch": true,
|
|
||||||
"config": {
|
|
||||||
"encoder": {
|
|
||||||
"type": "oobleck",
|
|
||||||
"config": {
|
|
||||||
"in_channels": 2,
|
|
||||||
"channels": 128,
|
|
||||||
"c_mults": [1, 2, 4, 8, 16],
|
|
||||||
"strides": [2, 4, 4, 8, 8],
|
|
||||||
"latent_dim": 128,
|
|
||||||
"use_snake": true
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"decoder": {
|
|
||||||
"type": "oobleck",
|
|
||||||
"config": {
|
|
||||||
"out_channels": 2,
|
|
||||||
"channels": 128,
|
|
||||||
"c_mults": [1, 2, 4, 8, 16],
|
|
||||||
"strides": [2, 4, 4, 8, 8],
|
|
||||||
"latent_dim": 64,
|
|
||||||
"use_snake": true,
|
|
||||||
"final_tanh": false
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"bottleneck": {
|
|
||||||
"type": "vae"
|
|
||||||
},
|
|
||||||
"latent_dim": 64,
|
|
||||||
"downsampling_ratio": 2048,
|
|
||||||
"io_channels": 2
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"conditioning": {
|
|
||||||
"configs": [
|
|
||||||
{
|
|
||||||
"id": "video_features",
|
|
||||||
"type": "cond_mlp",
|
|
||||||
"config": {
|
|
||||||
"dim": 1024,
|
|
||||||
"output_dim": 1024
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "text_features",
|
|
||||||
"type": "cond_mlp",
|
|
||||||
"config": {
|
|
||||||
"dim": 1024,
|
|
||||||
"output_dim": 1024
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "sync_features",
|
|
||||||
"type": "sync_mlp",
|
|
||||||
"config": {
|
|
||||||
"dim": 768,
|
|
||||||
"output_dim": 1024
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"cond_dim": 768
|
|
||||||
},
|
|
||||||
"diffusion": {
|
|
||||||
"cross_attention_cond_ids": ["video_features","text_features"],
|
|
||||||
"add_cond_ids": ["video_features"],
|
|
||||||
"sync_cond_ids": ["sync_features"],
|
|
||||||
"type": "dit",
|
|
||||||
"diffusion_objective": "rectified_flow",
|
|
||||||
"config": {
|
|
||||||
"io_channels": 64,
|
|
||||||
"embed_dim": 1024,
|
|
||||||
"depth": 24,
|
|
||||||
"num_heads": 16,
|
|
||||||
"cond_token_dim": 1024,
|
|
||||||
"add_token_dim": 1024,
|
|
||||||
"sync_token_dim": 1024,
|
|
||||||
"project_cond_tokens": false,
|
|
||||||
"transformer_type": "continuous_transformer",
|
|
||||||
"attn_kwargs":{
|
|
||||||
"qk_norm": "rns"
|
|
||||||
},
|
|
||||||
"use_gated": true,
|
|
||||||
"use_sync_gated": true
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"io_channels": 64
|
|
||||||
},
|
|
||||||
"training": {
|
|
||||||
"use_ema": true,
|
|
||||||
"log_loss_info": false,
|
|
||||||
"cfg_dropout_prob": 0.1,
|
|
||||||
"pre_encoded": true,
|
|
||||||
"timestep_sampler": "trunc_logit_normal",
|
|
||||||
"optimizer_configs": {
|
|
||||||
"diffusion": {
|
|
||||||
"optimizer": {
|
|
||||||
"type": "AdamW",
|
|
||||||
"config": {
|
|
||||||
"lr": 1e-4,
|
|
||||||
"betas": [0.9, 0.999],
|
|
||||||
"weight_decay": 1e-3
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"scheduler": {
|
|
||||||
"type": "InverseLR",
|
|
||||||
"config": {
|
|
||||||
"inv_gamma": 100000,
|
|
||||||
"power": 0.5,
|
|
||||||
"warmup": 0.99
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"demo": {
|
|
||||||
"demo_every": 5000,
|
|
||||||
"demo_steps": 24,
|
|
||||||
"num_demos": 10,
|
|
||||||
"demo_cond": [
|
|
||||||
"dataset/videoprism/test/0Cu33yBwAPg_000060.npz",
|
|
||||||
"dataset/videoprism/test/bmKtI808DsU_000009.npz",
|
|
||||||
"dataset/videoprism/test/VC0c22cJTbM_000424.npz",
|
|
||||||
"dataset/videoprism/test/F3gsbUTdc2U_000090.npz",
|
|
||||||
"dataset/videoprism/test/WatvT8A8iug_000100.npz",
|
|
||||||
"dataset/videoprism/test/0nvBTp-q7tU_000112.npz",
|
|
||||||
"dataset/videoprism/test/3-PFuDkTM48_000080.npz",
|
|
||||||
"dataset/videoprism/test/luSAuu-BoPs_000232.npz",
|
|
||||||
"dataset/videoprism/test/__8UJxW0aOQ_000002.npz",
|
|
||||||
"dataset/videoprism/test/_0m_YMpQayA_000168.npz"
|
|
||||||
],
|
|
||||||
"demo_cfg_scales": [5]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,413 +0,0 @@
|
|||||||
"""
|
|
||||||
Model factory functions for PrismAudio inference.
|
|
||||||
|
|
||||||
Extracted from:
|
|
||||||
- PrismAudio/models/factory.py
|
|
||||||
- PrismAudio/models/autoencoders.py (create_autoencoder_from_config)
|
|
||||||
- PrismAudio/models/diffusion.py (create_diffusion_cond_from_config)
|
|
||||||
- PrismAudio/models/conditioners.py (create_multi_conditioner_from_conditioning_config)
|
|
||||||
|
|
||||||
Source: https://github.com/FunAudioLLM/ThinkSound (prismaudio branch)
|
|
||||||
Only inference-critical factory functions are retained.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import typing as tp
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def create_model_from_config(model_config):
|
|
||||||
model_type = model_config.get('model_type', None)
|
|
||||||
|
|
||||||
assert model_type is not None, 'model_type must be specified in model config'
|
|
||||||
|
|
||||||
if model_type == 'autoencoder':
|
|
||||||
return create_autoencoder_from_config(model_config)
|
|
||||||
elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond":
|
|
||||||
return create_diffusion_cond_from_config(model_config)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
|
||||||
|
|
||||||
|
|
||||||
def create_pretransform_from_config(pretransform_config, sample_rate):
|
|
||||||
pretransform_type = pretransform_config.get('type', None)
|
|
||||||
|
|
||||||
assert pretransform_type is not None, 'type must be specified in pretransform config'
|
|
||||||
|
|
||||||
if pretransform_type == 'autoencoder':
|
|
||||||
from prismaudio_core.models.pretransforms import AutoencoderPretransform
|
|
||||||
|
|
||||||
# Create fake top-level config to pass sample rate to autoencoder constructor
|
|
||||||
# This is a bit of a hack but it keeps us from re-defining the sample rate in the config
|
|
||||||
autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
|
|
||||||
autoencoder = create_autoencoder_from_config(autoencoder_config)
|
|
||||||
|
|
||||||
scale = pretransform_config.get("scale", 1.0)
|
|
||||||
model_half = pretransform_config.get("model_half", False)
|
|
||||||
iterate_batch = pretransform_config.get("iterate_batch", False)
|
|
||||||
chunked = pretransform_config.get("chunked", False)
|
|
||||||
|
|
||||||
pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
|
|
||||||
elif pretransform_type == 'wavelet':
|
|
||||||
raise NotImplementedError("wavelet pretransform type is not supported")
|
|
||||||
elif pretransform_type == 'pqmf':
|
|
||||||
from prismaudio_core.models.pretransforms import PQMFPretransform
|
|
||||||
pqmf_config = pretransform_config["config"]
|
|
||||||
pretransform = PQMFPretransform(**pqmf_config)
|
|
||||||
elif pretransform_type == 'dac_pretrained':
|
|
||||||
from prismaudio_core.models.pretransforms import PretrainedDACPretransform
|
|
||||||
pretrained_dac_config = pretransform_config["config"]
|
|
||||||
pretransform = PretrainedDACPretransform(**pretrained_dac_config)
|
|
||||||
elif pretransform_type == "audiocraft_pretrained":
|
|
||||||
from prismaudio_core.models.pretransforms import AudiocraftCompressionPretransform
|
|
||||||
|
|
||||||
audiocraft_config = pretransform_config["config"]
|
|
||||||
pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
|
|
||||||
|
|
||||||
enable_grad = pretransform_config.get('enable_grad', False)
|
|
||||||
pretransform.enable_grad = enable_grad
|
|
||||||
|
|
||||||
pretransform.eval().requires_grad_(pretransform.enable_grad)
|
|
||||||
|
|
||||||
return pretransform
|
|
||||||
|
|
||||||
|
|
||||||
def create_bottleneck_from_config(bottleneck_config):
|
|
||||||
bottleneck_type = bottleneck_config.get('type', None)
|
|
||||||
|
|
||||||
assert bottleneck_type is not None, 'type must be specified in bottleneck config'
|
|
||||||
|
|
||||||
if bottleneck_type == 'tanh':
|
|
||||||
from prismaudio_core.models.bottleneck import TanhBottleneck
|
|
||||||
bottleneck = TanhBottleneck()
|
|
||||||
elif bottleneck_type == 'vae':
|
|
||||||
from prismaudio_core.models.bottleneck import VAEBottleneck
|
|
||||||
bottleneck = VAEBottleneck()
|
|
||||||
elif bottleneck_type == 'rvq':
|
|
||||||
from prismaudio_core.models.bottleneck import RVQBottleneck
|
|
||||||
|
|
||||||
quantizer_params = {
|
|
||||||
"dim": 128,
|
|
||||||
"codebook_size": 1024,
|
|
||||||
"num_quantizers": 8,
|
|
||||||
"decay": 0.99,
|
|
||||||
"kmeans_init": True,
|
|
||||||
"kmeans_iters": 50,
|
|
||||||
"threshold_ema_dead_code": 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
quantizer_params.update(bottleneck_config["config"])
|
|
||||||
|
|
||||||
bottleneck = RVQBottleneck(**quantizer_params)
|
|
||||||
elif bottleneck_type == "dac_rvq":
|
|
||||||
from prismaudio_core.models.bottleneck import DACRVQBottleneck
|
|
||||||
|
|
||||||
bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
|
|
||||||
|
|
||||||
elif bottleneck_type == 'rvq_vae':
|
|
||||||
from prismaudio_core.models.bottleneck import RVQVAEBottleneck
|
|
||||||
|
|
||||||
quantizer_params = {
|
|
||||||
"dim": 128,
|
|
||||||
"codebook_size": 1024,
|
|
||||||
"num_quantizers": 8,
|
|
||||||
"decay": 0.99,
|
|
||||||
"kmeans_init": True,
|
|
||||||
"kmeans_iters": 50,
|
|
||||||
"threshold_ema_dead_code": 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
quantizer_params.update(bottleneck_config["config"])
|
|
||||||
|
|
||||||
bottleneck = RVQVAEBottleneck(**quantizer_params)
|
|
||||||
|
|
||||||
elif bottleneck_type == 'dac_rvq_vae':
|
|
||||||
from prismaudio_core.models.bottleneck import DACRVQVAEBottleneck
|
|
||||||
bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
|
|
||||||
elif bottleneck_type == 'l2_norm':
|
|
||||||
from prismaudio_core.models.bottleneck import L2Bottleneck
|
|
||||||
bottleneck = L2Bottleneck()
|
|
||||||
elif bottleneck_type == "wasserstein":
|
|
||||||
from prismaudio_core.models.bottleneck import WassersteinBottleneck
|
|
||||||
bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
|
|
||||||
elif bottleneck_type == "fsq":
|
|
||||||
from prismaudio_core.models.bottleneck import FSQBottleneck
|
|
||||||
bottleneck = FSQBottleneck(**bottleneck_config["config"])
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
|
|
||||||
|
|
||||||
requires_grad = bottleneck_config.get('requires_grad', True)
|
|
||||||
if not requires_grad:
|
|
||||||
for param in bottleneck.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
return bottleneck
|
|
||||||
|
|
||||||
|
|
||||||
def create_autoencoder_from_config(config: Dict[str, Any]):
|
|
||||||
"""Create an AudioAutoencoder from a config dictionary.
|
|
||||||
|
|
||||||
Originally in PrismAudio/models/autoencoders.py.
|
|
||||||
"""
|
|
||||||
from prismaudio_core.models.autoencoders import (
|
|
||||||
AudioAutoencoder,
|
|
||||||
create_encoder_from_config,
|
|
||||||
create_decoder_from_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
ae_config = config["model"]
|
|
||||||
|
|
||||||
encoder = create_encoder_from_config(ae_config["encoder"])
|
|
||||||
decoder = create_decoder_from_config(ae_config["decoder"])
|
|
||||||
|
|
||||||
bottleneck = ae_config.get("bottleneck", None)
|
|
||||||
|
|
||||||
latent_dim = ae_config.get("latent_dim", None)
|
|
||||||
assert latent_dim is not None, "latent_dim must be specified in model config"
|
|
||||||
downsampling_ratio = ae_config.get("downsampling_ratio", None)
|
|
||||||
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
|
||||||
io_channels = ae_config.get("io_channels", None)
|
|
||||||
assert io_channels is not None, "io_channels must be specified in model config"
|
|
||||||
sample_rate = config.get("sample_rate", None)
|
|
||||||
assert sample_rate is not None, "sample_rate must be specified in model config"
|
|
||||||
|
|
||||||
in_channels = ae_config.get("in_channels", None)
|
|
||||||
out_channels = ae_config.get("out_channels", None)
|
|
||||||
|
|
||||||
pretransform = ae_config.get("pretransform", None)
|
|
||||||
|
|
||||||
if pretransform is not None:
|
|
||||||
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
|
||||||
|
|
||||||
if bottleneck is not None:
|
|
||||||
bottleneck = create_bottleneck_from_config(bottleneck)
|
|
||||||
|
|
||||||
soft_clip = ae_config["decoder"].get("soft_clip", False)
|
|
||||||
|
|
||||||
return AudioAutoencoder(
|
|
||||||
encoder,
|
|
||||||
decoder,
|
|
||||||
io_channels=io_channels,
|
|
||||||
latent_dim=latent_dim,
|
|
||||||
downsampling_ratio=downsampling_ratio,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
bottleneck=bottleneck,
|
|
||||||
pretransform=pretransform,
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
soft_clip=soft_clip
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]):
|
|
||||||
"""Create a MultiConditioner from a conditioning config dictionary.
|
|
||||||
|
|
||||||
Originally in PrismAudio/models/conditioners.py.
|
|
||||||
"""
|
|
||||||
from prismaudio_core.models.conditioners import (
|
|
||||||
MultiConditioner,
|
|
||||||
T5Conditioner,
|
|
||||||
CLAPTextConditioner,
|
|
||||||
CLIPTextConditioner,
|
|
||||||
MetaCLIPTextConditioner,
|
|
||||||
CLAPAudioConditioner,
|
|
||||||
Cond_MLP,
|
|
||||||
Global_MLP,
|
|
||||||
Sync_MLP,
|
|
||||||
Cond_MLP_1,
|
|
||||||
Cond_ConvMLP,
|
|
||||||
Cond_MLP_Global,
|
|
||||||
Cond_MLP_Global_1,
|
|
||||||
Cond_MLP_Global_2,
|
|
||||||
Video_Global,
|
|
||||||
Video_Sync,
|
|
||||||
Text_Linear,
|
|
||||||
CLIPConditioner,
|
|
||||||
IntConditioner,
|
|
||||||
NumberConditioner,
|
|
||||||
PhonemeConditioner,
|
|
||||||
TokenizerLUTConditioner,
|
|
||||||
PretransformConditioner,
|
|
||||||
mm_unchang,
|
|
||||||
)
|
|
||||||
from prismaudio_core.models.utils import load_ckpt_state_dict
|
|
||||||
|
|
||||||
conditioners = {}
|
|
||||||
cond_dim = config["cond_dim"]
|
|
||||||
|
|
||||||
default_keys = config.get("default_keys", {})
|
|
||||||
|
|
||||||
for conditioner_info in config["configs"]:
|
|
||||||
id = conditioner_info["id"]
|
|
||||||
|
|
||||||
conditioner_type = conditioner_info["type"]
|
|
||||||
|
|
||||||
conditioner_config = {"output_dim": cond_dim}
|
|
||||||
|
|
||||||
conditioner_config.update(conditioner_info["config"])
|
|
||||||
if conditioner_type == "t5":
|
|
||||||
conditioners[id] = T5Conditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "clap_text":
|
|
||||||
conditioners[id] = CLAPTextConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "clip_text":
|
|
||||||
conditioners[id] = CLIPTextConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "metaclip_text":
|
|
||||||
conditioners[id] = MetaCLIPTextConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "clap_audio":
|
|
||||||
conditioners[id] = CLAPAudioConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "cond_mlp":
|
|
||||||
conditioners[id] = Cond_MLP(**conditioner_config)
|
|
||||||
elif conditioner_type == "global_mlp":
|
|
||||||
conditioners[id] = Global_MLP(**conditioner_config)
|
|
||||||
elif conditioner_type == "sync_mlp":
|
|
||||||
conditioners[id] = Sync_MLP(**conditioner_config)
|
|
||||||
elif conditioner_type == "cond_mlp_1":
|
|
||||||
conditioners[id] = Cond_MLP_1(**conditioner_config)
|
|
||||||
elif conditioner_type == "cond_convmlp":
|
|
||||||
conditioners[id] = Cond_ConvMLP(**conditioner_config)
|
|
||||||
elif conditioner_type == "cond_mlp_global":
|
|
||||||
conditioners[id] = Cond_MLP_Global(**conditioner_config)
|
|
||||||
elif conditioner_type == "cond_mlp_global_1":
|
|
||||||
conditioners[id] = Cond_MLP_Global_1(**conditioner_config)
|
|
||||||
elif conditioner_type == "cond_mlp_global_2":
|
|
||||||
conditioners[id] = Cond_MLP_Global_2(**conditioner_config)
|
|
||||||
elif conditioner_type == "video_global":
|
|
||||||
conditioners[id] = Video_Global(**conditioner_config)
|
|
||||||
elif conditioner_type == "video_sync":
|
|
||||||
conditioners[id] = Video_Sync(**conditioner_config)
|
|
||||||
elif conditioner_type == "text_linear":
|
|
||||||
conditioners[id] = Text_Linear(**conditioner_config)
|
|
||||||
elif conditioner_type == "video_clip":
|
|
||||||
conditioners[id] = CLIPConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "int":
|
|
||||||
conditioners[id] = IntConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "number":
|
|
||||||
conditioners[id] = NumberConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "phoneme":
|
|
||||||
conditioners[id] = PhonemeConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "lut":
|
|
||||||
conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
|
|
||||||
elif conditioner_type == "pretransform":
|
|
||||||
sample_rate = conditioner_config.pop("sample_rate", None)
|
|
||||||
assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
|
|
||||||
|
|
||||||
pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
|
|
||||||
|
|
||||||
if conditioner_config.get("pretransform_ckpt_path", None) is not None:
|
|
||||||
pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
|
|
||||||
|
|
||||||
conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
|
|
||||||
elif conditioner_type == "mm_unchang":
|
|
||||||
conditioners[id] = mm_unchang(**conditioner_config)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown conditioner type: {conditioner_type}")
|
|
||||||
|
|
||||||
return MultiConditioner(conditioners, default_keys=default_keys)
|
|
||||||
|
|
||||||
|
|
||||||
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
|
||||||
"""Create a ConditionedDiffusionModelWrapper from a config dictionary.
|
|
||||||
|
|
||||||
Originally in PrismAudio/models/diffusion.py.
|
|
||||||
"""
|
|
||||||
from prismaudio_core.models.diffusion import (
|
|
||||||
ConditionedDiffusionModelWrapper,
|
|
||||||
MMConditionedDiffusionModelWrapper,
|
|
||||||
UNetCFG1DWrapper,
|
|
||||||
UNet1DCondWrapper,
|
|
||||||
DiTWrapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_config = config["model"]
|
|
||||||
|
|
||||||
model_type = config["model_type"]
|
|
||||||
|
|
||||||
diffusion_config = model_config.get('diffusion', None)
|
|
||||||
assert diffusion_config is not None, "Must specify diffusion config"
|
|
||||||
|
|
||||||
diffusion_model_type = diffusion_config.get('type', None)
|
|
||||||
assert diffusion_model_type is not None, "Must specify diffusion model type"
|
|
||||||
|
|
||||||
diffusion_model_config = diffusion_config.get('config', None)
|
|
||||||
assert diffusion_model_config is not None, "Must specify diffusion model config"
|
|
||||||
|
|
||||||
if diffusion_model_type == 'adp_cfg_1d':
|
|
||||||
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
|
|
||||||
elif diffusion_model_type == 'adp_1d':
|
|
||||||
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
|
||||||
elif diffusion_model_type == 'dit':
|
|
||||||
diffusion_model = DiTWrapper(**diffusion_model_config)
|
|
||||||
elif diffusion_model_type == 'mmdit':
|
|
||||||
raise NotImplementedError("mmdit diffusion model type is not supported")
|
|
||||||
|
|
||||||
io_channels = model_config.get('io_channels', None)
|
|
||||||
assert io_channels is not None, "Must specify io_channels in model config"
|
|
||||||
|
|
||||||
sample_rate = config.get('sample_rate', None)
|
|
||||||
assert sample_rate is not None, "Must specify sample_rate in config"
|
|
||||||
|
|
||||||
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
|
|
||||||
|
|
||||||
conditioning_config = model_config.get('conditioning', None)
|
|
||||||
|
|
||||||
conditioner = None
|
|
||||||
if conditioning_config is not None:
|
|
||||||
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
|
|
||||||
|
|
||||||
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
|
|
||||||
add_cond_ids = diffusion_config.get('add_cond_ids', [])
|
|
||||||
sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
|
|
||||||
global_cond_ids = diffusion_config.get('global_cond_ids', [])
|
|
||||||
input_concat_ids = diffusion_config.get('input_concat_ids', [])
|
|
||||||
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
|
|
||||||
mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
|
|
||||||
zero_init = diffusion_config.get('zero_init', False)
|
|
||||||
pretransform = model_config.get("pretransform", None)
|
|
||||||
|
|
||||||
if pretransform is not None:
|
|
||||||
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
|
||||||
min_input_length = pretransform.downsampling_ratio
|
|
||||||
else:
|
|
||||||
min_input_length = 1
|
|
||||||
|
|
||||||
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
|
|
||||||
min_input_length *= np.prod(diffusion_model_config["factors"])
|
|
||||||
elif diffusion_model_type == "dit":
|
|
||||||
min_input_length *= diffusion_model.model.patch_size
|
|
||||||
|
|
||||||
# Get the proper wrapper class
|
|
||||||
|
|
||||||
extra_kwargs = {}
|
|
||||||
|
|
||||||
if model_type == "mm_diffusion_cond":
|
|
||||||
wrapper_fn = MMConditionedDiffusionModelWrapper
|
|
||||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
|
||||||
extra_kwargs["mm_cond_ids"] = mm_cond_ids
|
|
||||||
|
|
||||||
if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
|
|
||||||
wrapper_fn = ConditionedDiffusionModelWrapper
|
|
||||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
|
||||||
|
|
||||||
elif model_type == "diffusion_prior":
|
|
||||||
raise NotImplementedError("diffusion_prior model type is not supported")
|
|
||||||
|
|
||||||
return wrapper_fn(
|
|
||||||
diffusion_model,
|
|
||||||
conditioner,
|
|
||||||
min_input_length=min_input_length,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
cross_attn_cond_ids=cross_attention_ids,
|
|
||||||
global_cond_ids=global_cond_ids,
|
|
||||||
input_concat_ids=input_concat_ids,
|
|
||||||
prepend_cond_ids=prepend_cond_ids,
|
|
||||||
add_cond_ids=add_cond_ids,
|
|
||||||
sync_cond_ids=sync_cond_ids,
|
|
||||||
pretransform=pretransform,
|
|
||||||
io_channels=io_channels,
|
|
||||||
zero_init=zero_init,
|
|
||||||
**extra_kwargs
|
|
||||||
)
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
from .sampling import sample_discrete_euler
|
|
||||||
from .utils import set_audio_channels, prepare_audio
|
|
||||||
|
|
||||||
__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
|
|
||||||
"""Discrete Euler sampler for rectified flow, with optional callback.
|
|
||||||
|
|
||||||
Modified from PrismAudio to add callback parameter for ComfyUI progress reporting.
|
|
||||||
Original uses tqdm internally.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The diffusion model (DiTWrapper)
|
|
||||||
x: Initial noise tensor [B, C, T]
|
|
||||||
steps: Number of sampling steps
|
|
||||||
sigma_max: Maximum sigma (default 1.0 for rectified flow)
|
|
||||||
callback: Optional callable({"i": step, "x": current_x}) for progress
|
|
||||||
**extra_args: Passed to model() — includes cross_attn_cond, add_cond,
|
|
||||||
sync_cond, cfg_scale, batch_cfg, etc.
|
|
||||||
"""
|
|
||||||
t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)
|
|
||||||
|
|
||||||
for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
|
|
||||||
dt = t_next - t_curr
|
|
||||||
t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
|
|
||||||
x = x + dt * model(x, t_curr_tensor, **extra_args)
|
|
||||||
if callback is not None:
|
|
||||||
callback({"i": i, "x": x})
|
|
||||||
|
|
||||||
return x
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torchaudio import transforms as T
|
|
||||||
|
|
||||||
|
|
||||||
def set_audio_channels(audio, target_channels):
|
|
||||||
"""Convert audio tensor to target number of channels.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio: Audio tensor of shape [B, C, T]
|
|
||||||
target_channels: Desired number of channels (1 for mono, 2 for stereo)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Audio tensor with the target number of channels.
|
|
||||||
"""
|
|
||||||
if target_channels == 1:
|
|
||||||
# Convert to mono
|
|
||||||
audio = audio.mean(1, keepdim=True)
|
|
||||||
elif target_channels == 2:
|
|
||||||
# Convert to stereo
|
|
||||||
if audio.shape[1] == 1:
|
|
||||||
audio = audio.repeat(1, 2, 1)
|
|
||||||
elif audio.shape[1] > 2:
|
|
||||||
audio = audio[:, :2, :]
|
|
||||||
return audio
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
|
||||||
"""Resample, pad/trim, and convert channels of an audio tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T])
|
|
||||||
in_sr: Input sample rate
|
|
||||||
target_sr: Target sample rate
|
|
||||||
target_length: Target length in samples (padded or cropped)
|
|
||||||
target_channels: Target number of channels
|
|
||||||
device: Torch device to place the audio on
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Audio tensor of shape [B, target_channels, target_length] on device.
|
|
||||||
"""
|
|
||||||
audio = audio.to(device)
|
|
||||||
|
|
||||||
if in_sr != target_sr:
|
|
||||||
resample_tf = T.Resample(in_sr, target_sr).to(device)
|
|
||||||
audio = resample_tf(audio)
|
|
||||||
|
|
||||||
# Add batch dimension
|
|
||||||
if audio.dim() == 1:
|
|
||||||
audio = audio.unsqueeze(0).unsqueeze(0)
|
|
||||||
elif audio.dim() == 2:
|
|
||||||
audio = audio.unsqueeze(0)
|
|
||||||
|
|
||||||
# Pad or crop to target_length
|
|
||||||
if audio.shape[-1] < target_length:
|
|
||||||
audio = F.pad(audio, (0, target_length - audio.shape[-1]))
|
|
||||||
elif audio.shape[-1] > target_length:
|
|
||||||
audio = audio[:, :, :target_length]
|
|
||||||
|
|
||||||
audio = set_audio_channels(audio, target_channels)
|
|
||||||
|
|
||||||
return audio
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
"""
|
|
||||||
PrismAudio model modules for inference.
|
|
||||||
|
|
||||||
Re-exports create_model_from_config from the factory module.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from prismaudio_core.factory import create_model_from_config
|
|
||||||
|
|
||||||
__all__ = ["create_model_from_config"]
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,821 +0,0 @@
|
|||||||
import torch
|
|
||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from torchaudio import transforms as T
|
|
||||||
from alias_free_torch import Activation1d
|
|
||||||
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
|
||||||
from typing import Literal, Dict, Any
|
|
||||||
|
|
||||||
from .blocks import SnakeBeta
|
|
||||||
from .bottleneck import Bottleneck, DiscreteBottleneck
|
|
||||||
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
|
||||||
from .pretransforms import Pretransform
|
|
||||||
from .utils import checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
|
|
||||||
"""Minimal stub for inference.utils.prepare_audio used by autoencoders."""
|
|
||||||
import torchaudio.transforms as T
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if in_sr != target_sr:
|
|
||||||
resample_tf = T.Resample(in_sr, target_sr).to(device)
|
|
||||||
audio = resample_tf(audio)
|
|
||||||
|
|
||||||
if audio.shape[0] > target_channels:
|
|
||||||
audio = audio[:target_channels]
|
|
||||||
elif audio.shape[0] < target_channels:
|
|
||||||
audio = audio.repeat(target_channels // audio.shape[0] + 1, 1)[:target_channels]
|
|
||||||
|
|
||||||
if audio.shape[-1] < target_length:
|
|
||||||
audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[-1]))
|
|
||||||
elif audio.shape[-1] > target_length:
|
|
||||||
audio = audio[..., :target_length]
|
|
||||||
|
|
||||||
return audio.unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
|
|
||||||
from prismaudio_core.factory import create_pretransform_from_config
|
|
||||||
return create_pretransform_from_config(pretransform, sample_rate)
|
|
||||||
|
|
||||||
|
|
||||||
def _lazy_create_bottleneck_from_config(bottleneck):
|
|
||||||
from prismaudio_core.factory import create_bottleneck_from_config
|
|
||||||
return create_bottleneck_from_config(bottleneck)
|
|
||||||
|
|
||||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
|
||||||
if activation == "elu":
|
|
||||||
act = nn.ELU()
|
|
||||||
elif activation == "snake":
|
|
||||||
act = SnakeBeta(channels)
|
|
||||||
elif activation == "none":
|
|
||||||
act = nn.Identity()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown activation {activation}")
|
|
||||||
|
|
||||||
if antialias:
|
|
||||||
act = Activation1d(act)
|
|
||||||
|
|
||||||
return act
|
|
||||||
|
|
||||||
class ResidualUnit(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dilation = dilation
|
|
||||||
|
|
||||||
padding = (dilation * (7-1)) // 2
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(
|
|
||||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
|
||||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
|
||||||
kernel_size=7, dilation=dilation, padding=padding),
|
|
||||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
|
||||||
WNConv1d(in_channels=out_channels, out_channels=out_channels,
|
|
||||||
kernel_size=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
res = x
|
|
||||||
|
|
||||||
#x = checkpoint(self.layers, x)
|
|
||||||
x = self.layers(x)
|
|
||||||
|
|
||||||
return x + res
|
|
||||||
|
|
||||||
class EncoderBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(
|
|
||||||
ResidualUnit(in_channels=in_channels,
|
|
||||||
out_channels=in_channels, dilation=1, use_snake=use_snake),
|
|
||||||
ResidualUnit(in_channels=in_channels,
|
|
||||||
out_channels=in_channels, dilation=3, use_snake=use_snake),
|
|
||||||
ResidualUnit(in_channels=in_channels,
|
|
||||||
out_channels=in_channels, dilation=9, use_snake=use_snake),
|
|
||||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
|
||||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
|
||||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.layers(x)
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if use_nearest_upsample:
|
|
||||||
upsample_layer = nn.Sequential(
|
|
||||||
nn.Upsample(scale_factor=stride, mode="nearest"),
|
|
||||||
WNConv1d(in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=2*stride,
|
|
||||||
stride=1,
|
|
||||||
bias=False,
|
|
||||||
padding='same')
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(
|
|
||||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
|
||||||
upsample_layer,
|
|
||||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
|
||||||
dilation=1, use_snake=use_snake),
|
|
||||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
|
||||||
dilation=3, use_snake=use_snake),
|
|
||||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
|
||||||
dilation=9, use_snake=use_snake),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.layers(x)
|
|
||||||
|
|
||||||
class OobleckEncoder(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
in_channels=2,
|
|
||||||
channels=128,
|
|
||||||
latent_dim=32,
|
|
||||||
c_mults = [1, 2, 4, 8],
|
|
||||||
strides = [2, 4, 8, 8],
|
|
||||||
use_snake=False,
|
|
||||||
antialias_activation=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
c_mults = [1] + c_mults
|
|
||||||
|
|
||||||
self.depth = len(c_mults)
|
|
||||||
|
|
||||||
layers = [
|
|
||||||
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
|
|
||||||
]
|
|
||||||
|
|
||||||
for i in range(self.depth-1):
|
|
||||||
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
|
|
||||||
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
|
|
||||||
]
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.layers(x)
|
|
||||||
|
|
||||||
|
|
||||||
class OobleckDecoder(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
out_channels=2,
|
|
||||||
channels=128,
|
|
||||||
latent_dim=32,
|
|
||||||
c_mults = [1, 2, 4, 8],
|
|
||||||
strides = [2, 4, 8, 8],
|
|
||||||
use_snake=False,
|
|
||||||
antialias_activation=False,
|
|
||||||
use_nearest_upsample=False,
|
|
||||||
final_tanh=True):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
c_mults = [1] + c_mults
|
|
||||||
|
|
||||||
self.depth = len(c_mults)
|
|
||||||
|
|
||||||
layers = [
|
|
||||||
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
|
|
||||||
]
|
|
||||||
|
|
||||||
for i in range(self.depth-1, 0, -1):
|
|
||||||
layers += [DecoderBlock(
|
|
||||||
in_channels=c_mults[i]*channels,
|
|
||||||
out_channels=c_mults[i-1]*channels,
|
|
||||||
stride=strides[i-1],
|
|
||||||
use_snake=use_snake,
|
|
||||||
antialias_activation=antialias_activation,
|
|
||||||
use_nearest_upsample=use_nearest_upsample
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
|
|
||||||
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
|
|
||||||
nn.Tanh() if final_tanh else nn.Identity()
|
|
||||||
]
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.layers(x)
|
|
||||||
|
|
||||||
|
|
||||||
class DACEncoderWrapper(nn.Module):
|
|
||||||
def __init__(self, in_channels=1, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
from dac.model.dac import Encoder as DACEncoder
|
|
||||||
|
|
||||||
latent_dim = kwargs.pop("latent_dim", None)
|
|
||||||
|
|
||||||
encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
|
|
||||||
self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
|
|
||||||
self.latent_dim = latent_dim
|
|
||||||
|
|
||||||
# Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
|
|
||||||
self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
|
|
||||||
|
|
||||||
if in_channels != 1:
|
|
||||||
self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.encoder(x)
|
|
||||||
x = self.proj_out(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
class DACDecoderWrapper(nn.Module):
|
|
||||||
def __init__(self, latent_dim, out_channels=1, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
from dac.model.dac import Decoder as DACDecoder
|
|
||||||
|
|
||||||
self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
|
|
||||||
|
|
||||||
self.latent_dim = latent_dim
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.decoder(x)
|
|
||||||
|
|
||||||
class AudioAutoencoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
encoder,
|
|
||||||
decoder,
|
|
||||||
latent_dim,
|
|
||||||
downsampling_ratio,
|
|
||||||
sample_rate,
|
|
||||||
io_channels=2,
|
|
||||||
bottleneck: Bottleneck = None,
|
|
||||||
pretransform: Pretransform = None,
|
|
||||||
in_channels = None,
|
|
||||||
out_channels = None,
|
|
||||||
soft_clip = False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.downsampling_ratio = downsampling_ratio
|
|
||||||
self.sample_rate = sample_rate
|
|
||||||
|
|
||||||
self.latent_dim = latent_dim
|
|
||||||
self.io_channels = io_channels
|
|
||||||
self.in_channels = io_channels
|
|
||||||
self.out_channels = io_channels
|
|
||||||
|
|
||||||
self.min_length = self.downsampling_ratio
|
|
||||||
|
|
||||||
if in_channels is not None:
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
if out_channels is not None:
|
|
||||||
self.out_channels = out_channels
|
|
||||||
|
|
||||||
self.bottleneck = bottleneck
|
|
||||||
|
|
||||||
self.encoder = encoder
|
|
||||||
|
|
||||||
self.decoder = decoder
|
|
||||||
|
|
||||||
self.pretransform = pretransform
|
|
||||||
|
|
||||||
self.soft_clip = soft_clip
|
|
||||||
|
|
||||||
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
|
|
||||||
|
|
||||||
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
|
||||||
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
if self.pretransform is not None and not skip_pretransform:
|
|
||||||
if self.pretransform.enable_grad:
|
|
||||||
if iterate_batch:
|
|
||||||
audios = []
|
|
||||||
for i in range(audio.shape[0]):
|
|
||||||
audios.append(self.pretransform.encode(audio[i:i+1]))
|
|
||||||
audio = torch.cat(audios, dim=0)
|
|
||||||
else:
|
|
||||||
audio = self.pretransform.encode(audio)
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
if iterate_batch:
|
|
||||||
audios = []
|
|
||||||
for i in range(audio.shape[0]):
|
|
||||||
audios.append(self.pretransform.encode(audio[i:i+1]))
|
|
||||||
audio = torch.cat(audios, dim=0)
|
|
||||||
else:
|
|
||||||
audio = self.pretransform.encode(audio)
|
|
||||||
|
|
||||||
if self.encoder is not None:
|
|
||||||
if iterate_batch:
|
|
||||||
latents = []
|
|
||||||
for i in range(audio.shape[0]):
|
|
||||||
latents.append(self.encoder(audio[i:i+1]))
|
|
||||||
latents = torch.cat(latents, dim=0)
|
|
||||||
else:
|
|
||||||
latents = self.encoder(audio)
|
|
||||||
else:
|
|
||||||
latents = audio
|
|
||||||
|
|
||||||
if self.bottleneck is not None:
|
|
||||||
# TODO: Add iterate batch logic, needs to merge the info dicts
|
|
||||||
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
|
|
||||||
|
|
||||||
info.update(bottleneck_info)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return latents, info
|
|
||||||
|
|
||||||
return latents
|
|
||||||
|
|
||||||
def decode(self, latents, iterate_batch=False, **kwargs):
|
|
||||||
|
|
||||||
if self.bottleneck is not None:
|
|
||||||
if iterate_batch:
|
|
||||||
decoded = []
|
|
||||||
for i in range(latents.shape[0]):
|
|
||||||
decoded.append(self.bottleneck.decode(latents[i:i+1]))
|
|
||||||
latents = torch.cat(decoded, dim=0)
|
|
||||||
else:
|
|
||||||
latents = self.bottleneck.decode(latents)
|
|
||||||
|
|
||||||
if iterate_batch:
|
|
||||||
decoded = []
|
|
||||||
for i in range(latents.shape[0]):
|
|
||||||
decoded.append(self.decoder(latents[i:i+1]))
|
|
||||||
decoded = torch.cat(decoded, dim=0)
|
|
||||||
else:
|
|
||||||
decoded = self.decoder(latents, **kwargs)
|
|
||||||
|
|
||||||
if self.pretransform is not None:
|
|
||||||
if self.pretransform.enable_grad:
|
|
||||||
if iterate_batch:
|
|
||||||
decodeds = []
|
|
||||||
for i in range(decoded.shape[0]):
|
|
||||||
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
|
||||||
decoded = torch.cat(decodeds, dim=0)
|
|
||||||
else:
|
|
||||||
decoded = self.pretransform.decode(decoded)
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
if iterate_batch:
|
|
||||||
decodeds = []
|
|
||||||
for i in range(latents.shape[0]):
|
|
||||||
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
|
||||||
decoded = torch.cat(decodeds, dim=0)
|
|
||||||
else:
|
|
||||||
decoded = self.pretransform.decode(decoded)
|
|
||||||
|
|
||||||
if self.soft_clip:
|
|
||||||
decoded = torch.tanh(decoded)
|
|
||||||
|
|
||||||
return decoded
|
|
||||||
|
|
||||||
def decode_tokens(self, tokens, **kwargs):
|
|
||||||
'''
|
|
||||||
Decode discrete tokens to audio
|
|
||||||
Only works with discrete autoencoders
|
|
||||||
'''
|
|
||||||
|
|
||||||
assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
|
|
||||||
|
|
||||||
latents = self.bottleneck.decode_tokens(tokens, **kwargs)
|
|
||||||
|
|
||||||
return self.decode(latents, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_audio_for_encoder(self, audio, in_sr):
|
|
||||||
'''
|
|
||||||
Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
|
|
||||||
If the model is mono, stereo audio will be converted to mono.
|
|
||||||
Audio will be silence-padded to be a multiple of the model's downsampling ratio.
|
|
||||||
Audio will be resampled to the model's sample rate.
|
|
||||||
The output will have batch size 1 and be shape (1 x Channels x Length)
|
|
||||||
'''
|
|
||||||
return self.preprocess_audio_list_for_encoder([audio], [in_sr])
|
|
||||||
|
|
||||||
def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
|
|
||||||
'''
|
|
||||||
Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder.
|
|
||||||
The audio in that list can be of different lengths and channels.
|
|
||||||
in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
|
|
||||||
All audio will be resampled to the model's sample rate.
|
|
||||||
Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
|
|
||||||
If the model is mono, all audio will be converted to mono.
|
|
||||||
The output will be a tensor of shape (Batch x Channels x Length)
|
|
||||||
'''
|
|
||||||
batch_size = len(audio_list)
|
|
||||||
if isinstance(in_sr_list, int):
|
|
||||||
in_sr_list = [in_sr_list]*batch_size
|
|
||||||
assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
|
|
||||||
new_audio = []
|
|
||||||
max_length = 0
|
|
||||||
# resample & find the max length
|
|
||||||
for i in range(batch_size):
|
|
||||||
audio = audio_list[i]
|
|
||||||
in_sr = in_sr_list[i]
|
|
||||||
if len(audio.shape) == 3 and audio.shape[0] == 1:
|
|
||||||
# batchsize 1 was given by accident. Just squeeze it.
|
|
||||||
audio = audio.squeeze(0)
|
|
||||||
elif len(audio.shape) == 1:
|
|
||||||
# Mono signal, channel dimension is missing, unsqueeze it in
|
|
||||||
audio = audio.unsqueeze(0)
|
|
||||||
assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
|
|
||||||
# Resample audio
|
|
||||||
if in_sr != self.sample_rate:
|
|
||||||
resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
|
|
||||||
audio = resample_tf(audio)
|
|
||||||
new_audio.append(audio)
|
|
||||||
if audio.shape[-1] > max_length:
|
|
||||||
max_length = audio.shape[-1]
|
|
||||||
# Pad every audio to the same length, multiple of model's downsampling ratio
|
|
||||||
padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
|
|
||||||
for i in range(batch_size):
|
|
||||||
# Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
|
|
||||||
new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
|
|
||||||
target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
|
|
||||||
# convert to tensor
|
|
||||||
return torch.stack(new_audio)
|
|
||||||
|
|
||||||
def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
|
||||||
'''
|
|
||||||
Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder.
|
|
||||||
If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
|
|
||||||
Overlap and chunk_size params are both measured in number of latents (not audio samples)
|
|
||||||
# and therefore you likely could use the same values with decode_audio.
|
|
||||||
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
|
|
||||||
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
|
||||||
You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
|
|
||||||
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
|
||||||
Smaller chunk_size uses less memory, but more compute.
|
|
||||||
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
|
||||||
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
|
||||||
'''
|
|
||||||
if not chunked:
|
|
||||||
# default behavior. Encode the entire audio in parallel
|
|
||||||
return self.encode(audio, **kwargs)
|
|
||||||
else:
|
|
||||||
# CHUNKED ENCODING
|
|
||||||
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
|
|
||||||
samples_per_latent = self.downsampling_ratio
|
|
||||||
total_size = audio.shape[2] # in samples
|
|
||||||
batch_size = audio.shape[0]
|
|
||||||
chunk_size *= samples_per_latent # converting metric in latents to samples
|
|
||||||
overlap *= samples_per_latent # converting metric in latents to samples
|
|
||||||
hop_size = chunk_size - overlap
|
|
||||||
chunks = []
|
|
||||||
for i in range(0, total_size - chunk_size + 1, hop_size):
|
|
||||||
chunk = audio[:,:,i:i+chunk_size]
|
|
||||||
chunks.append(chunk)
|
|
||||||
if i+chunk_size != total_size:
|
|
||||||
# Final chunk
|
|
||||||
chunk = audio[:,:,-chunk_size:]
|
|
||||||
chunks.append(chunk)
|
|
||||||
chunks = torch.stack(chunks)
|
|
||||||
num_chunks = chunks.shape[0]
|
|
||||||
# Note: y_size might be a different value from the latent length used in diffusion training
|
|
||||||
# because we can encode audio of varying lengths
|
|
||||||
# However, the audio should've been padded to a multiple of samples_per_latent by now.
|
|
||||||
y_size = total_size // samples_per_latent
|
|
||||||
# Create an empty latent, we will populate it with chunks as we encode them
|
|
||||||
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
|
|
||||||
for i in range(num_chunks):
|
|
||||||
x_chunk = chunks[i,:]
|
|
||||||
# encode the chunk
|
|
||||||
y_chunk = self.encode(x_chunk)
|
|
||||||
# figure out where to put the audio along the time domain
|
|
||||||
if i == num_chunks-1:
|
|
||||||
# final chunk always goes at the end
|
|
||||||
t_end = y_size
|
|
||||||
t_start = t_end - y_chunk.shape[2]
|
|
||||||
else:
|
|
||||||
t_start = i * hop_size // samples_per_latent
|
|
||||||
t_end = t_start + chunk_size // samples_per_latent
|
|
||||||
# remove the edges of the overlaps
|
|
||||||
ol = overlap//samples_per_latent//2
|
|
||||||
chunk_start = 0
|
|
||||||
chunk_end = y_chunk.shape[2]
|
|
||||||
if i > 0:
|
|
||||||
# no overlap for the start of the first chunk
|
|
||||||
t_start += ol
|
|
||||||
chunk_start += ol
|
|
||||||
if i < num_chunks-1:
|
|
||||||
# no overlap for the end of the last chunk
|
|
||||||
t_end -= ol
|
|
||||||
chunk_end -= ol
|
|
||||||
# paste the chunked audio into our y_final output audio
|
|
||||||
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
|
||||||
return y_final
|
|
||||||
|
|
||||||
def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
|
||||||
'''
|
|
||||||
Decode latents to audio.
|
|
||||||
If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
|
|
||||||
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
|
|
||||||
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
|
||||||
You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
|
|
||||||
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
|
||||||
Smaller chunk_size uses less memory, but more compute.
|
|
||||||
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
|
||||||
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
|
||||||
'''
|
|
||||||
if not chunked:
|
|
||||||
# default behavior. Decode the entire latent in parallel
|
|
||||||
return self.decode(latents, **kwargs)
|
|
||||||
else:
|
|
||||||
# chunked decoding
|
|
||||||
hop_size = chunk_size - overlap
|
|
||||||
total_size = latents.shape[2]
|
|
||||||
batch_size = latents.shape[0]
|
|
||||||
chunks = []
|
|
||||||
for i in range(0, total_size - chunk_size + 1, hop_size):
|
|
||||||
chunk = latents[:,:,i:i+chunk_size]
|
|
||||||
chunks.append(chunk)
|
|
||||||
if i+chunk_size != total_size:
|
|
||||||
# Final chunk
|
|
||||||
chunk = latents[:,:,-chunk_size:]
|
|
||||||
chunks.append(chunk)
|
|
||||||
chunks = torch.stack(chunks)
|
|
||||||
num_chunks = chunks.shape[0]
|
|
||||||
# samples_per_latent is just the downsampling ratio
|
|
||||||
samples_per_latent = self.downsampling_ratio
|
|
||||||
# Create an empty waveform, we will populate it with chunks as decode them
|
|
||||||
y_size = total_size * samples_per_latent
|
|
||||||
y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
|
|
||||||
for i in range(num_chunks):
|
|
||||||
x_chunk = chunks[i,:]
|
|
||||||
# decode the chunk
|
|
||||||
y_chunk = self.decode(x_chunk)
|
|
||||||
# figure out where to put the audio along the time domain
|
|
||||||
if i == num_chunks-1:
|
|
||||||
# final chunk always goes at the end
|
|
||||||
t_end = y_size
|
|
||||||
t_start = t_end - y_chunk.shape[2]
|
|
||||||
else:
|
|
||||||
t_start = i * hop_size * samples_per_latent
|
|
||||||
t_end = t_start + chunk_size * samples_per_latent
|
|
||||||
# remove the edges of the overlaps
|
|
||||||
ol = (overlap//2) * samples_per_latent
|
|
||||||
chunk_start = 0
|
|
||||||
chunk_end = y_chunk.shape[2]
|
|
||||||
if i > 0:
|
|
||||||
# no overlap for the start of the first chunk
|
|
||||||
t_start += ol
|
|
||||||
chunk_start += ol
|
|
||||||
if i < num_chunks-1:
|
|
||||||
# no overlap for the end of the last chunk
|
|
||||||
t_end -= ol
|
|
||||||
chunk_end -= ol
|
|
||||||
# paste the chunked audio into our y_final output audio
|
|
||||||
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
|
||||||
return y_final
|
|
||||||
|
|
||||||
|
|
||||||
class DiffusionAutoencoder(AudioAutoencoder):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
diffusion: ConditionedDiffusionModel,
|
|
||||||
diffusion_downsampling_ratio,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
self.diffusion = diffusion
|
|
||||||
|
|
||||||
self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
|
|
||||||
|
|
||||||
if self.encoder is not None:
|
|
||||||
# Shrink the initial encoder parameters to avoid saturated latents
|
|
||||||
with torch.no_grad():
|
|
||||||
for param in self.encoder.parameters():
|
|
||||||
param *= 0.5
|
|
||||||
|
|
||||||
def decode(self, latents, steps=100):
|
|
||||||
|
|
||||||
upsampled_length = latents.shape[2] * self.downsampling_ratio
|
|
||||||
|
|
||||||
if self.bottleneck is not None:
|
|
||||||
latents = self.bottleneck.decode(latents)
|
|
||||||
|
|
||||||
if self.decoder is not None:
|
|
||||||
latents = self.decoder(latents)
|
|
||||||
|
|
||||||
# Upsample latents to match diffusion length
|
|
||||||
if latents.shape[2] != upsampled_length:
|
|
||||||
latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
|
|
||||||
|
|
||||||
noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
|
|
||||||
from prismaudio_core.inference.sampling import sample
|
|
||||||
decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
|
|
||||||
|
|
||||||
if self.pretransform is not None:
|
|
||||||
if self.pretransform.enable_grad:
|
|
||||||
decoded = self.pretransform.decode(decoded)
|
|
||||||
else:
|
|
||||||
with torch.no_grad():
|
|
||||||
decoded = self.pretransform.decode(decoded)
|
|
||||||
|
|
||||||
return decoded
|
|
||||||
|
|
||||||
# AE factories
|
|
||||||
|
|
||||||
def create_encoder_from_config(encoder_config: Dict[str, Any]):
|
|
||||||
encoder_type = encoder_config.get("type", None)
|
|
||||||
assert encoder_type is not None, "Encoder type must be specified"
|
|
||||||
|
|
||||||
if encoder_type == "oobleck":
|
|
||||||
encoder = OobleckEncoder(
|
|
||||||
**encoder_config["config"]
|
|
||||||
)
|
|
||||||
|
|
||||||
elif encoder_type == "seanet":
|
|
||||||
from encodec.modules import SEANetEncoder
|
|
||||||
seanet_encoder_config = encoder_config["config"]
|
|
||||||
|
|
||||||
#SEANet encoder expects strides in reverse order
|
|
||||||
seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
|
|
||||||
encoder = SEANetEncoder(
|
|
||||||
**seanet_encoder_config
|
|
||||||
)
|
|
||||||
elif encoder_type == "dac":
|
|
||||||
dac_config = encoder_config["config"]
|
|
||||||
|
|
||||||
encoder = DACEncoderWrapper(**dac_config)
|
|
||||||
elif encoder_type == "local_attn":
|
|
||||||
from .local_attention import TransformerEncoder1D
|
|
||||||
|
|
||||||
local_attn_config = encoder_config["config"]
|
|
||||||
|
|
||||||
encoder = TransformerEncoder1D(
|
|
||||||
**local_attn_config
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown encoder type {encoder_type}")
|
|
||||||
|
|
||||||
requires_grad = encoder_config.get("requires_grad", True)
|
|
||||||
if not requires_grad:
|
|
||||||
for param in encoder.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
def create_decoder_from_config(decoder_config: Dict[str, Any]):
|
|
||||||
decoder_type = decoder_config.get("type", None)
|
|
||||||
assert decoder_type is not None, "Decoder type must be specified"
|
|
||||||
|
|
||||||
if decoder_type == "oobleck":
|
|
||||||
decoder = OobleckDecoder(
|
|
||||||
**decoder_config["config"]
|
|
||||||
)
|
|
||||||
elif decoder_type == "seanet":
|
|
||||||
from encodec.modules import SEANetDecoder
|
|
||||||
|
|
||||||
decoder = SEANetDecoder(
|
|
||||||
**decoder_config["config"]
|
|
||||||
)
|
|
||||||
elif decoder_type == "dac":
|
|
||||||
dac_config = decoder_config["config"]
|
|
||||||
|
|
||||||
decoder = DACDecoderWrapper(**dac_config)
|
|
||||||
elif decoder_type == "local_attn":
|
|
||||||
from .local_attention import TransformerDecoder1D
|
|
||||||
|
|
||||||
local_attn_config = decoder_config["config"]
|
|
||||||
|
|
||||||
decoder = TransformerDecoder1D(
|
|
||||||
**local_attn_config
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown decoder type {decoder_type}")
|
|
||||||
|
|
||||||
requires_grad = decoder_config.get("requires_grad", True)
|
|
||||||
if not requires_grad:
|
|
||||||
for param in decoder.parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
return decoder
|
|
||||||
|
|
||||||
def create_autoencoder_from_config(config: Dict[str, Any]):
|
|
||||||
|
|
||||||
ae_config = config["model"]
|
|
||||||
|
|
||||||
encoder = create_encoder_from_config(ae_config["encoder"])
|
|
||||||
decoder = create_decoder_from_config(ae_config["decoder"])
|
|
||||||
|
|
||||||
bottleneck = ae_config.get("bottleneck", None)
|
|
||||||
|
|
||||||
latent_dim = ae_config.get("latent_dim", None)
|
|
||||||
assert latent_dim is not None, "latent_dim must be specified in model config"
|
|
||||||
downsampling_ratio = ae_config.get("downsampling_ratio", None)
|
|
||||||
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
|
||||||
io_channels = ae_config.get("io_channels", None)
|
|
||||||
assert io_channels is not None, "io_channels must be specified in model config"
|
|
||||||
sample_rate = config.get("sample_rate", None)
|
|
||||||
assert sample_rate is not None, "sample_rate must be specified in model config"
|
|
||||||
|
|
||||||
in_channels = ae_config.get("in_channels", None)
|
|
||||||
out_channels = ae_config.get("out_channels", None)
|
|
||||||
|
|
||||||
pretransform = ae_config.get("pretransform", None)
|
|
||||||
|
|
||||||
if pretransform is not None:
|
|
||||||
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
|
|
||||||
|
|
||||||
if bottleneck is not None:
|
|
||||||
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
|
||||||
|
|
||||||
soft_clip = ae_config["decoder"].get("soft_clip", False)
|
|
||||||
|
|
||||||
return AudioAutoencoder(
|
|
||||||
encoder,
|
|
||||||
decoder,
|
|
||||||
io_channels=io_channels,
|
|
||||||
latent_dim=latent_dim,
|
|
||||||
downsampling_ratio=downsampling_ratio,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
bottleneck=bottleneck,
|
|
||||||
pretransform=pretransform,
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
soft_clip=soft_clip
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_diffAE_from_config(config: Dict[str, Any]):
|
|
||||||
|
|
||||||
diffae_config = config["model"]
|
|
||||||
|
|
||||||
if "encoder" in diffae_config:
|
|
||||||
encoder = create_encoder_from_config(diffae_config["encoder"])
|
|
||||||
else:
|
|
||||||
encoder = None
|
|
||||||
|
|
||||||
if "decoder" in diffae_config:
|
|
||||||
decoder = create_decoder_from_config(diffae_config["decoder"])
|
|
||||||
else:
|
|
||||||
decoder = None
|
|
||||||
|
|
||||||
diffusion_model_type = diffae_config["diffusion"]["type"]
|
|
||||||
|
|
||||||
if diffusion_model_type == "DAU1d":
|
|
||||||
diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
|
|
||||||
elif diffusion_model_type == "adp_1d":
|
|
||||||
diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
|
|
||||||
elif diffusion_model_type == "dit":
|
|
||||||
diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
|
|
||||||
|
|
||||||
latent_dim = diffae_config.get("latent_dim", None)
|
|
||||||
assert latent_dim is not None, "latent_dim must be specified in model config"
|
|
||||||
downsampling_ratio = diffae_config.get("downsampling_ratio", None)
|
|
||||||
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
|
||||||
io_channels = diffae_config.get("io_channels", None)
|
|
||||||
assert io_channels is not None, "io_channels must be specified in model config"
|
|
||||||
sample_rate = config.get("sample_rate", None)
|
|
||||||
assert sample_rate is not None, "sample_rate must be specified in model config"
|
|
||||||
|
|
||||||
bottleneck = diffae_config.get("bottleneck", None)
|
|
||||||
|
|
||||||
pretransform = diffae_config.get("pretransform", None)
|
|
||||||
|
|
||||||
if pretransform is not None:
|
|
||||||
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
|
|
||||||
|
|
||||||
if bottleneck is not None:
|
|
||||||
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
|
||||||
|
|
||||||
diffusion_downsampling_ratio = None
|
|
||||||
|
|
||||||
if diffusion_model_type == "DAU1d":
|
|
||||||
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
|
|
||||||
elif diffusion_model_type == "adp_1d":
|
|
||||||
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
|
|
||||||
elif diffusion_model_type == "dit":
|
|
||||||
diffusion_downsampling_ratio = 1
|
|
||||||
|
|
||||||
return DiffusionAutoencoder(
|
|
||||||
encoder=encoder,
|
|
||||||
decoder=decoder,
|
|
||||||
diffusion=diffusion,
|
|
||||||
io_channels=io_channels,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
latent_dim=latent_dim,
|
|
||||||
downsampling_ratio=downsampling_ratio,
|
|
||||||
diffusion_downsampling_ratio=diffusion_downsampling_ratio,
|
|
||||||
bottleneck=bottleneck,
|
|
||||||
pretransform=pretransform
|
|
||||||
)
|
|
||||||
@@ -1,331 +0,0 @@
|
|||||||
from functools import reduce
|
|
||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from torch.backends.cuda import sdp_kernel
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
from dac.nn.layers import Snake1d
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
|
||||||
def __init__(self, main, skip=None):
|
|
||||||
super().__init__()
|
|
||||||
self.main = nn.Sequential(*main)
|
|
||||||
self.skip = skip if skip else nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return self.main(input) + self.skip(input)
|
|
||||||
|
|
||||||
class ResConvBlock(ResidualBlock):
|
|
||||||
def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
|
|
||||||
skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
|
|
||||||
super().__init__([
|
|
||||||
nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
|
|
||||||
nn.GroupNorm(1, c_mid),
|
|
||||||
Snake1d(c_mid) if use_snake else nn.GELU(),
|
|
||||||
nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
|
|
||||||
nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
|
|
||||||
(Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
|
|
||||||
], skip)
|
|
||||||
|
|
||||||
class SelfAttention1d(nn.Module):
|
|
||||||
def __init__(self, c_in, n_head=1, dropout_rate=0.):
|
|
||||||
super().__init__()
|
|
||||||
assert c_in % n_head == 0
|
|
||||||
self.norm = nn.GroupNorm(1, c_in)
|
|
||||||
self.n_head = n_head
|
|
||||||
self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
|
|
||||||
self.out_proj = nn.Conv1d(c_in, c_in, 1)
|
|
||||||
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
|
||||||
|
|
||||||
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
|
|
||||||
|
|
||||||
if not self.use_flash:
|
|
||||||
return
|
|
||||||
|
|
||||||
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
|
|
||||||
|
|
||||||
if device_properties.major == 8 and device_properties.minor == 0:
|
|
||||||
# Use flash attention for A100 GPUs
|
|
||||||
self.sdp_kernel_config = (True, False, False)
|
|
||||||
else:
|
|
||||||
# Don't use flash attention for other GPUs
|
|
||||||
self.sdp_kernel_config = (False, True, True)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
n, c, s = input.shape
|
|
||||||
qkv = self.qkv_proj(self.norm(input))
|
|
||||||
qkv = qkv.view(
|
|
||||||
[n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
|
|
||||||
q, k, v = qkv.chunk(3, dim=1)
|
|
||||||
scale = k.shape[3]**-0.25
|
|
||||||
|
|
||||||
if self.use_flash:
|
|
||||||
with sdp_kernel(*self.sdp_kernel_config):
|
|
||||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
|
|
||||||
else:
|
|
||||||
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
|
|
||||||
y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
|
|
||||||
|
|
||||||
|
|
||||||
return input + self.dropout(self.out_proj(y))
|
|
||||||
|
|
||||||
class SkipBlock(nn.Module):
|
|
||||||
def __init__(self, *main):
|
|
||||||
super().__init__()
|
|
||||||
self.main = nn.Sequential(*main)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
return torch.cat([self.main(input), input], dim=1)
|
|
||||||
|
|
||||||
class FourierFeatures(nn.Module):
|
|
||||||
def __init__(self, in_features, out_features, std=1.):
|
|
||||||
super().__init__()
|
|
||||||
assert out_features % 2 == 0
|
|
||||||
self.weight = nn.Parameter(torch.randn(
|
|
||||||
[out_features // 2, in_features]) * std)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
f = 2 * math.pi * input @ self.weight.T
|
|
||||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
|
||||||
|
|
||||||
def expand_to_planes(input, shape):
|
|
||||||
return input[..., None].repeat([1, 1, shape[2]])
|
|
||||||
|
|
||||||
_kernels = {
|
|
||||||
'linear':
|
|
||||||
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
|
||||||
'cubic':
|
|
||||||
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
|
|
||||||
0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
|
||||||
'lanczos3':
|
|
||||||
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
|
|
||||||
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
|
|
||||||
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
|
|
||||||
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
|
|
||||||
}
|
|
||||||
|
|
||||||
class Downsample1d(nn.Module):
|
|
||||||
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
|
|
||||||
super().__init__()
|
|
||||||
self.pad_mode = pad_mode
|
|
||||||
kernel_1d = torch.tensor(_kernels[kernel])
|
|
||||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
|
||||||
self.register_buffer('kernel', kernel_1d)
|
|
||||||
self.channels_last = channels_last
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.channels_last:
|
|
||||||
x = x.permute(0, 2, 1)
|
|
||||||
x = F.pad(x, (self.pad,) * 2, self.pad_mode)
|
|
||||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
|
|
||||||
indices = torch.arange(x.shape[1], device=x.device)
|
|
||||||
weight[indices, indices] = self.kernel.to(weight)
|
|
||||||
x = F.conv1d(x, weight, stride=2)
|
|
||||||
if self.channels_last:
|
|
||||||
x = x.permute(0, 2, 1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample1d(nn.Module):
|
|
||||||
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
|
|
||||||
super().__init__()
|
|
||||||
self.pad_mode = pad_mode
|
|
||||||
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
|
||||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
|
||||||
self.register_buffer('kernel', kernel_1d)
|
|
||||||
self.channels_last = channels_last
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.channels_last:
|
|
||||||
x = x.permute(0, 2, 1)
|
|
||||||
x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
|
||||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
|
|
||||||
indices = torch.arange(x.shape[1], device=x.device)
|
|
||||||
weight[indices, indices] = self.kernel.to(weight)
|
|
||||||
x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
|
||||||
if self.channels_last:
|
|
||||||
x = x.permute(0, 2, 1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def Downsample1d_2(
|
|
||||||
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
|
|
||||||
) -> nn.Module:
|
|
||||||
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
|
|
||||||
|
|
||||||
return nn.Conv1d(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=factor * kernel_multiplier + 1,
|
|
||||||
stride=factor,
|
|
||||||
padding=factor * (kernel_multiplier // 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def Upsample1d_2(
|
|
||||||
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
|
|
||||||
) -> nn.Module:
|
|
||||||
|
|
||||||
if factor == 1:
|
|
||||||
return nn.Conv1d(
|
|
||||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_nearest:
|
|
||||||
return nn.Sequential(
|
|
||||||
nn.Upsample(scale_factor=factor, mode="nearest"),
|
|
||||||
nn.Conv1d(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
padding=1,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return nn.ConvTranspose1d(
|
|
||||||
in_channels=in_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
kernel_size=factor * 2,
|
|
||||||
stride=factor,
|
|
||||||
padding=factor // 2 + factor % 2,
|
|
||||||
output_padding=factor % 2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def zero_init(layer):
|
|
||||||
nn.init.zeros_(layer.weight)
|
|
||||||
if layer.bias is not None:
|
|
||||||
nn.init.zeros_(layer.bias)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
class AdaRMSNorm(nn.Module):
|
|
||||||
def __init__(self, features, cond_features, eps=1e-6):
|
|
||||||
super().__init__()
|
|
||||||
self.eps = eps
|
|
||||||
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return f"eps={self.eps},"
|
|
||||||
|
|
||||||
def forward(self, x, cond):
|
|
||||||
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
|
|
||||||
|
|
||||||
def normalize(x, eps=1e-4):
|
|
||||||
dim = list(range(1, x.ndim))
|
|
||||||
n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
|
|
||||||
alpha = np.sqrt(n.numel() / x.numel())
|
|
||||||
return x / torch.add(eps, n, alpha=alpha)
|
|
||||||
|
|
||||||
class ForcedWNConv1d(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=1):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.training:
|
|
||||||
with torch.no_grad():
|
|
||||||
self.weight.copy_(normalize(self.weight))
|
|
||||||
|
|
||||||
fan_in = self.weight[0].numel()
|
|
||||||
|
|
||||||
w = normalize(self.weight) / math.sqrt(fan_in)
|
|
||||||
|
|
||||||
return F.conv1d(x, w, padding='same')
|
|
||||||
|
|
||||||
# Kernels
|
|
||||||
|
|
||||||
use_compile = True
|
|
||||||
|
|
||||||
def compile(function, *args, **kwargs):
|
|
||||||
if not use_compile:
|
|
||||||
return function
|
|
||||||
try:
|
|
||||||
return torch.compile(function, *args, **kwargs)
|
|
||||||
except RuntimeError:
|
|
||||||
return function
|
|
||||||
|
|
||||||
|
|
||||||
@compile
|
|
||||||
def linear_geglu(x, weight, bias=None):
|
|
||||||
x = x @ weight.mT
|
|
||||||
if bias is not None:
|
|
||||||
x = x + bias
|
|
||||||
x, gate = x.chunk(2, dim=-1)
|
|
||||||
return x * F.gelu(gate)
|
|
||||||
|
|
||||||
|
|
||||||
@compile
|
|
||||||
def rms_norm(x, scale, eps):
|
|
||||||
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
|
||||||
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
|
||||||
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
|
||||||
return x * scale.to(x.dtype)
|
|
||||||
|
|
||||||
# Layers
|
|
||||||
|
|
||||||
class LinearGEGLU(nn.Linear):
|
|
||||||
def __init__(self, in_features, out_features, bias=True):
|
|
||||||
super().__init__(in_features, out_features * 2, bias=bias)
|
|
||||||
self.out_features = out_features
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return linear_geglu(x, self.weight, self.bias)
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, shape, fix_scale = False, eps=1e-6):
|
|
||||||
super().__init__()
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
if fix_scale:
|
|
||||||
self.register_buffer("scale", torch.ones(shape))
|
|
||||||
else:
|
|
||||||
self.scale = nn.Parameter(torch.ones(shape))
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return rms_norm(x, self.scale, self.eps)
|
|
||||||
|
|
||||||
def snake_beta(x, alpha, beta):
|
|
||||||
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# snake_beta = torch.compile(snake_beta)
|
|
||||||
# except RuntimeError:
|
|
||||||
# pass
|
|
||||||
|
|
||||||
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
|
||||||
# License available in LICENSES/LICENSE_NVIDIA.txt
|
|
||||||
class SnakeBeta(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
|
||||||
super(SnakeBeta, self).__init__()
|
|
||||||
self.in_features = in_features
|
|
||||||
|
|
||||||
# initialize alpha
|
|
||||||
self.alpha_logscale = alpha_logscale
|
|
||||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
|
||||||
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
|
||||||
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
|
||||||
else: # linear scale alphas initialized to ones
|
|
||||||
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
|
||||||
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
|
||||||
|
|
||||||
self.alpha.requires_grad = alpha_trainable
|
|
||||||
self.beta.requires_grad = alpha_trainable
|
|
||||||
|
|
||||||
self.no_div_by_zero = 0.000000001
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
|
||||||
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
|
||||||
if self.alpha_logscale:
|
|
||||||
alpha = torch.exp(alpha)
|
|
||||||
beta = torch.exp(beta)
|
|
||||||
x = snake_beta(x, alpha, beta)
|
|
||||||
|
|
||||||
return x
|
|
||||||
@@ -1,355 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from einops import rearrange
|
|
||||||
from vector_quantize_pytorch import ResidualVQ, FSQ
|
|
||||||
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
|
|
||||||
|
|
||||||
class Bottleneck(nn.Module):
|
|
||||||
def __init__(self, is_discrete: bool = False):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.is_discrete = is_discrete
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
class DiscreteBottleneck(Bottleneck):
|
|
||||||
def __init__(self, num_quantizers, codebook_size, tokens_id):
|
|
||||||
super().__init__(is_discrete=True)
|
|
||||||
|
|
||||||
self.num_quantizers = num_quantizers
|
|
||||||
self.codebook_size = codebook_size
|
|
||||||
self.tokens_id = tokens_id
|
|
||||||
|
|
||||||
def decode_tokens(self, codes, **kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
class TanhBottleneck(Bottleneck):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(is_discrete=False)
|
|
||||||
self.tanh = nn.Tanh()
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
x = torch.tanh(x)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
def vae_sample(mean, scale):
|
|
||||||
stdev = nn.functional.softplus(scale) + 1e-4
|
|
||||||
var = stdev * stdev
|
|
||||||
logvar = torch.log(var)
|
|
||||||
latents = torch.randn_like(mean) * stdev + mean
|
|
||||||
|
|
||||||
kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
|
||||||
|
|
||||||
return latents, kl
|
|
||||||
|
|
||||||
class VAEBottleneck(Bottleneck):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(is_discrete=False)
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False, **kwargs):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
mean, scale = x.chunk(2, dim=1)
|
|
||||||
|
|
||||||
x, kl = vae_sample(mean, scale)
|
|
||||||
|
|
||||||
info["kl"] = kl
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
def compute_mean_kernel(x, y):
|
|
||||||
kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
|
|
||||||
return torch.exp(-kernel_input).mean()
|
|
||||||
|
|
||||||
def compute_mmd(latents):
|
|
||||||
latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
|
|
||||||
noise = torch.randn_like(latents_reshaped)
|
|
||||||
|
|
||||||
latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
|
|
||||||
noise_kernel = compute_mean_kernel(noise, noise)
|
|
||||||
latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
|
|
||||||
|
|
||||||
mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
|
|
||||||
return mmd.mean()
|
|
||||||
|
|
||||||
class WassersteinBottleneck(Bottleneck):
|
|
||||||
def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
|
|
||||||
super().__init__(is_discrete=False)
|
|
||||||
|
|
||||||
self.noise_augment_dim = noise_augment_dim
|
|
||||||
self.bypass_mmd = bypass_mmd
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
if self.training and return_info:
|
|
||||||
if self.bypass_mmd:
|
|
||||||
mmd = torch.tensor(0.0)
|
|
||||||
else:
|
|
||||||
mmd = compute_mmd(x)
|
|
||||||
|
|
||||||
info["mmd"] = mmd
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
|
|
||||||
if self.noise_augment_dim > 0:
|
|
||||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
|
||||||
x.shape[-1]).type_as(x)
|
|
||||||
x = torch.cat([x, noise], dim=1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
class L2Bottleneck(Bottleneck):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(is_discrete=False)
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
x = F.normalize(x, dim=1)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
return F.normalize(x, dim=1)
|
|
||||||
|
|
||||||
class RVQBottleneck(DiscreteBottleneck):
|
|
||||||
def __init__(self, **quantizer_kwargs):
|
|
||||||
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
|
||||||
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
|
||||||
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False, **kwargs):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
x = rearrange(x, "b c n -> b n c")
|
|
||||||
x, indices, loss = self.quantizer(x)
|
|
||||||
x = rearrange(x, "b n c -> b c n")
|
|
||||||
|
|
||||||
info["quantizer_indices"] = indices
|
|
||||||
info["quantizer_loss"] = loss.mean()
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode_tokens(self, codes, **kwargs):
|
|
||||||
latents = self.quantizer.get_outputs_from_indices(codes)
|
|
||||||
|
|
||||||
return self.decode(latents, **kwargs)
|
|
||||||
|
|
||||||
class RVQVAEBottleneck(DiscreteBottleneck):
|
|
||||||
def __init__(self, **quantizer_kwargs):
|
|
||||||
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
|
||||||
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
|
||||||
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
x, kl = vae_sample(*x.chunk(2, dim=1))
|
|
||||||
|
|
||||||
info["kl"] = kl
|
|
||||||
|
|
||||||
x = rearrange(x, "b c n -> b n c")
|
|
||||||
x, indices, loss = self.quantizer(x)
|
|
||||||
x = rearrange(x, "b n c -> b c n")
|
|
||||||
|
|
||||||
info["quantizer_indices"] = indices
|
|
||||||
info["quantizer_loss"] = loss.mean()
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode_tokens(self, codes, **kwargs):
|
|
||||||
latents = self.quantizer.get_outputs_from_indices(codes)
|
|
||||||
|
|
||||||
return self.decode(latents, **kwargs)
|
|
||||||
|
|
||||||
class DACRVQBottleneck(DiscreteBottleneck):
|
|
||||||
def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
|
|
||||||
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
|
||||||
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
|
||||||
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
|
||||||
self.quantize_on_decode = quantize_on_decode
|
|
||||||
self.noise_augment_dim = noise_augment_dim
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False, **kwargs):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
info["pre_quantizer"] = x
|
|
||||||
|
|
||||||
if self.quantize_on_decode:
|
|
||||||
return x, info if return_info else x
|
|
||||||
|
|
||||||
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
|
|
||||||
|
|
||||||
output = {
|
|
||||||
"z": z,
|
|
||||||
"codes": codes,
|
|
||||||
"latents": latents,
|
|
||||||
"vq/commitment_loss": commitment_loss,
|
|
||||||
"vq/codebook_loss": codebook_loss,
|
|
||||||
}
|
|
||||||
|
|
||||||
output["vq/commitment_loss"] /= self.num_quantizers
|
|
||||||
output["vq/codebook_loss"] /= self.num_quantizers
|
|
||||||
|
|
||||||
info.update(output)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return output["z"], info
|
|
||||||
|
|
||||||
return output["z"]
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
|
|
||||||
if self.quantize_on_decode:
|
|
||||||
x = self.quantizer(x)[0]
|
|
||||||
|
|
||||||
if self.noise_augment_dim > 0:
|
|
||||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
|
||||||
x.shape[-1]).type_as(x)
|
|
||||||
x = torch.cat([x, noise], dim=1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode_tokens(self, codes, **kwargs):
|
|
||||||
latents, _, _ = self.quantizer.from_codes(codes)
|
|
||||||
|
|
||||||
return self.decode(latents, **kwargs)
|
|
||||||
|
|
||||||
class DACRVQVAEBottleneck(DiscreteBottleneck):
|
|
||||||
def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
|
|
||||||
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
|
||||||
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
|
||||||
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
|
||||||
self.quantize_on_decode = quantize_on_decode
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False, n_quantizers: int = None):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
mean, scale = x.chunk(2, dim=1)
|
|
||||||
|
|
||||||
x, kl = vae_sample(mean, scale)
|
|
||||||
|
|
||||||
info["pre_quantizer"] = x
|
|
||||||
info["kl"] = kl
|
|
||||||
|
|
||||||
if self.quantize_on_decode:
|
|
||||||
return x, info if return_info else x
|
|
||||||
|
|
||||||
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
|
|
||||||
|
|
||||||
output = {
|
|
||||||
"z": z,
|
|
||||||
"codes": codes,
|
|
||||||
"latents": latents,
|
|
||||||
"vq/commitment_loss": commitment_loss,
|
|
||||||
"vq/codebook_loss": codebook_loss,
|
|
||||||
}
|
|
||||||
|
|
||||||
output["vq/commitment_loss"] /= self.num_quantizers
|
|
||||||
output["vq/codebook_loss"] /= self.num_quantizers
|
|
||||||
|
|
||||||
info.update(output)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return output["z"], info
|
|
||||||
|
|
||||||
return output["z"]
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
|
|
||||||
if self.quantize_on_decode:
|
|
||||||
x = self.quantizer(x)[0]
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode_tokens(self, codes, **kwargs):
|
|
||||||
latents, _, _ = self.quantizer.from_codes(codes)
|
|
||||||
|
|
||||||
return self.decode(latents, **kwargs)
|
|
||||||
|
|
||||||
class FSQBottleneck(DiscreteBottleneck):
|
|
||||||
def __init__(self, noise_augment_dim=0, **kwargs):
|
|
||||||
super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
|
|
||||||
|
|
||||||
self.noise_augment_dim = noise_augment_dim
|
|
||||||
|
|
||||||
self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
|
|
||||||
|
|
||||||
def encode(self, x, return_info=False):
|
|
||||||
info = {}
|
|
||||||
|
|
||||||
orig_dtype = x.dtype
|
|
||||||
x = x.float()
|
|
||||||
|
|
||||||
x = rearrange(x, "b c n -> b n c")
|
|
||||||
x, indices = self.quantizer(x)
|
|
||||||
x = rearrange(x, "b n c -> b c n")
|
|
||||||
|
|
||||||
x = x.to(orig_dtype)
|
|
||||||
|
|
||||||
# Reorder indices to match the expected format
|
|
||||||
indices = rearrange(indices, "b n q -> b q n")
|
|
||||||
|
|
||||||
info["quantizer_indices"] = indices
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
else:
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
|
|
||||||
if self.noise_augment_dim > 0:
|
|
||||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
|
||||||
x.shape[-1]).type_as(x)
|
|
||||||
x = torch.cat([x, noise], dim=1)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def decode_tokens(self, tokens, **kwargs):
|
|
||||||
latents = self.quantizer.indices_to_codes(tokens)
|
|
||||||
|
|
||||||
return self.decode(latents, **kwargs)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,884 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from functools import partial
|
|
||||||
import numpy as np
|
|
||||||
import typing as tp
|
|
||||||
|
|
||||||
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
|
|
||||||
from .conditioners import MultiConditioner
|
|
||||||
from .dit import DiffusionTransformer
|
|
||||||
from .pretransforms import Pretransform
|
|
||||||
|
|
||||||
from .adp import UNetCFG1d, UNet1d
|
|
||||||
|
|
||||||
# Lazy imports for factory functions to avoid circular imports
|
|
||||||
def _get_create_pretransform_from_config():
|
|
||||||
from prismaudio_core.factory import create_pretransform_from_config
|
|
||||||
return create_pretransform_from_config
|
|
||||||
|
|
||||||
def _get_create_multi_conditioner_from_conditioning_config():
|
|
||||||
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
|
|
||||||
return create_multi_conditioner_from_conditioning_config
|
|
||||||
|
|
||||||
class DiffusionModel(nn.Module):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, x, t, **kwargs):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
class DiffusionModelWrapper(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: DiffusionModel,
|
|
||||||
io_channels,
|
|
||||||
sample_size,
|
|
||||||
sample_rate,
|
|
||||||
min_input_length,
|
|
||||||
pretransform: tp.Optional[Pretransform] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.io_channels = io_channels
|
|
||||||
self.sample_size = sample_size
|
|
||||||
self.sample_rate = sample_rate
|
|
||||||
self.min_input_length = min_input_length
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
if pretransform is not None:
|
|
||||||
self.pretransform = pretransform
|
|
||||||
else:
|
|
||||||
self.pretransform = None
|
|
||||||
|
|
||||||
def forward(self, x, t, **kwargs):
|
|
||||||
return self.model(x, t, **kwargs)
|
|
||||||
|
|
||||||
class ConditionedDiffusionModel(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
*args,
|
|
||||||
supports_cross_attention: bool = False,
|
|
||||||
supports_input_concat: bool = False,
|
|
||||||
supports_global_cond: bool = False,
|
|
||||||
supports_prepend_cond: bool = False,
|
|
||||||
**kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.supports_cross_attention = supports_cross_attention
|
|
||||||
self.supports_input_concat = supports_input_concat
|
|
||||||
self.supports_global_cond = supports_global_cond
|
|
||||||
self.supports_prepend_cond = supports_prepend_cond
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
t: torch.Tensor,
|
|
||||||
cross_attn_cond: torch.Tensor = None,
|
|
||||||
cross_attn_mask: torch.Tensor = None,
|
|
||||||
input_concat_cond: torch.Tensor = None,
|
|
||||||
global_embed: torch.Tensor = None,
|
|
||||||
prepend_cond: torch.Tensor = None,
|
|
||||||
prepend_cond_mask: torch.Tensor = None,
|
|
||||||
cfg_scale: float = 1.0,
|
|
||||||
cfg_dropout_prob: float = 0.0,
|
|
||||||
batch_cfg: bool = False,
|
|
||||||
rescale_cfg: bool = False,
|
|
||||||
**kwargs):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
class ConditionedDiffusionModelWrapper(nn.Module):
|
|
||||||
"""
|
|
||||||
A diffusion model that takes in conditioning
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: ConditionedDiffusionModel,
|
|
||||||
conditioner: MultiConditioner,
|
|
||||||
io_channels,
|
|
||||||
sample_rate,
|
|
||||||
min_input_length: int,
|
|
||||||
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
|
|
||||||
zero_init: bool = False,
|
|
||||||
pretransform: tp.Optional[Pretransform] = None,
|
|
||||||
cross_attn_cond_ids: tp.List[str] = [],
|
|
||||||
global_cond_ids: tp.List[str] = [],
|
|
||||||
input_concat_ids: tp.List[str] = [],
|
|
||||||
prepend_cond_ids: tp.List[str] = [],
|
|
||||||
add_cond_ids: tp.List[str] = [],
|
|
||||||
sync_cond_ids: tp.List[str] = [],
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.conditioner = conditioner
|
|
||||||
self.io_channels = io_channels
|
|
||||||
self.sample_rate = sample_rate
|
|
||||||
self.diffusion_objective = diffusion_objective
|
|
||||||
self.pretransform = pretransform
|
|
||||||
self.cross_attn_cond_ids = cross_attn_cond_ids
|
|
||||||
self.global_cond_ids = global_cond_ids
|
|
||||||
self.input_concat_ids = input_concat_ids
|
|
||||||
self.prepend_cond_ids = prepend_cond_ids
|
|
||||||
self.add_cond_ids = add_cond_ids
|
|
||||||
self.sync_cond_ids = sync_cond_ids
|
|
||||||
self.min_input_length = min_input_length
|
|
||||||
def _basic_init(module):
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
torch.nn.init.xavier_uniform_(module.weight)
|
|
||||||
if module.bias is not None:
|
|
||||||
nn.init.constant_(module.bias, 0)
|
|
||||||
|
|
||||||
if zero_init is True:
|
|
||||||
self.conditioner.apply(_basic_init)
|
|
||||||
self.model.model.initialize_weights()
|
|
||||||
|
|
||||||
|
|
||||||
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
|
|
||||||
cross_attention_input = None
|
|
||||||
cross_attention_masks = None
|
|
||||||
global_cond = None
|
|
||||||
input_concat_cond = None
|
|
||||||
prepend_cond = None
|
|
||||||
prepend_cond_mask = None
|
|
||||||
add_input = None
|
|
||||||
sync_input = None
|
|
||||||
|
|
||||||
if len(self.cross_attn_cond_ids) > 0:
|
|
||||||
# Concatenate all cross-attention inputs over the sequence dimension
|
|
||||||
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
|
||||||
cross_attention_input = []
|
|
||||||
cross_attention_masks = []
|
|
||||||
|
|
||||||
for key in self.cross_attn_cond_ids:
|
|
||||||
cross_attn_in, cross_attn_mask = conditioning_tensors[key]
|
|
||||||
|
|
||||||
# Add sequence dimension if it's not there
|
|
||||||
if len(cross_attn_in.shape) == 2:
|
|
||||||
cross_attn_in = cross_attn_in.unsqueeze(1)
|
|
||||||
# cross_attn_mask = cross_attn_mask.unsqueeze(1)
|
|
||||||
|
|
||||||
cross_attention_input.append(cross_attn_in)
|
|
||||||
cross_attention_masks.append(cross_attn_mask)
|
|
||||||
|
|
||||||
cross_attention_input = torch.cat(cross_attention_input, dim=1)
|
|
||||||
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
|
|
||||||
|
|
||||||
if len(self.add_cond_ids) > 0:
|
|
||||||
# Concatenate all cross-attention inputs over the sequence dimension
|
|
||||||
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
|
||||||
add_input = []
|
|
||||||
|
|
||||||
for key in self.add_cond_ids:
|
|
||||||
add_in = conditioning_tensors[key][0]
|
|
||||||
|
|
||||||
# Add sequence dimension if it's not there
|
|
||||||
if len(add_in.shape) == 2:
|
|
||||||
add_in = add_in.unsqueeze(1)
|
|
||||||
# add_in = add_in.transpose(1,2)
|
|
||||||
# add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False)
|
|
||||||
# add_in = add_in.transpose(1,2)
|
|
||||||
add_input.append(add_in)
|
|
||||||
|
|
||||||
add_input = torch.cat(add_input, dim=2)
|
|
||||||
|
|
||||||
if len(self.sync_cond_ids) > 0:
|
|
||||||
# Concatenate all cross-attention inputs over the sequence dimension
|
|
||||||
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
|
||||||
sync_input = []
|
|
||||||
|
|
||||||
for key in self.sync_cond_ids:
|
|
||||||
sync_in = conditioning_tensors[key][0]
|
|
||||||
|
|
||||||
# Add sequence dimension if it's not there
|
|
||||||
if len(sync_in.shape) == 2:
|
|
||||||
sync_in = sync_in.unsqueeze(1)
|
|
||||||
sync_input.append(sync_in)
|
|
||||||
|
|
||||||
sync_input = torch.cat(sync_input, dim=2)
|
|
||||||
|
|
||||||
if len(self.global_cond_ids) > 0:
|
|
||||||
# Concatenate all global conditioning inputs over the channel dimension
|
|
||||||
# Assumes that the global conditioning inputs are of shape (batch, channels)
|
|
||||||
global_conds = []
|
|
||||||
for key in self.global_cond_ids:
|
|
||||||
global_cond_input = conditioning_tensors[key][0]
|
|
||||||
if len(global_cond_input.shape) == 2:
|
|
||||||
global_cond_input = global_cond_input.unsqueeze(1)
|
|
||||||
global_conds.append(global_cond_input)
|
|
||||||
|
|
||||||
# # Concatenate over the channel dimension
|
|
||||||
# if global_conds[0].shape[-1] == 768:
|
|
||||||
# global_cond = torch.cat(global_conds, dim=-1)
|
|
||||||
# else:
|
|
||||||
# global_cond = sum(global_conds)
|
|
||||||
global_cond = sum(global_conds)
|
|
||||||
# global_cond = torch.cat(global_conds, dim=-1)
|
|
||||||
|
|
||||||
if len(global_cond.shape) == 3:
|
|
||||||
global_cond = global_cond.squeeze(1)
|
|
||||||
|
|
||||||
if len(self.input_concat_ids) > 0:
|
|
||||||
# Concatenate all input concat conditioning inputs over the channel dimension
|
|
||||||
# Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
|
|
||||||
input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
|
|
||||||
|
|
||||||
if len(self.prepend_cond_ids) > 0:
|
|
||||||
# Concatenate all prepend conditioning inputs over the sequence dimension
|
|
||||||
# Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
|
|
||||||
prepend_conds = []
|
|
||||||
prepend_cond_masks = []
|
|
||||||
|
|
||||||
for key in self.prepend_cond_ids:
|
|
||||||
prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
|
|
||||||
if len(prepend_cond_input.shape) == 2:
|
|
||||||
prepend_cond_input = prepend_cond_input.unsqueeze(1)
|
|
||||||
prepend_conds.append(prepend_cond_input)
|
|
||||||
prepend_cond_masks.append(prepend_cond_mask)
|
|
||||||
|
|
||||||
prepend_cond = torch.cat(prepend_conds, dim=1)
|
|
||||||
prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
|
|
||||||
|
|
||||||
if negative:
|
|
||||||
return {
|
|
||||||
"negative_cross_attn_cond": cross_attention_input,
|
|
||||||
"negative_cross_attn_mask": cross_attention_masks,
|
|
||||||
"negative_global_cond": global_cond,
|
|
||||||
"negative_input_concat_cond": input_concat_cond
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"cross_attn_cond": cross_attention_input,
|
|
||||||
"cross_attn_mask": cross_attention_masks,
|
|
||||||
"global_cond": global_cond,
|
|
||||||
"input_concat_cond": input_concat_cond,
|
|
||||||
"prepend_cond": prepend_cond,
|
|
||||||
"prepend_cond_mask": prepend_cond_mask,
|
|
||||||
"add_cond": add_input,
|
|
||||||
"sync_cond": sync_input
|
|
||||||
}
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
|
||||||
return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
|
|
||||||
|
|
||||||
def generate(self, *args, **kwargs):
|
|
||||||
from prismaudio_core.inference.generation import generate_diffusion_cond
|
|
||||||
return generate_diffusion_cond(self, *args, **kwargs)
|
|
||||||
|
|
||||||
class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
|
|
||||||
|
|
||||||
self.model = UNetCFG1d(*args, **kwargs)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for param in self.model.parameters():
|
|
||||||
param *= 0.5
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
cross_attn_cond=None,
|
|
||||||
cross_attn_mask=None,
|
|
||||||
input_concat_cond=None,
|
|
||||||
global_cond=None,
|
|
||||||
cfg_scale=1.0,
|
|
||||||
cfg_dropout_prob: float = 0.0,
|
|
||||||
batch_cfg: bool = False,
|
|
||||||
rescale_cfg: bool = False,
|
|
||||||
negative_cross_attn_cond=None,
|
|
||||||
negative_cross_attn_mask=None,
|
|
||||||
negative_global_cond=None,
|
|
||||||
negative_input_concat_cond=None,
|
|
||||||
prepend_cond=None,
|
|
||||||
prepend_cond_mask=None,
|
|
||||||
**kwargs):
|
|
||||||
channels_list = None
|
|
||||||
if input_concat_cond is not None:
|
|
||||||
channels_list = [input_concat_cond]
|
|
||||||
|
|
||||||
outputs = self.model(
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
embedding=cross_attn_cond,
|
|
||||||
embedding_mask=cross_attn_mask,
|
|
||||||
features=global_cond,
|
|
||||||
channels_list=channels_list,
|
|
||||||
embedding_scale=cfg_scale,
|
|
||||||
embedding_mask_proba=cfg_dropout_prob,
|
|
||||||
batch_cfg=batch_cfg,
|
|
||||||
rescale_cfg=rescale_cfg,
|
|
||||||
negative_embedding=negative_cross_attn_cond,
|
|
||||||
negative_embedding_mask=negative_cross_attn_mask,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
class UNet1DCondWrapper(ConditionedDiffusionModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
|
|
||||||
|
|
||||||
self.model = UNet1d(*args, **kwargs)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for param in self.model.parameters():
|
|
||||||
param *= 0.5
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
input_concat_cond=None,
|
|
||||||
global_cond=None,
|
|
||||||
cross_attn_cond=None,
|
|
||||||
cross_attn_mask=None,
|
|
||||||
prepend_cond=None,
|
|
||||||
prepend_cond_mask=None,
|
|
||||||
cfg_scale=1.0,
|
|
||||||
cfg_dropout_prob: float = 0.0,
|
|
||||||
batch_cfg: bool = False,
|
|
||||||
rescale_cfg: bool = False,
|
|
||||||
negative_cross_attn_cond=None,
|
|
||||||
negative_cross_attn_mask=None,
|
|
||||||
negative_global_cond=None,
|
|
||||||
negative_input_concat_cond=None,
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
channels_list = None
|
|
||||||
if input_concat_cond is not None:
|
|
||||||
|
|
||||||
# Interpolate input_concat_cond to the same length as x
|
|
||||||
if input_concat_cond.shape[2] != x.shape[2]:
|
|
||||||
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
|
||||||
|
|
||||||
channels_list = [input_concat_cond]
|
|
||||||
|
|
||||||
outputs = self.model(
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
features=global_cond,
|
|
||||||
channels_list=channels_list,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
class UNet1DUncondWrapper(DiffusionModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
|
|
||||||
|
|
||||||
self.io_channels = in_channels
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for param in self.model.parameters():
|
|
||||||
param *= 0.5
|
|
||||||
|
|
||||||
def forward(self, x, t, **kwargs):
|
|
||||||
return self.model(x, t, **kwargs)
|
|
||||||
|
|
||||||
class DAU1DCondWrapper(ConditionedDiffusionModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
|
|
||||||
|
|
||||||
self.model = DiffusionAttnUnet1D(*args, **kwargs)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for param in self.model.parameters():
|
|
||||||
param *= 0.5
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
input_concat_cond=None,
|
|
||||||
cross_attn_cond=None,
|
|
||||||
cross_attn_mask=None,
|
|
||||||
global_cond=None,
|
|
||||||
cfg_scale=1.0,
|
|
||||||
cfg_dropout_prob: float = 0.0,
|
|
||||||
batch_cfg: bool = False,
|
|
||||||
rescale_cfg: bool = False,
|
|
||||||
negative_cross_attn_cond=None,
|
|
||||||
negative_cross_attn_mask=None,
|
|
||||||
negative_global_cond=None,
|
|
||||||
negative_input_concat_cond=None,
|
|
||||||
prepend_cond=None,
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
return self.model(x, t, cond = input_concat_cond)
|
|
||||||
|
|
||||||
class DiffusionAttnUnet1D(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
io_channels = 2,
|
|
||||||
depth=14,
|
|
||||||
n_attn_layers = 6,
|
|
||||||
channels = [128, 128, 256, 256] + [512] * 10,
|
|
||||||
cond_dim = 0,
|
|
||||||
cond_noise_aug = False,
|
|
||||||
kernel_size = 5,
|
|
||||||
learned_resample = False,
|
|
||||||
strides = [2] * 13,
|
|
||||||
conv_bias = True,
|
|
||||||
use_snake = False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.cond_noise_aug = cond_noise_aug
|
|
||||||
|
|
||||||
self.io_channels = io_channels
|
|
||||||
|
|
||||||
if self.cond_noise_aug:
|
|
||||||
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
|
|
||||||
|
|
||||||
self.timestep_embed = FourierFeatures(1, 16)
|
|
||||||
|
|
||||||
attn_layer = depth - n_attn_layers
|
|
||||||
|
|
||||||
strides = [1] + strides
|
|
||||||
|
|
||||||
block = nn.Identity()
|
|
||||||
|
|
||||||
conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
|
|
||||||
|
|
||||||
for i in range(depth, 0, -1):
|
|
||||||
c = channels[i - 1]
|
|
||||||
stride = strides[i-1]
|
|
||||||
if stride > 2 and not learned_resample:
|
|
||||||
raise ValueError("Must have stride 2 without learned resampling")
|
|
||||||
|
|
||||||
if i > 1:
|
|
||||||
c_prev = channels[i - 2]
|
|
||||||
add_attn = i >= attn_layer and n_attn_layers > 0
|
|
||||||
block = SkipBlock(
|
|
||||||
Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
|
|
||||||
conv_block(c_prev, c, c),
|
|
||||||
SelfAttention1d(
|
|
||||||
c, c // 32) if add_attn else nn.Identity(),
|
|
||||||
conv_block(c, c, c),
|
|
||||||
SelfAttention1d(
|
|
||||||
c, c // 32) if add_attn else nn.Identity(),
|
|
||||||
conv_block(c, c, c),
|
|
||||||
SelfAttention1d(
|
|
||||||
c, c // 32) if add_attn else nn.Identity(),
|
|
||||||
block,
|
|
||||||
conv_block(c * 2 if i != depth else c, c, c),
|
|
||||||
SelfAttention1d(
|
|
||||||
c, c // 32) if add_attn else nn.Identity(),
|
|
||||||
conv_block(c, c, c),
|
|
||||||
SelfAttention1d(
|
|
||||||
c, c // 32) if add_attn else nn.Identity(),
|
|
||||||
conv_block(c, c, c_prev),
|
|
||||||
SelfAttention1d(c_prev, c_prev //
|
|
||||||
32) if add_attn else nn.Identity(),
|
|
||||||
Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cond_embed_dim = 16 if not self.cond_noise_aug else 32
|
|
||||||
block = nn.Sequential(
|
|
||||||
conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
|
|
||||||
conv_block(c, c, c),
|
|
||||||
conv_block(c, c, c),
|
|
||||||
block,
|
|
||||||
conv_block(c * 2, c, c),
|
|
||||||
conv_block(c, c, c),
|
|
||||||
conv_block(c, c, io_channels, is_last=True),
|
|
||||||
)
|
|
||||||
self.net = block
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for param in self.net.parameters():
|
|
||||||
param *= 0.5
|
|
||||||
|
|
||||||
def forward(self, x, t, cond=None, cond_aug_scale=None):
|
|
||||||
|
|
||||||
timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
|
|
||||||
|
|
||||||
inputs = [x, timestep_embed]
|
|
||||||
|
|
||||||
if cond is not None:
|
|
||||||
if cond.shape[2] != x.shape[2]:
|
|
||||||
cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
|
|
||||||
|
|
||||||
if self.cond_noise_aug:
|
|
||||||
# Get a random number between 0 and 1, uniformly sampled
|
|
||||||
if cond_aug_scale is None:
|
|
||||||
aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
|
|
||||||
else:
|
|
||||||
aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
|
|
||||||
|
|
||||||
# Add noise to the conditioning signal
|
|
||||||
cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
|
|
||||||
|
|
||||||
# Get embedding for noise cond level, reusing timestamp_embed
|
|
||||||
aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
|
|
||||||
|
|
||||||
inputs.append(aug_level_embed)
|
|
||||||
|
|
||||||
inputs.append(cond)
|
|
||||||
|
|
||||||
outputs = self.net(torch.cat(inputs, dim=1))
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
class DiTWrapper(ConditionedDiffusionModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
|
|
||||||
|
|
||||||
self.model = DiffusionTransformer(*args, **kwargs)
|
|
||||||
# with torch.no_grad():
|
|
||||||
# for param in self.model.parameters():
|
|
||||||
# param *= 0.5
|
|
||||||
|
|
||||||
def forward(self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
cross_attn_cond=None,
|
|
||||||
cross_attn_mask=None,
|
|
||||||
negative_cross_attn_cond=None,
|
|
||||||
negative_cross_attn_mask=None,
|
|
||||||
input_concat_cond=None,
|
|
||||||
negative_input_concat_cond=None,
|
|
||||||
global_cond=None,
|
|
||||||
negative_global_cond=None,
|
|
||||||
prepend_cond=None,
|
|
||||||
prepend_cond_mask=None,
|
|
||||||
cfg_scale=1.0,
|
|
||||||
cfg_dropout_prob: float = 0.0,
|
|
||||||
batch_cfg: bool = True,
|
|
||||||
rescale_cfg: bool = False,
|
|
||||||
scale_phi: float = 0.0,
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
|
|
||||||
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
|
|
||||||
|
|
||||||
return self.model(
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
cross_attn_cond=cross_attn_cond,
|
|
||||||
cross_attn_cond_mask=cross_attn_mask,
|
|
||||||
negative_cross_attn_cond=negative_cross_attn_cond,
|
|
||||||
negative_cross_attn_mask=negative_cross_attn_mask,
|
|
||||||
input_concat_cond=input_concat_cond,
|
|
||||||
prepend_cond=prepend_cond,
|
|
||||||
prepend_cond_mask=prepend_cond_mask,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
cfg_dropout_prob=cfg_dropout_prob,
|
|
||||||
scale_phi=scale_phi,
|
|
||||||
global_embed=global_cond,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
|
||||||
"""
|
|
||||||
A diffusion model that takes in conditioning
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
conditioner: MultiConditioner,
|
|
||||||
io_channels,
|
|
||||||
sample_rate,
|
|
||||||
min_input_length: int,
|
|
||||||
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
|
|
||||||
pretransform: tp.Optional[Pretransform] = None,
|
|
||||||
cross_attn_cond_ids: tp.List[str] = [],
|
|
||||||
global_cond_ids: tp.List[str] = [],
|
|
||||||
input_concat_ids: tp.List[str] = [],
|
|
||||||
prepend_cond_ids: tp.List[str] = [],
|
|
||||||
add_cond_ids: tp.List[str] = [],
|
|
||||||
mm_cond_ids: tp.List[str] = [],
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.conditioner = conditioner
|
|
||||||
self.io_channels = io_channels
|
|
||||||
self.sample_rate = sample_rate
|
|
||||||
self.diffusion_objective = diffusion_objective
|
|
||||||
self.pretransform = pretransform
|
|
||||||
self.cross_attn_cond_ids = cross_attn_cond_ids
|
|
||||||
self.global_cond_ids = global_cond_ids
|
|
||||||
self.input_concat_ids = input_concat_ids
|
|
||||||
self.prepend_cond_ids = prepend_cond_ids
|
|
||||||
self.add_cond_ids = add_cond_ids
|
|
||||||
self.min_input_length = min_input_length
|
|
||||||
self.mm_cond_ids = mm_cond_ids
|
|
||||||
|
|
||||||
assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
|
|
||||||
assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
|
|
||||||
assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
|
|
||||||
assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
|
|
||||||
assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
|
|
||||||
assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
|
|
||||||
assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper"
|
|
||||||
assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
|
|
||||||
assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"
|
|
||||||
# assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper"
|
|
||||||
|
|
||||||
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
|
|
||||||
assert negative == False, "negative conditioning is not supported for MMDiTWrapper"
|
|
||||||
cross_attention_input = None
|
|
||||||
cross_attention_masks = None
|
|
||||||
global_cond = None
|
|
||||||
input_concat_cond = None
|
|
||||||
prepend_cond = None
|
|
||||||
prepend_cond_mask = None
|
|
||||||
add_input = None
|
|
||||||
inpaint_masked_input = None
|
|
||||||
t5_features = None
|
|
||||||
metaclip_global_text_features = None
|
|
||||||
clip_f = conditioning_tensors["metaclip_features"]
|
|
||||||
sync_f = conditioning_tensors["sync_features"]
|
|
||||||
text_f = conditioning_tensors["metaclip_text_features"]
|
|
||||||
if 'inpaint_masked_input' in conditioning_tensors.keys():
|
|
||||||
inpaint_masked_input = conditioning_tensors["inpaint_masked_input"]
|
|
||||||
if 't5_features' in conditioning_tensors.keys():
|
|
||||||
t5_features = conditioning_tensors["t5_features"]
|
|
||||||
if 'metaclip_global_text_features' in conditioning_tensors.keys():
|
|
||||||
metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"]
|
|
||||||
return {
|
|
||||||
"clip_f": clip_f,
|
|
||||||
"sync_f": sync_f,
|
|
||||||
"text_f": text_f,
|
|
||||||
"inpaint_masked_input": inpaint_masked_input,
|
|
||||||
"t5_features": t5_features,
|
|
||||||
"metaclip_global_text_features": metaclip_global_text_features
|
|
||||||
}
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
|
||||||
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
|
|
||||||
|
|
||||||
def generate(self, *args, **kwargs):
|
|
||||||
from prismaudio_core.inference.generation import generate_diffusion_cond
|
|
||||||
return generate_diffusion_cond(self, *args, **kwargs)
|
|
||||||
|
|
||||||
class DiTUncondWrapper(DiffusionModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
io_channels,
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs)
|
|
||||||
|
|
||||||
self.io_channels = io_channels
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for param in self.model.parameters():
|
|
||||||
param *= 0.5
|
|
||||||
|
|
||||||
def forward(self, x, t, **kwargs):
|
|
||||||
return self.model(x, t, **kwargs)
|
|
||||||
|
|
||||||
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
|
|
||||||
diffusion_uncond_config = config["model"]
|
|
||||||
|
|
||||||
model_type = diffusion_uncond_config.get('type', None)
|
|
||||||
|
|
||||||
diffusion_config = diffusion_uncond_config.get('config', {})
|
|
||||||
|
|
||||||
assert model_type is not None, "Must specify model type in config"
|
|
||||||
|
|
||||||
pretransform = diffusion_uncond_config.get("pretransform", None)
|
|
||||||
|
|
||||||
sample_size = config.get("sample_size", None)
|
|
||||||
assert sample_size is not None, "Must specify sample size in config"
|
|
||||||
|
|
||||||
sample_rate = config.get("sample_rate", None)
|
|
||||||
assert sample_rate is not None, "Must specify sample rate in config"
|
|
||||||
|
|
||||||
if pretransform is not None:
|
|
||||||
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
|
||||||
min_input_length = pretransform.downsampling_ratio
|
|
||||||
else:
|
|
||||||
min_input_length = 1
|
|
||||||
|
|
||||||
if model_type == 'DAU1d':
|
|
||||||
|
|
||||||
model = DiffusionAttnUnet1D(
|
|
||||||
**diffusion_config
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == "adp_uncond_1d":
|
|
||||||
|
|
||||||
model = UNet1DUncondWrapper(
|
|
||||||
**diffusion_config
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == "dit":
|
|
||||||
model = DiTUncondWrapper(
|
|
||||||
**diffusion_config
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
|
||||||
|
|
||||||
return DiffusionModelWrapper(model,
|
|
||||||
io_channels=model.io_channels,
|
|
||||||
sample_size=sample_size,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
pretransform=pretransform,
|
|
||||||
min_input_length=min_input_length)
|
|
||||||
|
|
||||||
def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
|
|
||||||
diffusion_uncond_config = config["model"]
|
|
||||||
|
|
||||||
|
|
||||||
diffusion_config = diffusion_uncond_config.get('diffusion', {})
|
|
||||||
model_type = diffusion_config.get('type', None)
|
|
||||||
model_config = diffusion_config.get("config",{})
|
|
||||||
assert model_type is not None, "Must specify model type in config"
|
|
||||||
|
|
||||||
pretransform = diffusion_uncond_config.get("pretransform", None)
|
|
||||||
|
|
||||||
sample_size = config.get("sample_size", None)
|
|
||||||
assert sample_size is not None, "Must specify sample size in config"
|
|
||||||
|
|
||||||
sample_rate = config.get("sample_rate", None)
|
|
||||||
assert sample_rate is not None, "Must specify sample rate in config"
|
|
||||||
|
|
||||||
if pretransform is not None:
|
|
||||||
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
|
||||||
min_input_length = pretransform.downsampling_ratio
|
|
||||||
else:
|
|
||||||
min_input_length = 1
|
|
||||||
|
|
||||||
if model_type == 'DAU1d':
|
|
||||||
|
|
||||||
model = DiffusionAttnUnet1D(
|
|
||||||
**model_config
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == "adp_uncond_1d":
|
|
||||||
|
|
||||||
io_channels = model_config.get("io_channels", 64)
|
|
||||||
model = UNet1DUncondWrapper(
|
|
||||||
io_channels = io_channels,
|
|
||||||
**model_config
|
|
||||||
)
|
|
||||||
elif model_type == "dit":
|
|
||||||
model = DiTUncondWrapper(
|
|
||||||
**model_config
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
|
||||||
|
|
||||||
return DiffusionModelWrapper(model,
|
|
||||||
io_channels=model.io_channels,
|
|
||||||
sample_size=sample_size,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
pretransform=pretransform,
|
|
||||||
min_input_length=min_input_length)
|
|
||||||
|
|
||||||
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
|
||||||
|
|
||||||
model_config = config["model"]
|
|
||||||
|
|
||||||
model_type = config["model_type"]
|
|
||||||
|
|
||||||
diffusion_config = model_config.get('diffusion', None)
|
|
||||||
assert diffusion_config is not None, "Must specify diffusion config"
|
|
||||||
|
|
||||||
diffusion_model_type = diffusion_config.get('type', None)
|
|
||||||
assert diffusion_model_type is not None, "Must specify diffusion model type"
|
|
||||||
|
|
||||||
diffusion_model_config = diffusion_config.get('config', None)
|
|
||||||
assert diffusion_model_config is not None, "Must specify diffusion model config"
|
|
||||||
|
|
||||||
if diffusion_model_type == 'adp_cfg_1d':
|
|
||||||
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
|
|
||||||
elif diffusion_model_type == 'adp_1d':
|
|
||||||
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
|
||||||
elif diffusion_model_type == 'dit':
|
|
||||||
diffusion_model = DiTWrapper(**diffusion_model_config)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
|
|
||||||
|
|
||||||
io_channels = model_config.get('io_channels', None)
|
|
||||||
assert io_channels is not None, "Must specify io_channels in model config"
|
|
||||||
|
|
||||||
sample_rate = config.get('sample_rate', None)
|
|
||||||
assert sample_rate is not None, "Must specify sample_rate in config"
|
|
||||||
|
|
||||||
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
|
|
||||||
|
|
||||||
conditioning_config = model_config.get('conditioning', None)
|
|
||||||
|
|
||||||
conditioner = None
|
|
||||||
if conditioning_config is not None:
|
|
||||||
conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config)
|
|
||||||
|
|
||||||
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
|
|
||||||
add_cond_ids = diffusion_config.get('add_cond_ids', [])
|
|
||||||
sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
|
|
||||||
global_cond_ids = diffusion_config.get('global_cond_ids', [])
|
|
||||||
input_concat_ids = diffusion_config.get('input_concat_ids', [])
|
|
||||||
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
|
|
||||||
mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
|
|
||||||
zero_init = diffusion_config.get('zero_init', False)
|
|
||||||
pretransform = model_config.get("pretransform", None)
|
|
||||||
|
|
||||||
if pretransform is not None:
|
|
||||||
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
|
||||||
min_input_length = pretransform.downsampling_ratio
|
|
||||||
else:
|
|
||||||
min_input_length = 1
|
|
||||||
|
|
||||||
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
|
|
||||||
min_input_length *= np.prod(diffusion_model_config["factors"])
|
|
||||||
elif diffusion_model_type == "dit":
|
|
||||||
min_input_length *= diffusion_model.model.patch_size
|
|
||||||
|
|
||||||
# Get the proper wrapper class
|
|
||||||
|
|
||||||
extra_kwargs = {}
|
|
||||||
|
|
||||||
if model_type == "mm_diffusion_cond":
|
|
||||||
wrapper_fn = MMConditionedDiffusionModelWrapper
|
|
||||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
|
||||||
extra_kwargs["mm_cond_ids"] = mm_cond_ids
|
|
||||||
|
|
||||||
elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
|
|
||||||
wrapper_fn = ConditionedDiffusionModelWrapper
|
|
||||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
|
||||||
|
|
||||||
return wrapper_fn(
|
|
||||||
diffusion_model,
|
|
||||||
conditioner,
|
|
||||||
min_input_length=min_input_length,
|
|
||||||
sample_rate=sample_rate,
|
|
||||||
cross_attn_cond_ids=cross_attention_ids,
|
|
||||||
global_cond_ids=global_cond_ids,
|
|
||||||
input_concat_ids=input_concat_ids,
|
|
||||||
prepend_cond_ids=prepend_cond_ids,
|
|
||||||
add_cond_ids=add_cond_ids,
|
|
||||||
sync_cond_ids=sync_cond_ids,
|
|
||||||
pretransform=pretransform,
|
|
||||||
io_channels=io_channels,
|
|
||||||
zero_init=zero_init,
|
|
||||||
**extra_kwargs
|
|
||||||
)
|
|
||||||
@@ -1,539 +0,0 @@
|
|||||||
import typing as tp
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
# from beartype.typing import Tuple
|
|
||||||
from einops import rearrange
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
|
||||||
from .blocks import FourierFeatures
|
|
||||||
from .transformer import ContinuousTransformer
|
|
||||||
from .utils import mask_from_frac_lengths, resample
|
|
||||||
|
|
||||||
class DiffusionTransformer(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
io_channels=32,
|
|
||||||
patch_size=1,
|
|
||||||
embed_dim=768,
|
|
||||||
cond_token_dim=0,
|
|
||||||
project_cond_tokens=True,
|
|
||||||
global_cond_dim=0,
|
|
||||||
project_global_cond=True,
|
|
||||||
input_concat_dim=0,
|
|
||||||
prepend_cond_dim=0,
|
|
||||||
cond_ctx_dim=0,
|
|
||||||
depth=12,
|
|
||||||
num_heads=8,
|
|
||||||
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
|
|
||||||
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
|
|
||||||
timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
|
|
||||||
add_token_dim=0,
|
|
||||||
sync_token_dim=0,
|
|
||||||
use_mlp=False,
|
|
||||||
use_zero_init=False,
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.cond_token_dim = cond_token_dim
|
|
||||||
|
|
||||||
# Timestep embeddings
|
|
||||||
timestep_features_dim = 256
|
|
||||||
# Timestep embeddings
|
|
||||||
self.timestep_cond_type = timestep_cond_type
|
|
||||||
self.timestep_features = FourierFeatures(1, timestep_features_dim)
|
|
||||||
|
|
||||||
if timestep_cond_type == "global":
|
|
||||||
timestep_embed_dim = embed_dim
|
|
||||||
elif timestep_cond_type == "input_concat":
|
|
||||||
assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat"
|
|
||||||
input_concat_dim += timestep_embed_dim
|
|
||||||
|
|
||||||
self.to_timestep_embed = nn.Sequential(
|
|
||||||
nn.Linear(timestep_features_dim, embed_dim, bias=True),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(embed_dim, embed_dim, bias=True),
|
|
||||||
)
|
|
||||||
self.use_mlp = use_mlp
|
|
||||||
if cond_token_dim > 0:
|
|
||||||
# Conditioning tokens
|
|
||||||
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
|
|
||||||
self.to_cond_embed = nn.Sequential(
|
|
||||||
nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cond_embed_dim = 0
|
|
||||||
|
|
||||||
if global_cond_dim > 0:
|
|
||||||
# Global conditioning
|
|
||||||
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
|
|
||||||
self.to_global_embed = nn.Sequential(
|
|
||||||
nn.Linear(global_cond_dim, global_embed_dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(global_embed_dim, global_embed_dim, bias=False)
|
|
||||||
)
|
|
||||||
if add_token_dim > 0:
|
|
||||||
# Conditioning tokens
|
|
||||||
add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
|
|
||||||
self.to_add_embed = nn.Sequential(
|
|
||||||
nn.Linear(add_token_dim, add_embed_dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(add_embed_dim, add_embed_dim, bias=False)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
add_embed_dim = 0
|
|
||||||
|
|
||||||
if sync_token_dim > 0:
|
|
||||||
# Conditioning tokens
|
|
||||||
sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
|
|
||||||
self.to_sync_embed = nn.Sequential(
|
|
||||||
nn.Linear(sync_token_dim, sync_embed_dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(sync_embed_dim, sync_embed_dim, bias=False)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sync_embed_dim = 0
|
|
||||||
|
|
||||||
|
|
||||||
if prepend_cond_dim > 0:
|
|
||||||
# Prepend conditioning
|
|
||||||
self.to_prepend_embed = nn.Sequential(
|
|
||||||
nn.Linear(prepend_cond_dim, embed_dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(embed_dim, embed_dim, bias=False)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.input_concat_dim = input_concat_dim
|
|
||||||
|
|
||||||
dim_in = io_channels + self.input_concat_dim
|
|
||||||
|
|
||||||
self.patch_size = patch_size
|
|
||||||
|
|
||||||
# Transformer
|
|
||||||
|
|
||||||
self.transformer_type = transformer_type
|
|
||||||
|
|
||||||
self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
|
|
||||||
self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
|
|
||||||
self.global_cond_type = global_cond_type
|
|
||||||
if self.transformer_type == "continuous_transformer":
|
|
||||||
|
|
||||||
global_dim = None
|
|
||||||
|
|
||||||
if self.global_cond_type == "adaLN":
|
|
||||||
# The global conditioning is projected to the embed_dim already at this point
|
|
||||||
global_dim = embed_dim
|
|
||||||
|
|
||||||
self.transformer = ContinuousTransformer(
|
|
||||||
dim=embed_dim,
|
|
||||||
depth=depth,
|
|
||||||
dim_heads=embed_dim // num_heads,
|
|
||||||
dim_in=dim_in * patch_size,
|
|
||||||
dim_out=io_channels * patch_size,
|
|
||||||
cross_attend = cond_token_dim > 0,
|
|
||||||
cond_token_dim = cond_embed_dim,
|
|
||||||
global_cond_dim=global_dim,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
|
|
||||||
|
|
||||||
self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
|
|
||||||
self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
|
|
||||||
nn.init.zeros_(self.preprocess_conv.weight)
|
|
||||||
nn.init.zeros_(self.postprocess_conv.weight)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_weights(self):
|
|
||||||
def _basic_init(module):
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
torch.nn.init.xavier_uniform_(module.weight)
|
|
||||||
if module.bias is not None:
|
|
||||||
nn.init.constant_(module.bias, 0)
|
|
||||||
|
|
||||||
# if isinstance(module, nn.Conv1d):
|
|
||||||
# if module.bias is not None:
|
|
||||||
# nn.init.constant_(module.bias, 0)
|
|
||||||
|
|
||||||
self.apply(_basic_init)
|
|
||||||
|
|
||||||
# Initialize timestep embedding MLP:
|
|
||||||
nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02)
|
|
||||||
nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02)
|
|
||||||
|
|
||||||
# Zero-out output layers:
|
|
||||||
if self.global_cond_type == "adaLN":
|
|
||||||
for block in self.transformer.layers:
|
|
||||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
|
||||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
|
||||||
|
|
||||||
nn.init.constant_(self.empty_clip_feat, 0)
|
|
||||||
nn.init.constant_(self.empty_sync_feat, 0)
|
|
||||||
|
|
||||||
def _forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
mask=None,
|
|
||||||
cross_attn_cond=None,
|
|
||||||
cross_attn_cond_mask=None,
|
|
||||||
input_concat_cond=None,
|
|
||||||
global_embed=None,
|
|
||||||
prepend_cond=None,
|
|
||||||
prepend_cond_mask=None,
|
|
||||||
add_cond=None,
|
|
||||||
add_masks=None,
|
|
||||||
sync_cond=None,
|
|
||||||
return_info=False,
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
if cross_attn_cond is not None:
|
|
||||||
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
|
|
||||||
|
|
||||||
if global_embed is not None:
|
|
||||||
# Project the global conditioning to the embedding dimension
|
|
||||||
global_embed = self.to_global_embed(global_embed)
|
|
||||||
|
|
||||||
prepend_inputs = None
|
|
||||||
prepend_mask = None
|
|
||||||
prepend_length = 0
|
|
||||||
if prepend_cond is not None:
|
|
||||||
# Project the prepend conditioning to the embedding dimension
|
|
||||||
prepend_cond = self.to_prepend_embed(prepend_cond)
|
|
||||||
|
|
||||||
prepend_inputs = prepend_cond
|
|
||||||
if prepend_cond_mask is not None:
|
|
||||||
prepend_mask = prepend_cond_mask
|
|
||||||
|
|
||||||
if input_concat_cond is not None:
|
|
||||||
# reshape from (b, n, c) to (b, c, n)
|
|
||||||
if input_concat_cond.shape[1] != x.shape[1]:
|
|
||||||
input_concat_cond = input_concat_cond.transpose(1,2)
|
|
||||||
# Interpolate input_concat_cond to the same length as x
|
|
||||||
# if input_concat_cond.shape[1] != x.shape[2]:
|
|
||||||
# input_concat_cond = input_concat_cond.transpose(1,2)
|
|
||||||
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
|
||||||
# input_concat_cond = input_concat_cond.transpose(1,2)
|
|
||||||
# if len(global_embed.shape) == 2:
|
|
||||||
# global_embed = global_embed.unsqueeze(1)
|
|
||||||
# global_embed = global_embed + input_concat_cond
|
|
||||||
x = torch.cat([x, input_concat_cond], dim=1)
|
|
||||||
|
|
||||||
# Get the batch of timestep embeddings
|
|
||||||
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
|
|
||||||
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
|
||||||
if self.timestep_cond_type == "global":
|
|
||||||
if global_embed is not None:
|
|
||||||
if len(global_embed.shape) == 3:
|
|
||||||
timestep_embed = timestep_embed.unsqueeze(1)
|
|
||||||
global_embed = global_embed + timestep_embed
|
|
||||||
else:
|
|
||||||
global_embed = timestep_embed
|
|
||||||
elif self.timestep_cond_type == "input_concat":
|
|
||||||
x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1)
|
|
||||||
|
|
||||||
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
|
|
||||||
if self.global_cond_type == "prepend" and global_embed is not None:
|
|
||||||
if prepend_inputs is None:
|
|
||||||
# Prepend inputs are just the global embed, and the mask is all ones
|
|
||||||
if len(global_embed.shape) == 2:
|
|
||||||
prepend_inputs = global_embed.unsqueeze(1)
|
|
||||||
else:
|
|
||||||
prepend_inputs = global_embed
|
|
||||||
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
|
|
||||||
else:
|
|
||||||
# Prepend inputs are the prepend conditioning + the global embed
|
|
||||||
if len(global_embed.shape) == 2:
|
|
||||||
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
|
|
||||||
else:
|
|
||||||
prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1)
|
|
||||||
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
|
|
||||||
|
|
||||||
prepend_length = prepend_inputs.shape[1]
|
|
||||||
|
|
||||||
x = self.preprocess_conv(x) + x
|
|
||||||
x = rearrange(x, "b c t -> b t c")
|
|
||||||
|
|
||||||
extra_args = {}
|
|
||||||
|
|
||||||
if self.global_cond_type == "adaLN":
|
|
||||||
extra_args["global_cond"] = global_embed
|
|
||||||
|
|
||||||
if self.patch_size > 1:
|
|
||||||
b, seq_len, c = x.shape
|
|
||||||
|
|
||||||
# 计算需要填充的数量
|
|
||||||
pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size
|
|
||||||
|
|
||||||
if pad_amount > 0:
|
|
||||||
# 在时间维度上进行填充
|
|
||||||
x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
|
|
||||||
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
|
|
||||||
|
|
||||||
if add_cond is not None:
|
|
||||||
# Interpolate add_cond to the same length as x
|
|
||||||
# if self.use_mlp:
|
|
||||||
add_cond = self.to_add_embed(add_cond)
|
|
||||||
if add_cond.shape[1] != x.shape[1]:
|
|
||||||
add_cond = add_cond.transpose(1,2)
|
|
||||||
add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
|
|
||||||
add_cond = add_cond.transpose(1,2)
|
|
||||||
# add_cond = resample(add_cond, x)
|
|
||||||
|
|
||||||
if sync_cond is not None:
|
|
||||||
sync_cond = self.to_sync_embed(sync_cond)
|
|
||||||
|
|
||||||
if self.transformer_type == "continuous_transformer":
|
|
||||||
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
output, info = output
|
|
||||||
|
|
||||||
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
|
|
||||||
|
|
||||||
if self.patch_size > 1:
|
|
||||||
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
|
|
||||||
# 移除之前添加的填充
|
|
||||||
if pad_amount > 0:
|
|
||||||
output = output[:, :, :seq_len]
|
|
||||||
|
|
||||||
output = self.postprocess_conv(output) + output
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return output, info
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
cross_attn_cond=None,
|
|
||||||
cross_attn_cond_mask=None,
|
|
||||||
negative_cross_attn_cond=None,
|
|
||||||
negative_cross_attn_mask=None,
|
|
||||||
input_concat_cond=None,
|
|
||||||
global_embed=None,
|
|
||||||
negative_global_embed=None,
|
|
||||||
prepend_cond=None,
|
|
||||||
prepend_cond_mask=None,
|
|
||||||
add_cond=None,
|
|
||||||
sync_cond=None,
|
|
||||||
cfg_scale=1.0,
|
|
||||||
cfg_dropout_prob=0.0,
|
|
||||||
causal=False,
|
|
||||||
scale_phi=0.0,
|
|
||||||
mask=None,
|
|
||||||
return_info=False,
|
|
||||||
**kwargs):
|
|
||||||
|
|
||||||
assert causal == False, "Causal mode is not supported for DiffusionTransformer"
|
|
||||||
bsz, a, b = x.shape
|
|
||||||
model_dtype = next(self.parameters()).dtype
|
|
||||||
x = x.to(model_dtype)
|
|
||||||
t = t.to(model_dtype)
|
|
||||||
|
|
||||||
if cross_attn_cond is not None:
|
|
||||||
cross_attn_cond = cross_attn_cond.to(model_dtype)
|
|
||||||
|
|
||||||
if negative_cross_attn_cond is not None:
|
|
||||||
negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype)
|
|
||||||
|
|
||||||
if input_concat_cond is not None:
|
|
||||||
input_concat_cond = input_concat_cond.to(model_dtype)
|
|
||||||
|
|
||||||
if global_embed is not None:
|
|
||||||
global_embed = global_embed.to(model_dtype)
|
|
||||||
|
|
||||||
if negative_global_embed is not None:
|
|
||||||
negative_global_embed = negative_global_embed.to(model_dtype)
|
|
||||||
|
|
||||||
if prepend_cond is not None:
|
|
||||||
prepend_cond = prepend_cond.to(model_dtype)
|
|
||||||
|
|
||||||
if add_cond is not None:
|
|
||||||
add_cond = add_cond.to(model_dtype)
|
|
||||||
|
|
||||||
if sync_cond is not None:
|
|
||||||
sync_cond = sync_cond.to(model_dtype)
|
|
||||||
|
|
||||||
if cross_attn_cond_mask is not None:
|
|
||||||
cross_attn_cond_mask = cross_attn_cond_mask.bool()
|
|
||||||
|
|
||||||
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
|
|
||||||
|
|
||||||
if prepend_cond_mask is not None:
|
|
||||||
prepend_cond_mask = prepend_cond_mask.bool()
|
|
||||||
|
|
||||||
|
|
||||||
# CFG dropout
|
|
||||||
if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
|
|
||||||
if cross_attn_cond is not None:
|
|
||||||
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
|
||||||
dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
|
|
||||||
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
|
|
||||||
|
|
||||||
if prepend_cond is not None:
|
|
||||||
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
|
||||||
dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
|
|
||||||
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
|
|
||||||
|
|
||||||
if add_cond is not None:
|
|
||||||
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
|
|
||||||
dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool)
|
|
||||||
add_cond = torch.where(dropout_mask, null_embed, add_cond)
|
|
||||||
|
|
||||||
if sync_cond is not None:
|
|
||||||
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
|
|
||||||
dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool)
|
|
||||||
sync_cond = torch.where(dropout_mask, null_embed, sync_cond)
|
|
||||||
|
|
||||||
if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
|
|
||||||
# Classifier-free guidance
|
|
||||||
# Concatenate conditioned and unconditioned inputs on the batch dimension
|
|
||||||
batch_inputs = torch.cat([x, x], dim=0)
|
|
||||||
batch_timestep = torch.cat([t, t], dim=0)
|
|
||||||
if global_embed is not None and global_embed.shape[0] == bsz:
|
|
||||||
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
|
|
||||||
elif global_embed is not None:
|
|
||||||
batch_global_cond = global_embed
|
|
||||||
else:
|
|
||||||
batch_global_cond = None
|
|
||||||
|
|
||||||
if input_concat_cond is not None and input_concat_cond.shape[0] == bsz:
|
|
||||||
batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
|
|
||||||
elif input_concat_cond is not None:
|
|
||||||
batch_input_concat_cond = input_concat_cond
|
|
||||||
else:
|
|
||||||
batch_input_concat_cond = None
|
|
||||||
|
|
||||||
batch_cond = None
|
|
||||||
batch_cond_masks = None
|
|
||||||
|
|
||||||
# Handle CFG for cross-attention conditioning
|
|
||||||
if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz:
|
|
||||||
|
|
||||||
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
|
||||||
|
|
||||||
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
|
|
||||||
if negative_cross_attn_cond is not None:
|
|
||||||
|
|
||||||
# If there's a negative cross-attention mask, set the masked tokens to the null embed
|
|
||||||
if negative_cross_attn_mask is not None:
|
|
||||||
negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
|
|
||||||
|
|
||||||
negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
|
|
||||||
|
|
||||||
batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
|
|
||||||
|
|
||||||
else:
|
|
||||||
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
|
|
||||||
|
|
||||||
if cross_attn_cond_mask is not None:
|
|
||||||
batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
|
|
||||||
elif cross_attn_cond is not None:
|
|
||||||
batch_cond = cross_attn_cond
|
|
||||||
else:
|
|
||||||
batch_cond = None
|
|
||||||
|
|
||||||
batch_prepend_cond = None
|
|
||||||
batch_prepend_cond_mask = None
|
|
||||||
|
|
||||||
if prepend_cond is not None and prepend_cond.shape[0] == bsz:
|
|
||||||
|
|
||||||
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
|
||||||
|
|
||||||
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
|
|
||||||
|
|
||||||
if prepend_cond_mask is not None:
|
|
||||||
batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
|
|
||||||
elif prepend_cond is not None:
|
|
||||||
batch_prepend_cond = prepend_cond
|
|
||||||
else:
|
|
||||||
batch_prepend_cond = None
|
|
||||||
|
|
||||||
batch_add_cond = None
|
|
||||||
|
|
||||||
# Handle CFG for cross-attention conditioning
|
|
||||||
if add_cond is not None and add_cond.shape[0] == bsz:
|
|
||||||
|
|
||||||
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
|
|
||||||
|
|
||||||
|
|
||||||
batch_add_cond = torch.cat([add_cond, null_embed], dim=0)
|
|
||||||
elif add_cond is not None:
|
|
||||||
batch_add_cond = add_cond
|
|
||||||
else:
|
|
||||||
batch_add_cond = None
|
|
||||||
|
|
||||||
batch_sync_cond = None
|
|
||||||
|
|
||||||
# Handle CFG for cross-attention conditioning
|
|
||||||
if sync_cond is not None and sync_cond.shape[0] == bsz:
|
|
||||||
|
|
||||||
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
|
|
||||||
|
|
||||||
|
|
||||||
batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0)
|
|
||||||
elif sync_cond is not None:
|
|
||||||
batch_sync_cond = sync_cond
|
|
||||||
else:
|
|
||||||
batch_sync_cond = None
|
|
||||||
|
|
||||||
if mask is not None:
|
|
||||||
batch_masks = torch.cat([mask, mask], dim=0)
|
|
||||||
else:
|
|
||||||
batch_masks = None
|
|
||||||
|
|
||||||
batch_output = self._forward(
|
|
||||||
batch_inputs,
|
|
||||||
batch_timestep,
|
|
||||||
cross_attn_cond=batch_cond,
|
|
||||||
cross_attn_cond_mask=batch_cond_masks,
|
|
||||||
mask = batch_masks,
|
|
||||||
input_concat_cond=batch_input_concat_cond,
|
|
||||||
global_embed = batch_global_cond,
|
|
||||||
prepend_cond = batch_prepend_cond,
|
|
||||||
prepend_cond_mask = batch_prepend_cond_mask,
|
|
||||||
add_cond = batch_add_cond,
|
|
||||||
sync_cond = batch_sync_cond,
|
|
||||||
return_info = return_info,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
batch_output, info = batch_output
|
|
||||||
|
|
||||||
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
|
|
||||||
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
|
|
||||||
|
|
||||||
# CFG Rescale
|
|
||||||
if scale_phi != 0.0:
|
|
||||||
cond_out_std = cond_output.std(dim=1, keepdim=True)
|
|
||||||
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
|
|
||||||
output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
|
|
||||||
else:
|
|
||||||
output = cfg_output
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return output, info
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
else:
|
|
||||||
return self._forward(
|
|
||||||
x,
|
|
||||||
t,
|
|
||||||
cross_attn_cond=cross_attn_cond,
|
|
||||||
cross_attn_cond_mask=cross_attn_cond_mask,
|
|
||||||
input_concat_cond=input_concat_cond,
|
|
||||||
global_embed=global_embed,
|
|
||||||
prepend_cond=prepend_cond,
|
|
||||||
prepend_cond_mask=prepend_cond_mask,
|
|
||||||
add_cond=add_cond,
|
|
||||||
sync_cond=sync_cond,
|
|
||||||
mask=mask,
|
|
||||||
return_info=return_info,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
@@ -1,275 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
from einops import rearrange
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from .blocks import AdaRMSNorm
|
|
||||||
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
|
|
||||||
from .utils import checkpoint
|
|
||||||
|
|
||||||
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
|
|
||||||
class ContinuousLocalTransformer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
dim,
|
|
||||||
depth,
|
|
||||||
dim_in = None,
|
|
||||||
dim_out = None,
|
|
||||||
causal = False,
|
|
||||||
local_attn_window_size = 64,
|
|
||||||
heads = 8,
|
|
||||||
ff_mult = 2,
|
|
||||||
cond_dim = 0,
|
|
||||||
cross_attn_cond_dim = 0,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
dim_head = dim//heads
|
|
||||||
|
|
||||||
self.layers = nn.ModuleList([])
|
|
||||||
|
|
||||||
self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
|
|
||||||
|
|
||||||
self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
|
|
||||||
|
|
||||||
self.local_attn_window_size = local_attn_window_size
|
|
||||||
|
|
||||||
self.cond_dim = cond_dim
|
|
||||||
|
|
||||||
self.cross_attn_cond_dim = cross_attn_cond_dim
|
|
||||||
|
|
||||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
|
|
||||||
|
|
||||||
for _ in range(depth):
|
|
||||||
|
|
||||||
self.layers.append(nn.ModuleList([
|
|
||||||
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
|
|
||||||
Attention(
|
|
||||||
dim=dim,
|
|
||||||
dim_heads=dim_head,
|
|
||||||
causal=causal,
|
|
||||||
zero_init_output=True,
|
|
||||||
natten_kernel_size=local_attn_window_size,
|
|
||||||
),
|
|
||||||
Attention(
|
|
||||||
dim=dim,
|
|
||||||
dim_heads=dim_head,
|
|
||||||
dim_context = cross_attn_cond_dim,
|
|
||||||
zero_init_output=True
|
|
||||||
) if self.cross_attn_cond_dim > 0 else nn.Identity(),
|
|
||||||
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
|
|
||||||
FeedForward(dim = dim, mult = ff_mult, no_bias=True)
|
|
||||||
]))
|
|
||||||
|
|
||||||
def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
|
|
||||||
|
|
||||||
x = checkpoint(self.project_in, x)
|
|
||||||
|
|
||||||
if prepend_cond is not None:
|
|
||||||
x = torch.cat([prepend_cond, x], dim=1)
|
|
||||||
|
|
||||||
pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
|
||||||
|
|
||||||
for attn_norm, attn, xattn, ff_norm, ff in self.layers:
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
if cond is not None:
|
|
||||||
x = checkpoint(attn_norm, x, cond)
|
|
||||||
else:
|
|
||||||
x = checkpoint(attn_norm, x)
|
|
||||||
|
|
||||||
x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
|
|
||||||
|
|
||||||
if cross_attn_cond is not None:
|
|
||||||
x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
|
|
||||||
if cond is not None:
|
|
||||||
x = checkpoint(ff_norm, x, cond)
|
|
||||||
else:
|
|
||||||
x = checkpoint(ff_norm, x)
|
|
||||||
|
|
||||||
x = checkpoint(ff, x) + residual
|
|
||||||
|
|
||||||
return checkpoint(self.project_out, x)
|
|
||||||
|
|
||||||
class TransformerDownsampleBlock1D(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
embed_dim = 768,
|
|
||||||
depth = 3,
|
|
||||||
heads = 12,
|
|
||||||
downsample_ratio = 2,
|
|
||||||
local_attn_window_size = 64,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.downsample_ratio = downsample_ratio
|
|
||||||
|
|
||||||
self.transformer = ContinuousLocalTransformer(
|
|
||||||
dim=embed_dim,
|
|
||||||
depth=depth,
|
|
||||||
heads=heads,
|
|
||||||
local_attn_window_size=local_attn_window_size,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
|
|
||||||
|
|
||||||
self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
|
|
||||||
x = checkpoint(self.project_in, x)
|
|
||||||
|
|
||||||
# Compute
|
|
||||||
x = self.transformer(x)
|
|
||||||
|
|
||||||
# Trade sequence length for channels
|
|
||||||
x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
|
|
||||||
|
|
||||||
# Project back to embed dim
|
|
||||||
x = checkpoint(self.project_down, x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
class TransformerUpsampleBlock1D(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
embed_dim,
|
|
||||||
depth = 3,
|
|
||||||
heads = 12,
|
|
||||||
upsample_ratio = 2,
|
|
||||||
local_attn_window_size = 64,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.upsample_ratio = upsample_ratio
|
|
||||||
|
|
||||||
self.transformer = ContinuousLocalTransformer(
|
|
||||||
dim=embed_dim,
|
|
||||||
depth=depth,
|
|
||||||
heads=heads,
|
|
||||||
local_attn_window_size = local_attn_window_size,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
|
|
||||||
|
|
||||||
self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
|
|
||||||
# Project to embed dim
|
|
||||||
x = checkpoint(self.project_in, x)
|
|
||||||
|
|
||||||
# Project to increase channel dim
|
|
||||||
x = checkpoint(self.project_up, x)
|
|
||||||
|
|
||||||
# Trade channels for sequence length
|
|
||||||
x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
|
|
||||||
|
|
||||||
# Compute
|
|
||||||
x = self.transformer(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder1D(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
embed_dims = [96, 192, 384, 768],
|
|
||||||
heads = [12, 12, 12, 12],
|
|
||||||
depths = [3, 3, 3, 3],
|
|
||||||
ratios = [2, 2, 2, 2],
|
|
||||||
local_attn_window_size = 64,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
layers = []
|
|
||||||
|
|
||||||
for layer in range(len(depths)):
|
|
||||||
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
|
|
||||||
|
|
||||||
layers.append(
|
|
||||||
TransformerDownsampleBlock1D(
|
|
||||||
in_channels = prev_dim,
|
|
||||||
embed_dim = embed_dims[layer],
|
|
||||||
heads = heads[layer],
|
|
||||||
depth = depths[layer],
|
|
||||||
downsample_ratio = ratios[layer],
|
|
||||||
local_attn_window_size = local_attn_window_size,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
|
|
||||||
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = rearrange(x, "b c n -> b n c")
|
|
||||||
x = checkpoint(self.project_in, x)
|
|
||||||
x = self.layers(x)
|
|
||||||
x = checkpoint(self.project_out, x)
|
|
||||||
x = rearrange(x, "b n c -> b c n")
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoder1D(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
embed_dims = [768, 384, 192, 96],
|
|
||||||
heads = [12, 12, 12, 12],
|
|
||||||
depths = [3, 3, 3, 3],
|
|
||||||
ratios = [2, 2, 2, 2],
|
|
||||||
local_attn_window_size = 64,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
layers = []
|
|
||||||
|
|
||||||
for layer in range(len(depths)):
|
|
||||||
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
|
|
||||||
|
|
||||||
layers.append(
|
|
||||||
TransformerUpsampleBlock1D(
|
|
||||||
in_channels = prev_dim,
|
|
||||||
embed_dim = embed_dims[layer],
|
|
||||||
heads = heads[layer],
|
|
||||||
depth = depths[layer],
|
|
||||||
upsample_ratio = ratios[layer],
|
|
||||||
local_attn_window_size = local_attn_window_size,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
|
|
||||||
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = rearrange(x, "b c n -> b n c")
|
|
||||||
x = checkpoint(self.project_in, x)
|
|
||||||
x = self.layers(x)
|
|
||||||
x = checkpoint(self.project_out, x)
|
|
||||||
x = rearrange(x, "b n c -> b c n")
|
|
||||||
return x
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
# mmmodules package
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
# mmmodules.model package
|
|
||||||
@@ -1,393 +0,0 @@
|
|||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from einops import rearrange
|
|
||||||
from scipy.optimize import fmin
|
|
||||||
from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord
|
|
||||||
|
|
||||||
class PQMF(nn.Module):
|
|
||||||
"""
|
|
||||||
Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
|
|
||||||
Uses polyphase representation which is computationally more efficient for real-time.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
|
|
||||||
- num_bands (int): Number of desired frequency bands. It must be a power of 2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, attenuation, num_bands):
|
|
||||||
super(PQMF, self).__init__()
|
|
||||||
|
|
||||||
# Ensure num_bands is a power of 2
|
|
||||||
is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
|
|
||||||
assert is_power_of_2, "'num_bands' must be a power of 2."
|
|
||||||
|
|
||||||
# Create the prototype filter
|
|
||||||
prototype_filter = design_prototype_filter(attenuation, num_bands)
|
|
||||||
filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
|
|
||||||
padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
|
|
||||||
|
|
||||||
# Register filters and settings
|
|
||||||
self.register_buffer("filter_bank", padded_filter_bank)
|
|
||||||
self.register_buffer("prototype", prototype_filter)
|
|
||||||
self.num_bands = num_bands
|
|
||||||
|
|
||||||
def forward(self, signal):
|
|
||||||
"""Decompose the signal into multiple frequency bands."""
|
|
||||||
# If signal is not a pytorch tensor of Batch x Channels x Length, convert it
|
|
||||||
signal = prepare_signal_dimensions(signal)
|
|
||||||
# The signal length must be a multiple of num_bands. Pad it with zeros.
|
|
||||||
signal = pad_signal(signal, self.num_bands)
|
|
||||||
# run it
|
|
||||||
signal = polyphase_analysis(signal, self.filter_bank)
|
|
||||||
return apply_alias_cancellation(signal)
|
|
||||||
|
|
||||||
def inverse(self, bands):
|
|
||||||
"""Reconstruct the original signal from the frequency bands."""
|
|
||||||
bands = apply_alias_cancellation(bands)
|
|
||||||
return polyphase_synthesis(bands, self.filter_bank)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_signal_dimensions(signal):
|
|
||||||
"""
|
|
||||||
Rearrange signal into Batch x Channels x Length.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
signal : torch.Tensor or numpy.ndarray
|
|
||||||
The input signal.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Preprocessed signal tensor.
|
|
||||||
"""
|
|
||||||
# Convert numpy to torch tensor
|
|
||||||
if isinstance(signal, np.ndarray):
|
|
||||||
signal = torch.from_numpy(signal)
|
|
||||||
|
|
||||||
# Ensure tensor
|
|
||||||
if not isinstance(signal, torch.Tensor):
|
|
||||||
raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
|
|
||||||
|
|
||||||
# Modify dimension of signal to Batch x Channels x Length
|
|
||||||
if signal.dim() == 1:
|
|
||||||
# This is just a mono signal. Unsqueeze to 1 x 1 x Length
|
|
||||||
signal = signal.unsqueeze(0).unsqueeze(0)
|
|
||||||
elif signal.dim() == 2:
|
|
||||||
# This is a multi-channel signal (e.g. stereo)
|
|
||||||
# Rearrange so that larger dimension (Length) is last
|
|
||||||
if signal.shape[0] > signal.shape[1]:
|
|
||||||
signal = signal.T
|
|
||||||
# Unsqueeze to 1 x Channels x Length
|
|
||||||
signal = signal.unsqueeze(0)
|
|
||||||
return signal
|
|
||||||
|
|
||||||
def pad_signal(signal, num_bands):
|
|
||||||
"""
|
|
||||||
Pads the signal to make its length divisible by the given number of bands.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
signal : torch.Tensor
|
|
||||||
The input signal tensor, where the last dimension represents the signal length.
|
|
||||||
|
|
||||||
num_bands : int
|
|
||||||
The number of bands by which the signal length should be divisible.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
The padded signal tensor. If the original signal length was already divisible
|
|
||||||
by num_bands, returns the original signal unchanged.
|
|
||||||
"""
|
|
||||||
remainder = signal.shape[-1] % num_bands
|
|
||||||
if remainder > 0:
|
|
||||||
padding_size = num_bands - remainder
|
|
||||||
signal = nn.functional.pad(signal, (0, padding_size))
|
|
||||||
return signal
|
|
||||||
|
|
||||||
def generate_modulated_filter_bank(prototype_filter, num_bands):
|
|
||||||
"""
|
|
||||||
Generate a QMF bank of cosine modulated filters based on a given prototype filter.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
prototype_filter : torch.Tensor
|
|
||||||
The prototype filter used as the basis for modulation.
|
|
||||||
num_bands : int
|
|
||||||
The number of desired subbands or filters.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
A bank of cosine modulated filters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Initialize indices for modulation.
|
|
||||||
subband_indices = torch.arange(num_bands).reshape(-1, 1)
|
|
||||||
|
|
||||||
# Calculate the length of the prototype filter.
|
|
||||||
filter_length = prototype_filter.shape[-1]
|
|
||||||
|
|
||||||
# Generate symmetric time indices centered around zero.
|
|
||||||
time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
|
|
||||||
|
|
||||||
# Calculate phase offsets to ensure orthogonality between subbands.
|
|
||||||
phase_offsets = (-1)**subband_indices * np.pi / 4
|
|
||||||
|
|
||||||
# Compute the cosine modulation function.
|
|
||||||
modulation = torch.cos(
|
|
||||||
(2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply modulation to the prototype filter.
|
|
||||||
modulated_filters = 2 * prototype_filter * modulation
|
|
||||||
|
|
||||||
return modulated_filters
|
|
||||||
|
|
||||||
|
|
||||||
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
|
|
||||||
"""
|
|
||||||
Design a lowpass filter using the Kaiser window.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
angular_cutoff : float
|
|
||||||
The angular frequency cutoff of the filter.
|
|
||||||
attenuation : float
|
|
||||||
The desired stopband attenuation in decibels (dB).
|
|
||||||
filter_length : int, optional
|
|
||||||
Desired length of the filter. If not provided, it's computed based on the given specs.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
ndarray
|
|
||||||
The designed lowpass filter coefficients.
|
|
||||||
"""
|
|
||||||
|
|
||||||
estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
|
|
||||||
|
|
||||||
# Ensure the estimated length is odd.
|
|
||||||
estimated_length = 2 * (estimated_length // 2) + 1
|
|
||||||
|
|
||||||
if filter_length is None:
|
|
||||||
filter_length = estimated_length
|
|
||||||
|
|
||||||
return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
|
|
||||||
"""
|
|
||||||
Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
angular_cutoff : float
|
|
||||||
Angular frequency cutoff of the filter.
|
|
||||||
attenuation : float
|
|
||||||
Desired stopband attenuation in dB.
|
|
||||||
num_bands : int
|
|
||||||
Number of bands for the multiband filter system.
|
|
||||||
filter_length : int, optional
|
|
||||||
Desired length of the filter.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
float
|
|
||||||
The computed objective (loss) value for the given filter specs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
|
|
||||||
convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
|
|
||||||
|
|
||||||
return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
|
|
||||||
|
|
||||||
|
|
||||||
def design_prototype_filter(attenuation, num_bands, filter_length=None):
|
|
||||||
"""
|
|
||||||
Design the optimal prototype filter for a multiband system given the desired specs.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
attenuation : float
|
|
||||||
The desired stopband attenuation in dB.
|
|
||||||
num_bands : int
|
|
||||||
Number of bands for the multiband filter system.
|
|
||||||
filter_length : int, optional
|
|
||||||
Desired length of the filter. If not provided, it's computed based on the given specs.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
ndarray
|
|
||||||
The optimal prototype filter coefficients.
|
|
||||||
"""
|
|
||||||
|
|
||||||
optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
|
|
||||||
1 / num_bands, disp=0)[0]
|
|
||||||
|
|
||||||
prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
|
|
||||||
return torch.tensor(prototype_filter, dtype=torch.float32)
|
|
||||||
|
|
||||||
def pad_to_nearest_power_of_two(x):
|
|
||||||
"""
|
|
||||||
Pads the input tensor 'x' on both sides such that its last dimension
|
|
||||||
becomes the nearest larger power of two.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
x : torch.Tensor
|
|
||||||
The input tensor to be padded.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
torch.Tensor
|
|
||||||
The padded tensor.
|
|
||||||
"""
|
|
||||||
current_length = x.shape[-1]
|
|
||||||
target_length = 2**math.ceil(math.log2(current_length))
|
|
||||||
|
|
||||||
total_padding = target_length - current_length
|
|
||||||
left_padding = total_padding // 2
|
|
||||||
right_padding = total_padding - left_padding
|
|
||||||
|
|
||||||
return nn.functional.pad(x, (left_padding, right_padding))
|
|
||||||
|
|
||||||
def apply_alias_cancellation(x):
|
|
||||||
"""
|
|
||||||
Applies alias cancellation by inverting the sign of every
|
|
||||||
second element of every second row, starting from the second
|
|
||||||
row's first element in a tensor.
|
|
||||||
|
|
||||||
This operation helps ensure that the aliasing introduced in
|
|
||||||
each band during the decomposition will be counteracted during
|
|
||||||
the reconstruction.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
x : torch.Tensor
|
|
||||||
The input tensor.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
torch.Tensor
|
|
||||||
Tensor with specific elements' sign inverted for alias cancellation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Create a mask of the same shape as 'x', initialized with all ones
|
|
||||||
mask = torch.ones_like(x)
|
|
||||||
|
|
||||||
# Update specific elements in the mask to -1 to perform inversion
|
|
||||||
mask[..., 1::2, ::2] = -1
|
|
||||||
|
|
||||||
# Apply the mask to the input tensor 'x'
|
|
||||||
return x * mask
|
|
||||||
|
|
||||||
def ensure_odd_length(tensor):
|
|
||||||
"""
|
|
||||||
Pads the last dimension of a tensor to ensure its size is odd.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
tensor : torch.Tensor
|
|
||||||
Input tensor whose last dimension might need padding.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
torch.Tensor
|
|
||||||
The original tensor if its last dimension was already odd,
|
|
||||||
or the padded tensor with an odd-sized last dimension.
|
|
||||||
"""
|
|
||||||
|
|
||||||
last_dim_size = tensor.shape[-1]
|
|
||||||
|
|
||||||
if last_dim_size % 2 == 0:
|
|
||||||
tensor = nn.functional.pad(tensor, (0, 1))
|
|
||||||
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
def polyphase_analysis(signal, filter_bank):
|
|
||||||
"""
|
|
||||||
Applies the polyphase method to efficiently analyze the signal using a filter bank.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
-----------
|
|
||||||
signal : torch.Tensor
|
|
||||||
Input signal tensor with shape (Batch x Channels x Length).
|
|
||||||
|
|
||||||
filter_bank : torch.Tensor
|
|
||||||
Filter bank tensor with shape (Bands x Length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
--------
|
|
||||||
torch.Tensor
|
|
||||||
Signal split into sub-bands. (Batch x Channels x Bands x Length)
|
|
||||||
"""
|
|
||||||
|
|
||||||
num_bands = filter_bank.shape[0]
|
|
||||||
num_channels = signal.shape[1]
|
|
||||||
|
|
||||||
# Rearrange signal for polyphase processing.
|
|
||||||
# Also combine Batch x Channel into one dimension for now.
|
|
||||||
#signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
|
|
||||||
signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
|
|
||||||
|
|
||||||
# Rearrange the filter bank for matching signal shape
|
|
||||||
filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
|
|
||||||
|
|
||||||
# Apply convolution with appropriate padding to maintain spatial dimensions
|
|
||||||
padding = filter_bank.shape[-1] // 2
|
|
||||||
filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
|
|
||||||
|
|
||||||
# Truncate the last dimension post-convolution to adjust the output shape
|
|
||||||
filtered_signal = filtered_signal[..., :-1]
|
|
||||||
# Rearrange the first dimension back into Batch x Channels
|
|
||||||
filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
|
|
||||||
|
|
||||||
return filtered_signal
|
|
||||||
|
|
||||||
def polyphase_synthesis(signal, filter_bank):
|
|
||||||
"""
|
|
||||||
Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
signal : torch.Tensor
|
|
||||||
Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
|
|
||||||
|
|
||||||
filter_bank : torch.Tensor
|
|
||||||
Analysis filter bank (shape: Bands x Length).
|
|
||||||
|
|
||||||
should_rearrange : bool, optional
|
|
||||||
Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Reconstructed signal (shape: Batch x Channels X Length)
|
|
||||||
"""
|
|
||||||
|
|
||||||
num_bands = filter_bank.shape[0]
|
|
||||||
num_channels = signal.shape[1]
|
|
||||||
|
|
||||||
# Rearrange the filter bank
|
|
||||||
filter_bank = filter_bank.flip(-1)
|
|
||||||
filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
|
|
||||||
|
|
||||||
# Combine Batch x Channels into one dimension for now.
|
|
||||||
signal = rearrange(signal, "b c n t -> (b c) n t")
|
|
||||||
|
|
||||||
# Apply convolution with appropriate padding
|
|
||||||
padding_amount = filter_bank.shape[-1] // 2 + 1
|
|
||||||
reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
|
|
||||||
|
|
||||||
# Scale the result
|
|
||||||
reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
|
|
||||||
|
|
||||||
# Reorganize the output and truncate
|
|
||||||
reconstructed_signal = reconstructed_signal.flip(1)
|
|
||||||
reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
|
|
||||||
reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
|
|
||||||
|
|
||||||
return reconstructed_signal
|
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
import torch
|
|
||||||
from einops import rearrange
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
class Pretransform(nn.Module):
|
|
||||||
def __init__(self, enable_grad, io_channels, is_discrete):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.is_discrete = is_discrete
|
|
||||||
self.io_channels = io_channels
|
|
||||||
self.encoded_channels = None
|
|
||||||
self.downsampling_ratio = None
|
|
||||||
|
|
||||||
self.enable_grad = enable_grad
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def decode(self, z):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def tokenize(self, x):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def decode_tokens(self, tokens):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
class AutoencoderPretransform(Pretransform):
|
|
||||||
def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
|
|
||||||
super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
|
|
||||||
self.model = model
|
|
||||||
self.model.requires_grad_(False).eval()
|
|
||||||
self.scale=scale
|
|
||||||
self.downsampling_ratio = model.downsampling_ratio
|
|
||||||
self.io_channels = model.io_channels
|
|
||||||
self.sample_rate = model.sample_rate
|
|
||||||
|
|
||||||
self.model_half = model_half
|
|
||||||
self.iterate_batch = iterate_batch
|
|
||||||
|
|
||||||
self.encoded_channels = model.latent_dim
|
|
||||||
|
|
||||||
self.chunked = chunked
|
|
||||||
self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
|
|
||||||
self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
|
|
||||||
|
|
||||||
if self.model_half:
|
|
||||||
self.model.half()
|
|
||||||
|
|
||||||
def encode(self, x, **kwargs):
|
|
||||||
|
|
||||||
if self.model_half:
|
|
||||||
x = x.half()
|
|
||||||
self.model.to(torch.float16)
|
|
||||||
|
|
||||||
encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
|
|
||||||
|
|
||||||
if self.model_half:
|
|
||||||
encoded = encoded.float()
|
|
||||||
|
|
||||||
return encoded / self.scale
|
|
||||||
|
|
||||||
def decode(self, z, **kwargs):
|
|
||||||
z = z * self.scale
|
|
||||||
|
|
||||||
if self.model_half:
|
|
||||||
z = z.half()
|
|
||||||
self.model.to(torch.float16)
|
|
||||||
|
|
||||||
decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
|
|
||||||
|
|
||||||
if self.model_half:
|
|
||||||
decoded = decoded.float()
|
|
||||||
|
|
||||||
return decoded
|
|
||||||
|
|
||||||
def tokenize(self, x, **kwargs):
|
|
||||||
assert self.model.is_discrete, "Cannot tokenize with a continuous model"
|
|
||||||
|
|
||||||
_, info = self.model.encode(x, return_info = True, **kwargs)
|
|
||||||
|
|
||||||
return info[self.model.bottleneck.tokens_id]
|
|
||||||
|
|
||||||
def decode_tokens(self, tokens, **kwargs):
|
|
||||||
assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
|
|
||||||
|
|
||||||
return self.model.decode_tokens(tokens, **kwargs)
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict, strict=True):
|
|
||||||
self.model.load_state_dict(state_dict, strict=strict)
|
|
||||||
|
|
||||||
class PQMFPretransform(Pretransform):
|
|
||||||
def __init__(self, attenuation=100, num_bands=16):
|
|
||||||
# TODO: Fix PQMF to take in in-channels
|
|
||||||
super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
|
|
||||||
from .pqmf import PQMF
|
|
||||||
self.pqmf = PQMF(attenuation, num_bands)
|
|
||||||
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
# x is (Batch x Channels x Time)
|
|
||||||
x = self.pqmf.forward(x)
|
|
||||||
# pqmf.forward returns (Batch x Channels x Bands x Time)
|
|
||||||
# but Pretransform needs Batch x Channels x Time
|
|
||||||
# so concatenate channels and bands into one axis
|
|
||||||
return rearrange(x, "b c n t -> b (c n) t")
|
|
||||||
|
|
||||||
def decode(self, x):
|
|
||||||
# x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
|
|
||||||
x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
|
|
||||||
# returns (Batch x Channels x Time)
|
|
||||||
return self.pqmf.inverse(x)
|
|
||||||
|
|
||||||
class PretrainedDACPretransform(Pretransform):
|
|
||||||
def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
|
|
||||||
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
|
|
||||||
|
|
||||||
import dac
|
|
||||||
|
|
||||||
model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
|
|
||||||
|
|
||||||
self.model = dac.DAC.load(model_path)
|
|
||||||
|
|
||||||
self.quantize_on_decode = quantize_on_decode
|
|
||||||
|
|
||||||
if model_type == "44khz":
|
|
||||||
self.downsampling_ratio = 512
|
|
||||||
else:
|
|
||||||
self.downsampling_ratio = 320
|
|
||||||
|
|
||||||
self.io_channels = 1
|
|
||||||
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
self.chunked = chunked
|
|
||||||
|
|
||||||
self.encoded_channels = self.model.latent_dim
|
|
||||||
|
|
||||||
self.num_quantizers = self.model.n_codebooks
|
|
||||||
|
|
||||||
self.codebook_size = self.model.codebook_size
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
|
|
||||||
latents = self.model.encoder(x)
|
|
||||||
|
|
||||||
if self.quantize_on_decode:
|
|
||||||
output = latents
|
|
||||||
else:
|
|
||||||
z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
|
|
||||||
output = z
|
|
||||||
|
|
||||||
if self.scale != 1.0:
|
|
||||||
output = output / self.scale
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def decode(self, z):
|
|
||||||
|
|
||||||
if self.scale != 1.0:
|
|
||||||
z = z * self.scale
|
|
||||||
|
|
||||||
if self.quantize_on_decode:
|
|
||||||
z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
|
|
||||||
|
|
||||||
return self.model.decode(z)
|
|
||||||
|
|
||||||
def tokenize(self, x):
|
|
||||||
return self.model.encode(x)[1]
|
|
||||||
|
|
||||||
def decode_tokens(self, tokens):
|
|
||||||
latents = self.model.quantizer.from_codes(tokens)
|
|
||||||
return self.model.decode(latents)
|
|
||||||
|
|
||||||
class AudiocraftCompressionPretransform(Pretransform):
|
|
||||||
def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
|
|
||||||
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from audiocraft.models import CompressionModel
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
|
|
||||||
|
|
||||||
self.model = CompressionModel.get_pretrained(model_type)
|
|
||||||
|
|
||||||
self.quantize_on_decode = quantize_on_decode
|
|
||||||
|
|
||||||
self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
|
|
||||||
|
|
||||||
self.sample_rate = self.model.sample_rate
|
|
||||||
|
|
||||||
self.io_channels = self.model.channels
|
|
||||||
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
#self.encoded_channels = self.model.latent_dim
|
|
||||||
|
|
||||||
self.num_quantizers = self.model.num_codebooks
|
|
||||||
|
|
||||||
self.codebook_size = self.model.cardinality
|
|
||||||
|
|
||||||
self.model.to(torch.float16).eval().requires_grad_(False)
|
|
||||||
|
|
||||||
def encode(self, x):
|
|
||||||
|
|
||||||
assert False, "Audiocraft compression models do not support continuous encoding"
|
|
||||||
|
|
||||||
# latents = self.model.encoder(x)
|
|
||||||
|
|
||||||
# if self.quantize_on_decode:
|
|
||||||
# output = latents
|
|
||||||
# else:
|
|
||||||
# z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
|
|
||||||
# output = z
|
|
||||||
|
|
||||||
# if self.scale != 1.0:
|
|
||||||
# output = output / self.scale
|
|
||||||
|
|
||||||
# return output
|
|
||||||
|
|
||||||
def decode(self, z):
|
|
||||||
|
|
||||||
assert False, "Audiocraft compression models do not support continuous decoding"
|
|
||||||
|
|
||||||
# if self.scale != 1.0:
|
|
||||||
# z = z * self.scale
|
|
||||||
|
|
||||||
# if self.quantize_on_decode:
|
|
||||||
# z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
|
|
||||||
|
|
||||||
# return self.model.decode(z)
|
|
||||||
|
|
||||||
def tokenize(self, x):
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
|
||||||
return self.model.encode(x.to(torch.float16))[0]
|
|
||||||
|
|
||||||
def decode_tokens(self, tokens):
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
|
||||||
return self.model.decode(tokens)
|
|
||||||
@@ -1,989 +0,0 @@
|
|||||||
from functools import reduce, partial
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from einops.layers.torch import Rearrange
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn, einsum
|
|
||||||
from torch.cuda.amp import autocast
|
|
||||||
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
|
||||||
from typing import Callable, Literal
|
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
|
||||||
HAS_FLASH_ATTN = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_FLASH_ATTN = False
|
|
||||||
flash_attn_kvpacked_func = None
|
|
||||||
flash_attn_func = None
|
|
||||||
|
|
||||||
from .utils import compile, checkpoint
|
|
||||||
try:
|
|
||||||
import natten
|
|
||||||
except ImportError:
|
|
||||||
natten = None
|
|
||||||
|
|
||||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
|
||||||
return x * (1 + scale) + shift
|
|
||||||
|
|
||||||
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
|
|
||||||
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
|
|
||||||
|
|
||||||
def create_causal_mask(i, j, device):
|
|
||||||
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
|
|
||||||
|
|
||||||
def or_reduce(masks):
|
|
||||||
head, *body = masks
|
|
||||||
for rest in body:
|
|
||||||
head = head | rest
|
|
||||||
return head
|
|
||||||
|
|
||||||
# positional embeddings
|
|
||||||
|
|
||||||
class AbsolutePositionalEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim, max_seq_len):
|
|
||||||
super().__init__()
|
|
||||||
self.scale = dim ** -0.5
|
|
||||||
self.max_seq_len = max_seq_len
|
|
||||||
self.emb = nn.Embedding(max_seq_len, dim)
|
|
||||||
|
|
||||||
def forward(self, x, pos = None, seq_start_pos = None):
|
|
||||||
seq_len, device = x.shape[1], x.device
|
|
||||||
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
|
|
||||||
|
|
||||||
if pos is None:
|
|
||||||
pos = torch.arange(seq_len, device = device)
|
|
||||||
|
|
||||||
if seq_start_pos is not None:
|
|
||||||
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
|
|
||||||
|
|
||||||
pos_emb = self.emb(pos)
|
|
||||||
pos_emb = pos_emb * self.scale
|
|
||||||
return pos_emb
|
|
||||||
|
|
||||||
class ScaledSinusoidalEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim, theta = 10000):
|
|
||||||
super().__init__()
|
|
||||||
assert (dim % 2) == 0, 'dimension must be divisible by 2'
|
|
||||||
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
|
|
||||||
|
|
||||||
half_dim = dim // 2
|
|
||||||
freq_seq = torch.arange(half_dim).float() / half_dim
|
|
||||||
inv_freq = theta ** -freq_seq
|
|
||||||
self.register_buffer('inv_freq', inv_freq, persistent = False)
|
|
||||||
|
|
||||||
def forward(self, x, pos = None, seq_start_pos = None):
|
|
||||||
seq_len, device = x.shape[1], x.device
|
|
||||||
|
|
||||||
if pos is None:
|
|
||||||
pos = torch.arange(seq_len, device = device)
|
|
||||||
|
|
||||||
if seq_start_pos is not None:
|
|
||||||
pos = pos - seq_start_pos[..., None]
|
|
||||||
|
|
||||||
emb = einsum('i, j -> i j', pos, self.inv_freq)
|
|
||||||
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
|
|
||||||
return emb * self.scale
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
use_xpos = False,
|
|
||||||
scale_base = 512,
|
|
||||||
interpolation_factor = 1.,
|
|
||||||
base = 10000,
|
|
||||||
base_rescale_factor = 1.
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
|
||||||
# has some connection to NTK literature
|
|
||||||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
|
||||||
base *= base_rescale_factor ** (dim / (dim - 2))
|
|
||||||
|
|
||||||
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
|
||||||
self.register_buffer('inv_freq', inv_freq)
|
|
||||||
|
|
||||||
assert interpolation_factor >= 1.
|
|
||||||
self.interpolation_factor = interpolation_factor
|
|
||||||
|
|
||||||
if not use_xpos:
|
|
||||||
self.register_buffer('scale', None)
|
|
||||||
return
|
|
||||||
|
|
||||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
|
||||||
|
|
||||||
self.scale_base = scale_base
|
|
||||||
self.register_buffer('scale', scale)
|
|
||||||
|
|
||||||
def forward_from_seq_len(self, seq_len):
|
|
||||||
device = self.inv_freq.device
|
|
||||||
|
|
||||||
t = torch.arange(seq_len, device = device)
|
|
||||||
return self.forward(t)
|
|
||||||
|
|
||||||
@autocast(enabled = False)
|
|
||||||
def forward(self, t):
|
|
||||||
device = self.inv_freq.device
|
|
||||||
|
|
||||||
t = t.to(torch.float32)
|
|
||||||
|
|
||||||
t = t / self.interpolation_factor
|
|
||||||
|
|
||||||
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
|
|
||||||
freqs = torch.cat((freqs, freqs), dim = -1)
|
|
||||||
|
|
||||||
if self.scale is None:
|
|
||||||
return freqs, 1.
|
|
||||||
|
|
||||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
|
||||||
scale = self.scale ** rearrange(power, 'n -> n 1')
|
|
||||||
scale = torch.cat((scale, scale), dim = -1)
|
|
||||||
|
|
||||||
return freqs, scale
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x = rearrange(x, '... (j d) -> ... j d', j = 2)
|
|
||||||
x1, x2 = x.unbind(dim = -2)
|
|
||||||
return torch.cat((-x2, x1), dim = -1)
|
|
||||||
|
|
||||||
@autocast(enabled = False)
|
|
||||||
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
|
||||||
out_dtype = t.dtype
|
|
||||||
|
|
||||||
# cast to float32 if necessary for numerical stability
|
|
||||||
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
|
|
||||||
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
|
|
||||||
freqs, t = freqs.to(dtype), t.to(dtype)
|
|
||||||
freqs = freqs[-seq_len:, :]
|
|
||||||
|
|
||||||
if t.ndim == 4 and freqs.ndim == 3:
|
|
||||||
freqs = rearrange(freqs, 'b n d -> b 1 n d')
|
|
||||||
|
|
||||||
# partial rotary embeddings, Wang et al. GPT-J
|
|
||||||
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
|
|
||||||
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
|
||||||
|
|
||||||
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
|
|
||||||
|
|
||||||
return torch.cat((t, t_unrotated), dim = -1)
|
|
||||||
|
|
||||||
# norms
|
|
||||||
class DynamicTanh(nn.Module):
|
|
||||||
def __init__(self, dim, init_alpha=10.0):
|
|
||||||
super().__init__()
|
|
||||||
self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
|
|
||||||
self.gamma = nn.Parameter(torch.ones(dim))
|
|
||||||
self.beta = nn.Parameter(torch.zeros(dim))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.tanh(self.alpha * x)
|
|
||||||
return self.gamma * x + self.beta
|
|
||||||
|
|
||||||
class RunningInstanceNorm(nn.Module):
|
|
||||||
def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
|
|
||||||
super().__init__()
|
|
||||||
self.register_buffer("running_mean", torch.zeros(1,1,dim))
|
|
||||||
self.register_buffer("running_std", torch.ones(1,1,dim))
|
|
||||||
self.saturate = saturate
|
|
||||||
self.eps = eps
|
|
||||||
self.momentum = momentum
|
|
||||||
self.dim = dim
|
|
||||||
self.trainable_gain = trainable_gain
|
|
||||||
if self.trainable_gain:
|
|
||||||
self.gain = nn.Parameter(torch.ones(1))
|
|
||||||
|
|
||||||
def _update_stats(self, x):
|
|
||||||
self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)
|
|
||||||
self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.training:
|
|
||||||
self._update_stats(x)
|
|
||||||
x = (x - self.running_mean) / self.running_std
|
|
||||||
if self.saturate:
|
|
||||||
x = torch.asinh(x)
|
|
||||||
if self.trainable_gain:
|
|
||||||
x = x * self.gain
|
|
||||||
return x
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
|
||||||
def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
|
|
||||||
"""
|
|
||||||
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if fix_scale:
|
|
||||||
self.register_buffer("gamma", torch.ones(dim))
|
|
||||||
else:
|
|
||||||
self.gamma = nn.Parameter(torch.ones(dim))
|
|
||||||
|
|
||||||
if bias:
|
|
||||||
self.beta = nn.Parameter(torch.zeros(dim))
|
|
||||||
else:
|
|
||||||
self.register_buffer("beta", torch.zeros(dim))
|
|
||||||
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
self.force_fp32 = force_fp32
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if not self.force_fp32:
|
|
||||||
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps)
|
|
||||||
else:
|
|
||||||
output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
|
|
||||||
return output.to(x.dtype)
|
|
||||||
|
|
||||||
class LayerScale(nn.Module):
|
|
||||||
def __init__(self, dim, init_val = 1e-5):
|
|
||||||
super().__init__()
|
|
||||||
self.scale = nn.Parameter(torch.full([dim], init_val))
|
|
||||||
def forward(self, x):
|
|
||||||
return x * self.scale
|
|
||||||
|
|
||||||
class GLU(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim_in,
|
|
||||||
dim_out,
|
|
||||||
activation: Callable,
|
|
||||||
use_conv = False,
|
|
||||||
conv_kernel_size = 3,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.act = activation
|
|
||||||
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
|
|
||||||
self.use_conv = use_conv
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.use_conv:
|
|
||||||
x = rearrange(x, 'b n d -> b d n')
|
|
||||||
x = self.proj(x)
|
|
||||||
x = rearrange(x, 'b d n -> b n d')
|
|
||||||
else:
|
|
||||||
x = self.proj(x)
|
|
||||||
|
|
||||||
x, gate = x.chunk(2, dim = -1)
|
|
||||||
return x * self.act(gate)
|
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
dim_out = None,
|
|
||||||
mult = 4,
|
|
||||||
no_bias = False,
|
|
||||||
glu = True,
|
|
||||||
use_conv = False,
|
|
||||||
conv_kernel_size = 3,
|
|
||||||
zero_init_output = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
inner_dim = int(dim * mult)
|
|
||||||
|
|
||||||
# Default to SwiGLU
|
|
||||||
|
|
||||||
activation = nn.SiLU()
|
|
||||||
|
|
||||||
dim_out = dim if dim_out is None else dim_out
|
|
||||||
|
|
||||||
if glu:
|
|
||||||
linear_in = GLU(dim, inner_dim, activation)
|
|
||||||
else:
|
|
||||||
linear_in = nn.Sequential(
|
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
|
||||||
nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
|
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
|
||||||
activation
|
|
||||||
)
|
|
||||||
|
|
||||||
linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
|
|
||||||
|
|
||||||
# init last linear layer to 0
|
|
||||||
if zero_init_output:
|
|
||||||
nn.init.zeros_(linear_out.weight)
|
|
||||||
if not no_bias:
|
|
||||||
nn.init.zeros_(linear_out.bias)
|
|
||||||
|
|
||||||
|
|
||||||
self.ff = nn.Sequential(
|
|
||||||
linear_in,
|
|
||||||
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
|
||||||
linear_out,
|
|
||||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.ff(x)
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
dim_heads = 64,
|
|
||||||
dim_context = None,
|
|
||||||
causal = False,
|
|
||||||
zero_init_output=True,
|
|
||||||
qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
|
|
||||||
differential = False,
|
|
||||||
feat_scale = False
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.dim_heads = dim_heads
|
|
||||||
self.differential = differential
|
|
||||||
|
|
||||||
dim_kv = dim_context if dim_context is not None else dim
|
|
||||||
|
|
||||||
self.num_heads = dim // dim_heads
|
|
||||||
self.kv_heads = dim_kv // dim_heads
|
|
||||||
|
|
||||||
if dim_context is not None:
|
|
||||||
if differential:
|
|
||||||
self.to_q = nn.Linear(dim, dim * 2, bias=False)
|
|
||||||
self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
|
|
||||||
else:
|
|
||||||
self.to_q = nn.Linear(dim, dim, bias=False)
|
|
||||||
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
|
|
||||||
else:
|
|
||||||
if differential:
|
|
||||||
self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
|
|
||||||
else:
|
|
||||||
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
|
||||||
|
|
||||||
self.to_out = nn.Linear(dim, dim, bias=False)
|
|
||||||
|
|
||||||
if zero_init_output:
|
|
||||||
nn.init.zeros_(self.to_out.weight)
|
|
||||||
|
|
||||||
if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
|
|
||||||
raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}')
|
|
||||||
|
|
||||||
self.qk_norm = qk_norm
|
|
||||||
if self.qk_norm == "ln":
|
|
||||||
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
|
||||||
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
|
||||||
elif self.qk_norm == 'rns':
|
|
||||||
self.q_norm = nn.RMSNorm(dim_heads)
|
|
||||||
self.k_norm = nn.RMSNorm(dim_heads)
|
|
||||||
elif self.qk_norm == 'dyt':
|
|
||||||
self.q_norm = DynamicTanh(dim_heads)
|
|
||||||
self.k_norm = DynamicTanh(dim_heads)
|
|
||||||
|
|
||||||
self.sdp_kwargs = dict(
|
|
||||||
enable_flash = True,
|
|
||||||
enable_math = True,
|
|
||||||
enable_mem_efficient = True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.feat_scale = feat_scale
|
|
||||||
|
|
||||||
if self.feat_scale:
|
|
||||||
self.lambda_dc = nn.Parameter(torch.zeros(dim))
|
|
||||||
self.lambda_hf = nn.Parameter(torch.zeros(dim))
|
|
||||||
|
|
||||||
self.causal = causal
|
|
||||||
|
|
||||||
@compile
|
|
||||||
def apply_qk_layernorm(self, q, k):
|
|
||||||
q_type = q.dtype
|
|
||||||
k_type = k.dtype
|
|
||||||
q = self.q_norm(q).to(q_type)
|
|
||||||
k = self.k_norm(k).to(k_type)
|
|
||||||
return q, k
|
|
||||||
|
|
||||||
|
|
||||||
def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
|
|
||||||
|
|
||||||
if self.num_heads != self.kv_heads:
|
|
||||||
# Repeat interleave kv_heads to match q_heads for grouped query attention
|
|
||||||
heads_per_kv_head = self.num_heads // self.kv_heads
|
|
||||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
|
||||||
|
|
||||||
flash_attn_available = HAS_FLASH_ATTN
|
|
||||||
|
|
||||||
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
|
|
||||||
flex_attention_block_mask = None
|
|
||||||
flex_attention_score_mod = None
|
|
||||||
|
|
||||||
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"FlexAttention is not available in this build. "
|
|
||||||
"flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
|
|
||||||
)
|
|
||||||
elif flash_attn_available:
|
|
||||||
fa_dtype_in = q.dtype
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
|
|
||||||
|
|
||||||
if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
|
|
||||||
q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))
|
|
||||||
|
|
||||||
out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])
|
|
||||||
|
|
||||||
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
|
|
||||||
else:
|
|
||||||
out = F.scaled_dot_product_attention(q, k, v, is_causal = causal)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
#@compile
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
context = None,
|
|
||||||
rotary_pos_emb = None,
|
|
||||||
causal = None,
|
|
||||||
flex_attention_block_mask = None,
|
|
||||||
flex_attention_score_mod = None,
|
|
||||||
flash_attn_sliding_window = None
|
|
||||||
):
|
|
||||||
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
|
||||||
|
|
||||||
kv_input = context if has_context else x
|
|
||||||
|
|
||||||
if hasattr(self, 'to_q'):
|
|
||||||
# Use separate linear projections for q and k/v
|
|
||||||
if self.differential:
|
|
||||||
q, q_diff = self.to_q(x).chunk(2, dim=-1)
|
|
||||||
q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
|
|
||||||
q = torch.stack([q, q_diff], dim = 1)
|
|
||||||
k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
|
|
||||||
k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
|
|
||||||
k = torch.stack([k, k_diff], dim = 1)
|
|
||||||
else:
|
|
||||||
q = self.to_q(x)
|
|
||||||
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
|
|
||||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
|
||||||
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
|
|
||||||
else:
|
|
||||||
# Use fused linear projection
|
|
||||||
if self.differential:
|
|
||||||
q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
|
|
||||||
q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
|
|
||||||
q = torch.stack([q, q_diff], dim = 1)
|
|
||||||
k = torch.stack([k, k_diff], dim = 1)
|
|
||||||
else:
|
|
||||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
|
||||||
|
|
||||||
# Normalize q and k for cosine sim attention
|
|
||||||
if self.qk_norm == "l2":
|
|
||||||
q = F.normalize(q, dim=-1)
|
|
||||||
k = F.normalize(k, dim=-1)
|
|
||||||
elif self.qk_norm != "none":
|
|
||||||
q, k = self.apply_qk_layernorm(q, k)
|
|
||||||
|
|
||||||
if rotary_pos_emb is not None:
|
|
||||||
freqs, _ = rotary_pos_emb
|
|
||||||
q_dtype = q.dtype
|
|
||||||
k_dtype = k.dtype
|
|
||||||
q = q.to(torch.float32)
|
|
||||||
k = k.to(torch.float32)
|
|
||||||
freqs = freqs.to(torch.float32)
|
|
||||||
if q.shape[-2] >= k.shape[-2]:
|
|
||||||
ratio = q.shape[-2] / k.shape[-2]
|
|
||||||
q_freqs, k_freqs = freqs, ratio * freqs
|
|
||||||
else:
|
|
||||||
ratio = k.shape[-2] / q.shape[-2]
|
|
||||||
q_freqs, k_freqs = ratio * freqs, freqs
|
|
||||||
q = apply_rotary_pos_emb(q, q_freqs)
|
|
||||||
k = apply_rotary_pos_emb(k, k_freqs)
|
|
||||||
q = q.to(v.dtype)
|
|
||||||
k = k.to(v.dtype)
|
|
||||||
|
|
||||||
n, device = q.shape[-2], q.device
|
|
||||||
|
|
||||||
causal = self.causal if causal is None else causal
|
|
||||||
|
|
||||||
if n == 1 and causal:
|
|
||||||
causal = False
|
|
||||||
|
|
||||||
if self.differential:
|
|
||||||
q, q_diff = q.unbind(dim = 1)
|
|
||||||
k, k_diff = k.unbind(dim = 1)
|
|
||||||
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
|
||||||
out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
|
||||||
out = out - out_diff
|
|
||||||
else:
|
|
||||||
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
|
||||||
|
|
||||||
# merge heads
|
|
||||||
out = rearrange(out, ' b h n d -> b n (h d)')
|
|
||||||
|
|
||||||
# Communicate between heads
|
|
||||||
|
|
||||||
# with autocast(enabled = False):
|
|
||||||
# out_dtype = out.dtype
|
|
||||||
# out = out.to(torch.float32)
|
|
||||||
# out = self.to_out(out).to(out_dtype)
|
|
||||||
out = self.to_out(out)
|
|
||||||
|
|
||||||
if self.feat_scale:
|
|
||||||
out_dc = out.mean(dim=-2, keepdim=True)
|
|
||||||
out_hf = out - out_dc
|
|
||||||
|
|
||||||
# Selectively modulate DC and high frequency components
|
|
||||||
out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
class ConformerModule(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
norm_kwargs = {},
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dim = dim
|
|
||||||
|
|
||||||
self.in_norm = LayerNorm(dim, **norm_kwargs)
|
|
||||||
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
|
||||||
self.glu = GLU(dim, dim, nn.SiLU())
|
|
||||||
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
|
|
||||||
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
|
|
||||||
self.swish = nn.SiLU()
|
|
||||||
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
|
||||||
|
|
||||||
#@compile
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.in_norm(x)
|
|
||||||
x = rearrange(x, 'b n d -> b d n')
|
|
||||||
x = self.pointwise_conv(x)
|
|
||||||
x = rearrange(x, 'b d n -> b n d')
|
|
||||||
x = self.glu(x)
|
|
||||||
x = rearrange(x, 'b n d -> b d n')
|
|
||||||
x = self.depthwise_conv(x)
|
|
||||||
x = rearrange(x, 'b d n -> b n d')
|
|
||||||
x = self.mid_norm(x)
|
|
||||||
x = self.swish(x)
|
|
||||||
x = rearrange(x, 'b n d -> b d n')
|
|
||||||
x = self.pointwise_conv_2(x)
|
|
||||||
x = rearrange(x, 'b d n -> b n d')
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
class TransformerBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
dim_heads = 64,
|
|
||||||
cross_attend = False,
|
|
||||||
dim_context = None,
|
|
||||||
global_cond_dim = None,
|
|
||||||
causal = False,
|
|
||||||
zero_init_branch_outputs = True,
|
|
||||||
conformer = False,
|
|
||||||
layer_ix = -1,
|
|
||||||
remove_norms = False,
|
|
||||||
add_rope = False,
|
|
||||||
layer_scale = False,
|
|
||||||
use_sync_block_film = False,
|
|
||||||
attn_kwargs = {},
|
|
||||||
ff_kwargs = {},
|
|
||||||
norm_kwargs = {}
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.dim_heads = min(dim_heads,dim)
|
|
||||||
self.cross_attend = cross_attend
|
|
||||||
self.dim_context = dim_context
|
|
||||||
self.causal = causal
|
|
||||||
if layer_scale and zero_init_branch_outputs:
|
|
||||||
zero_init_branch_outputs = False
|
|
||||||
|
|
||||||
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
|
||||||
|
|
||||||
self.add_rope = add_rope
|
|
||||||
|
|
||||||
self.self_attn = Attention(
|
|
||||||
dim,
|
|
||||||
dim_heads = self.dim_heads,
|
|
||||||
causal = causal,
|
|
||||||
zero_init_output=zero_init_branch_outputs,
|
|
||||||
**attn_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
|
||||||
|
|
||||||
self.cross_attend = cross_attend
|
|
||||||
if cross_attend:
|
|
||||||
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
|
||||||
self.cross_attn = Attention(
|
|
||||||
dim,
|
|
||||||
dim_heads = self.dim_heads,
|
|
||||||
dim_context=dim_context,
|
|
||||||
causal = causal,
|
|
||||||
zero_init_output=zero_init_branch_outputs,
|
|
||||||
**attn_kwargs
|
|
||||||
)
|
|
||||||
self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
|
||||||
|
|
||||||
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
|
||||||
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
|
|
||||||
self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
|
||||||
|
|
||||||
self.layer_ix = layer_ix
|
|
||||||
|
|
||||||
self.conformer = None
|
|
||||||
if conformer:
|
|
||||||
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
|
|
||||||
self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
|
||||||
|
|
||||||
self.global_cond_dim = global_cond_dim
|
|
||||||
if global_cond_dim is not None:
|
|
||||||
self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)
|
|
||||||
|
|
||||||
self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None
|
|
||||||
|
|
||||||
if use_sync_block_film:
|
|
||||||
self.sync_film_generator = nn.Sequential(
|
|
||||||
nn.Linear(dim, dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
|
||||||
)
|
|
||||||
|
|
||||||
@compile
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
context = None,
|
|
||||||
global_cond=None,
|
|
||||||
rotary_pos_emb = None,
|
|
||||||
self_attention_block_mask = None,
|
|
||||||
self_attention_score_mod = None,
|
|
||||||
cross_attention_block_mask = None,
|
|
||||||
cross_attention_score_mod = None,
|
|
||||||
self_attention_flash_sliding_window = None,
|
|
||||||
cross_attention_flash_sliding_window = None,
|
|
||||||
sync_cond = None,
|
|
||||||
prepend_length=0
|
|
||||||
):
|
|
||||||
if rotary_pos_emb is None and self.add_rope:
|
|
||||||
rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])
|
|
||||||
|
|
||||||
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
|
||||||
if len(global_cond.shape) == 2:
|
|
||||||
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1)
|
|
||||||
else:
|
|
||||||
|
|
||||||
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1)
|
|
||||||
|
|
||||||
# self-attention with adaLN
|
|
||||||
residual = x
|
|
||||||
x = self.pre_norm(x)
|
|
||||||
x = x * (1 + scale_self) + shift_self
|
|
||||||
x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
|
|
||||||
x = x * torch.sigmoid(1 - gate_self)
|
|
||||||
x = self.self_attn_scale(x)
|
|
||||||
x = x + residual
|
|
||||||
|
|
||||||
if context is not None and self.cross_attend:
|
|
||||||
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
|
|
||||||
|
|
||||||
if self.conformer is not None:
|
|
||||||
x = x + self.conformer_scale(self.conformer(x))
|
|
||||||
|
|
||||||
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
|
|
||||||
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
|
||||||
x = x * (1 + scale) + shift
|
|
||||||
|
|
||||||
# feedforward with adaLN
|
|
||||||
residual = x
|
|
||||||
x = self.ff_norm(x)
|
|
||||||
x = x * (1 + scale_ff) + shift_ff
|
|
||||||
x = self.ff(x)
|
|
||||||
x = x * torch.sigmoid(1 - gate_ff)
|
|
||||||
x = self.ff_scale(x)
|
|
||||||
x = x + residual
|
|
||||||
|
|
||||||
else:
|
|
||||||
x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))
|
|
||||||
|
|
||||||
if context is not None and self.cross_attend:
|
|
||||||
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
|
|
||||||
|
|
||||||
if self.conformer is not None:
|
|
||||||
x = x + self.conformer_scale(self.conformer(x))
|
|
||||||
|
|
||||||
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
|
|
||||||
prepend_part = x[:, :prepend_length, :]
|
|
||||||
audio_part = x[:, prepend_length:, :]
|
|
||||||
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
|
||||||
modulated_audio_part = audio_part * (1 + scale) + shift
|
|
||||||
x = torch.cat([prepend_part, modulated_audio_part], dim=1)
|
|
||||||
|
|
||||||
x = x + self.ff_scale(self.ff(self.ff_norm(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
class ContinuousTransformer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
depth,
|
|
||||||
*,
|
|
||||||
dim_in = None,
|
|
||||||
dim_out = None,
|
|
||||||
dim_heads = 64,
|
|
||||||
cross_attend=False,
|
|
||||||
cond_token_dim=None,
|
|
||||||
pre_cross_attn_ix=-1,
|
|
||||||
final_cross_attn_ix=-1,
|
|
||||||
global_cond_dim=None,
|
|
||||||
causal=False,
|
|
||||||
rotary_pos_emb=True,
|
|
||||||
zero_init_branch_outputs=True,
|
|
||||||
conformer=False,
|
|
||||||
use_sinusoidal_emb=False,
|
|
||||||
use_abs_pos_emb=False,
|
|
||||||
abs_pos_emb_max_length=10000,
|
|
||||||
num_memory_tokens=0,
|
|
||||||
sliding_window=None,
|
|
||||||
use_mlp=False,
|
|
||||||
use_add_norm=False,
|
|
||||||
use_gated=False,
|
|
||||||
use_final_layer=False,
|
|
||||||
use_zeros=False,
|
|
||||||
use_conv=False,
|
|
||||||
use_fusion_mlp=False,
|
|
||||||
use_film=False,
|
|
||||||
use_sync_film=False,
|
|
||||||
use_sync_gated=False,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dim = dim
|
|
||||||
self.depth = depth
|
|
||||||
self.causal = causal
|
|
||||||
self.layers = nn.ModuleList([])
|
|
||||||
if use_mlp:
|
|
||||||
self.project_in = nn.Sequential(
|
|
||||||
nn.Linear(dim_in, dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(dim, dim, bias=False)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
|
|
||||||
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
|
|
||||||
self.video_temporal_conv = None
|
|
||||||
self.audio_temporal_conv = None
|
|
||||||
self.fusion_mlp = None
|
|
||||||
if use_conv:
|
|
||||||
self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
|
||||||
self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
|
||||||
if use_fusion_mlp:
|
|
||||||
self.fusion_mlp = nn.Sequential(
|
|
||||||
nn.Linear(dim, dim),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(dim, dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
if rotary_pos_emb:
|
|
||||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
|
||||||
else:
|
|
||||||
self.rotary_pos_emb = None
|
|
||||||
self.num_memory_tokens = num_memory_tokens
|
|
||||||
if num_memory_tokens > 0:
|
|
||||||
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
|
||||||
|
|
||||||
self.use_sinusoidal_emb = use_sinusoidal_emb
|
|
||||||
if use_sinusoidal_emb:
|
|
||||||
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
|
||||||
|
|
||||||
self.use_abs_pos_emb = use_abs_pos_emb
|
|
||||||
if use_abs_pos_emb:
|
|
||||||
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)
|
|
||||||
|
|
||||||
self.adaLN_modulation = None
|
|
||||||
if global_cond_dim is not None:
|
|
||||||
if use_final_layer:
|
|
||||||
self.norm_final = LayerNorm(dim)
|
|
||||||
self.adaLN_modulation = nn.Sequential(
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(
|
|
||||||
dim, 2 * dim, bias=True
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_zeros:
|
|
||||||
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
|
||||||
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
|
||||||
nn.init.constant_(self.project_out.weight, 0)
|
|
||||||
self.global_cond_embedder = nn.Sequential(
|
|
||||||
nn.Linear(global_cond_dim, dim),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(dim, dim * 6)
|
|
||||||
)
|
|
||||||
if use_zeros:
|
|
||||||
nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
|
|
||||||
nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
|
|
||||||
nn.init.constant_(self.global_cond_embedder[0].weight, 0)
|
|
||||||
nn.init.constant_(self.global_cond_embedder[0].bias, 0)
|
|
||||||
|
|
||||||
self.final_cross_attn_ix = final_cross_attn_ix
|
|
||||||
self.use_gated = use_gated
|
|
||||||
self.use_film = use_film
|
|
||||||
self.use_add_norm = use_add_norm
|
|
||||||
if self.use_add_norm:
|
|
||||||
self.add_norm = nn.LayerNorm(dim)
|
|
||||||
if use_gated:
|
|
||||||
self.gate = nn.Parameter(torch.ones(1, 1, dim))
|
|
||||||
|
|
||||||
if use_film:
|
|
||||||
self.film_generator = nn.Sequential(
|
|
||||||
nn.Linear(dim, dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.film_generator = None
|
|
||||||
|
|
||||||
if use_sync_film:
|
|
||||||
self.sync_film_generator = nn.Sequential(
|
|
||||||
nn.Linear(dim, dim, bias=False),
|
|
||||||
nn.SiLU(),
|
|
||||||
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.sync_film_generator = None
|
|
||||||
if use_sync_gated:
|
|
||||||
self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
|
|
||||||
else:
|
|
||||||
self.sync_gate = None
|
|
||||||
|
|
||||||
self.sliding_window = sliding_window
|
|
||||||
|
|
||||||
for i in range(depth):
|
|
||||||
should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
|
|
||||||
# print(f"Layer {i} cross attends: {should_cross_attend}")
|
|
||||||
self.layers.append(
|
|
||||||
TransformerBlock(
|
|
||||||
dim,
|
|
||||||
dim_heads = dim_heads,
|
|
||||||
cross_attend = should_cross_attend,
|
|
||||||
dim_context = cond_token_dim,
|
|
||||||
global_cond_dim = global_cond_dim,
|
|
||||||
causal = causal,
|
|
||||||
zero_init_branch_outputs = zero_init_branch_outputs,
|
|
||||||
conformer=conformer,
|
|
||||||
layer_ix=i,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
mask = None,
|
|
||||||
prepend_embeds = None,
|
|
||||||
prepend_mask = None,
|
|
||||||
add_cond = None,
|
|
||||||
sync_cond = None,
|
|
||||||
global_cond = None,
|
|
||||||
return_info = False,
|
|
||||||
use_checkpointing = True,
|
|
||||||
exit_layer_ix = None,
|
|
||||||
video_dropout_prob = 0.0,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
batch, seq, device = *x.shape[:2], x.device
|
|
||||||
model_dtype = next(self.parameters()).dtype
|
|
||||||
x = x.to(model_dtype)
|
|
||||||
|
|
||||||
prepend_length = 0
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"hidden_states": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
x = self.project_in(x)
|
|
||||||
if add_cond is not None:
|
|
||||||
if self.use_gated:
|
|
||||||
gate = torch.sigmoid(self.gate)
|
|
||||||
x = x + gate * add_cond
|
|
||||||
elif self.use_film:
|
|
||||||
scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
|
|
||||||
x = x * (1 + scale) + shift
|
|
||||||
else:
|
|
||||||
x = x + add_cond
|
|
||||||
|
|
||||||
if self.use_add_norm:
|
|
||||||
x = self.add_norm(x)
|
|
||||||
if self.fusion_mlp is not None:
|
|
||||||
x = self.fusion_mlp(x)
|
|
||||||
|
|
||||||
if sync_cond is not None:
|
|
||||||
# Resample sync_cond to match audio sequence length if needed
|
|
||||||
if sync_cond.shape[1] != x.shape[1]:
|
|
||||||
sync_cond = torch.nn.functional.interpolate(
|
|
||||||
sync_cond.transpose(1, 2), size=x.shape[1],
|
|
||||||
mode='linear', align_corners=False,
|
|
||||||
).transpose(1, 2)
|
|
||||||
if self.sync_film_generator is not None:
|
|
||||||
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
|
||||||
x = x * (1 + scale) + shift
|
|
||||||
elif self.sync_gate is not None:
|
|
||||||
gate_value = torch.sigmoid(self.sync_gate)
|
|
||||||
x = x + gate_value * sync_cond
|
|
||||||
# else:
|
|
||||||
# x = x + sync_cond
|
|
||||||
|
|
||||||
if prepend_embeds is not None:
|
|
||||||
prepend_length, prepend_dim = prepend_embeds.shape[1:]
|
|
||||||
|
|
||||||
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
|
|
||||||
|
|
||||||
x = torch.cat((prepend_embeds, x), dim = -2)
|
|
||||||
|
|
||||||
if self.num_memory_tokens > 0:
|
|
||||||
memory_tokens = self.memory_tokens.expand(batch, -1, -1)
|
|
||||||
x = torch.cat((memory_tokens, x), dim=1)
|
|
||||||
|
|
||||||
if self.rotary_pos_emb is not None:
|
|
||||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
|
||||||
else:
|
|
||||||
rotary_pos_emb = None
|
|
||||||
|
|
||||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
|
||||||
x = x + self.pos_emb(x)
|
|
||||||
|
|
||||||
if global_cond is not None and self.global_cond_embedder is not None:
|
|
||||||
global_cond_embed = self.global_cond_embedder(global_cond)
|
|
||||||
else:
|
|
||||||
global_cond_embed = global_cond
|
|
||||||
# Iterate over the transformer layers
|
|
||||||
for layer_ix, layer in enumerate(self.layers):
|
|
||||||
if use_checkpointing:
|
|
||||||
x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
|
|
||||||
else:
|
|
||||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
info["hidden_states"].append(x)
|
|
||||||
|
|
||||||
if exit_layer_ix is not None and layer_ix == exit_layer_ix:
|
|
||||||
x = x[:, self.num_memory_tokens:, :]
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
x = x[:, self.num_memory_tokens:, :]
|
|
||||||
if global_cond is not None and self.adaLN_modulation is not None:
|
|
||||||
if len(global_cond.shape) == 2:
|
|
||||||
global_cond = global_cond.unsqueeze(1)
|
|
||||||
shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
|
|
||||||
x = modulate(self.norm_final(x), shift, scale)
|
|
||||||
x = self.project_out(x)
|
|
||||||
|
|
||||||
if return_info:
|
|
||||||
return x, info
|
|
||||||
|
|
||||||
return x
|
|
||||||
@@ -1,180 +0,0 @@
|
|||||||
import torch
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
|
|
||||||
#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
|
|
||||||
from torch.nn.utils import remove_weight_norm
|
|
||||||
|
|
||||||
def load_ckpt_state_dict(ckpt_path, prefix=None):
|
|
||||||
if ckpt_path.endswith(".safetensors"):
|
|
||||||
state_dict = load_file(ckpt_path)
|
|
||||||
else:
|
|
||||||
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
|
||||||
|
|
||||||
# 过滤特定前缀的state_dict
|
|
||||||
filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict
|
|
||||||
|
|
||||||
return filtered_state_dict
|
|
||||||
|
|
||||||
def remove_weight_norm_from_model(model):
|
|
||||||
for module in model.modules():
|
|
||||||
if hasattr(module, "weight"):
|
|
||||||
remove_weight_norm(module)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
|
|
||||||
# License can be found in LICENSES/LICENSE_META.txt
|
|
||||||
|
|
||||||
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
|
||||||
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input (torch.Tensor): The input tensor containing probabilities.
|
|
||||||
num_samples (int): Number of samples to draw.
|
|
||||||
replacement (bool): Whether to draw with replacement or not.
|
|
||||||
Keywords args:
|
|
||||||
generator (torch.Generator): A pseudorandom number generator for sampling.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Last dimension contains num_samples indices
|
|
||||||
sampled from the multinomial probability distribution
|
|
||||||
located in the last dimension of tensor input.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if num_samples == 1:
|
|
||||||
q = torch.empty_like(input).exponential_(1, generator=generator)
|
|
||||||
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
|
|
||||||
|
|
||||||
input_ = input.reshape(-1, input.shape[-1])
|
|
||||||
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
|
||||||
output = output_.reshape(*list(input.shape[:-1]), -1)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
|
|
||||||
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
|
||||||
k (int): The k in “top-k”.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Sampled tokens.
|
|
||||||
"""
|
|
||||||
top_k_value, _ = torch.topk(probs, k, dim=-1)
|
|
||||||
min_value_top_k = top_k_value[..., [-1]]
|
|
||||||
probs *= (probs >= min_value_top_k).float()
|
|
||||||
probs.div_(probs.sum(dim=-1, keepdim=True))
|
|
||||||
next_token = multinomial(probs, num_samples=1)
|
|
||||||
return next_token
|
|
||||||
|
|
||||||
|
|
||||||
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
|
||||||
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
|
||||||
p (int): The p in “top-p”.
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Sampled tokens.
|
|
||||||
"""
|
|
||||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
|
||||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
|
||||||
mask = probs_sum - probs_sort > p
|
|
||||||
probs_sort *= (~mask).float()
|
|
||||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
|
||||||
next_token = multinomial(probs_sort, num_samples=1)
|
|
||||||
next_token = torch.gather(probs_idx, -1, next_token)
|
|
||||||
return next_token
|
|
||||||
|
|
||||||
def next_power_of_two(n):
|
|
||||||
return 2 ** (n - 1).bit_length()
|
|
||||||
|
|
||||||
def next_multiple_of_64(n):
|
|
||||||
return ((n + 63) // 64) * 64
|
|
||||||
|
|
||||||
|
|
||||||
# mask construction helpers
|
|
||||||
|
|
||||||
def mask_from_start_end_indices(
|
|
||||||
seq_len: int,
|
|
||||||
start: Tensor,
|
|
||||||
end: Tensor
|
|
||||||
):
|
|
||||||
assert start.shape == end.shape
|
|
||||||
device = start.device
|
|
||||||
|
|
||||||
seq = torch.arange(seq_len, device = device, dtype = torch.long)
|
|
||||||
seq = seq.reshape(*((-1,) * start.ndim), seq_len)
|
|
||||||
seq = seq.expand(*start.shape, seq_len)
|
|
||||||
|
|
||||||
mask = seq >= start[..., None].long()
|
|
||||||
mask &= seq < end[..., None].long()
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def mask_from_frac_lengths(
|
|
||||||
seq_len: int,
|
|
||||||
frac_lengths: Tensor
|
|
||||||
):
|
|
||||||
device = frac_lengths.device
|
|
||||||
|
|
||||||
lengths = (frac_lengths * seq_len).long()
|
|
||||||
max_start = seq_len - lengths
|
|
||||||
|
|
||||||
rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
|
|
||||||
start = (max_start * rand).clamp(min = 0)
|
|
||||||
end = start + lengths
|
|
||||||
|
|
||||||
return mask_from_start_end_indices(seq_len, start, end)
|
|
||||||
|
|
||||||
def _build_spline(video_feat, video_t, target_t):
|
|
||||||
# 三次样条插值核心实现
|
|
||||||
coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1))
|
|
||||||
spline = NaturalCubicSpline(coeffs)
|
|
||||||
return spline.evaluate(target_t).permute(0,2,1)
|
|
||||||
|
|
||||||
def resample(video_feat, audio_latent):
|
|
||||||
"""
|
|
||||||
9s
|
|
||||||
video_feat: [B, 72, D]
|
|
||||||
audio_latent: [B, D', 194] or int
|
|
||||||
"""
|
|
||||||
B, Tv, D = video_feat.shape
|
|
||||||
|
|
||||||
if isinstance(audio_latent, torch.Tensor):
|
|
||||||
# audio_latent is a tensor
|
|
||||||
if audio_latent.shape[1] != 64:
|
|
||||||
Ta = audio_latent.shape[1]
|
|
||||||
else:
|
|
||||||
Ta = audio_latent.shape[2]
|
|
||||||
elif isinstance(audio_latent, int):
|
|
||||||
# audio_latent is an int
|
|
||||||
Ta = audio_latent
|
|
||||||
else:
|
|
||||||
raise TypeError("audio_latent must be either a tensor or an int")
|
|
||||||
|
|
||||||
# 构建时间戳 (关键改进点)
|
|
||||||
video_time = torch.linspace(0, 9, Tv, device=video_feat.device)
|
|
||||||
audio_time = torch.linspace(0, 9, Ta, device=video_feat.device)
|
|
||||||
|
|
||||||
# 三维化处理 (Batch, Feature, Time)
|
|
||||||
video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv]
|
|
||||||
|
|
||||||
# 三次样条插值
|
|
||||||
aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta]
|
|
||||||
return aligned_video.permute(0, 2, 1) # [B, Ta, D]
|
|
||||||
|
|
||||||
def checkpoint(function, *args, **kwargs):
|
|
||||||
kwargs.setdefault("use_reentrant", False)
|
|
||||||
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
|
||||||
|
|
||||||
import os
|
|
||||||
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"
|
|
||||||
|
|
||||||
def compile(function, *args, **kwargs):
|
|
||||||
|
|
||||||
if enable_torch_compile:
|
|
||||||
try:
|
|
||||||
return torch.compile(function, *args, **kwargs)
|
|
||||||
except RuntimeError:
|
|
||||||
return function
|
|
||||||
|
|
||||||
return function
|
|
||||||
@@ -1,11 +1,5 @@
|
|||||||
einops>=0.7.0
|
einops>=0.7.0
|
||||||
einops-exts
|
|
||||||
safetensors
|
|
||||||
huggingface_hub
|
huggingface_hub
|
||||||
transformers>=4.52.3
|
transformers>=4.52.3
|
||||||
k-diffusion>=0.1.1
|
|
||||||
alias-free-torch
|
|
||||||
descript-audio-codec
|
|
||||||
vector-quantize-pytorch
|
|
||||||
scipy
|
scipy
|
||||||
tqdm
|
tqdm
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
name: prismaudio-extract
|
|
||||||
channels:
|
|
||||||
- conda-forge
|
|
||||||
- defaults
|
|
||||||
dependencies:
|
|
||||||
- python=3.10
|
|
||||||
- pip
|
|
||||||
- ffmpeg<7
|
|
||||||
- pip:
|
|
||||||
- torch>=2.6.0
|
|
||||||
- torchaudio>=2.6.0
|
|
||||||
- torchvision>=0.21.0
|
|
||||||
- tensorflow-cpu==2.15.0
|
|
||||||
- jax
|
|
||||||
- jaxlib
|
|
||||||
- transformers>=4.52.3
|
|
||||||
- decord
|
|
||||||
- einops>=0.7.0
|
|
||||||
- numpy
|
|
||||||
- mediapy
|
|
||||||
- git+https://github.com/google-deepmind/videoprism.git
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Standalone PrismAudio feature extraction script.
|
|
||||||
Runs in a separate Python env with JAX/TF installed (auto-created by PrismAudioFeatureExtractor).
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Add plugin root to sys.path so data_utils (and prismaudio_core) are importable
|
|
||||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
_PLUGIN_DIR = os.path.dirname(_SCRIPT_DIR)
|
|
||||||
if _PLUGIN_DIR not in sys.path:
|
|
||||||
sys.path.insert(0, _PLUGIN_DIR)
|
|
||||||
|
|
||||||
|
|
||||||
def _step(n, total, label):
|
|
||||||
"""Print step header and return start time."""
|
|
||||||
print(f"[extract] Step {n}/{total} — {label}...", flush=True)
|
|
||||||
return time.perf_counter()
|
|
||||||
|
|
||||||
|
|
||||||
def _done(t0, extra=""):
|
|
||||||
elapsed = time.perf_counter() - t0
|
|
||||||
suffix = f" {extra}" if extra else ""
|
|
||||||
print(f"[extract] done in {elapsed:.1f}s{suffix}", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
t_total = time.perf_counter()
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
|
|
||||||
parser.add_argument("--video", required=True, help="Path to input video")
|
|
||||||
parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
|
|
||||||
parser.add_argument("--output", required=True, help="Output .npz path")
|
|
||||||
parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
|
|
||||||
parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
|
|
||||||
parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)")
|
|
||||||
parser.add_argument("--clip_fps", type=float, default=4.0)
|
|
||||||
parser.add_argument("--clip_size", type=int, default=288)
|
|
||||||
parser.add_argument("--sync_fps", type=float, default=25.0)
|
|
||||||
parser.add_argument("--sync_size", type=int, default=224)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
print(f"[extract] Python : {sys.executable}", flush=True)
|
|
||||||
print(f"[extract] Video : {args.video}", flush=True)
|
|
||||||
print(f"[extract] Output : {args.output}", flush=True)
|
|
||||||
print(f"[extract] CoT text : {args.cot_text[:80]}{'...' if len(args.cot_text) > 80 else ''}", flush=True)
|
|
||||||
|
|
||||||
if not os.path.exists(args.video):
|
|
||||||
print(f"[extract] ERROR: video not found: {args.video}", flush=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print(f"[extract] Device : {'cuda' if torch.cuda.is_available() else 'cpu'}", flush=True)
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
t0 = _step(1, 6, "importing dependencies")
|
|
||||||
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
|
|
||||||
import torchvision.transforms as T
|
|
||||||
_done(t0)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
t0 = _step(2, 6, "loading models (T5, VideoPrism, Synchformer)")
|
|
||||||
feat_utils = FeaturesUtils(
|
|
||||||
vae_config_path=args.vae_config,
|
|
||||||
synchformer_ckpt=args.synchformer_ckpt,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
_done(t0)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
t0 = _step(3, 6, "reading and preprocessing video")
|
|
||||||
if args.video.endswith(".npy"):
|
|
||||||
all_frames = np.load(args.video) # [T, H, W, C] uint8
|
|
||||||
fps = args.source_fps
|
|
||||||
total_frames = all_frames.shape[0]
|
|
||||||
duration = total_frames / fps
|
|
||||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
|
||||||
|
|
||||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
|
||||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
|
||||||
clip_frames = all_frames[clip_indices]
|
|
||||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
|
||||||
|
|
||||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
|
||||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
|
||||||
sync_frames = all_frames[sync_indices]
|
|
||||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
|
||||||
else:
|
|
||||||
from decord import VideoReader, cpu
|
|
||||||
vr = VideoReader(args.video, ctx=cpu(0))
|
|
||||||
fps = vr.get_avg_fps()
|
|
||||||
total_frames = len(vr)
|
|
||||||
duration = total_frames / fps
|
|
||||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
|
||||||
|
|
||||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))]
|
|
||||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
|
||||||
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
|
||||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
|
||||||
|
|
||||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))]
|
|
||||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
|
||||||
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
|
||||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
|
||||||
|
|
||||||
clip_transform = T.Compose([
|
|
||||||
T.ToPILImage(),
|
|
||||||
T.Resize(args.clip_size),
|
|
||||||
T.CenterCrop(args.clip_size),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
||||||
])
|
|
||||||
clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
sync_transform = T.Compose([
|
|
||||||
T.ToPILImage(),
|
|
||||||
T.Resize(args.sync_size),
|
|
||||||
T.CenterCrop(args.sync_size),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
||||||
])
|
|
||||||
sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device)
|
|
||||||
_done(t0)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
t0 = _step(4, 6, "encoding text with T5-Gemma")
|
|
||||||
text_features = feat_utils.encode_t5_text([args.cot_text])
|
|
||||||
_done(t0, f"shape={tuple(text_features.shape)}")
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
t0 = _step(5, 6, "encoding video with VideoPrism")
|
|
||||||
global_video_features, video_features, global_text_features = \
|
|
||||||
feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
|
|
||||||
_done(t0, f"video={tuple(video_features.shape)} global={tuple(global_video_features.shape)}")
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
t0 = _step(6, 6, "encoding video with Synchformer")
|
|
||||||
sync_features = feat_utils.encode_video_with_sync(sync_input)
|
|
||||||
_done(t0, f"shape={tuple(sync_features.shape)}")
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
print(f"[extract] Saving features to {args.output} ...", flush=True)
|
|
||||||
np.savez(
|
|
||||||
args.output,
|
|
||||||
video_features=video_features.cpu().float().numpy(),
|
|
||||||
global_video_features=global_video_features.cpu().float().numpy(),
|
|
||||||
text_features=text_features.cpu().float().numpy(),
|
|
||||||
global_text_features=global_text_features.cpu().float().numpy(),
|
|
||||||
sync_features=sync_features.cpu().float().numpy(),
|
|
||||||
caption_cot=args.cot_text,
|
|
||||||
duration=duration,
|
|
||||||
)
|
|
||||||
print(f"[extract] Saved in {time.perf_counter() - t0:.1f}s", flush=True)
|
|
||||||
print(f"[extract] Total time: {time.perf_counter() - t_total:.1f}s", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
# Install the PrismAudio feature-extraction environment using pip venv.
|
|
||||||
# Use this instead of environment.yml when conda is unavailable (e.g. NVIDIA Docker).
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# bash scripts/install_extract_env.sh [/path/to/venv]
|
|
||||||
#
|
|
||||||
# Default venv path: /opt/prismaudio-extract
|
|
||||||
# After installation, point the PrismAudioFeatureExtractor node's python_env to:
|
|
||||||
# <venv>/bin/python (Linux/Mac)
|
|
||||||
# <venv>\Scripts\python.exe (Windows)
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
VENV_DIR="${1:-/opt/prismaudio-extract}"
|
|
||||||
|
|
||||||
echo "[PrismAudio] Creating venv at: ${VENV_DIR}"
|
|
||||||
python3 -m venv "${VENV_DIR}"
|
|
||||||
|
|
||||||
PIP="${VENV_DIR}/bin/pip"
|
|
||||||
|
|
||||||
echo "[PrismAudio] Upgrading pip..."
|
|
||||||
"${PIP}" install --upgrade pip
|
|
||||||
|
|
||||||
echo "[PrismAudio] Installing PyTorch stack..."
|
|
||||||
"${PIP}" install torch torchaudio torchvision
|
|
||||||
|
|
||||||
echo "[PrismAudio] Installing feature-extraction dependencies..."
|
|
||||||
"${PIP}" install \
|
|
||||||
"tensorflow-cpu>=2.16.0" \
|
|
||||||
"jax[cpu]" \
|
|
||||||
"jaxlib" \
|
|
||||||
"transformers" \
|
|
||||||
"decord" \
|
|
||||||
"einops" \
|
|
||||||
"numpy" \
|
|
||||||
"mediapy"
|
|
||||||
|
|
||||||
echo "[PrismAudio] Installing VideoPrism..."
|
|
||||||
"${PIP}" install "git+https://github.com/google-deepmind/videoprism.git"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "[PrismAudio] Done. Set python_env in PrismAudioFeatureExtractor to:"
|
|
||||||
echo " ${VENV_DIR}/bin/python"
|
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
# Vendored from https://github.com/jnwnlee/selva
|
||||||
|
# Pinned commit: d7d40a992aab58e7cf246055681a657e5d8b4a4d
|
||||||
|
# Imports rewritten from selva.* → selva_core.*
|
||||||
@@ -0,0 +1,190 @@
|
|||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from fractions import Fraction
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from av import AudioFrame
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoInfo:
|
||||||
|
duration_sec: float
|
||||||
|
fps: Fraction
|
||||||
|
clip_frames: torch.Tensor
|
||||||
|
sync_frames: torch.Tensor
|
||||||
|
all_frames: Optional[list[np.ndarray]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self):
|
||||||
|
return self.all_frames[0].shape[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def width(self):
|
||||||
|
return self.all_frames[0].shape[1]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
|
||||||
|
fps: Fraction) -> 'VideoInfo':
|
||||||
|
num_frames = int(duration_sec * fps)
|
||||||
|
all_frames = [image_info.original_frame] * num_frames
|
||||||
|
return cls(duration_sec=duration_sec,
|
||||||
|
fps=fps,
|
||||||
|
clip_frames=image_info.clip_frames,
|
||||||
|
sync_frames=image_info.sync_frames,
|
||||||
|
all_frames=all_frames)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImageInfo:
|
||||||
|
clip_frames: torch.Tensor
|
||||||
|
sync_frames: torch.Tensor
|
||||||
|
original_frame: Optional[np.ndarray]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self):
|
||||||
|
return self.original_frame.shape[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def width(self):
|
||||||
|
return self.original_frame.shape[1]
|
||||||
|
|
||||||
|
|
||||||
|
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
|
||||||
|
need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
|
||||||
|
output_frames = [[] for _ in list_of_fps]
|
||||||
|
next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
|
||||||
|
time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
|
||||||
|
all_frames = []
|
||||||
|
|
||||||
|
# container = av.open(video_path)
|
||||||
|
with av.open(video_path) as container:
|
||||||
|
stream = container.streams.video[0]
|
||||||
|
fps = stream.guessed_rate
|
||||||
|
stream.thread_type = 'AUTO'
|
||||||
|
for packet in container.demux(stream):
|
||||||
|
for frame in packet.decode():
|
||||||
|
frame_time = frame.time
|
||||||
|
if frame_time < start_sec:
|
||||||
|
continue
|
||||||
|
if frame_time > end_sec:
|
||||||
|
break
|
||||||
|
|
||||||
|
frame_np = None
|
||||||
|
if need_all_frames:
|
||||||
|
frame_np = frame.to_ndarray(format='rgb24')
|
||||||
|
all_frames.append(frame_np)
|
||||||
|
|
||||||
|
for i, _ in enumerate(list_of_fps):
|
||||||
|
this_time = frame_time
|
||||||
|
while this_time >= next_frame_time_for_each_fps[i]:
|
||||||
|
if frame_np is None:
|
||||||
|
frame_np = frame.to_ndarray(format='rgb24')
|
||||||
|
|
||||||
|
output_frames[i].append(frame_np)
|
||||||
|
next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
|
||||||
|
|
||||||
|
output_frames = [np.stack(frames) for frames in output_frames]
|
||||||
|
return output_frames, all_frames, fps
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_video_chunk(video_chunk: torch.Tensor,
|
||||||
|
expected_length: int,
|
||||||
|
*,
|
||||||
|
n_tolerance_frame: int = 1,
|
||||||
|
desc: str = "") \
|
||||||
|
-> torch.Tensor:
|
||||||
|
# video_chunk: [T, H, W, C]
|
||||||
|
if video_chunk.shape[0] < expected_length:
|
||||||
|
if expected_length - video_chunk.shape[0] <= n_tolerance_frame:
|
||||||
|
# copy the last frame to make it the right length
|
||||||
|
log.warning(f'Video too short {desc}, padding {expected_length - video_chunk.shape[0]} frames with the last frame')
|
||||||
|
video_chunk = torch.cat([video_chunk, video_chunk[-1:].repeat(expected_length - video_chunk.shape[0], 1, 1, 1)])
|
||||||
|
assert video_chunk.shape[0] == expected_length
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Video too short {desc}, expected {expected_length}, got {video_chunk.shape[0]}'
|
||||||
|
)
|
||||||
|
video_chunk = video_chunk[:expected_length]
|
||||||
|
if video_chunk.shape[0] != expected_length:
|
||||||
|
raise RuntimeError(f'Video wrong length {desc}, '
|
||||||
|
f'expected {expected_length}, '
|
||||||
|
f'got {video_chunk.shape[0]}')
|
||||||
|
|
||||||
|
return video_chunk
|
||||||
|
|
||||||
|
|
||||||
|
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
|
||||||
|
sampling_rate: int):
|
||||||
|
container = av.open(output_path, 'w')
|
||||||
|
output_video_stream = container.add_stream('h264', video_info.fps)
|
||||||
|
output_video_stream.codec_context.bit_rate = 10 * 1e6 # 10 Mbps
|
||||||
|
output_video_stream.width = video_info.width
|
||||||
|
output_video_stream.height = video_info.height
|
||||||
|
output_video_stream.pix_fmt = 'yuv420p'
|
||||||
|
|
||||||
|
output_audio_stream = container.add_stream('aac', sampling_rate)
|
||||||
|
|
||||||
|
# encode video
|
||||||
|
for image in video_info.all_frames:
|
||||||
|
image = av.VideoFrame.from_ndarray(image)
|
||||||
|
packet = output_video_stream.encode(image)
|
||||||
|
container.mux(packet)
|
||||||
|
|
||||||
|
for packet in output_video_stream.encode():
|
||||||
|
container.mux(packet)
|
||||||
|
|
||||||
|
# convert float tensor audio to numpy array
|
||||||
|
audio_np = audio.numpy().astype(np.float32)
|
||||||
|
audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
||||||
|
audio_frame.sample_rate = sampling_rate
|
||||||
|
|
||||||
|
for packet in output_audio_stream.encode(audio_frame):
|
||||||
|
container.mux(packet)
|
||||||
|
|
||||||
|
for packet in output_audio_stream.encode():
|
||||||
|
container.mux(packet)
|
||||||
|
|
||||||
|
container.close()
|
||||||
|
|
||||||
|
|
||||||
|
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
|
||||||
|
"""
|
||||||
|
NOTE: I don't think we can get the exact video duration right without re-encoding
|
||||||
|
so we are not using this but keeping it here for reference
|
||||||
|
"""
|
||||||
|
video = av.open(video_path)
|
||||||
|
output = av.open(output_path, 'w')
|
||||||
|
input_video_stream = video.streams.video[0]
|
||||||
|
output_video_stream = output.add_stream(template=input_video_stream)
|
||||||
|
output_audio_stream = output.add_stream('aac', sampling_rate)
|
||||||
|
|
||||||
|
duration_sec = audio.shape[-1] / sampling_rate
|
||||||
|
|
||||||
|
for packet in video.demux(input_video_stream):
|
||||||
|
# We need to skip the "flushing" packets that `demux` generates.
|
||||||
|
if packet.dts is None:
|
||||||
|
continue
|
||||||
|
# We need to assign the packet to the new stream.
|
||||||
|
packet.stream = output_video_stream
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
# convert float tensor audio to numpy array
|
||||||
|
audio_np = audio.numpy().astype(np.float32)
|
||||||
|
audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
||||||
|
audio_frame.sample_rate = sampling_rate
|
||||||
|
|
||||||
|
for packet in output_audio_stream.encode(audio_frame):
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
for packet in output_audio_stream.encode():
|
||||||
|
output.mux(packet)
|
||||||
|
|
||||||
|
video.close()
|
||||||
|
output.close()
|
||||||
|
|
||||||
|
output.close()
|
||||||
@@ -0,0 +1,227 @@
|
|||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from omegaconf import DictConfig, open_dict
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torch.utils.data.dataloader import default_collate
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
from selva_core.data.vgg_sound import VGGSound
|
||||||
|
from selva_core.data.eval.eval_video_dataset import VGGSound as VGGSoundEval
|
||||||
|
from selva_core.data.eval.eval_video_dataset import InferenceVideoData, VGGMonoAudioBench
|
||||||
|
from selva_core.data.eval.audiocaps import AudioCapsData
|
||||||
|
from selva_core.data.mm_dataset import MultiModalDataset
|
||||||
|
from selva_core.data.mixup import DataMixupCollate
|
||||||
|
from selva_core.utils.dist_utils import local_rank
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
# Re-seed randomness every time we start a worker
|
||||||
|
def worker_init_fn(worker_id: int):
|
||||||
|
worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
|
||||||
|
np.random.seed(worker_seed)
|
||||||
|
random.seed(worker_seed)
|
||||||
|
log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
|
||||||
|
|
||||||
|
|
||||||
|
def load_video_data(cfg: DictConfig, data_cfg: DictConfig, normalize_audio: bool = False,
|
||||||
|
) -> Dataset:
|
||||||
|
dataset = VGGSound(root=data_cfg.root,
|
||||||
|
tsv_path=data_cfg.subset_name,
|
||||||
|
sample_rate=16_000,
|
||||||
|
duration_sec=8.0,
|
||||||
|
normalize_audio=normalize_audio,
|
||||||
|
mmap_dir=data_cfg.memmap_dir,
|
||||||
|
tsv_tsynch_path=data_cfg.tsv_tsynch,
|
||||||
|
mmap_tsync_dir=data_cfg.memmap_dir_tsynch,
|
||||||
|
data_dim=cfg.data_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
|
||||||
|
raise NotImplementedError('Audio data loading is not implemented yet')
|
||||||
|
|
||||||
|
|
||||||
|
def setup_training_datasets(cfg: DictConfig,
|
||||||
|
generator: torch.Generator,
|
||||||
|
) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
||||||
|
if cfg.mini_train:
|
||||||
|
vgg = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=True)
|
||||||
|
dataset = MultiModalDataset([vgg], [])
|
||||||
|
if cfg.example_train:
|
||||||
|
video = load_video_data(cfg, cfg.data.Example_video, normalize_audio=True)
|
||||||
|
dataset = MultiModalDataset([video], [])
|
||||||
|
else:
|
||||||
|
vgg = load_video_data(cfg, cfg.data.VGGSound, normalize_audio=True)
|
||||||
|
# load the largest one first
|
||||||
|
# you can add more video/audio data upon demand, such as
|
||||||
|
# clotho = load_audio_data(cfg, cfg.data.Clotho)
|
||||||
|
dataset = MultiModalDataset([vgg], [])
|
||||||
|
|
||||||
|
batch_size = cfg.batch_size
|
||||||
|
num_workers = cfg.num_workers
|
||||||
|
pin_memory = cfg.pin_memory
|
||||||
|
|
||||||
|
if cfg.mixup.domain == 'data':
|
||||||
|
mixup_params = cfg.mixup.params
|
||||||
|
collate_fn = DataMixupCollate(generator=generator,
|
||||||
|
**mixup_params)
|
||||||
|
else:
|
||||||
|
collate_fn = None
|
||||||
|
|
||||||
|
sampler, loader = construct_loader(dataset,
|
||||||
|
batch_size,
|
||||||
|
num_workers,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
return dataset, sampler, loader
|
||||||
|
|
||||||
|
|
||||||
|
def setup_test_datasets(cfg: DictConfig,
|
||||||
|
generator: torch.Generator,
|
||||||
|
) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
||||||
|
if cfg.example_train:
|
||||||
|
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False, split='test')
|
||||||
|
elif cfg.dataset.startswith('vggsound'):
|
||||||
|
dataset = load_video_data(cfg, cfg.data.VGGSound_test, normalize_audio=False, split='test')
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unknown dataset for test: {cfg.dataset}')
|
||||||
|
|
||||||
|
batch_size = cfg.batch_size
|
||||||
|
num_workers = cfg.get('num_workers_val', cfg.num_workers)
|
||||||
|
pin_memory = cfg.pin_memory
|
||||||
|
|
||||||
|
if cfg.mixup.domain == 'data':
|
||||||
|
mixup_config = cfg.mixup.params
|
||||||
|
collate_fn = DataMixupCollate(generator=generator,
|
||||||
|
**mixup_config)
|
||||||
|
else:
|
||||||
|
collate_fn = None
|
||||||
|
|
||||||
|
sampler, loader = construct_loader(dataset,
|
||||||
|
batch_size,
|
||||||
|
num_workers,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
return dataset, sampler, loader
|
||||||
|
|
||||||
|
|
||||||
|
def setup_val_datasets(cfg: DictConfig,
|
||||||
|
generator: torch.Generator,
|
||||||
|
) -> tuple[Dataset, DataLoader, DataLoader]:
|
||||||
|
if cfg.example_train:
|
||||||
|
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
|
||||||
|
else:
|
||||||
|
dataset = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=False)
|
||||||
|
|
||||||
|
val_batch_size = cfg.batch_size
|
||||||
|
val_eval_batch_size = cfg.eval_batch_size
|
||||||
|
num_workers = cfg.get('num_workers_val', cfg.num_workers)
|
||||||
|
pin_memory = cfg.pin_memory
|
||||||
|
|
||||||
|
if cfg.mixup.domain == 'data':
|
||||||
|
mixup_config = cfg.mixup.params
|
||||||
|
collate_fn = DataMixupCollate(generator=generator,
|
||||||
|
**mixup_config)
|
||||||
|
else:
|
||||||
|
collate_fn = None
|
||||||
|
|
||||||
|
_, val_loader = construct_loader(dataset,
|
||||||
|
val_batch_size,
|
||||||
|
num_workers,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
_, eval_loader = construct_loader(dataset,
|
||||||
|
val_eval_batch_size,
|
||||||
|
num_workers,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
collate_fn=collate_fn)
|
||||||
|
|
||||||
|
return dataset, val_loader, eval_loader
|
||||||
|
|
||||||
|
|
||||||
|
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
|
||||||
|
if dataset_name.startswith('audiocaps_full'):
|
||||||
|
dataset = AudioCapsData(cfg.eval_data.audiocaps_full.audio_path,
|
||||||
|
cfg.eval_data.audiocaps_full.csv_path)
|
||||||
|
elif dataset_name.startswith('audiocaps'):
|
||||||
|
dataset = AudioCapsData(cfg.eval_data.audiocaps.audio_path,
|
||||||
|
cfg.eval_data.audiocaps.csv_path)
|
||||||
|
elif dataset_name.startswith('vggsound'):
|
||||||
|
dataset = VGGSound(cfg.eval_data.vggsound.video_path,
|
||||||
|
cfg.eval_data.vggsound.csv_path,
|
||||||
|
duration_sec=cfg.duration_s)
|
||||||
|
elif dataset_name.startswith('infer_video'):
|
||||||
|
dataset = InferenceVideoData(cfg.eval_data.infer_video.video_path,
|
||||||
|
cfg.eval_data.infer_video.jsonl_path,
|
||||||
|
duration_sec=cfg.duration_s)
|
||||||
|
cfg.batch_size = 1
|
||||||
|
elif dataset_name.startswith('example_video'):
|
||||||
|
dataset = VGGSoundEval(cfg.eval_data.Example_video.video_path,
|
||||||
|
cfg.eval_data.Example_video.csv_path,
|
||||||
|
duration_sec=cfg.duration_s)
|
||||||
|
elif dataset_name in ['vgg_monoaudio_intra', 'vgg_monoaudio_inter']:
|
||||||
|
dataset = VGGMonoAudioBench(cfg.eval_data[dataset_name].video_path,
|
||||||
|
cfg.eval_data[dataset_name].csv_path,
|
||||||
|
duration_sec=cfg.duration_s)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid dataset name: {dataset_name}')
|
||||||
|
|
||||||
|
batch_size = cfg.batch_size
|
||||||
|
num_workers = cfg.num_workers
|
||||||
|
pin_memory = cfg.pin_memory
|
||||||
|
_, loader = construct_loader(dataset,
|
||||||
|
batch_size,
|
||||||
|
num_workers,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
error_avoidance=True)
|
||||||
|
return dataset, loader
|
||||||
|
|
||||||
|
|
||||||
|
def error_avoidance_collate(batch):
|
||||||
|
# Filter our None values
|
||||||
|
batch = [item for item in batch if item is not None]
|
||||||
|
if len(batch) == 0:
|
||||||
|
return None
|
||||||
|
return default_collate(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def construct_loader(dataset: Dataset,
|
||||||
|
batch_size: int,
|
||||||
|
num_workers: int,
|
||||||
|
*,
|
||||||
|
shuffle: bool = True,
|
||||||
|
drop_last: bool = True,
|
||||||
|
pin_memory: bool = False,
|
||||||
|
error_avoidance: bool = False,
|
||||||
|
collate_fn = None) -> tuple[DistributedSampler, DataLoader]:
|
||||||
|
train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
|
||||||
|
train_loader = DataLoader(dataset,
|
||||||
|
batch_size,
|
||||||
|
sampler=train_sampler,
|
||||||
|
num_workers=num_workers,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
drop_last=drop_last,
|
||||||
|
persistent_workers=num_workers > 0,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
collate_fn=error_avoidance_collate if error_avoidance else collate_fn)
|
||||||
|
return train_sampler, train_loader
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class AudioCapsData(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
|
||||||
|
df = pd.read_csv(csv_path).to_dict(orient='records')
|
||||||
|
|
||||||
|
audio_files = sorted(os.listdir(audio_path))
|
||||||
|
audio_files = set(
|
||||||
|
[Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
|
||||||
|
|
||||||
|
self.data = []
|
||||||
|
for row in df:
|
||||||
|
self.data.append({
|
||||||
|
'name': row['name'],
|
||||||
|
'caption': row['caption'],
|
||||||
|
})
|
||||||
|
|
||||||
|
self.audio_path = Path(audio_path)
|
||||||
|
self.csv_path = Path(csv_path)
|
||||||
|
|
||||||
|
log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> torch.Tensor:
|
||||||
|
return self.data[idx]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
@@ -0,0 +1,237 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from torio.io import StreamingMediaDecoder
|
||||||
|
|
||||||
|
from selva_core.data.av_utils import normalize_video_chunk
|
||||||
|
from selva_core.utils.dist_utils import local_rank
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
_CLIP_SIZE = 384
|
||||||
|
_CLIP_FPS = 8.0
|
||||||
|
|
||||||
|
_SYNC_SIZE = 224
|
||||||
|
_SYNC_FPS = 25.0
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
video_root: Union[str, Path],
|
||||||
|
*,
|
||||||
|
duration_sec: float = 8.0,
|
||||||
|
clip_video_required: bool = False,
|
||||||
|
):
|
||||||
|
self.video_root = Path(video_root)
|
||||||
|
self.duration_sec = duration_sec
|
||||||
|
self.clip_video_required = clip_video_required
|
||||||
|
|
||||||
|
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
||||||
|
self.sync_transform = v2.Compose([
|
||||||
|
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||||
|
# v2.CenterCrop(_SYNC_SIZE),
|
||||||
|
v2.ToImage(),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
if self.clip_video_required:
|
||||||
|
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
||||||
|
self.clip_transform = v2.Compose([
|
||||||
|
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||||
|
v2.ToImage(),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
])
|
||||||
|
|
||||||
|
# to be implemented by subclasses
|
||||||
|
self.captions = {}
|
||||||
|
self.negative_captions = {}
|
||||||
|
self.videos = sorted(list(self.captions.keys()))
|
||||||
|
|
||||||
|
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
|
video_id = self.videos[idx]
|
||||||
|
caption = self.captions[video_id]
|
||||||
|
negative_caption = self.negative_captions.get(video_id, None)
|
||||||
|
|
||||||
|
reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
|
||||||
|
reader.add_basic_video_stream(
|
||||||
|
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
||||||
|
frame_rate=_SYNC_FPS,
|
||||||
|
format='rgb24',
|
||||||
|
)
|
||||||
|
if self.clip_video_required:
|
||||||
|
reader.add_basic_video_stream(
|
||||||
|
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
||||||
|
frame_rate=_CLIP_FPS,
|
||||||
|
format='rgb24',
|
||||||
|
)
|
||||||
|
|
||||||
|
reader.fill_buffer()
|
||||||
|
data_chunk = reader.pop_chunks()
|
||||||
|
|
||||||
|
sync_chunk = data_chunk[0]
|
||||||
|
if sync_chunk is None:
|
||||||
|
raise RuntimeError(f'Sync video returned None {video_id}')
|
||||||
|
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
||||||
|
n_tolerance_frame=3, desc=video_id)
|
||||||
|
sync_chunk = self.sync_transform(sync_chunk)
|
||||||
|
|
||||||
|
if self.clip_video_required:
|
||||||
|
clip_chunk = data_chunk[1]
|
||||||
|
if clip_chunk is None:
|
||||||
|
raise RuntimeError(f'CLIP video returned None {video_id}')
|
||||||
|
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
||||||
|
n_tolerance_frame=1, desc=video_id)
|
||||||
|
clip_chunk = self.clip_transform(clip_chunk)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'name': video_id,
|
||||||
|
'caption': caption,
|
||||||
|
'sync_video': sync_chunk,
|
||||||
|
}
|
||||||
|
if self.clip_video_required:
|
||||||
|
data['clip_video'] = clip_chunk
|
||||||
|
if negative_caption is not None:
|
||||||
|
data['negative_caption'] = negative_caption
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
|
try:
|
||||||
|
return self.sample(idx)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.captions)
|
||||||
|
|
||||||
|
|
||||||
|
class VGGSound(VideoDataset):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
video_root: Union[str, Path],
|
||||||
|
csv_path: Union[str, Path],
|
||||||
|
*,
|
||||||
|
duration_sec: float = 8.0,
|
||||||
|
clip_video_required: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(video_root, duration_sec=duration_sec,
|
||||||
|
clip_video_required=clip_video_required)
|
||||||
|
self.video_root = Path(video_root)
|
||||||
|
self.csv_path = Path(csv_path)
|
||||||
|
|
||||||
|
videos = sorted(os.listdir(self.video_root))
|
||||||
|
if local_rank == 0:
|
||||||
|
log.info(f'{len(videos)} videos found in {video_root}')
|
||||||
|
self.captions = {}
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
|
||||||
|
'split']).to_dict(orient='records')
|
||||||
|
|
||||||
|
videos_no_found = []
|
||||||
|
for row in df:
|
||||||
|
if row['split'] == 'test':
|
||||||
|
start_sec = int(row['sec'])
|
||||||
|
video_id = str(row['id'])
|
||||||
|
# this is how our videos are named
|
||||||
|
video_name = f'{video_id}_{start_sec:06d}'
|
||||||
|
if video_name + '.mp4' not in videos:
|
||||||
|
videos_no_found.append(video_name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.captions[video_name] = row['caption']
|
||||||
|
|
||||||
|
if local_rank == 0:
|
||||||
|
log.info(f'{len(videos)} videos found in {video_root}')
|
||||||
|
log.info(f'{len(self.captions)} useable videos found')
|
||||||
|
if videos_no_found:
|
||||||
|
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
|
||||||
|
log.info(
|
||||||
|
'A small amount is expected, as not all videos are still available on YouTube')
|
||||||
|
|
||||||
|
self.videos = sorted(list(self.captions.keys()))
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceVideoData(VideoDataset):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
video_root: Union[str, Path],
|
||||||
|
jsonl_root: Union[str, Path],
|
||||||
|
*,
|
||||||
|
duration_sec: float = 10.0,
|
||||||
|
clip_video_required: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(video_root, duration_sec=duration_sec,
|
||||||
|
clip_video_required=clip_video_required)
|
||||||
|
self.video_root = Path(video_root)
|
||||||
|
self.jsonl_root = Path(jsonl_root)
|
||||||
|
|
||||||
|
videos = sorted(os.listdir(self.video_root))
|
||||||
|
videos = [v[:-4] for v in videos] # remove extensions
|
||||||
|
self.captions = {}
|
||||||
|
|
||||||
|
for v in videos:
|
||||||
|
with open(self.jsonl_root / (v + '.jsonl')) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
self.captions[v] = data['audio_prompt']
|
||||||
|
self.negative_captions[v] = data.get('negative_audio_prompt', None)
|
||||||
|
|
||||||
|
if local_rank == 0:
|
||||||
|
log.info(f'{len(videos)} videos found in {video_root}')
|
||||||
|
|
||||||
|
self.videos = videos
|
||||||
|
|
||||||
|
|
||||||
|
class VGGMonoAudioBench(VideoDataset):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
video_root: Union[str, Path],
|
||||||
|
csv_path: Union[str, Path],
|
||||||
|
*,
|
||||||
|
duration_sec: float = 8.0,
|
||||||
|
clip_video_required: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(video_root, duration_sec=duration_sec,
|
||||||
|
clip_video_required=clip_video_required)
|
||||||
|
self.video_root = Path(video_root)
|
||||||
|
self.csv_path = Path(csv_path)
|
||||||
|
|
||||||
|
videos = sorted(os.listdir(self.video_root))
|
||||||
|
if local_rank == 0:
|
||||||
|
log.info(f'{len(videos)} videos found in {video_root}')
|
||||||
|
self.captions = {}
|
||||||
|
self.negative_captions = {}
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path, header=0, usecols=['file_name', 'label', 'paired_label']
|
||||||
|
).to_dict(orient='records')
|
||||||
|
|
||||||
|
videos_no_found = []
|
||||||
|
for row in df:
|
||||||
|
video_name = str(Path(row['file_name']).stem)
|
||||||
|
if video_name + '.mp4' not in videos:
|
||||||
|
videos_no_found.append(video_name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.captions[video_name] = row['label']
|
||||||
|
self.negative_captions[video_name] = row['paired_label']
|
||||||
|
|
||||||
|
if local_rank == 0:
|
||||||
|
log.info(f'{len(videos)} videos found in {video_root}')
|
||||||
|
log.info(f'{len(self.captions)} useable videos found')
|
||||||
|
if videos_no_found:
|
||||||
|
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}!')
|
||||||
|
|
||||||
|
self.videos = sorted(list(self.captions.keys()))
|
||||||
@@ -0,0 +1,194 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from torio.io import StreamingMediaDecoder
|
||||||
|
|
||||||
|
from selva_core.data.av_utils import normalize_video_chunk
|
||||||
|
from selva_core.utils.dist_utils import local_rank
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
_CLIP_SIZE = 384
|
||||||
|
_CLIP_FPS = 8.0
|
||||||
|
|
||||||
|
_SYNC_SIZE = 224
|
||||||
|
_SYNC_FPS = 25.0
|
||||||
|
|
||||||
|
|
||||||
|
class VGGSound(Dataset):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root: Union[str, Path],
|
||||||
|
*,
|
||||||
|
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
||||||
|
audio_required: bool = True,
|
||||||
|
sample_rate: int = 16_000,
|
||||||
|
duration_sec: float = 8.0,
|
||||||
|
audio_samples: Optional[int] = None,
|
||||||
|
normalize_audio: bool = False,
|
||||||
|
clip_video_required: bool = True,
|
||||||
|
):
|
||||||
|
self.root = Path(root)
|
||||||
|
self.audio_required = audio_required
|
||||||
|
if audio_required:
|
||||||
|
self.normalize_audio = normalize_audio
|
||||||
|
if audio_samples is None:
|
||||||
|
self.audio_samples = int(sample_rate * duration_sec)
|
||||||
|
else:
|
||||||
|
self.audio_samples = audio_samples
|
||||||
|
effective_duration = audio_samples / sample_rate
|
||||||
|
# make sure the duration is close enough, within 15ms
|
||||||
|
assert abs(effective_duration - duration_sec) < 0.015, \
|
||||||
|
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
||||||
|
self.clip_video_required = clip_video_required
|
||||||
|
|
||||||
|
videos = sorted(os.listdir(self.root))
|
||||||
|
videos = set([Path(v).stem for v in videos]) # remove extensions
|
||||||
|
self.labels = {}
|
||||||
|
self.videos = []
|
||||||
|
missing_videos = []
|
||||||
|
|
||||||
|
# read the tsv for subset information
|
||||||
|
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
||||||
|
for record in df_list:
|
||||||
|
id = record['id']
|
||||||
|
label = record['label']
|
||||||
|
if id in videos:
|
||||||
|
self.labels[id] = label
|
||||||
|
self.videos.append(id)
|
||||||
|
else:
|
||||||
|
missing_videos.append(id)
|
||||||
|
|
||||||
|
if local_rank == 0:
|
||||||
|
log.info(f'{len(videos)} videos found in {root}')
|
||||||
|
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
||||||
|
log.info(f'{len(missing_videos)} videos missing in {root}')
|
||||||
|
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.duration_sec = duration_sec
|
||||||
|
|
||||||
|
if audio_required:
|
||||||
|
self.expected_audio_length = self.audio_samples
|
||||||
|
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
||||||
|
if clip_video_required:
|
||||||
|
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
||||||
|
|
||||||
|
self.sync_transform = v2.Compose([
|
||||||
|
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||||
|
# v2.CenterCrop(_SYNC_SIZE),
|
||||||
|
v2.ToImage(),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
if clip_video_required:
|
||||||
|
self.clip_transform = v2.Compose([
|
||||||
|
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||||
|
v2.ToImage(),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
])
|
||||||
|
if audio_required:
|
||||||
|
self.resampler = {}
|
||||||
|
|
||||||
|
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
|
video_id = self.videos[idx]
|
||||||
|
|
||||||
|
label = self.labels[video_id]
|
||||||
|
|
||||||
|
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
||||||
|
reader.add_basic_video_stream(
|
||||||
|
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
||||||
|
frame_rate=_SYNC_FPS,
|
||||||
|
format='rgb24',
|
||||||
|
)
|
||||||
|
if self.audio_required:
|
||||||
|
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
|
||||||
|
if self.clip_video_required:
|
||||||
|
reader.add_basic_video_stream(
|
||||||
|
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
||||||
|
frame_rate=_CLIP_FPS,
|
||||||
|
format='rgb24',
|
||||||
|
)
|
||||||
|
|
||||||
|
reader.fill_buffer()
|
||||||
|
data_chunk = reader.pop_chunks()
|
||||||
|
|
||||||
|
sync_chunk = data_chunk[0]
|
||||||
|
if sync_chunk is None:
|
||||||
|
raise RuntimeError(f'Sync video returned None {video_id}')
|
||||||
|
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
||||||
|
n_tolerance_frame=3, desc=video_id)
|
||||||
|
sync_chunk = self.sync_transform(sync_chunk)
|
||||||
|
|
||||||
|
if self.audio_required:
|
||||||
|
audio_chunk = data_chunk[1]
|
||||||
|
|
||||||
|
if self.clip_video_required:
|
||||||
|
clip_chunk = data_chunk[2 if self.audio_required else 1]
|
||||||
|
if clip_chunk is None:
|
||||||
|
raise RuntimeError(f'CLIP video returned None {video_id}')
|
||||||
|
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
||||||
|
n_tolerance_frame=1, desc=video_id)
|
||||||
|
clip_chunk = self.clip_transform(clip_chunk)
|
||||||
|
|
||||||
|
# process audio
|
||||||
|
if self.audio_required:
|
||||||
|
sample_rate = int(reader.get_out_stream_info(1).sample_rate)
|
||||||
|
audio_chunk = audio_chunk.transpose(0, 1)
|
||||||
|
audio_chunk = audio_chunk.mean(dim=0) # mono
|
||||||
|
if self.normalize_audio:
|
||||||
|
abs_max = audio_chunk.abs().max()
|
||||||
|
audio_chunk = audio_chunk * (0.95 / abs_max)
|
||||||
|
if abs_max <= 1e-6:
|
||||||
|
raise RuntimeError(f'Audio is silent {video_id}')
|
||||||
|
|
||||||
|
# resample
|
||||||
|
if sample_rate == self.sample_rate:
|
||||||
|
audio_chunk = audio_chunk
|
||||||
|
else:
|
||||||
|
if sample_rate not in self.resampler:
|
||||||
|
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
||||||
|
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
||||||
|
sample_rate,
|
||||||
|
self.sample_rate,
|
||||||
|
lowpass_filter_width=64,
|
||||||
|
rolloff=0.9475937167399596,
|
||||||
|
resampling_method='sinc_interp_kaiser',
|
||||||
|
beta=14.769656459379492,
|
||||||
|
)
|
||||||
|
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
||||||
|
|
||||||
|
if audio_chunk.shape[0] < self.expected_audio_length:
|
||||||
|
raise RuntimeError(f'Audio too short {video_id}')
|
||||||
|
audio_chunk = audio_chunk[:self.expected_audio_length]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'id': video_id,
|
||||||
|
'caption': label,
|
||||||
|
'sync_video': sync_chunk,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.audio_required:
|
||||||
|
data['audio'] = audio_chunk
|
||||||
|
if self.clip_video_required:
|
||||||
|
data['clip_video'] = clip_chunk
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
|
try:
|
||||||
|
return self.sample(idx)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
@@ -0,0 +1,129 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import open_clip
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class WavTextClipsDataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root: Union[str, Path],
|
||||||
|
*,
|
||||||
|
captions_tsv: Union[str, Path],
|
||||||
|
clips_tsv: Union[str, Path],
|
||||||
|
sample_rate: int,
|
||||||
|
num_samples: int,
|
||||||
|
normalize_audio: bool = False,
|
||||||
|
reject_silent: bool = False,
|
||||||
|
tokenizer_id: str = 'ViT-H-14-378-quickgelu',
|
||||||
|
):
|
||||||
|
self.root = Path(root)
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.normalize_audio = normalize_audio
|
||||||
|
self.reject_silent = reject_silent
|
||||||
|
self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
|
||||||
|
|
||||||
|
audios = sorted(os.listdir(self.root))
|
||||||
|
audios = set([
|
||||||
|
Path(audio).stem for audio in audios
|
||||||
|
if audio.endswith('.wav') or audio.endswith('.flac')
|
||||||
|
])
|
||||||
|
self.captions = {}
|
||||||
|
|
||||||
|
# read the caption tsv
|
||||||
|
df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
|
||||||
|
for record in df_list:
|
||||||
|
id = record['id']
|
||||||
|
caption = record['caption']
|
||||||
|
self.captions[id] = caption
|
||||||
|
|
||||||
|
# read the clip tsv
|
||||||
|
df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
|
||||||
|
'id': str,
|
||||||
|
'name': str
|
||||||
|
}).to_dict('records')
|
||||||
|
self.clips = []
|
||||||
|
for record in df_list:
|
||||||
|
record['id'] = record['id']
|
||||||
|
record['name'] = record['name']
|
||||||
|
id = record['id']
|
||||||
|
name = record['name']
|
||||||
|
record['caption'] = self.captions[name]
|
||||||
|
self.clips.append(record)
|
||||||
|
|
||||||
|
log.info(f'Found {len(self.clips)} audio files in {self.root}')
|
||||||
|
|
||||||
|
self.resampler = {}
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> torch.Tensor:
|
||||||
|
try:
|
||||||
|
clip = self.clips[idx]
|
||||||
|
audio_name = clip['name']
|
||||||
|
audio_id = clip['id']
|
||||||
|
caption = clip['caption']
|
||||||
|
start_sample = clip['start_sample']
|
||||||
|
end_sample = clip['end_sample']
|
||||||
|
|
||||||
|
audio_path = self.root / f'{audio_name}.flac'
|
||||||
|
if not audio_path.exists():
|
||||||
|
audio_path = self.root / f'{audio_name}.wav'
|
||||||
|
assert audio_path.exists()
|
||||||
|
|
||||||
|
audio_chunk, sample_rate = torchaudio.load(audio_path)
|
||||||
|
audio_chunk = audio_chunk.mean(dim=0) # mono
|
||||||
|
abs_max = audio_chunk.abs().max()
|
||||||
|
if self.normalize_audio:
|
||||||
|
audio_chunk = audio_chunk / abs_max * 0.95
|
||||||
|
|
||||||
|
if self.reject_silent and abs_max < 1e-6:
|
||||||
|
log.warning(f'Rejecting silent audio')
|
||||||
|
return None
|
||||||
|
|
||||||
|
audio_chunk = audio_chunk[start_sample:end_sample]
|
||||||
|
|
||||||
|
# resample
|
||||||
|
if sample_rate == self.sample_rate:
|
||||||
|
audio_chunk = audio_chunk
|
||||||
|
else:
|
||||||
|
if sample_rate not in self.resampler:
|
||||||
|
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
||||||
|
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
||||||
|
sample_rate,
|
||||||
|
self.sample_rate,
|
||||||
|
lowpass_filter_width=64,
|
||||||
|
rolloff=0.9475937167399596,
|
||||||
|
resampling_method='sinc_interp_kaiser',
|
||||||
|
beta=14.769656459379492,
|
||||||
|
)
|
||||||
|
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
||||||
|
|
||||||
|
if audio_chunk.shape[0] < self.num_samples:
|
||||||
|
raise ValueError('Audio is too short')
|
||||||
|
audio_chunk = audio_chunk[:self.num_samples]
|
||||||
|
|
||||||
|
tokens = self.tokenizer([caption])[0]
|
||||||
|
|
||||||
|
output = {
|
||||||
|
'waveform': audio_chunk,
|
||||||
|
'id': audio_id,
|
||||||
|
'caption': caption,
|
||||||
|
'tokens': tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
return output
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f'Error reading {audio_path}: {e}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.clips)
|
||||||
@@ -0,0 +1,338 @@
|
|||||||
|
""" Embedding Mixup
|
||||||
|
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
|
||||||
|
"""
|
||||||
|
from typing import Literal, Tuple, Union, List, Optional
|
||||||
|
from functools import partial
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.dataloader import default_collate
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from einops import rearrange
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
|
||||||
|
from selva_core.data.vgg_sound import _SYNC_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
class MixupBase:
|
||||||
|
""" Base class for mixup on either data or feature domain.
|
||||||
|
Applies different params to each element or whole batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
||||||
|
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
||||||
|
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
||||||
|
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
||||||
|
prob (float): Probability of applying mixup per batch or element
|
||||||
|
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||||
|
eps (float): Small epsilon value to avoid zero lambda
|
||||||
|
"""
|
||||||
|
def __init__(self, generator:torch.Generator,
|
||||||
|
*,
|
||||||
|
modality:Literal['video', 'audio', 'both'],
|
||||||
|
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||||
|
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||||
|
eps:float=0.05
|
||||||
|
):
|
||||||
|
self.modality = modality
|
||||||
|
self.mixup_lambda:float = mixup_lambda
|
||||||
|
self.mixup_alpha:float = mixup_alpha
|
||||||
|
self.mix_prob:float = prob
|
||||||
|
self.mode:str = mode
|
||||||
|
self.eps:float = eps
|
||||||
|
self.mixup_enabled:bool = True # set to false to disable mixing (intended to be set by train loop)
|
||||||
|
if generator.device.type == 'cuda':
|
||||||
|
self.generator_cuda = generator
|
||||||
|
generator_seed = generator.initial_seed()
|
||||||
|
self.generator = torch.Generator(device='cpu')
|
||||||
|
self.generator.manual_seed(generator_seed)
|
||||||
|
else:
|
||||||
|
self.generator = generator
|
||||||
|
|
||||||
|
if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
|
||||||
|
raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
|
||||||
|
if not self.mixup_alpha >= 0.:
|
||||||
|
raise ValueError(f"mixup_alpha {self.mixup_alpha} >= 0. should be true.")
|
||||||
|
if (self.mixup_alpha > 0. and self.mixup_lambda < 1.) or (self.mixup_alpha == 0. and self.mixup_lambda == 1.):
|
||||||
|
raise ValueError(f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true.")
|
||||||
|
|
||||||
|
def _params_per_elem(self, batch_size:int) -> np.ndarray:
|
||||||
|
lam:np.ndarray = np.ones(batch_size, dtype=np.float32)
|
||||||
|
if self.mixup_enabled:
|
||||||
|
if self.mixup_lambda < 1.: # constant lambda
|
||||||
|
lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
|
||||||
|
elif self.mixup_alpha > 0.: # sampled lambda
|
||||||
|
# Use torch's beta distribution with generator
|
||||||
|
lam_mix = torch.distributions.Beta(
|
||||||
|
torch.tensor([self.mixup_alpha]),
|
||||||
|
torch.tensor([self.mixup_alpha]),
|
||||||
|
).sample([batch_size]).numpy().astype(np.float32).reshape(-1)
|
||||||
|
else:
|
||||||
|
assert False, f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
|
||||||
|
lam_mix[lam_mix < self.eps] = self.eps
|
||||||
|
|
||||||
|
# Use torch's random with generator for the random comparison
|
||||||
|
rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
|
||||||
|
lam = np.where(rand_vals < self.mix_prob, lam_mix, lam)
|
||||||
|
return lam
|
||||||
|
|
||||||
|
def _params_per_batch(self) -> float:
|
||||||
|
lam:float = 1.
|
||||||
|
if self.mixup_enabled:
|
||||||
|
if self.mixup_lambda < 1.: # constant lambda
|
||||||
|
lam = self.mixup_lambda
|
||||||
|
elif self.mixup_alpha > 0.: # sampled lambda
|
||||||
|
lam = torch.distributions.Beta(
|
||||||
|
torch.tensor([self.mixup_alpha]),
|
||||||
|
torch.tensor([self.mixup_alpha]),
|
||||||
|
).sample().item()
|
||||||
|
else:
|
||||||
|
assert False, f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
|
||||||
|
if lam < self.eps: lam = self.eps
|
||||||
|
lam = float(lam)
|
||||||
|
return lam
|
||||||
|
|
||||||
|
|
||||||
|
class DataMixupCollate(MixupBase):
|
||||||
|
""" Mixup video in data domain.
|
||||||
|
Applies different params to each element or whole batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
||||||
|
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
||||||
|
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
||||||
|
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
||||||
|
prob (float): Probability of applying mixup per batch or element
|
||||||
|
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||||
|
eps (float): Small epsilon value to avoid zero lambda
|
||||||
|
"""
|
||||||
|
def __init__(self, generator:torch.Generator,
|
||||||
|
*,
|
||||||
|
modality:Literal['video', 'audio', 'both']='video',
|
||||||
|
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||||
|
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||||
|
eps:float=0.05
|
||||||
|
):
|
||||||
|
super().__init__(generator, modality=modality,
|
||||||
|
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
|
||||||
|
mode=mode, eps=eps)
|
||||||
|
|
||||||
|
self.source_video_key= 'sync_video'
|
||||||
|
self.source_audio_key = 'audio'
|
||||||
|
self.target_video_key = 'sync_video_mixed'
|
||||||
|
self.target_audio_key = 'audio_mixed'
|
||||||
|
|
||||||
|
if not mode == 'batch':
|
||||||
|
raise ValueError(f"Mode {mode} is not supported for data domain.")
|
||||||
|
self.sync_transform = v2.Compose([
|
||||||
|
v2.CenterCrop(_SYNC_SIZE),
|
||||||
|
v2.ToImage(),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
def _concat_video_frames(self, batch:list, target_key:str='sync_video_mixed', source_key:str='sync_video') -> float:
|
||||||
|
# only batch mode supported
|
||||||
|
batch_size:int = len(batch)
|
||||||
|
lam:float = self._params_per_batch()
|
||||||
|
|
||||||
|
if lam == 1.:
|
||||||
|
# no mixup, just return
|
||||||
|
for i in range(batch_size):
|
||||||
|
batch[i][target_key] = batch[i][source_key]
|
||||||
|
return lam
|
||||||
|
|
||||||
|
# Randomly choose between horizontal and vertical resizing using
|
||||||
|
orig_size = int(lam * _SYNC_SIZE)
|
||||||
|
is_horizontal = True # torch.rand(1, generator=self.generator).item() < 0.5
|
||||||
|
if is_horizontal:
|
||||||
|
# Horizontal resize
|
||||||
|
resize_shape_orig = (_SYNC_SIZE, orig_size)
|
||||||
|
resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE-orig_size)
|
||||||
|
else:
|
||||||
|
# Vertical resize
|
||||||
|
resize_shape_orig = (orig_size, _SYNC_SIZE)
|
||||||
|
resize_shape_pair = (_SYNC_SIZE-orig_size, _SYNC_SIZE)
|
||||||
|
sync_resize_orig = v2.Compose([
|
||||||
|
v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC),
|
||||||
|
])
|
||||||
|
sync_resize_pair = v2.Compose([
|
||||||
|
v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC),
|
||||||
|
])
|
||||||
|
|
||||||
|
batch_videos_orig = torch.stack([batch[i][source_key] for i in range(batch_size)], dim=0)
|
||||||
|
batch_videos_pair = torch.stack([batch[batch_size - i - 1][source_key] for i in range(batch_size)], dim=0)
|
||||||
|
# (B, T, C, H, W)
|
||||||
|
# pass through resize, transform and concat
|
||||||
|
batch_videos_orig = sync_resize_orig(batch_videos_orig)
|
||||||
|
batch_videos_pair = sync_resize_pair(batch_videos_pair)
|
||||||
|
batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair), dim=-1 if is_horizontal else -2)
|
||||||
|
batch_videos_concat = self.sync_transform(batch_videos_concat)
|
||||||
|
|
||||||
|
num_mixup = int(self.mix_prob * batch_size)
|
||||||
|
for i in range(num_mixup):
|
||||||
|
batch[i][target_key] = batch_videos_concat[i]
|
||||||
|
for i in range(num_mixup, batch_size):
|
||||||
|
batch[i][target_key] = batch[i][source_key] # no mixup
|
||||||
|
|
||||||
|
del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return lam
|
||||||
|
|
||||||
|
def _mix_audio_samples(self, batch:list, target_key:str='audio_mixed', source_key:str='audio',
|
||||||
|
normalize:bool = True) -> float:
|
||||||
|
# assume source_key audios are normalized
|
||||||
|
batch_size:int = len(batch)
|
||||||
|
lam:float = self._params_per_batch()
|
||||||
|
|
||||||
|
if lam == 1.:
|
||||||
|
# no mixup, just return
|
||||||
|
for i in range(batch_size):
|
||||||
|
batch[i][target_key] = batch[i][source_key]
|
||||||
|
return lam
|
||||||
|
|
||||||
|
num_mixup = int(self.mix_prob * batch_size)
|
||||||
|
for i in range(num_mixup):
|
||||||
|
batch[i][target_key] = batch[i][source_key] * lam + batch[batch_size - i - 1][source_key] * (1 - lam)
|
||||||
|
if normalize:
|
||||||
|
source_abs_max = batch[i][source_key].abs().max()
|
||||||
|
target_abs_max = batch[i][target_key].abs().max()
|
||||||
|
batch[i][target_key] = batch[i][target_key] * (source_abs_max / target_abs_max)
|
||||||
|
for i in range(num_mixup, batch_size):
|
||||||
|
batch[i][target_key] = batch[i][source_key] # no mixup
|
||||||
|
|
||||||
|
return lam
|
||||||
|
|
||||||
|
def __call__(self, batch:list, _=None) -> torch.tensor:
|
||||||
|
batch_size:int = len(batch)
|
||||||
|
assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'
|
||||||
|
half = 'half' in self.mode
|
||||||
|
if half:
|
||||||
|
batch_size //= 2
|
||||||
|
|
||||||
|
if self.modality == 'video' or self.modality == 'both':
|
||||||
|
lam = self._concat_video_frames(batch, target_key=self.target_video_key, source_key=self.source_video_key)
|
||||||
|
if self.modality == 'audio' or self.modality == 'both':
|
||||||
|
# raise NotImplementedError('Audio mixup is not implemented yet.')
|
||||||
|
lam = self._mix_audio_samples(batch, target_key=self.target_audio_key, source_key=self.source_audio_key)
|
||||||
|
|
||||||
|
return default_collate(batch)
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMixup(MixupBase):
|
||||||
|
""" Mixup video in feature domain.
|
||||||
|
Applies different params to each element or whole batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator (Optional[torch.Generator]): Random number generator for reproducibility
|
||||||
|
modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
|
||||||
|
mixup_lambda (float): Mixup lambda value, mixup is active if in [0., 1.].
|
||||||
|
mixup_alpha (float): Mixup alpha value, mixup is active if > 0.
|
||||||
|
prob (float): Probability of applying mixup per batch or element
|
||||||
|
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||||
|
eps (float): Small epsilon value to avoid zero lambda
|
||||||
|
"""
|
||||||
|
def __init__(self, generator:torch.Generator,
|
||||||
|
*,
|
||||||
|
modality:Literal['video', 'audio', 'both']='video',
|
||||||
|
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||||
|
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||||
|
eps:float=0.05
|
||||||
|
):
|
||||||
|
super().__init__(generator, modality=modality,
|
||||||
|
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
|
||||||
|
mode=mode, eps=eps)
|
||||||
|
self.source_video_key= 'sync_f_vid_orig'
|
||||||
|
self.source_audio_key = 'sync_f_aud_orig'
|
||||||
|
self.target_video_key = 'sync_f_vid_mixed'
|
||||||
|
self.target_audio_key = 'sync_f_aud_mixed'
|
||||||
|
|
||||||
|
def _mix_elem_collate(self, batch:dict,
|
||||||
|
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig'],
|
||||||
|
half:bool=False) -> torch.tensor:
|
||||||
|
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||||
|
batch_size:int = len(batch['id'])
|
||||||
|
num_elem:int = batch_size // 2 if half else batch_size
|
||||||
|
lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(num_elem))
|
||||||
|
|
||||||
|
indices = torch.arange(num_elem)
|
||||||
|
mix_indices = batch_size - indices - 1
|
||||||
|
mix_mask = lam_batch < 1
|
||||||
|
active_indices = indices[mix_mask]
|
||||||
|
active_mix_indices = mix_indices[mix_mask]
|
||||||
|
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
|
||||||
|
for target_key, source_key in zip(target_keys, source_keys):
|
||||||
|
batch[target_key][active_indices] = (
|
||||||
|
batch[source_key][active_indices] * active_lambdas +
|
||||||
|
batch[source_key][active_mix_indices] * (1 - active_lambdas)
|
||||||
|
)
|
||||||
|
batch[target_key][~indices[mix_mask]] = batch[source_key][~indices[mix_mask]]
|
||||||
|
if half:
|
||||||
|
lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
|
||||||
|
return lam_batch.unsqueeze(1)
|
||||||
|
|
||||||
|
def _mix_pair_collate(self, batch:dict,
|
||||||
|
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> torch.tensor:
|
||||||
|
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||||
|
batch_size:int = len(batch['id'])
|
||||||
|
lam_batch:torch.tensor = torch.from_numpy(self._params_per_elem(batch_size // 2))
|
||||||
|
|
||||||
|
indices = torch.arange(batch_size // 2)
|
||||||
|
mix_indices = batch_size - indices - 1
|
||||||
|
mix_mask = lam_batch < 1
|
||||||
|
active_indices = indices[mix_mask]
|
||||||
|
active_mix_indices = mix_indices[mix_mask]
|
||||||
|
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
|
||||||
|
for target_key, source_key in zip(target_keys, source_keys):
|
||||||
|
batch[target_key][active_indices] = (
|
||||||
|
batch[source_key][active_indices] * active_lambdas +
|
||||||
|
batch[source_key][active_mix_indices] * (1 - active_lambdas)
|
||||||
|
)
|
||||||
|
batch[target_key][active_mix_indices] = (
|
||||||
|
batch[source_key][active_mix_indices] * active_lambdas +
|
||||||
|
batch[source_key][active_indices] * (1 - active_lambdas)
|
||||||
|
)
|
||||||
|
batch[target_key][~indices[mix_mask]] = batch[source_key][~indices[mix_mask]]
|
||||||
|
batch[target_key][~mix_indices[mix_mask]] = batch[source_key][~mix_indices[mix_mask]]
|
||||||
|
lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
|
||||||
|
return lam_batch.unsqueeze(1)
|
||||||
|
|
||||||
|
def _mix_batch_collate(self, batch:dict,
|
||||||
|
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> float:
|
||||||
|
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||||
|
lam:float = self._params_per_batch()
|
||||||
|
|
||||||
|
for target_key, source_key in zip(target_keys, source_keys):
|
||||||
|
num_mixup = int(self.mix_prob * batch[source_key].shape[0])
|
||||||
|
flipped_source = torch.flip(batch[source_key], dims=[0])
|
||||||
|
batch[target_key] = batch[source_key] * lam + flipped_source * (1 - lam)
|
||||||
|
batch[target_key][num_mixup:] = batch[source_key][num_mixup:] # no mixup
|
||||||
|
return lam
|
||||||
|
|
||||||
|
def __call__(self, batch:dict, _=None) -> None:
|
||||||
|
batch_size:int = len(batch['id'])
|
||||||
|
assert batch_size % 2 == 0, f'Batch size(={batch_size}) should be even when using this'
|
||||||
|
half = 'half' in self.mode
|
||||||
|
if half:
|
||||||
|
batch_size //= 2
|
||||||
|
|
||||||
|
# Mixup
|
||||||
|
if self.mode == 'elem' or self.mode == 'half':
|
||||||
|
collate_fn = partial(self._mix_elem_collate, half=half)
|
||||||
|
elif self.mode == 'pair':
|
||||||
|
collate_fn = self._mix_pair_collate
|
||||||
|
else:
|
||||||
|
collate_fn = self._mix_batch_collate
|
||||||
|
|
||||||
|
if self.modality == 'both':
|
||||||
|
target_keys, source_keys = [self.target_video_key, self.target_audio_key], [self.source_video_key, self.source_audio_key]
|
||||||
|
elif self.modality == 'video':
|
||||||
|
target_keys, source_keys = [self.target_video_key], [self.source_video_key]
|
||||||
|
elif self.modality == 'audio':
|
||||||
|
target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
|
||||||
|
lam = collate_fn(batch, target_keys=target_keys, source_keys=source_keys)
|
||||||
|
|
||||||
|
# return batch
|
||||||
|
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
import bisect
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
|
||||||
|
class MultiModalDataset(Dataset):
|
||||||
|
datasets: list[Dataset]
|
||||||
|
cumulative_sizes: list[int]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cumsum(sequence):
|
||||||
|
r, s = [], 0
|
||||||
|
for e in sequence:
|
||||||
|
l = len(e)
|
||||||
|
r.append(l + s)
|
||||||
|
s += l
|
||||||
|
return r
|
||||||
|
|
||||||
|
def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
|
||||||
|
super().__init__()
|
||||||
|
self.video_datasets = list(video_datasets)
|
||||||
|
self.audio_datasets = list(audio_datasets)
|
||||||
|
self.datasets = self.video_datasets + self.audio_datasets
|
||||||
|
|
||||||
|
self.cumulative_sizes = self.cumsum(self.datasets)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.cumulative_sizes[-1]
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if idx < 0:
|
||||||
|
if -idx > len(self):
|
||||||
|
raise ValueError("absolute value of index should not exceed dataset length")
|
||||||
|
idx = len(self) + idx
|
||||||
|
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
||||||
|
if dataset_idx == 0:
|
||||||
|
sample_idx = idx
|
||||||
|
else:
|
||||||
|
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
||||||
|
return self.datasets[dataset_idx][sample_idx]
|
||||||
|
|
||||||
|
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return self.video_datasets[0].compute_latent_stats()
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from tensordict import MemoryMappedTensor
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from selva_core.utils.dist_utils import local_rank, world_size
|
||||||
|
|
||||||
|
scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
|
||||||
|
shm_path = Path('/dev/shm')
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
def reseed(seed):
|
||||||
|
random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def local_scatter_torch(obj: Optional[Any]):
|
||||||
|
if world_size == 1:
|
||||||
|
# Just one worker. Do nothing.
|
||||||
|
return obj
|
||||||
|
|
||||||
|
array = [obj] * world_size
|
||||||
|
target_array = [None]
|
||||||
|
if local_rank == 0:
|
||||||
|
dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
|
||||||
|
else:
|
||||||
|
dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
|
||||||
|
return target_array[0]
|
||||||
|
|
||||||
|
|
||||||
|
class ShardDataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, root):
|
||||||
|
self.root = root
|
||||||
|
self.shards = sorted(os.listdir(root))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.shards)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tmp_dir(in_memory: bool) -> Path:
|
||||||
|
return shm_path if in_memory else scratch_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
|
||||||
|
in_memory: bool) -> MemoryMappedTensor:
|
||||||
|
if local_rank == 0:
|
||||||
|
with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
|
||||||
|
log.info(f'Loading shards from {data_path} into {f.name}...')
|
||||||
|
data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
|
||||||
|
data = share_tensor_to_all(data)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
f.close() # why does the context manager not close the file for me?
|
||||||
|
else:
|
||||||
|
log.info('Waiting for the data to be shared with me...')
|
||||||
|
data = share_tensor_to_all(None)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def load_shards(
|
||||||
|
data_path: Union[str, Path],
|
||||||
|
ids: list[int],
|
||||||
|
*,
|
||||||
|
tmp_file_path: str,
|
||||||
|
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
|
||||||
|
|
||||||
|
id_set = set(ids)
|
||||||
|
shards = sorted(os.listdir(data_path))
|
||||||
|
log.info(f'Found {len(shards)} shards in {data_path}.')
|
||||||
|
first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
|
||||||
|
|
||||||
|
log.info(f'Rank {local_rank} created file {tmp_file_path}')
|
||||||
|
first_item = next(iter(first_shard.values()))
|
||||||
|
log.info(f'First item shape: {first_item.shape}')
|
||||||
|
mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
|
||||||
|
dtype=torch.float32,
|
||||||
|
filename=tmp_file_path,
|
||||||
|
existsok=True)
|
||||||
|
total_count = 0
|
||||||
|
used_index = set()
|
||||||
|
id_indexing = {i: idx for idx, i in enumerate(ids)}
|
||||||
|
# faster with no workers; otherwise we need to set_sharing_strategy('file_system')
|
||||||
|
loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
|
||||||
|
for data in tqdm(loader, desc='Loading shards'):
|
||||||
|
for i, v in data.items():
|
||||||
|
if i not in id_set:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# tensor_index = ids.index(i)
|
||||||
|
tensor_index = id_indexing[i]
|
||||||
|
if tensor_index in used_index:
|
||||||
|
raise ValueError(f'Duplicate id {i} found in {data_path}.')
|
||||||
|
used_index.add(tensor_index)
|
||||||
|
mm_tensor[tensor_index] = v
|
||||||
|
total_count += 1
|
||||||
|
|
||||||
|
assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
|
||||||
|
log.info(f'Loaded {total_count} tensors from {data_path}.')
|
||||||
|
|
||||||
|
return mm_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
|
||||||
|
"""
|
||||||
|
x: the tensor to be shared; None if local_rank != 0
|
||||||
|
return: the shared tensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
# there is no need to share your stuff with anyone if you are alone; must be in memory
|
||||||
|
if world_size == 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if local_rank == 0:
|
||||||
|
assert x is not None, 'x must not be None if local_rank == 0'
|
||||||
|
else:
|
||||||
|
assert x is None, 'x must be None if local_rank != 0'
|
||||||
|
|
||||||
|
if local_rank == 0:
|
||||||
|
filename = x.filename
|
||||||
|
meta_information = (filename, x.shape, x.dtype)
|
||||||
|
else:
|
||||||
|
meta_information = None
|
||||||
|
|
||||||
|
filename, data_shape, data_type = local_scatter_torch(meta_information)
|
||||||
|
if local_rank == 0:
|
||||||
|
data = x
|
||||||
|
else:
|
||||||
|
data = MemoryMappedTensor.from_filename(filename=filename,
|
||||||
|
dtype=data_type,
|
||||||
|
shape=data_shape)
|
||||||
|
|
||||||
|
return data
|
||||||
@@ -0,0 +1,299 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from torio.io import StreamingMediaDecoder
|
||||||
|
from tensordict import TensorDict
|
||||||
|
|
||||||
|
from selva_core.data.av_utils import normalize_video_chunk
|
||||||
|
from selva_core.utils.dist_utils import local_rank
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
_CLIP_SIZE = 384
|
||||||
|
_CLIP_FPS = 8.0
|
||||||
|
|
||||||
|
_SYNC_SIZE = 224
|
||||||
|
_SYNC_FPS = 25.0
|
||||||
|
|
||||||
|
|
||||||
|
class VGGSound(Dataset):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
root: Union[str, Path],
|
||||||
|
*,
|
||||||
|
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
||||||
|
for_generator: bool = True,
|
||||||
|
audio_required: bool = False,
|
||||||
|
sample_rate: int = 16_000,
|
||||||
|
duration_sec: float = 8.0,
|
||||||
|
audio_samples: Optional[int] = None,
|
||||||
|
normalize_audio: bool = False,
|
||||||
|
clip_video_required: bool = False,
|
||||||
|
mmap_dir: Union[str, Path] = None,
|
||||||
|
tsv_tsynch_path: Union[str, Path] = None,
|
||||||
|
mmap_tsync_dir: Union[str, Path] = None,
|
||||||
|
data_dim: dict[str, int] = None,
|
||||||
|
):
|
||||||
|
self.root = Path(root)
|
||||||
|
self.audio_required = audio_required
|
||||||
|
if audio_required:
|
||||||
|
self.normalize_audio = normalize_audio
|
||||||
|
if audio_samples is None:
|
||||||
|
self.audio_samples = int(sample_rate * duration_sec)
|
||||||
|
else:
|
||||||
|
self.audio_samples = audio_samples
|
||||||
|
effective_duration = audio_samples / sample_rate
|
||||||
|
# make sure the duration is close enough, within 15ms
|
||||||
|
assert abs(effective_duration - duration_sec) < 0.015, \
|
||||||
|
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
||||||
|
self.clip_video_required = clip_video_required
|
||||||
|
self.for_generator = for_generator
|
||||||
|
|
||||||
|
videos = sorted(os.listdir(self.root))
|
||||||
|
videos = set([Path(v).stem for v in videos]) # remove extensions
|
||||||
|
self.labels = {}
|
||||||
|
self.videos = []
|
||||||
|
missing_videos = []
|
||||||
|
|
||||||
|
# read the tsv for subset information
|
||||||
|
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
||||||
|
for record in df_list:
|
||||||
|
id = record['id']
|
||||||
|
label = record['label']
|
||||||
|
if id in videos:
|
||||||
|
self.labels[id] = label
|
||||||
|
self.videos.append(id)
|
||||||
|
else:
|
||||||
|
missing_videos.append(id)
|
||||||
|
|
||||||
|
if local_rank == 0:
|
||||||
|
log.info(f'{len(videos)} videos found in {root}')
|
||||||
|
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
||||||
|
log.info(f'{len(missing_videos)} videos missing in {root}')
|
||||||
|
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.duration_sec = duration_sec
|
||||||
|
|
||||||
|
if audio_required:
|
||||||
|
self.expected_audio_length = self.audio_samples
|
||||||
|
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
||||||
|
if clip_video_required:
|
||||||
|
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
||||||
|
|
||||||
|
self.sync_transform = v2.Compose([
|
||||||
|
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||||
|
# v2.CenterCrop(_SYNC_SIZE),
|
||||||
|
v2.ToImage(),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||||
|
])
|
||||||
|
|
||||||
|
if clip_video_required:
|
||||||
|
self.clip_transform = v2.Compose([
|
||||||
|
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||||
|
v2.ToImage(),
|
||||||
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
|
])
|
||||||
|
if audio_required:
|
||||||
|
self.resampler = {}
|
||||||
|
|
||||||
|
# mmap
|
||||||
|
log.info(f'Loading precomputed mmap from {mmap_dir}')
|
||||||
|
mmap_dir = Path(mmap_dir)
|
||||||
|
td = TensorDict.load_memmap(mmap_dir)
|
||||||
|
log.info(f'Loaded precomputed mmap from {mmap_dir}')
|
||||||
|
self.sync_features = td['sync_features']
|
||||||
|
if for_generator:
|
||||||
|
self.mean = td['mean']
|
||||||
|
self.std = td['std']
|
||||||
|
self.text_clip_features = td['text_features']
|
||||||
|
if clip_video_required:
|
||||||
|
self.clip_features = td['clip_features']
|
||||||
|
else:
|
||||||
|
self.clip_features = None
|
||||||
|
self.id2idx_mmap = {d['id']: i for i, d in enumerate(df_list)}
|
||||||
|
|
||||||
|
mmap_tsync_dir = Path(mmap_tsync_dir)
|
||||||
|
td_tsync = TensorDict.load_memmap(mmap_tsync_dir)
|
||||||
|
log.info(f'Loaded precomputed tsync mmap from {mmap_tsync_dir}')
|
||||||
|
self.text_features = td_tsync['text_features']
|
||||||
|
self.text_masks = td_tsync['text_masks']
|
||||||
|
df_list_tsync = pd.read_csv(tsv_tsynch_path, sep='\t').to_dict('records')
|
||||||
|
self.id2idx_mmap_tsync = {d['id']: i for i, d in enumerate(df_list_tsync)}
|
||||||
|
|
||||||
|
if local_rank == 0:
|
||||||
|
log.info(f'Loaded {len(self)} samples.')
|
||||||
|
log.info(f'Loaded sync_features: {self.sync_features.shape}.')
|
||||||
|
log.info(f'Loaded text_features: {self.text_features.shape}.')
|
||||||
|
log.info(f'Loaded text_masks: {self.text_masks.shape}.')
|
||||||
|
if for_generator:
|
||||||
|
log.info(f'Loaded mean: {self.mean.shape}.')
|
||||||
|
log.info(f'Loaded std: {self.std.shape}.')
|
||||||
|
log.info(f'Loaded text_clip_features: {self.text_clip_features.shape}.')
|
||||||
|
if clip_video_required:
|
||||||
|
log.info(f'Loaded clip_features: {self.clip_features.shape}.')
|
||||||
|
|
||||||
|
assert self.sync_features.shape[1] == data_dim['sync_seq_len'], \
|
||||||
|
f'{self.sync_features.shape[1]} != {data_dim["sync_seq_len"]}'
|
||||||
|
assert self.text_features.shape[1] <= data_dim['text_flant5_max_seq_len'], \
|
||||||
|
f'{self.text_features.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
|
||||||
|
assert self.text_masks.shape[1] <= data_dim['text_flant5_max_seq_len'], \
|
||||||
|
f'{self.text_masks.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
|
||||||
|
assert self.sync_features.shape[-1] == data_dim['sync_dim'], \
|
||||||
|
f'{self.sync_features.shape[-1]} != {data_dim["sync_dim"]}'
|
||||||
|
assert self.text_features.shape[-1] == data_dim['text_flant5_dim'], \
|
||||||
|
f'{self.text_features.shape[-1]} != {data_dim["text_flant5_dim"]}'
|
||||||
|
if for_generator:
|
||||||
|
assert self.mean.shape[1] == data_dim['latent_seq_len'], \
|
||||||
|
f'{self.mean.shape[1]} != {data_dim["latent_seq_len"]}'
|
||||||
|
assert self.std.shape[1] == data_dim['latent_seq_len'], \
|
||||||
|
f'{self.std.shape[1]} != {data_dim["latent_seq_len"]}'
|
||||||
|
assert self.text_clip_features.shape[1] == data_dim['text_clip_seq_len'], \
|
||||||
|
f'{self.text_clip_features.shape[1]} != {data_dim["text_clip_seq_len"]}'
|
||||||
|
assert self.text_clip_features.shape[-1] == data_dim['text_clip_dim'], \
|
||||||
|
f'{self.text_clip_features.shape[-1]} != {data_dim["text_clip_dim"]}'
|
||||||
|
if clip_video_required:
|
||||||
|
assert self.clip_features.shape[1] == data_dim['clip_seq_len'], \
|
||||||
|
f'{self.clip_features.shape[1]} != {data_dim["clip_seq_len"]}'
|
||||||
|
assert self.clip_features.shape[-1] == data_dim['clip_dim'], \
|
||||||
|
f'{self.clip_features.shape[-1]} != {data_dim["clip_dim"]}'
|
||||||
|
|
||||||
|
self.video_exist = torch.tensor(1, dtype=torch.bool)
|
||||||
|
self.text_exist = torch.tensor(1, dtype=torch.bool)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: # mmap
|
||||||
|
latents = self.mean
|
||||||
|
return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
|
||||||
|
|
||||||
|
def get_memory_mapped_tensor(self) -> TensorDict:
|
||||||
|
td = TensorDict({
|
||||||
|
'sync_features': self.sync_features,
|
||||||
|
'text_features': self.text_features,
|
||||||
|
'text_masks': self.text_masks,
|
||||||
|
})
|
||||||
|
if self.for_generator:
|
||||||
|
td['mean'] = self.mean
|
||||||
|
td['std'] = self.std
|
||||||
|
td['text_clip_features'] = self.text_clip_features
|
||||||
|
if self.clip_video_required:
|
||||||
|
td['clip_features'] = self.clip_features
|
||||||
|
return td
|
||||||
|
|
||||||
|
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
|
video_id = self.videos[idx]
|
||||||
|
|
||||||
|
if video_id in self.captions and torch.rand(1).item() < self.autoacd_sample_prob:
|
||||||
|
label = self.captions[video_id]
|
||||||
|
else:
|
||||||
|
label = self.labels[video_id]
|
||||||
|
|
||||||
|
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
||||||
|
reader.add_basic_video_stream(
|
||||||
|
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
||||||
|
frame_rate=_SYNC_FPS,
|
||||||
|
format='rgb24',
|
||||||
|
)
|
||||||
|
if self.audio_required:
|
||||||
|
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
|
||||||
|
if self.clip_video_required:
|
||||||
|
reader.add_basic_video_stream(
|
||||||
|
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
||||||
|
frame_rate=_CLIP_FPS,
|
||||||
|
format='rgb24',
|
||||||
|
)
|
||||||
|
|
||||||
|
reader.fill_buffer()
|
||||||
|
data_chunk = reader.pop_chunks()
|
||||||
|
|
||||||
|
sync_chunk = data_chunk[0]
|
||||||
|
if sync_chunk is None:
|
||||||
|
raise RuntimeError(f'Sync video returned None {video_id}')
|
||||||
|
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
||||||
|
n_tolerance_frame=3, desc=video_id)
|
||||||
|
sync_chunk = self.sync_transform(sync_chunk)
|
||||||
|
|
||||||
|
if self.audio_required:
|
||||||
|
audio_chunk = data_chunk[1]
|
||||||
|
|
||||||
|
if self.clip_video_required:
|
||||||
|
clip_chunk = data_chunk[2 if self.audio_required else 1]
|
||||||
|
if clip_chunk is None:
|
||||||
|
raise RuntimeError(f'CLIP video returned None {video_id}')
|
||||||
|
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
||||||
|
n_tolerance_frame=1, desc=video_id)
|
||||||
|
clip_chunk = self.clip_transform(clip_chunk)
|
||||||
|
|
||||||
|
# process audio
|
||||||
|
if self.audio_required:
|
||||||
|
sample_rate = int(reader.get_out_stream_info(1).sample_rate)
|
||||||
|
audio_chunk = audio_chunk.transpose(0, 1)
|
||||||
|
audio_chunk = audio_chunk.mean(dim=0) # mono
|
||||||
|
if self.normalize_audio:
|
||||||
|
abs_max = audio_chunk.abs().max()
|
||||||
|
audio_chunk = audio_chunk * (0.95 / abs_max)
|
||||||
|
if abs_max <= 1e-6:
|
||||||
|
raise RuntimeError(f'Audio is silent {video_id}')
|
||||||
|
|
||||||
|
# resample
|
||||||
|
if sample_rate == self.sample_rate:
|
||||||
|
audio_chunk = audio_chunk
|
||||||
|
else:
|
||||||
|
if sample_rate not in self.resampler:
|
||||||
|
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
||||||
|
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
||||||
|
sample_rate,
|
||||||
|
self.sample_rate,
|
||||||
|
lowpass_filter_width=64,
|
||||||
|
rolloff=0.9475937167399596,
|
||||||
|
resampling_method='sinc_interp_kaiser',
|
||||||
|
beta=14.769656459379492,
|
||||||
|
)
|
||||||
|
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
||||||
|
|
||||||
|
if audio_chunk.shape[0] < self.expected_audio_length:
|
||||||
|
raise RuntimeError(f'Audio too short {video_id}')
|
||||||
|
audio_chunk = audio_chunk[:self.expected_audio_length]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'id': video_id,
|
||||||
|
'caption': label,
|
||||||
|
'sync_video': sync_chunk,
|
||||||
|
'sync_f_vid_orig': self.sync_features[self.id2idx_mmap[video_id]],
|
||||||
|
'text_features': self.text_features[self.id2idx_mmap_tsync[video_id]],
|
||||||
|
'text_masks': self.text_masks[self.id2idx_mmap_tsync[video_id]],
|
||||||
|
'video_exist': self.video_exist,
|
||||||
|
'text_exist': self.text_exist,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.for_generator:
|
||||||
|
data['a_mean'] = self.mean[self.id2idx_mmap[video_id]]
|
||||||
|
data['a_std'] = self.std[self.id2idx_mmap[video_id]]
|
||||||
|
data['text_clip_features'] = self.text_clip_features[self.id2idx_mmap[video_id]]
|
||||||
|
|
||||||
|
if self.audio_required:
|
||||||
|
data['audio'] = audio_chunk
|
||||||
|
|
||||||
|
if self.clip_video_required:
|
||||||
|
data['clip_video'] = clip_chunk
|
||||||
|
data['clip_features'] = self.clip_features[self.id2idx_mmap[video_id]],
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||||
|
try:
|
||||||
|
return self.sample(idx)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
from .autoencoder import AutoEncoderModule
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from selva_core.ext.autoencoder.vae import VAE, get_my_vae
|
||||||
|
from selva_core.ext.bigvgan import BigVGAN
|
||||||
|
from selva_core.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
|
||||||
|
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
|
||||||
|
class AutoEncoderModule(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
vae_ckpt_path,
|
||||||
|
vocoder_ckpt_path: Optional[str] = None,
|
||||||
|
mode: Literal['16k', '44k'],
|
||||||
|
need_vae_encoder: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
self.vae: VAE = get_my_vae(mode).eval()
|
||||||
|
vae_state_dict = torch.load(vae_ckpt_path, weights_only=False, map_location='cpu')
|
||||||
|
self.vae.load_state_dict(vae_state_dict)
|
||||||
|
self.vae.remove_weight_norm()
|
||||||
|
|
||||||
|
if mode == '16k':
|
||||||
|
assert vocoder_ckpt_path is not None
|
||||||
|
self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
|
||||||
|
elif mode == '44k':
|
||||||
|
self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
|
||||||
|
use_cuda_kernel=False)
|
||||||
|
self.vocoder.remove_weight_norm()
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown mode: {mode}')
|
||||||
|
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
if not need_vae_encoder:
|
||||||
|
del self.vae.encoder
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
|
||||||
|
return self.vae.encode(x)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.vae.decode(z)
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def vocode(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.vocoder(spec)
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# This work is licensed under a Creative Commons
|
||||||
|
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
||||||
|
# You should have received a copy of the license along with this
|
||||||
|
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||||
|
"""Improved diffusion model architecture proposed in the paper
|
||||||
|
"Analyzing and Improving the Training Dynamics of Diffusion Models"."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------------
|
||||||
|
# Variant of constant() that inherits dtype and device from the given
|
||||||
|
# reference tensor by default.
|
||||||
|
|
||||||
|
_constant_cache = dict()
|
||||||
|
|
||||||
|
|
||||||
|
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
||||||
|
value = np.asarray(value)
|
||||||
|
if shape is not None:
|
||||||
|
shape = tuple(shape)
|
||||||
|
if dtype is None:
|
||||||
|
dtype = torch.get_default_dtype()
|
||||||
|
if device is None:
|
||||||
|
device = torch.device('cpu')
|
||||||
|
if memory_format is None:
|
||||||
|
memory_format = torch.contiguous_format
|
||||||
|
|
||||||
|
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
||||||
|
tensor = _constant_cache.get(key, None)
|
||||||
|
if tensor is None:
|
||||||
|
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
||||||
|
if shape is not None:
|
||||||
|
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
||||||
|
tensor = tensor.contiguous(memory_format=memory_format)
|
||||||
|
_constant_cache[key] = tensor
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
|
||||||
|
if dtype is None:
|
||||||
|
dtype = ref.dtype
|
||||||
|
if device is None:
|
||||||
|
device = ref.device
|
||||||
|
return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
|
||||||
|
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------------
|
||||||
|
# Normalize given tensor to unit magnitude with respect to the given
|
||||||
|
# dimensions. Default = all dimensions except the first.
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(x, dim=None, eps=1e-4):
|
||||||
|
if dim is None:
|
||||||
|
dim = list(range(1, x.ndim))
|
||||||
|
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||||
|
norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
|
||||||
|
return x / norm.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Normalize(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, dim=None, eps=1e-4):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return normalize(x, dim=self.dim, eps=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------------
|
||||||
|
# Upsample or downsample the given tensor with the given filter,
|
||||||
|
# or keep it as is.
|
||||||
|
|
||||||
|
|
||||||
|
def resample(x, f=[1, 1], mode='keep'):
|
||||||
|
if mode == 'keep':
|
||||||
|
return x
|
||||||
|
f = np.float32(f)
|
||||||
|
assert f.ndim == 1 and len(f) % 2 == 0
|
||||||
|
pad = (len(f) - 1) // 2
|
||||||
|
f = f / f.sum()
|
||||||
|
f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
|
||||||
|
f = const_like(x, f)
|
||||||
|
c = x.shape[1]
|
||||||
|
if mode == 'down':
|
||||||
|
return torch.nn.functional.conv2d(x,
|
||||||
|
f.tile([c, 1, 1, 1]),
|
||||||
|
groups=c,
|
||||||
|
stride=2,
|
||||||
|
padding=(pad, ))
|
||||||
|
assert mode == 'up'
|
||||||
|
return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]),
|
||||||
|
groups=c,
|
||||||
|
stride=2,
|
||||||
|
padding=(pad, ))
|
||||||
|
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------------
|
||||||
|
# Magnitude-preserving SiLU (Equation 81).
|
||||||
|
|
||||||
|
|
||||||
|
def mp_silu(x):
|
||||||
|
return torch.nn.functional.silu(x) / 0.596
|
||||||
|
|
||||||
|
|
||||||
|
class MPSiLU(torch.nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return mp_silu(x)
|
||||||
|
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------------
|
||||||
|
# Magnitude-preserving sum (Equation 88).
|
||||||
|
|
||||||
|
|
||||||
|
def mp_sum(a, b, t=0.5):
|
||||||
|
return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
|
||||||
|
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------------
|
||||||
|
# Magnitude-preserving concatenation (Equation 103).
|
||||||
|
|
||||||
|
|
||||||
|
def mp_cat(a, b, dim=1, t=0.5):
|
||||||
|
Na = a.shape[dim]
|
||||||
|
Nb = b.shape[dim]
|
||||||
|
C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
|
||||||
|
wa = C / np.sqrt(Na) * (1 - t)
|
||||||
|
wb = C / np.sqrt(Nb) * t
|
||||||
|
return torch.cat([wa * a, wb * b], dim=dim)
|
||||||
|
|
||||||
|
|
||||||
|
#----------------------------------------------------------------------------
|
||||||
|
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
|
||||||
|
# with force weight normalization (Equation 66).
|
||||||
|
|
||||||
|
|
||||||
|
class MPConv1D(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, kernel_size):
|
||||||
|
super().__init__()
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
|
||||||
|
|
||||||
|
self.weight_norm_removed = False
|
||||||
|
|
||||||
|
def forward(self, x, gain=1):
|
||||||
|
assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
|
||||||
|
|
||||||
|
w = self.weight * gain
|
||||||
|
if w.ndim == 2:
|
||||||
|
return x @ w.t()
|
||||||
|
assert w.ndim == 3
|
||||||
|
return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2, ))
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
w = self.weight.to(torch.float32)
|
||||||
|
w = normalize(w) # traditional weight normalization
|
||||||
|
w = w / np.sqrt(w[0].numel())
|
||||||
|
w = w.to(self.weight.dtype)
|
||||||
|
self.weight.data.copy_(w)
|
||||||
|
|
||||||
|
self.weight_norm_removed = True
|
||||||
|
return self
|
||||||
@@ -0,0 +1,369 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from selva_core.ext.autoencoder.edm2_utils import MPConv1D
|
||||||
|
from selva_core.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||||
|
Upsample1D, nonlinearity)
|
||||||
|
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
DATA_MEAN_80D = [
|
||||||
|
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
||||||
|
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
|
||||||
|
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
|
||||||
|
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
|
||||||
|
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
|
||||||
|
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
|
||||||
|
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
|
||||||
|
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
|
||||||
|
]
|
||||||
|
|
||||||
|
DATA_STD_80D = [
|
||||||
|
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
|
||||||
|
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
|
||||||
|
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
|
||||||
|
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
|
||||||
|
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
|
||||||
|
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
|
||||||
|
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
|
||||||
|
]
|
||||||
|
|
||||||
|
DATA_MEAN_128D = [
|
||||||
|
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
|
||||||
|
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
|
||||||
|
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
|
||||||
|
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
|
||||||
|
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
|
||||||
|
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
|
||||||
|
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
|
||||||
|
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
|
||||||
|
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
|
||||||
|
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
|
||||||
|
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
|
||||||
|
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
|
||||||
|
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
|
||||||
|
]
|
||||||
|
|
||||||
|
DATA_STD_128D = [
|
||||||
|
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
|
||||||
|
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
|
||||||
|
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
|
||||||
|
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
|
||||||
|
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
|
||||||
|
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
|
||||||
|
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
|
||||||
|
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
|
||||||
|
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
|
||||||
|
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
|
||||||
|
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class VAE(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
data_dim: int,
|
||||||
|
embed_dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if data_dim == 80:
|
||||||
|
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
|
||||||
|
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
|
||||||
|
elif data_dim == 128:
|
||||||
|
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
|
||||||
|
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
|
||||||
|
|
||||||
|
self.data_mean = self.data_mean.view(1, -1, 1)
|
||||||
|
self.data_std = self.data_std.view(1, -1, 1)
|
||||||
|
|
||||||
|
self.encoder = Encoder1D(
|
||||||
|
dim=hidden_dim,
|
||||||
|
ch_mult=(1, 2, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_layers=[3],
|
||||||
|
down_layers=[0],
|
||||||
|
in_dim=data_dim,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
)
|
||||||
|
self.decoder = Decoder1D(
|
||||||
|
dim=hidden_dim,
|
||||||
|
ch_mult=(1, 2, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_layers=[3],
|
||||||
|
down_layers=[0],
|
||||||
|
in_dim=data_dim,
|
||||||
|
out_dim=data_dim,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
|
||||||
|
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
|
||||||
|
|
||||||
|
self.initialize_weights()
|
||||||
|
|
||||||
|
def initialize_weights(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
|
||||||
|
if normalize:
|
||||||
|
x = self.normalize(x)
|
||||||
|
moments = self.encoder(x)
|
||||||
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
|
return posterior
|
||||||
|
|
||||||
|
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
|
||||||
|
dec = self.decoder(z)
|
||||||
|
if unnormalize:
|
||||||
|
dec = self.unnormalize(dec)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return (x - self.data_mean) / self.data_std
|
||||||
|
|
||||||
|
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x * self.data_std + self.data_mean
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
sample_posterior: bool = True,
|
||||||
|
rng: Optional[torch.Generator] = None,
|
||||||
|
normalize: bool = True,
|
||||||
|
unnormalize: bool = True,
|
||||||
|
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
||||||
|
|
||||||
|
posterior = self.encode(x, normalize=normalize)
|
||||||
|
if sample_posterior:
|
||||||
|
z = posterior.sample(rng)
|
||||||
|
else:
|
||||||
|
z = posterior.mode()
|
||||||
|
dec = self.decode(z, unnormalize=unnormalize)
|
||||||
|
return dec, posterior
|
||||||
|
|
||||||
|
def load_weights(self, src_dict) -> None:
|
||||||
|
self.load_state_dict(src_dict, strict=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for name, m in self.named_modules():
|
||||||
|
if isinstance(m, MPConv1D):
|
||||||
|
m.remove_weight_norm()
|
||||||
|
log.debug(f"Removed weight norm from {name}")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
dim: int,
|
||||||
|
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||||
|
num_res_blocks: int,
|
||||||
|
attn_layers: list[int] = [],
|
||||||
|
down_layers: list[int] = [],
|
||||||
|
resamp_with_conv: bool = True,
|
||||||
|
in_dim: int,
|
||||||
|
embed_dim: int,
|
||||||
|
double_z: bool = True,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
clip_act: float = 256.0):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_layers = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.in_channels = in_dim
|
||||||
|
self.clip_act = clip_act
|
||||||
|
self.down_layers = down_layers
|
||||||
|
self.attn_layers = attn_layers
|
||||||
|
self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)
|
||||||
|
|
||||||
|
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
# downsampling
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
for i_level in range(self.num_layers):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = dim * in_ch_mult[i_level]
|
||||||
|
block_out = dim * ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock1D(in_dim=block_in,
|
||||||
|
out_dim=block_out,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
use_norm=True))
|
||||||
|
block_in = block_out
|
||||||
|
if i_level in attn_layers:
|
||||||
|
attn.append(AttnBlock1D(block_in))
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level in down_layers:
|
||||||
|
down.downsample = Downsample1D(block_in, resamp_with_conv)
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
|
||||||
|
out_dim=block_in,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
use_norm=True)
|
||||||
|
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||||
|
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
|
||||||
|
out_dim=block_in,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
use_norm=True)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.conv_out = MPConv1D(block_in,
|
||||||
|
2 * embed_dim if double_z else embed_dim,
|
||||||
|
kernel_size=kernel_size)
|
||||||
|
|
||||||
|
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
hs = [self.conv_in(x)]
|
||||||
|
for i_level in range(self.num_layers):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level].block[i_block](hs[-1])
|
||||||
|
if len(self.down[i_level].attn) > 0:
|
||||||
|
h = self.down[i_level].attn[i_block](h)
|
||||||
|
h = h.clamp(-self.clip_act, self.clip_act)
|
||||||
|
hs.append(h)
|
||||||
|
if i_level in self.down_layers:
|
||||||
|
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = hs[-1]
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
h = h.clamp(-self.clip_act, self.clip_act)
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h, gain=(self.learnable_gain + 1))
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||||
|
num_res_blocks: int,
|
||||||
|
attn_layers: list[int] = [],
|
||||||
|
down_layers: list[int] = [],
|
||||||
|
kernel_size: int = 3,
|
||||||
|
resamp_with_conv: bool = True,
|
||||||
|
in_dim: int,
|
||||||
|
embed_dim: int,
|
||||||
|
clip_act: float = 256.0):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = dim
|
||||||
|
self.num_layers = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.in_channels = in_dim
|
||||||
|
self.clip_act = clip_act
|
||||||
|
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
|
||||||
|
|
||||||
|
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||||
|
block_in = dim * ch_mult[self.num_layers - 1]
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||||
|
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||||
|
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_layers)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = dim * ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
|
||||||
|
block_in = block_out
|
||||||
|
if i_level in attn_layers:
|
||||||
|
attn.append(AttnBlock1D(block_in))
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level in self.down_layers:
|
||||||
|
up.upsample = Upsample1D(block_in, resamp_with_conv)
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
|
||||||
|
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||||
|
|
||||||
|
def forward(self, z):
|
||||||
|
# z to block_in
|
||||||
|
h = self.conv_in(z)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
h = h.clamp(-self.clip_act, self.clip_act)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_layers)):
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
h = self.up[i_level].block[i_block](h)
|
||||||
|
if len(self.up[i_level].attn) > 0:
|
||||||
|
h = self.up[i_level].attn[i_block](h)
|
||||||
|
h = h.clamp(-self.clip_act, self.clip_act)
|
||||||
|
if i_level in self.down_layers:
|
||||||
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h, gain=(self.learnable_gain + 1))
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def VAE_16k(**kwargs) -> VAE:
|
||||||
|
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def VAE_44k(**kwargs) -> VAE:
|
||||||
|
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_my_vae(name: str, **kwargs) -> VAE:
|
||||||
|
if name == '16k':
|
||||||
|
return VAE_16k(**kwargs)
|
||||||
|
if name == '44k':
|
||||||
|
return VAE_44k(**kwargs)
|
||||||
|
raise ValueError(f'Unknown model: {name}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
network = get_my_vae('standard')
|
||||||
|
|
||||||
|
# print the number of parameters in terms of millions
|
||||||
|
num_params = sum(p.numel() for p in network.parameters()) / 1e6
|
||||||
|
print(f'Number of parameters: {num_params:.2f}M')
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
from selva_core.ext.autoencoder.edm2_utils import (MPConv1D, mp_silu, mp_sum, normalize)
|
||||||
|
|
||||||
|
|
||||||
|
def nonlinearity(x):
|
||||||
|
# swish
|
||||||
|
return mp_silu(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ResnetBlock1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
||||||
|
super().__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
out_dim = in_dim if out_dim is None else out_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
self.use_norm = use_norm
|
||||||
|
|
||||||
|
self.conv1 = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
|
||||||
|
self.conv2 = MPConv1D(out_dim, out_dim, kernel_size=kernel_size)
|
||||||
|
if self.in_dim != self.out_dim:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
self.conv_shortcut = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = MPConv1D(in_dim, out_dim, kernel_size=1)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
# pixel norm
|
||||||
|
if self.use_norm:
|
||||||
|
x = normalize(x, dim=1)
|
||||||
|
|
||||||
|
h = x
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_dim != self.out_dim:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
x = self.conv_shortcut(x)
|
||||||
|
else:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return mp_sum(x, h, t=0.3)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, num_heads=1):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.qkv = MPConv1D(in_channels, in_channels * 3, kernel_size=1)
|
||||||
|
self.proj_out = MPConv1D(in_channels, in_channels, kernel_size=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = x
|
||||||
|
y = self.qkv(h)
|
||||||
|
y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[-1])
|
||||||
|
q, k, v = normalize(y, dim=2).unbind(3)
|
||||||
|
|
||||||
|
q = rearrange(q, 'b h c l -> b h l c')
|
||||||
|
k = rearrange(k, 'b h c l -> b h l c')
|
||||||
|
v = rearrange(v, 'b h c l -> b h l c')
|
||||||
|
|
||||||
|
h = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
h = rearrange(h, 'b h l c -> b (h c) l')
|
||||||
|
|
||||||
|
h = self.proj_out(h)
|
||||||
|
|
||||||
|
return mp_sum(x, h, t=0.3)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, with_conv):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
if self.with_conv:
|
||||||
|
self.conv = MPConv1D(in_channels, in_channels, kernel_size=3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, with_conv):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
if self.with_conv:
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv1 = MPConv1D(in_channels, in_channels, kernel_size=1)
|
||||||
|
self.conv2 = MPConv1D(in_channels, in_channels, kernel_size=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv1(x)
|
||||||
|
|
||||||
|
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv2(x)
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2022 NVIDIA CORPORATION.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
from .bigvgan import BigVGAN
|
||||||
@@ -0,0 +1,120 @@
|
|||||||
|
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, sin, pow
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
class Snake(nn.Module):
|
||||||
|
'''
|
||||||
|
Implementation of a sine-based periodic activation function
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter
|
||||||
|
References:
|
||||||
|
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snake(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
'''
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
|
'''
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha: trainable parameter
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
'''
|
||||||
|
super(Snake, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||||
|
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # linear scale alphas initialized to ones
|
||||||
|
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
Snake ∶= x + 1/a * sin^2 (xa)
|
||||||
|
'''
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
'''
|
||||||
|
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
References:
|
||||||
|
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snakebeta(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
'''
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
|
'''
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
'''
|
||||||
|
super(SnakeBeta, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||||
|
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # linear scale alphas initialized to ones
|
||||||
|
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
self.beta = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
self.beta.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||||
|
'''
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
|
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
beta = torch.exp(beta)
|
||||||
|
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
from .filter import *
|
||||||
|
from .resample import *
|
||||||
|
from .act import *
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from .resample import UpSample1d, DownSample1d
|
||||||
|
|
||||||
|
|
||||||
|
class Activation1d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
activation,
|
||||||
|
up_ratio: int = 2,
|
||||||
|
down_ratio: int = 2,
|
||||||
|
up_kernel_size: int = 12,
|
||||||
|
down_kernel_size: int = 12):
|
||||||
|
super().__init__()
|
||||||
|
self.up_ratio = up_ratio
|
||||||
|
self.down_ratio = down_ratio
|
||||||
|
self.act = activation
|
||||||
|
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||||
|
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||||
|
|
||||||
|
# x: [B,C,T]
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.upsample(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
if 'sinc' in dir(torch):
|
||||||
|
sinc = torch.sinc
|
||||||
|
else:
|
||||||
|
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/core.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def sinc(x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||||
|
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||||
|
"""
|
||||||
|
return torch.where(x == 0,
|
||||||
|
torch.tensor(1., device=x.device, dtype=x.dtype),
|
||||||
|
torch.sin(math.pi * x) / math.pi / x)
|
||||||
|
|
||||||
|
|
||||||
|
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||||
|
even = (kernel_size % 2 == 0)
|
||||||
|
half_size = kernel_size // 2
|
||||||
|
|
||||||
|
#For kaiser window
|
||||||
|
delta_f = 4 * half_width
|
||||||
|
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||||
|
if A > 50.:
|
||||||
|
beta = 0.1102 * (A - 8.7)
|
||||||
|
elif A >= 21.:
|
||||||
|
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||||
|
else:
|
||||||
|
beta = 0.
|
||||||
|
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||||
|
|
||||||
|
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||||
|
if even:
|
||||||
|
time = (torch.arange(-half_size, half_size) + 0.5)
|
||||||
|
else:
|
||||||
|
time = torch.arange(kernel_size) - half_size
|
||||||
|
if cutoff == 0:
|
||||||
|
filter_ = torch.zeros_like(time)
|
||||||
|
else:
|
||||||
|
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||||
|
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
||||||
|
# of the constant component in the input signal.
|
||||||
|
filter_ /= filter_.sum()
|
||||||
|
filter = filter_.view(1, 1, kernel_size)
|
||||||
|
|
||||||
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter1d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
cutoff=0.5,
|
||||||
|
half_width=0.6,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: bool = True,
|
||||||
|
padding_mode: str = 'replicate',
|
||||||
|
kernel_size: int = 12):
|
||||||
|
# kernel_size should be even number for stylegan3 setup,
|
||||||
|
# in this implementation, odd number is also possible.
|
||||||
|
super().__init__()
|
||||||
|
if cutoff < -0.:
|
||||||
|
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||||
|
if cutoff > 0.5:
|
||||||
|
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.even = (kernel_size % 2 == 0)
|
||||||
|
self.pad_left = kernel_size // 2 - int(self.even)
|
||||||
|
self.pad_right = kernel_size // 2
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
#input [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
if self.padding:
|
||||||
|
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||||
|
mode=self.padding_mode)
|
||||||
|
out = F.conv1d(x, self.filter.expand(C, -1, -1),
|
||||||
|
stride=self.stride, groups=C)
|
||||||
|
|
||||||
|
return out
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from .filter import LowPassFilter1d
|
||||||
|
from .filter import kaiser_sinc_filter1d
|
||||||
|
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
self.stride = ratio
|
||||||
|
self.pad = self.kernel_size // ratio - 1
|
||||||
|
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||||
|
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
kernel_size=self.kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
# x: [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||||
|
x = self.ratio * F.conv_transpose1d(
|
||||||
|
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||||
|
x = x[..., self.pad_left:-self.pad_right]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
stride=ratio,
|
||||||
|
kernel_size=self.kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
xx = self.lowpass(x)
|
||||||
|
|
||||||
|
return xx
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from selva_core.ext.bigvgan.models import BigVGANVocoder
|
||||||
|
|
||||||
|
_bigvgan_vocoder_path = Path(__file__).parent / 'bigvgan_vocoder.yml'
|
||||||
|
|
||||||
|
|
||||||
|
class BigVGAN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, ckpt_path, config_path=_bigvgan_vocoder_path):
|
||||||
|
super().__init__()
|
||||||
|
vocoder_cfg = OmegaConf.load(config_path)
|
||||||
|
self.vocoder = BigVGANVocoder(vocoder_cfg).eval()
|
||||||
|
vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)['generator']
|
||||||
|
self.vocoder.load_state_dict(vocoder_ckpt)
|
||||||
|
|
||||||
|
self.weight_norm_removed = False
|
||||||
|
self.remove_weight_norm()
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def forward(self, x):
|
||||||
|
assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
|
||||||
|
return self.vocoder(x)
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
self.vocoder.remove_weight_norm()
|
||||||
|
self.weight_norm_removed = True
|
||||||
|
return self
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
resblock: '1'
|
||||||
|
num_gpus: 0
|
||||||
|
batch_size: 64
|
||||||
|
num_mels: 80
|
||||||
|
learning_rate: 0.0001
|
||||||
|
adam_b1: 0.8
|
||||||
|
adam_b2: 0.99
|
||||||
|
lr_decay: 0.999
|
||||||
|
seed: 1234
|
||||||
|
upsample_rates:
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
upsample_kernel_sizes:
|
||||||
|
- 8
|
||||||
|
- 8
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
upsample_initial_channel: 1536
|
||||||
|
resblock_kernel_sizes:
|
||||||
|
- 3
|
||||||
|
- 7
|
||||||
|
- 11
|
||||||
|
resblock_dilation_sizes:
|
||||||
|
- - 1
|
||||||
|
- 3
|
||||||
|
- 5
|
||||||
|
- - 1
|
||||||
|
- 3
|
||||||
|
- 5
|
||||||
|
- - 1
|
||||||
|
- 3
|
||||||
|
- 5
|
||||||
|
activation: snakebeta
|
||||||
|
snake_logscale: true
|
||||||
|
resolutions:
|
||||||
|
- - 1024
|
||||||
|
- 120
|
||||||
|
- 600
|
||||||
|
- - 2048
|
||||||
|
- 240
|
||||||
|
- 1200
|
||||||
|
- - 512
|
||||||
|
- 50
|
||||||
|
- 240
|
||||||
|
mpd_reshapes:
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 5
|
||||||
|
- 7
|
||||||
|
- 11
|
||||||
|
use_spectral_norm: false
|
||||||
|
discriminator_channel_mult: 1
|
||||||
|
num_workers: 4
|
||||||
|
dist_config:
|
||||||
|
dist_backend: nccl
|
||||||
|
dist_url: tcp://localhost:54341
|
||||||
|
world_size: 1
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
class AttrDict(dict):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(AttrDict, self).__init__(*args, **kwargs)
|
||||||
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
|
def build_env(config, config_name, path):
|
||||||
|
t_path = os.path.join(path, config_name)
|
||||||
|
if config != t_path:
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
shutil.copyfile(config, os.path.join(path, config_name))
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2020 Jungil Kong
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2020 Edward Dixon
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2019, Seungwon Park 박승원
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
Copyright 2020 Alexandre Défossez
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||||
|
associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||||
|
including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||||
|
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or
|
||||||
|
substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
|
||||||
|
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||||
|
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
@@ -0,0 +1,255 @@
|
|||||||
|
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
|
from selva_core.ext.bigvgan import activations
|
||||||
|
from selva_core.ext.bigvgan.alias_free_torch import *
|
||||||
|
from selva_core.ext.bigvgan.utils import get_padding, init_weights
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock1(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||||
|
super(AMPBlock1, self).__init__()
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.convs1 = nn.ModuleList([
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]))),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]))),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2])))
|
||||||
|
])
|
||||||
|
self.convs1.apply(init_weights)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList([
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1))),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1))),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1)))
|
||||||
|
])
|
||||||
|
self.convs2.apply(init_weights)
|
||||||
|
|
||||||
|
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||||
|
|
||||||
|
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||||
|
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||||
|
xt = a1(x)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = a2(xt)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs1:
|
||||||
|
remove_parametrizations(l, 'weight')
|
||||||
|
for l in self.convs2:
|
||||||
|
remove_parametrizations(l, 'weight')
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock2(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
|
||||||
|
super(AMPBlock2, self).__init__()
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList([
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]))),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1])))
|
||||||
|
])
|
||||||
|
self.convs.apply(init_weights)
|
||||||
|
|
||||||
|
self.num_layers = len(self.convs) # total number of conv layers
|
||||||
|
|
||||||
|
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c, a in zip(self.convs, self.activations):
|
||||||
|
xt = a(x)
|
||||||
|
xt = c(xt)
|
||||||
|
x = xt + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs:
|
||||||
|
remove_parametrizations(l, 'weight')
|
||||||
|
|
||||||
|
|
||||||
|
class BigVGANVocoder(torch.nn.Module):
|
||||||
|
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
||||||
|
def __init__(self, h):
|
||||||
|
super().__init__()
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(h.upsample_rates)
|
||||||
|
|
||||||
|
# pre conv
|
||||||
|
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||||
|
|
||||||
|
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||||
|
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
|
||||||
|
|
||||||
|
# transposed conv-based upsamplers. does not apply anti-aliasing
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
nn.ModuleList([
|
||||||
|
weight_norm(
|
||||||
|
ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
||||||
|
h.upsample_initial_channel // (2**(i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2))
|
||||||
|
]))
|
||||||
|
|
||||||
|
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||||
|
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||||
|
|
||||||
|
# post conv
|
||||||
|
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||||
|
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||||
|
self.activation_post = Activation1d(activation=activation_post)
|
||||||
|
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
|
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||||
|
self.activation_post = Activation1d(activation=activation_post)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
||||||
|
|
||||||
|
# weight initialization
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
self.ups[i].apply(init_weights)
|
||||||
|
self.conv_post.apply(init_weights)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# pre conv
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
# upsampling
|
||||||
|
for i_up in range(len(self.ups[i])):
|
||||||
|
x = self.ups[i][i_up](x)
|
||||||
|
# AMP blocks
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
|
||||||
|
# post conv
|
||||||
|
x = self.activation_post(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
print('Removing weight norm...')
|
||||||
|
for l in self.ups:
|
||||||
|
for l_i in l:
|
||||||
|
remove_parametrizations(l_i, 'weight')
|
||||||
|
for l in self.resblocks:
|
||||||
|
l.remove_weight_norm()
|
||||||
|
remove_parametrizations(self.conv_pre, 'weight')
|
||||||
|
remove_parametrizations(self.conv_post, 'weight')
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m, mean=0.0, std=0.01):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
m.weight.data.normal_(mean, std)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_weight_norm(m):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
weight_norm(m)
|
||||||
|
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(filepath, device):
|
||||||
|
assert os.path.isfile(filepath)
|
||||||
|
print("Loading '{}'".format(filepath))
|
||||||
|
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||||
|
print("Complete.")
|
||||||
|
return checkpoint_dict
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, sin, pow
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
class Snake(nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of a sine-based periodic activation function
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter
|
||||||
|
References:
|
||||||
|
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snake(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha: trainable parameter
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
"""
|
||||||
|
super(Snake, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# Initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
||||||
|
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # Linear scale alphas initialized to ones
|
||||||
|
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
Snake ∶= x + 1/a * sin^2 (xa)
|
||||||
|
"""
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
"""
|
||||||
|
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
References:
|
||||||
|
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snakebeta(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
"""
|
||||||
|
super(SnakeBeta, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# Initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
||||||
|
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # Linear scale alphas initialized to ones
|
||||||
|
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
self.beta = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
self.beta.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||||
|
"""
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
||||||
|
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
beta = torch.exp(beta)
|
||||||
|
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
||||||
|
|
||||||
|
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
||||||
|
from alias_free_activation.cuda import load
|
||||||
|
|
||||||
|
anti_alias_activation_cuda = load.load()
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAntiAliasActivation(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
|
||||||
|
The hyperparameters are hard-coded in the kernel to maximize speed.
|
||||||
|
NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
||||||
|
activation_results = anti_alias_activation_cuda.forward(
|
||||||
|
inputs, up_ftr, down_ftr, alpha, beta
|
||||||
|
)
|
||||||
|
|
||||||
|
return activation_results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, output_grads):
|
||||||
|
raise NotImplementedError
|
||||||
|
return output_grads, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class Activation1d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation,
|
||||||
|
up_ratio: int = 2,
|
||||||
|
down_ratio: int = 2,
|
||||||
|
up_kernel_size: int = 12,
|
||||||
|
down_kernel_size: int = 12,
|
||||||
|
fused: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.up_ratio = up_ratio
|
||||||
|
self.down_ratio = down_ratio
|
||||||
|
self.act = activation
|
||||||
|
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||||
|
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||||
|
|
||||||
|
self.fused = fused # Whether to use fused CUDA kernel or not
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.fused:
|
||||||
|
x = self.upsample(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.downsample(x)
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
if self.act.__class__.__name__ == "Snake":
|
||||||
|
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
||||||
|
else:
|
||||||
|
beta = (
|
||||||
|
self.act.beta.data
|
||||||
|
) # Snakebeta uses different params for alpha and beta
|
||||||
|
alpha = self.act.alpha.data
|
||||||
|
if (
|
||||||
|
not self.act.alpha_logscale
|
||||||
|
): # Exp baked into cuda kernel, cancel it out with a log
|
||||||
|
alpha = torch.log(alpha)
|
||||||
|
beta = torch.log(beta)
|
||||||
|
|
||||||
|
x = FusedAntiAliasActivation.apply(
|
||||||
|
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
|
||||||
|
)
|
||||||
|
return x
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
/* coding=utf-8
|
||||||
|
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
|
||||||
|
}
|
||||||
@@ -0,0 +1,246 @@
|
|||||||
|
/* coding=utf-8
|
||||||
|
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_profiler_api.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include "type_shim.h"
|
||||||
|
#include <assert.h>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <limits>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
// Hard-coded hyperparameters
|
||||||
|
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||||
|
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
|
||||||
|
constexpr int BUFFER_SIZE = 32;
|
||||||
|
constexpr int FILTER_SIZE = 12;
|
||||||
|
constexpr int HALF_FILTER_SIZE = 6;
|
||||||
|
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
|
||||||
|
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
|
||||||
|
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
|
__global__ void anti_alias_activation_forward(
|
||||||
|
output_t *dst,
|
||||||
|
const input_t *src,
|
||||||
|
const input_t *up_ftr,
|
||||||
|
const input_t *down_ftr,
|
||||||
|
const input_t *alpha,
|
||||||
|
const input_t *beta,
|
||||||
|
int batch_size,
|
||||||
|
int channels,
|
||||||
|
int seq_len)
|
||||||
|
{
|
||||||
|
// Up and downsample filters
|
||||||
|
input_t up_filter[FILTER_SIZE];
|
||||||
|
input_t down_filter[FILTER_SIZE];
|
||||||
|
|
||||||
|
// Load data from global memory including extra indices reserved for replication paddings
|
||||||
|
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
|
||||||
|
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
|
||||||
|
|
||||||
|
// Output stores downsampled output before writing to dst
|
||||||
|
output_t output[BUFFER_SIZE];
|
||||||
|
|
||||||
|
// blockDim/threadIdx = (128, 1, 1)
|
||||||
|
// gridDim/blockIdx = (seq_blocks, channels, batches)
|
||||||
|
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
||||||
|
int local_offset = threadIdx.x * BUFFER_SIZE;
|
||||||
|
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
|
||||||
|
|
||||||
|
// intermediate have double the seq_len
|
||||||
|
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
||||||
|
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
|
||||||
|
|
||||||
|
// Get values needed for replication padding before moving pointer
|
||||||
|
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
||||||
|
input_t seq_left_most_value = right_most_pntr[0];
|
||||||
|
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
|
||||||
|
|
||||||
|
// Move src and dst pointers
|
||||||
|
src += block_offset + local_offset;
|
||||||
|
dst += block_offset + local_offset;
|
||||||
|
|
||||||
|
// Alpha and beta values for snake activatons. Applies exp by default
|
||||||
|
alpha = alpha + blockIdx.y;
|
||||||
|
input_t alpha_val = expf(alpha[0]);
|
||||||
|
beta = beta + blockIdx.y;
|
||||||
|
input_t beta_val = expf(beta[0]);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < FILTER_SIZE; it += 1)
|
||||||
|
{
|
||||||
|
up_filter[it] = up_ftr[it];
|
||||||
|
down_filter[it] = down_ftr[it];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply replication padding for upsampling, matching torch impl
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
|
||||||
|
{
|
||||||
|
int element_index = seq_offset + it; // index for element
|
||||||
|
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
|
||||||
|
{
|
||||||
|
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
|
||||||
|
}
|
||||||
|
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
|
||||||
|
{
|
||||||
|
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
|
||||||
|
}
|
||||||
|
if ((element_index >= 0) && (element_index < seq_len))
|
||||||
|
{
|
||||||
|
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
|
||||||
|
{
|
||||||
|
input_t acc = 0.0;
|
||||||
|
int element_index = intermediate_seq_offset + it; // index for intermediate
|
||||||
|
#pragma unroll
|
||||||
|
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
||||||
|
{
|
||||||
|
if ((element_index + f_idx) >= 0)
|
||||||
|
{
|
||||||
|
acc += up_filter[f_idx] * elements[it + f_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
|
||||||
|
double no_div_by_zero = 0.000000001;
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
|
||||||
|
{
|
||||||
|
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply replication padding before downsampling conv from intermediates
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
|
||||||
|
{
|
||||||
|
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
|
||||||
|
{
|
||||||
|
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply downsample strided convolution (assuming stride=2) from intermediates
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < BUFFER_SIZE; it += 1)
|
||||||
|
{
|
||||||
|
input_t acc = 0.0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
||||||
|
{
|
||||||
|
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
|
||||||
|
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
|
||||||
|
}
|
||||||
|
output[it] = acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write output to dst
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
|
||||||
|
{
|
||||||
|
int element_index = seq_offset + it;
|
||||||
|
if (element_index < seq_len)
|
||||||
|
{
|
||||||
|
dst[it] = output[it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
|
void dispatch_anti_alias_activation_forward(
|
||||||
|
output_t *dst,
|
||||||
|
const input_t *src,
|
||||||
|
const input_t *up_ftr,
|
||||||
|
const input_t *down_ftr,
|
||||||
|
const input_t *alpha,
|
||||||
|
const input_t *beta,
|
||||||
|
int batch_size,
|
||||||
|
int channels,
|
||||||
|
int seq_len)
|
||||||
|
{
|
||||||
|
if (seq_len == 0)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Use 128 threads per block to maximimize gpu utilization
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
constexpr int seq_len_per_block = 4096;
|
||||||
|
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
||||||
|
dim3 blocks(blocks_per_seq_len, channels, batch_size);
|
||||||
|
dim3 threads(threads_per_block, 1, 1);
|
||||||
|
|
||||||
|
anti_alias_activation_forward<input_t, output_t, acc_t>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
|
||||||
|
{
|
||||||
|
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
|
||||||
|
const int batches = input.size(0);
|
||||||
|
const int channels = input.size(1);
|
||||||
|
const int seq_len = input.size(2);
|
||||||
|
|
||||||
|
// Output
|
||||||
|
auto act_options = input.options().requires_grad(false);
|
||||||
|
|
||||||
|
torch::Tensor anti_alias_activation_results =
|
||||||
|
torch::empty({batches, channels, seq_len}, act_options);
|
||||||
|
|
||||||
|
void *input_ptr = static_cast<void *>(input.data_ptr());
|
||||||
|
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
|
||||||
|
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
|
||||||
|
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
|
||||||
|
void *beta_ptr = static_cast<void *>(beta.data_ptr());
|
||||||
|
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
|
||||||
|
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"dispatch anti alias activation_forward",
|
||||||
|
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
|
||||||
|
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(input_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(up_filter_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(down_filter_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(alpha_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(beta_ptr),
|
||||||
|
batches,
|
||||||
|
channels,
|
||||||
|
seq_len););
|
||||||
|
return anti_alias_activation_results;
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
/* coding=utf-8
|
||||||
|
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*This code is copied fron NVIDIA apex:
|
||||||
|
* https://github.com/NVIDIA/apex
|
||||||
|
* with minor changes. */
|
||||||
|
|
||||||
|
#ifndef TORCH_CHECK
|
||||||
|
#define TORCH_CHECK AT_CHECK
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef VERSION_GE_1_3
|
||||||
|
#define DATA_PTR data_ptr
|
||||||
|
#else
|
||||||
|
#define DATA_PTR data
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from torch.utils import cpp_extension
|
||||||
|
|
||||||
|
"""
|
||||||
|
Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels.
|
||||||
|
Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below
|
||||||
|
"""
|
||||||
|
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
||||||
|
|
||||||
|
|
||||||
|
def load():
|
||||||
|
# Check if cuda 11 is installed for compute capability 8.0
|
||||||
|
cc_flag = []
|
||||||
|
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
|
||||||
|
if int(bare_metal_major) >= 11:
|
||||||
|
cc_flag.append("-gencode")
|
||||||
|
cc_flag.append("arch=compute_80,code=sm_80")
|
||||||
|
|
||||||
|
# Build path
|
||||||
|
srcpath = pathlib.Path(__file__).parent.absolute()
|
||||||
|
buildpath = srcpath / "build"
|
||||||
|
_create_build_dir(buildpath)
|
||||||
|
|
||||||
|
# Helper function to build the kernels.
|
||||||
|
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
|
||||||
|
return cpp_extension.load(
|
||||||
|
name=name,
|
||||||
|
sources=sources,
|
||||||
|
build_directory=buildpath,
|
||||||
|
extra_cflags=[
|
||||||
|
"-O3",
|
||||||
|
],
|
||||||
|
extra_cuda_cflags=[
|
||||||
|
"-O3",
|
||||||
|
"-gencode",
|
||||||
|
"arch=compute_70,code=sm_70",
|
||||||
|
"--use_fast_math",
|
||||||
|
]
|
||||||
|
+ extra_cuda_flags
|
||||||
|
+ cc_flag,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
extra_cuda_flags = [
|
||||||
|
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||||
|
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||||
|
"--expt-relaxed-constexpr",
|
||||||
|
"--expt-extended-lambda",
|
||||||
|
]
|
||||||
|
|
||||||
|
sources = [
|
||||||
|
srcpath / "anti_alias_activation.cpp",
|
||||||
|
srcpath / "anti_alias_activation_cuda.cu",
|
||||||
|
]
|
||||||
|
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
||||||
|
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
||||||
|
)
|
||||||
|
|
||||||
|
return anti_alias_activation_cuda
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cuda_bare_metal_version(cuda_dir):
|
||||||
|
raw_output = subprocess.check_output(
|
||||||
|
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
||||||
|
)
|
||||||
|
output = raw_output.split()
|
||||||
|
release_idx = output.index("release") + 1
|
||||||
|
release = output[release_idx].split(".")
|
||||||
|
bare_metal_major = release[0]
|
||||||
|
bare_metal_minor = release[1][0]
|
||||||
|
|
||||||
|
return raw_output, bare_metal_major, bare_metal_minor
|
||||||
|
|
||||||
|
|
||||||
|
def _create_build_dir(buildpath):
|
||||||
|
try:
|
||||||
|
os.mkdir(buildpath)
|
||||||
|
except OSError:
|
||||||
|
if not os.path.isdir(buildpath):
|
||||||
|
print(f"Creation of the build directory {buildpath} failed")
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
/* coding=utf-8
|
||||||
|
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include "compat.h"
|
||||||
|
|
||||||
|
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||||
|
switch (TYPE) \
|
||||||
|
{ \
|
||||||
|
case at::ScalarType::Float: \
|
||||||
|
{ \
|
||||||
|
using scalar_t = float; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::Half: \
|
||||||
|
{ \
|
||||||
|
using scalar_t = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: \
|
||||||
|
{ \
|
||||||
|
using scalar_t = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||||
|
switch (TYPEIN) \
|
||||||
|
{ \
|
||||||
|
case at::ScalarType::Float: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_in = float; \
|
||||||
|
switch (TYPEOUT) \
|
||||||
|
{ \
|
||||||
|
case at::ScalarType::Float: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_out = float; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::Half: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_out = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_out = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
||||||
|
} \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::Half: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_in = at::Half; \
|
||||||
|
using scalar_t_out = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_in = at::BFloat16; \
|
||||||
|
using scalar_t_out = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
||||||
|
}
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
from .filter import *
|
||||||
|
from .resample import *
|
||||||
|
from .act import *
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from selva_core.ext.bigvgan_v2.alias_free_activation.torch.resample import (DownSample1d, UpSample1d)
|
||||||
|
|
||||||
|
|
||||||
|
class Activation1d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation,
|
||||||
|
up_ratio: int = 2,
|
||||||
|
down_ratio: int = 2,
|
||||||
|
up_kernel_size: int = 12,
|
||||||
|
down_kernel_size: int = 12,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.up_ratio = up_ratio
|
||||||
|
self.down_ratio = down_ratio
|
||||||
|
self.act = activation
|
||||||
|
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||||
|
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||||
|
|
||||||
|
# x: [B,C,T]
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.upsample(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return x
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
if "sinc" in dir(torch):
|
||||||
|
sinc = torch.sinc
|
||||||
|
else:
|
||||||
|
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/core.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def sinc(x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||||
|
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||||
|
"""
|
||||||
|
return torch.where(
|
||||||
|
x == 0,
|
||||||
|
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||||
|
torch.sin(math.pi * x) / math.pi / x,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def kaiser_sinc_filter1d(
|
||||||
|
cutoff, half_width, kernel_size
|
||||||
|
): # return filter [1,1,kernel_size]
|
||||||
|
even = kernel_size % 2 == 0
|
||||||
|
half_size = kernel_size // 2
|
||||||
|
|
||||||
|
# For kaiser window
|
||||||
|
delta_f = 4 * half_width
|
||||||
|
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||||
|
if A > 50.0:
|
||||||
|
beta = 0.1102 * (A - 8.7)
|
||||||
|
elif A >= 21.0:
|
||||||
|
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
||||||
|
else:
|
||||||
|
beta = 0.0
|
||||||
|
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||||
|
|
||||||
|
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||||
|
if even:
|
||||||
|
time = torch.arange(-half_size, half_size) + 0.5
|
||||||
|
else:
|
||||||
|
time = torch.arange(kernel_size) - half_size
|
||||||
|
if cutoff == 0:
|
||||||
|
filter_ = torch.zeros_like(time)
|
||||||
|
else:
|
||||||
|
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||||
|
"""
|
||||||
|
Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
|
||||||
|
"""
|
||||||
|
filter_ /= filter_.sum()
|
||||||
|
filter = filter_.view(1, 1, kernel_size)
|
||||||
|
|
||||||
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter1d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cutoff=0.5,
|
||||||
|
half_width=0.6,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: bool = True,
|
||||||
|
padding_mode: str = "replicate",
|
||||||
|
kernel_size: int = 12,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if cutoff < -0.0:
|
||||||
|
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||||
|
if cutoff > 0.5:
|
||||||
|
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.even = kernel_size % 2 == 0
|
||||||
|
self.pad_left = kernel_size // 2 - int(self.even)
|
||||||
|
self.pad_right = kernel_size // 2
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
# Input [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
if self.padding:
|
||||||
|
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||||
|
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||||
|
|
||||||
|
return out
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from selva_core.ext.bigvgan_v2.alias_free_activation.torch.filter import (LowPassFilter1d,
|
||||||
|
kaiser_sinc_filter1d)
|
||||||
|
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = (int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size)
|
||||||
|
self.stride = ratio
|
||||||
|
self.pad = self.kernel_size // ratio - 1
|
||||||
|
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||||
|
self.pad_right = (self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2)
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
kernel_size=self.kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
# x: [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||||
|
x = self.ratio * F.conv_transpose1d(
|
||||||
|
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||||
|
x = x[..., self.pad_left:-self.pad_right]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample1d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = (int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size)
|
||||||
|
self.lowpass = LowPassFilter1d(
|
||||||
|
cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
stride=ratio,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
xx = self.lowpass(x)
|
||||||
|
|
||||||
|
return xx
|
||||||
@@ -0,0 +1,439 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||||
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
|
from selva_core.ext.bigvgan_v2 import activations
|
||||||
|
from selva_core.ext.bigvgan_v2.alias_free_activation.torch.act import \
|
||||||
|
Activation1d as TorchActivation1d
|
||||||
|
from selva_core.ext.bigvgan_v2.env import AttrDict
|
||||||
|
from selva_core.ext.bigvgan_v2.utils import get_padding, init_weights
|
||||||
|
|
||||||
|
|
||||||
|
def load_hparams_from_json(path) -> AttrDict:
|
||||||
|
with open(path) as f:
|
||||||
|
data = f.read()
|
||||||
|
return AttrDict(json.loads(data))
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock1(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
||||||
|
AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
|
||||||
|
|
||||||
|
Args:
|
||||||
|
h (AttrDict): Hyperparameters.
|
||||||
|
channels (int): Number of convolution channels.
|
||||||
|
kernel_size (int): Size of the convolution kernel. Default is 3.
|
||||||
|
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
||||||
|
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
h: AttrDict,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
dilation: tuple = (1, 3, 5),
|
||||||
|
activation: str = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.convs1 = nn.ModuleList([
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=d,
|
||||||
|
padding=get_padding(kernel_size, d),
|
||||||
|
)) for d in dilation
|
||||||
|
])
|
||||||
|
self.convs1.apply(init_weights)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList([
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
)) for _ in range(len(dilation))
|
||||||
|
])
|
||||||
|
self.convs2.apply(init_weights)
|
||||||
|
|
||||||
|
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
|
||||||
|
|
||||||
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
|
if self.h.get("use_cuda_kernel", False):
|
||||||
|
from alias_free_activation.cuda.activation1d import \
|
||||||
|
Activation1d as CudaActivation1d
|
||||||
|
|
||||||
|
Activation1d = CudaActivation1d
|
||||||
|
else:
|
||||||
|
Activation1d = TorchActivation1d
|
||||||
|
|
||||||
|
# Activation functions
|
||||||
|
if activation == "snake":
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
elif activation == "snakebeta":
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||||
|
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||||
|
xt = a1(x)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = a2(xt)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs1:
|
||||||
|
remove_parametrizations(l, 'weight')
|
||||||
|
for l in self.convs2:
|
||||||
|
remove_parametrizations(l, 'weight')
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock2(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
||||||
|
Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
|
||||||
|
|
||||||
|
Args:
|
||||||
|
h (AttrDict): Hyperparameters.
|
||||||
|
channels (int): Number of convolution channels.
|
||||||
|
kernel_size (int): Size of the convolution kernel. Default is 3.
|
||||||
|
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
||||||
|
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
h: AttrDict,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
dilation: tuple = (1, 3, 5),
|
||||||
|
activation: str = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList([
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=d,
|
||||||
|
padding=get_padding(kernel_size, d),
|
||||||
|
)) for d in dilation
|
||||||
|
])
|
||||||
|
self.convs.apply(init_weights)
|
||||||
|
|
||||||
|
self.num_layers = len(self.convs) # Total number of conv layers
|
||||||
|
|
||||||
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
|
if self.h.get("use_cuda_kernel", False):
|
||||||
|
from alias_free_activation.cuda.activation1d import \
|
||||||
|
Activation1d as CudaActivation1d
|
||||||
|
|
||||||
|
Activation1d = CudaActivation1d
|
||||||
|
else:
|
||||||
|
Activation1d = TorchActivation1d
|
||||||
|
|
||||||
|
# Activation functions
|
||||||
|
if activation == "snake":
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
elif activation == "snakebeta":
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c, a in zip(self.convs, self.activations):
|
||||||
|
xt = a(x)
|
||||||
|
xt = c(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
|
||||||
|
|
||||||
|
class BigVGAN(
|
||||||
|
torch.nn.Module,
|
||||||
|
PyTorchModelHubMixin,
|
||||||
|
library_name="bigvgan",
|
||||||
|
repo_url="https://github.com/NVIDIA/BigVGAN",
|
||||||
|
docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
|
||||||
|
pipeline_tag="audio-to-audio",
|
||||||
|
license="mit",
|
||||||
|
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
|
||||||
|
New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
h (AttrDict): Hyperparameters.
|
||||||
|
use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
|
||||||
|
- Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.h = h
|
||||||
|
self.h["use_cuda_kernel"] = use_cuda_kernel
|
||||||
|
|
||||||
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
|
if self.h.get("use_cuda_kernel", False):
|
||||||
|
from alias_free_activation.cuda.activation1d import \
|
||||||
|
Activation1d as CudaActivation1d
|
||||||
|
|
||||||
|
Activation1d = CudaActivation1d
|
||||||
|
else:
|
||||||
|
Activation1d = TorchActivation1d
|
||||||
|
|
||||||
|
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(h.upsample_rates)
|
||||||
|
|
||||||
|
# Pre-conv
|
||||||
|
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||||
|
|
||||||
|
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||||
|
if h.resblock == "1":
|
||||||
|
resblock_class = AMPBlock1
|
||||||
|
elif h.resblock == "2":
|
||||||
|
resblock_class = AMPBlock2
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
|
||||||
|
|
||||||
|
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
nn.ModuleList([
|
||||||
|
weight_norm(
|
||||||
|
ConvTranspose1d(
|
||||||
|
h.upsample_initial_channel // (2**i),
|
||||||
|
h.upsample_initial_channel // (2**(i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2,
|
||||||
|
))
|
||||||
|
]))
|
||||||
|
|
||||||
|
# Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||||
|
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
|
||||||
|
|
||||||
|
# Post-conv
|
||||||
|
activation_post = (activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||||
|
if h.activation == "snake" else
|
||||||
|
(activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||||
|
if h.activation == "snakebeta" else None))
|
||||||
|
if activation_post is None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.activation_post = Activation1d(activation=activation_post)
|
||||||
|
|
||||||
|
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
||||||
|
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
||||||
|
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
|
||||||
|
|
||||||
|
# Weight initialization
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
self.ups[i].apply(init_weights)
|
||||||
|
self.conv_post.apply(init_weights)
|
||||||
|
|
||||||
|
# Final tanh activation. Defaults to True for backward compatibility
|
||||||
|
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Pre-conv
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
# Upsampling
|
||||||
|
for i_up in range(len(self.ups[i])):
|
||||||
|
x = self.ups[i][i_up](x)
|
||||||
|
# AMP blocks
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
|
||||||
|
# Post-conv
|
||||||
|
x = self.activation_post(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
# Final tanh activation
|
||||||
|
if self.use_tanh_at_final:
|
||||||
|
x = torch.tanh(x)
|
||||||
|
else:
|
||||||
|
x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
try:
|
||||||
|
print("Removing weight norm...")
|
||||||
|
for l in self.ups:
|
||||||
|
for l_i in l:
|
||||||
|
remove_parametrizations(l_i, 'weight')
|
||||||
|
for l in self.resblocks:
|
||||||
|
l.remove_weight_norm()
|
||||||
|
remove_parametrizations(self.conv_pre, 'weight')
|
||||||
|
remove_parametrizations(self.conv_post, 'weight')
|
||||||
|
except ValueError:
|
||||||
|
print("[INFO] Model already removed weight norm. Skipping!")
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Additional methods for huggingface_hub support
|
||||||
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
|
"""Save weights and config.json from a Pytorch model to a local directory."""
|
||||||
|
|
||||||
|
model_path = save_directory / "bigvgan_generator.pt"
|
||||||
|
torch.save({"generator": self.state_dict()}, model_path)
|
||||||
|
|
||||||
|
config_path = save_directory / "config.json"
|
||||||
|
with open(config_path, "w") as config_file:
|
||||||
|
json.dump(self.h, config_file, indent=4)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_pretrained(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
model_id: str,
|
||||||
|
revision: str,
|
||||||
|
cache_dir: str,
|
||||||
|
force_download: bool,
|
||||||
|
proxies: Optional[Dict] = None,
|
||||||
|
resume_download: bool = False,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
token: Union[str, bool, None] = None,
|
||||||
|
map_location: str = "cpu", # Additional argument
|
||||||
|
strict: bool = False, # Additional argument
|
||||||
|
use_cuda_kernel: bool = False,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
"""Load Pytorch pretrained weights and return the loaded model."""
|
||||||
|
|
||||||
|
# Download and load hyperparameters (h) used by BigVGAN
|
||||||
|
if os.path.isdir(model_id):
|
||||||
|
print("Loading config.json from local directory")
|
||||||
|
config_file = os.path.join(model_id, "config.json")
|
||||||
|
else:
|
||||||
|
config_file = hf_hub_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
filename="config.json",
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
token=token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
h = load_hparams_from_json(config_file)
|
||||||
|
|
||||||
|
# instantiate BigVGAN using h
|
||||||
|
if use_cuda_kernel:
|
||||||
|
print(
|
||||||
|
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||||
|
)
|
||||||
|
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
||||||
|
|
||||||
|
# Download and load pretrained generator weight
|
||||||
|
if os.path.isdir(model_id):
|
||||||
|
print("Loading weights from local directory")
|
||||||
|
model_file = os.path.join(model_id, "bigvgan_generator.pt")
|
||||||
|
else:
|
||||||
|
print(f"Loading weights from {model_id}")
|
||||||
|
model_file = hf_hub_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
filename="bigvgan_generator.pt",
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
token=token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model.load_state_dict(checkpoint_dict["generator"])
|
||||||
|
except RuntimeError:
|
||||||
|
print(
|
||||||
|
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||||
|
)
|
||||||
|
model.remove_weight_norm()
|
||||||
|
model.load_state_dict(checkpoint_dict["generator"])
|
||||||
|
|
||||||
|
return model
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
class AttrDict(dict):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(AttrDict, self).__init__(*args, **kwargs)
|
||||||
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
|
def build_env(config, config_name, path):
|
||||||
|
t_path = os.path.join(path, config_name)
|
||||||
|
if config != t_path:
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
shutil.copyfile(config, os.path.join(path, config_name))
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2020 Jungil Kong
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2020 Edward Dixon
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2019, Seungwon Park 박승원
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user