Compare commits
34 Commits
| Author | SHA1 | Date |
|---|---|---|
| | c9550ce693 | |
| | f3cabcad90 | |
| | b519b042e2 | |
| | f28759f1e3 | |
| | 3dd6badfd9 | |
| | 8bb2fb7015 | |
| | f4a7292cde | |
| | bd53744e2d | |
| | 429810db5b | |
| | 57f56c04e2 | |
| | ff26d0b87d | |
| | 83b1da9520 | |
| | 679a607a85 | |
| | d495939367 | |
| | 982d66e078 | |
| | b4124f58b3 | |
| | 2c9d521565 | |
| | 28229d62ce | |
| | 92593189f0 | |
| | 614a2e02aa | |
| | 40388ba6de | |
| | 789e09535d | |
| | 4da4858e4a | |
| | ab8e1e5b7b | |
| | e3a3384727 | |
| | 9a985499e7 | |
| | 27b4424e1a | |
| | 0e417f4078 | |
| | 6474e2816c | |
| | c23d210ab2 | |
| | b59b657b6f | |
| | 578b501d38 | |
| | fe94438356 | |
| | 6bc3fd6443 | |
@@ -1,156 +1,134 @@
|
||||
# ComfyUI-PrismAudio
|
||||
# ComfyUI-SelVA
|
||||
|
||||
Custom nodes for [PrismAudio](https://huggingface.co/FunAudioLLM/PrismAudio) (ICLR 2026) — video-to-audio and text-to-audio generation using decomposed Chain-of-Thought reasoning with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
|
||||
Custom nodes for [SelVA](https://github.com/jnwnlee/selva) — video-to-audio generation driven by text prompts. SelVA conditions audio synthesis on both visual content and natural language, letting you describe *what* sounds to generate rather than just *when*.
|
||||
|
||||
## Installation
|
||||
Built on [MMAudio](https://github.com/hkchengrex/MMAudio) with a TextSynchformer encoder that injects text guidance directly into the visual sync stream.
|
||||
|
||||
Clone into your ComfyUI custom nodes directory:
|
||||
|
||||
```bash
|
||||
cd ComfyUI/custom_nodes
|
||||
git clone https://github.com/Ethanfel/ComfyUI-Prismaudio.git ComfyUI-PrismAudio
|
||||
pip install -r ComfyUI-PrismAudio/requirements.txt
|
||||
```
|
||||
|
||||
**flash-attn** is optional — detected at runtime, falls back to PyTorch SDPA if unavailable.
|
||||
---
|
||||
|
||||
## Nodes
|
||||
|
||||
### PrismAudio Model Loader
|
||||
### SelVA Model Loader
|
||||
|
||||
Loads the DiT diffusion model and VAE. Auto-downloads weights from HuggingFace on first use.
|
||||
Loads the generator, TextSynchformer encoder, and all feature utilities (CLIP, T5, Synchformer, VAE). Weights are auto-downloaded from HuggingFace on first use.
|
||||
|
||||
| Input | Options | Description |
|
||||
|-------|---------|-------------|
|
||||
| `precision` | auto / fp32 / fp16 / bf16 | DiT and conditioner dtype. VAE is always fp32. |
|
||||
| `offload_strategy` | auto / keep_in_vram / offload_to_cpu | Memory management. |
|
||||
| `variant` | small_16k / small_44k / medium_44k / large_44k | Model size and output sample rate |
|
||||
| `precision` | bf16 / fp16 / fp32 | Compute dtype |
|
||||
| `offload_strategy` | auto / keep_in_vram / offload_to_cpu | Memory management |
|
||||
|
||||
**Output:** `model` (SELVA_MODEL)
|
||||
|
||||
---
|
||||
|
||||
### PrismAudio Feature Extractor
|
||||
### SelVA Feature Extractor
|
||||
|
||||
Extracts video features (VideoPrism LvT, Synchformer) and text features (T5-Gemma) from a video in a subprocess. Results are cached on disk.
|
||||
Extracts CLIP visual features and text-guided sync features from a video. Results are cached on disk — re-running with the same inputs is instant.
|
||||
|
||||
| Input | Description |
|
||||
|-------|-------------|
|
||||
| `model` | From SelVA Model Loader |
|
||||
| `video` | IMAGE tensor from any ComfyUI video loader |
|
||||
| `caption_cot` | Chain-of-thought description of the audio scene |
|
||||
| `video_info` | *(optional)* `VHS_VIDEOINFO` from VHS LoadVideo — sets fps automatically |
|
||||
| `prompt` | Text description of the audio to generate |
|
||||
| `video_info` | *(optional)* VHS_VIDEOINFO from VHS LoadVideo — sets fps automatically |
|
||||
| `fps` | Source fps — ignored if `video_info` is connected |
|
||||
| `python_env` | `managed_env` (auto-created isolated venv, recommended) or `comfyui_env` (current Python, see warning below) |
|
||||
| `cache_dir` | Directory for cached `.npz` files. Empty = system temp dir. |
|
||||
| `hf_token` | HuggingFace token for gated models. Prefer `HF_TOKEN` env var instead. |
|
||||
| `duration` | Override clip duration in seconds. `0` = infer from video length |
|
||||
| `cache_dir` | Directory for cached `.npz` files. Empty = system temp dir |
|
||||
| `mask` | *(optional)* Segmentation mask `[T,H,W]` float [0,1] — static (1 frame) or per-frame |
|
||||
| `mask_strength` | Background suppression strength. `1.0` = full neutral fill, `0.0` = no effect |
|
||||
| `mask_clip` | Apply mask to CLIP features (384px path). Disable to let CLIP see the full scene |
|
||||
| `mask_sync` | Apply mask to TextSynchformer sync features (224px path) |
|
||||
|
||||
**Outputs:** `features` (PRISMAUDIO_FEATURES), `fps` (FLOAT)
|
||||
**Outputs:** `features` (SELVA_FEATURES), `fps` (FLOAT), `prompt` (STRING)
|
||||
|
||||
**`managed_env`** auto-creates a venv at `_extract_env/` inside the plugin directory on first use and installs JAX, TF, VideoPrism, and Synchformer. This takes several minutes the first time.
|
||||
Connect `prompt` output to the Sampler's `prompt` input to avoid entering it twice.
|
||||
|
||||
**`comfyui_env`** uses the current ComfyUI Python — JAX/TF/videoprism must already be installed. Installing them into the ComfyUI environment may conflict with existing packages.
|
||||
#### Masking
|
||||
|
||||
Connect a segmentation mask (SAM2, Grounding DINO+SAM, or any ComfyUI mask node) to isolate a specific object's motion before encoding. Background pixels are filled with a neutral value (0.5) rather than zeroed — this keeps them in-distribution for CLIP and maps to exactly 0 after sync's `[-1,1]` normalization, minimising the influence of background motion on the generated audio.
|
||||
|
||||
Use `mask_sync=true, mask_clip=false` if you want sync features focused on the target object while CLIP still sees the full scene for broader context. Changing any mask parameter invalidates the cached features, so they are re-extracted on the next run.
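The neutral-fill behaviour described above can be illustrated with a minimal sketch. The blend rule in `apply_mask` and the `x * 2 - 1` mapping in `to_sync_range` are assumptions chosen to match the description (`mask_strength` of `1.0` gives a full neutral fill, `0.0` has no effect); the function names are hypothetical and are not taken from the node's source.

```python
NEUTRAL = 0.5  # neutral fill value named in the Masking section


def apply_mask(pixel, mask, strength=1.0):
    """Blend one pixel value in [0, 1] toward NEUTRAL where the mask is 0.

    Hypothetical blend rule: strength=1.0 fully replaces background pixels,
    strength=0.0 leaves every pixel untouched.
    """
    keep = 1.0 - strength * (1.0 - mask)  # 1.0 inside the mask, lower outside
    return keep * pixel + (1.0 - keep) * NEUTRAL


def to_sync_range(pixel):
    """Map [0, 1] to [-1, 1] (assumed form of the sync normalization)."""
    return pixel * 2.0 - 1.0


background = apply_mask(0.9, mask=0.0, strength=1.0)
print(background, to_sync_range(background))
```

Under these assumptions a fully masked background pixel becomes `0.5`, which lands on `0.0` in the `[-1, 1]` range, while pixels inside the mask pass through unchanged.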
|
||||
|
||||
---
|
||||
|
||||
### PrismAudio Feature Loader
|
||||
### SelVA Sampler
|
||||
|
||||
Loads a pre-computed `.npz` feature file. Use this to re-use extracted features without re-running the extractor.
|
||||
Generates audio from video features. Runs the rectified flow ODE with classifier-free guidance.
|
||||
|
||||
| Input | Description |
|
||||
|-------|-------------|
|
||||
| `npz_path` | Path to a `.npz` file produced by the Feature Extractor |
|
||||
|
||||
---
|
||||
|
||||
### PrismAudio Sampler
|
||||
|
||||
Video-to-audio generation. Takes model + features, produces AUDIO.
|
||||
|
||||
| Input | Description |
|
||||
|-------|-------------|
|
||||
| `model` | From Model Loader |
|
||||
| `features` | From Feature Extractor or Feature Loader |
|
||||
| `duration` | Audio duration in seconds. Set to `0` to use the video duration from features automatically. |
|
||||
| `steps` | Sampling steps (default: 100) |
|
||||
| `cfg_scale` | Classifier-free guidance scale (default: 7.0) |
|
||||
| `model` | From SelVA Model Loader |
|
||||
| `features` | From SelVA Feature Extractor |
|
||||
| `prompt` | Text description — leave empty to use the prompt stored in features |
|
||||
| `negative_prompt` | What to suppress (e.g. `"speech, voice, talking"`) |
|
||||
| `duration` | Audio duration in seconds. `0` = use duration from features |
|
||||
| `steps` | Sampling steps (default: 25) |
|
||||
| `cfg_strength` | Classifier-free guidance scale (default: 4.5) |
|
||||
| `seed` | RNG seed |
|
||||
| `normalize` | Peak-normalize output to [-1, 1] (default: true) |
|
||||
|
||||
**Output:** `AUDIO`
|
||||
|
||||
---
|
||||
|
||||
### PrismAudio Text Only
|
||||
## Workflow
|
||||
|
||||
Text-to-audio generation without video. Uses the T5-Gemma encoder.
|
||||
```
|
||||
VHS LoadVideo ──► SelVA Feature Extractor ──────────────────────► SelVA Sampler ──► Save Audio
|
||||
│ (video_info) ─► (fps auto) ▲
|
||||
│ (features) ────────────────────────────────────►│
|
||||
│ (prompt) ──────────────────────────────────────►│
|
||||
```
|
||||
|
||||
| Input | Description |
|
||||
|-------|-------------|
|
||||
| `model` | From Model Loader |
|
||||
| `text_prompt` | Chain-of-thought audio scene description. Longer, more detailed prompts produce better results. |
|
||||
| `duration` | Audio duration in seconds |
|
||||
| `steps` | Sampling steps (default: 100) |
|
||||
| `cfg_scale` | Classifier-free guidance scale (default: 7.0) |
|
||||
| `seed` | RNG seed |
|
||||
Connect the `prompt` output of Feature Extractor directly to Sampler's `prompt` to keep them in sync. Leave Sampler's `prompt` empty and it will use whatever was stored during extraction.
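The fallback described above amounts to a simple precedence rule. The sketch below assumes the features object behaves like a dict with a `"prompt"` key; both that layout and the `resolve_prompt` name are hypothetical.

```python
def resolve_prompt(sampler_prompt, features):
    """Return the Sampler's prompt, or the one stored at extraction time if it is empty."""
    text = sampler_prompt.strip()
    if text:
        return text
    # Assumed layout: the extractor stored its prompt under the "prompt" key.
    return features.get("prompt", "")


stored = {"prompt": "footsteps on gravel"}
print(resolve_prompt("", stored))
print(resolve_prompt("dog barking", stored))
```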
|
||||
|
||||
---
|
||||
|
||||
## Workflows
|
||||
## Installation
|
||||
|
||||
### Video-to-Audio
|
||||
|
||||
```
|
||||
VHS LoadVideo ──► PrismAudio Feature Extractor ──► PrismAudio Sampler ──► Save Audio
|
||||
(video_info) ──────────────────► (fps auto)
|
||||
(features) ────────────────────► (features)
|
||||
duration=0 ─────────────────────► (auto from features)
|
||||
```bash
|
||||
cd ComfyUI/custom_nodes
|
||||
git clone https://github.com/Ethanfel/ComfyUI-SelVA.git
|
||||
pip install -r ComfyUI-SelVA/requirements.txt
|
||||
```
|
||||
|
||||
### Pre-computed Features
|
||||
---
|
||||
|
||||
```
|
||||
PrismAudio Feature Loader (.npz) ──► PrismAudio Sampler ──► Save Audio
|
||||
```
|
||||
## Model Weights
|
||||
|
||||
### Text-to-Audio
|
||||
|
||||
```
|
||||
PrismAudio Text Only ──► Save Audio
|
||||
```
|
||||
|
||||
## HuggingFace Authentication
|
||||
|
||||
Required for T5-Gemma (gated model) and PrismAudio weights.
|
||||
|
||||
1. Visit <https://huggingface.co/FunAudioLLM/PrismAudio> and accept the license.
|
||||
2. Authenticate via one of:
|
||||
- **Environment variable:** `export HF_TOKEN=hf_...`
|
||||
- **CLI login:** `huggingface-cli login`
|
||||
|
||||
There is no `hf_token` widget on the main nodes by design — ComfyUI saves all STRING values to workflow JSON, which would expose your token. The Feature Extractor has an `hf_token` input as a convenience but using `HF_TOKEN` env var is preferred.
|
||||
|
||||
## Model Files
|
||||
|
||||
Weights are auto-downloaded to `ComfyUI/models/prismaudio/`:
|
||||
Weights are auto-downloaded to `ComfyUI/models/selva/` on first load. No manual setup required.
|
||||
|
||||
| File | Size | Description |
|
||||
|------|------|-------------|
|
||||
| `prismaudio.ckpt` | ~2.7 GB | Diffusion model (DiT) |
|
||||
| `vae.ckpt` | ~2.5 GB | Stable Audio 2.0 VAE |
|
||||
| `synchformer_state_dict.pth` | ~950 MB | Synchformer visual encoder |
|
||||
| `video_enc_sup_5.pth` | ~300 MB | TextSynchformer encoder |
|
||||
| `generator_small_16k_sup_5.pth` | ~340 MB | Small generator, 16 kHz output |
|
||||
| `generator_small_44k_sup_5.pth` | ~340 MB | Small generator, 44.1 kHz output |
|
||||
| `generator_medium_44k_sup_5.pth` | ~860 MB | Medium generator, 44.1 kHz output |
|
||||
| `generator_large_44k_sup_5.pth` | ~2.0 GB | Large generator, 44.1 kHz output |
|
||||
| `v1-16.pth` | ~1.1 GB | VAE for 16 kHz |
|
||||
| `v1-44.pth` | ~1.1 GB | VAE for 44.1 kHz |
|
||||
| `best_netG.pt` | ~90 MB | BigVGAN vocoder for 16 kHz |
|
||||
| `synchformer_state_dict.pth` | ~950 MB | Synchformer (shared with PrismAudio if present) |
|
||||
|
||||
T5-Gemma and VideoPrism LvT are cached in `~/.cache/huggingface/`.
|
||||
CLIP (DFN5B-ViT-H-14-384) and T5 (flan-t5-base) are downloaded automatically from HuggingFace to `~/.cache/huggingface/`.
|
||||
|
||||
---
|
||||
|
||||
## VRAM Requirements
|
||||
|
||||
| VRAM | Recommended settings |
|
||||
|------|----------------------|
|
||||
| 24 GB+ | `keep_in_vram`, any precision |
|
||||
| 12–24 GB | `offload_to_cpu`, bf16/fp16 |
|
||||
| 8–12 GB | `offload_to_cpu`, fp16 |
|
||||
| < 8 GB | May work with `offload_to_cpu` + fp16 |
|
||||
| 24 GB+ | `keep_in_vram`, any variant |
|
||||
| 12–24 GB | `offload_to_cpu`, medium or smaller |
|
||||
| 8–12 GB | `offload_to_cpu`, small variant, fp16 |
|
||||
|
||||
## Troubleshooting
|
||||
The `auto` offload strategy picks `keep_in_vram` if ≥ 16 GB VRAM is available, otherwise `offload_to_cpu`.
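As a rough illustration of the `auto` rule, the sketch below takes the VRAM figure as a plain number in GB. How the node actually measures VRAM is not shown in this README, so the `resolve_offload_strategy` name and its arguments are assumptions.

```python
AUTO_THRESHOLD_GB = 16  # threshold stated in the README


def resolve_offload_strategy(strategy, vram_gb):
    """Turn 'auto' into a concrete strategy; pass explicit choices through unchanged."""
    if strategy != "auto":
        return strategy
    return "keep_in_vram" if vram_gb >= AUTO_THRESHOLD_GB else "offload_to_cpu"


print(resolve_offload_strategy("auto", 24))
print(resolve_offload_strategy("auto", 12))
```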
|
||||
|
||||
- **Gated model errors** — Accept the license at <https://huggingface.co/FunAudioLLM/PrismAudio> and set `HF_TOKEN`.
|
||||
- **VRAM errors** — Switch `offload_strategy` to `offload_to_cpu` and/or use `fp16` precision.
|
||||
- **Feature extraction fails** — Ensure `synchformer_state_dict.pth` is in `models/prismaudio/`. On first run with `managed_env`, installation takes several minutes.
|
||||
- **flash-attn** — Optional. Auto-detected at runtime; falls back to PyTorch SDPA.
|
||||
---
|
||||
|
||||
## Credits
|
||||
|
||||
PrismAudio by [FunAudioLLM](https://github.com/FunAudioLLM) (ICLR 2026). [Model & weights](https://huggingface.co/FunAudioLLM/PrismAudio).
|
||||
- [SelVA](https://github.com/jnwnlee/selva) by Jaehwan Lee et al. — TextSynchformer and SelVA training
|
||||
- [MMAudio](https://github.com/hkchengrex/MMAudio) by Feng et al. — MM-DiT audio generator and flow matching framework
|
||||
- [BigVGAN](https://github.com/NVIDIA/BigVGAN) by NVIDIA — neural vocoder for 16 kHz synthesis
|
||||
|
||||
+1
-1
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
|
||||
ComfyUI-SelVA: Text-guided video-to-audio generation using SelVA / MMAudio.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
@@ -1,337 +0,0 @@
|
||||
"""
|
||||
PrismAudio feature extraction utilities.
|
||||
|
||||
Implements FeaturesUtils used by scripts/extract_features.py to extract:
|
||||
- Text features via T5-Gemma (transformers)
|
||||
- Video features via VideoPrism (JAX/Flax, google-deepmind/videoprism)
|
||||
- Sync features via Synchformer visual encoder (PyTorch)
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FeaturesUtils:
|
||||
def __init__(self, vae_config_path=None, synchformer_ckpt=None, device=None):
|
||||
self.device = device or torch.device("cpu")
|
||||
self._t5_tokenizer = None
|
||||
self._t5_encoder = None
|
||||
self._vp_model = None
|
||||
self._vp_state = None
|
||||
self._vp_text_tokenizer = None
|
||||
self._sync_model = None
|
||||
|
||||
self._synchformer_ckpt = synchformer_ckpt
|
||||
self._load_synchformer()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# T5-Gemma text encoding
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _ensure_t5(self):
|
||||
if self._t5_encoder is not None:
|
||||
return
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
model_id = "google/t5gemma-l-l-ul2-it"
|
||||
print(f"[FeaturesUtils] Loading T5-Gemma: {model_id}")
|
||||
self._t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
self._t5_encoder = (
|
||||
AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
.get_encoder()
|
||||
.to(self.device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
def encode_t5_text(self, texts):
|
||||
"""
|
||||
Args:
|
||||
texts: list of str
|
||||
Returns:
|
||||
Tensor [seq_len, 1024]
|
||||
"""
|
||||
self._ensure_t5()
|
||||
tokens = self._t5_tokenizer(
|
||||
texts, return_tensors="pt", padding=True
|
||||
).to(self.device)
|
||||
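# Move the encoder back to the target device; an earlier call may have offloaded it to CPU
self._t5_encoder.to(self.device)
with torch.no_grad():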
|
||||
out = self._t5_encoder(**tokens)
|
||||
# Move encoder off GPU to save VRAM
|
||||
self._t5_encoder.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
return out.last_hidden_state.squeeze(0) # [seq_len, 1024]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# VideoPrism video + text encoding (JAX)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _ensure_videoprism(self):
|
||||
if self._vp_model is not None:
|
||||
return
|
||||
from videoprism import models as vp
|
||||
import jax
|
||||
model_name = "videoprism_lvt_public_v1_large"
|
||||
print("[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...")
|
||||
self._vp_model = vp.get_model(model_name)
|
||||
self._vp_state = vp.load_pretrained_weights(model_name)
|
||||
self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en")
|
||||
jax_dev = jax.devices()[0]
|
||||
self._jax_forward = jax.jit(
|
||||
lambda x, y, z: self._vp_model.apply(
|
||||
self._vp_state, x, y, z, train=False, return_intermediate=True
|
||||
),
|
||||
device=jax_dev,
|
||||
)
|
||||
|
||||
def encode_video_and_text_with_videoprism(self, clip_input, texts):
|
||||
"""
|
||||
Args:
|
||||
clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
|
||||
texts: list of str — CoT captions, passed to VideoPrism LvT text tower
|
||||
Returns:
|
||||
global_video_features: Tensor [1, D]
|
||||
video_features: Tensor [T, D] — per-frame L2-normalized embeddings
|
||||
global_text_features: Tensor [1, D]
|
||||
"""
|
||||
self._ensure_videoprism()
|
||||
import jax.numpy as jnp
|
||||
from videoprism import models as vp
|
||||
|
||||
# Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
|
||||
frames = clip_input.squeeze(0) # [T, C, H, W]
|
||||
frames = (frames + 1.0) / 2.0 # [-1,1] → [0,1]
|
||||
frames = frames.permute(0, 2, 3, 1) # [T, H, W, C]
|
||||
frames_np = frames.cpu().numpy().astype(np.float32)
|
||||
frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C]
|
||||
|
||||
# Tokenize text (padding value 1.0 = pad, 0.0 = real token)
|
||||
text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts)
|
||||
|
||||
# Joint video+text forward with intermediate outputs
|
||||
video_embeddings, text_embeddings, outputs = self._jax_forward(
|
||||
frames_jax, text_ids, text_paddings
|
||||
)
|
||||
|
||||
# Per-frame features: [B, T, 1024] L2-normalized
|
||||
frame_embed_np = np.array(outputs["frame_embeddings"]) # [1, T, 1024]
|
||||
per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device) # [T, 1024]
|
||||
|
||||
# Global video embedding: [1024] → [1, 1024]
|
||||
global_video = torch.from_numpy(
|
||||
np.array(video_embeddings[0])
|
||||
).unsqueeze(0).to(self.device) # [1, 1024]
|
||||
|
||||
# Global text embedding: [1024] → [1, 1024]
|
||||
global_text = torch.from_numpy(
|
||||
np.array(text_embeddings[0])
|
||||
).unsqueeze(0).to(self.device) # [1, 1024]
|
||||
|
||||
return global_video, per_frame, global_text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Synchformer sync feature encoding
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _load_synchformer(self):
|
||||
if not self._synchformer_ckpt or not os.path.exists(self._synchformer_ckpt):
|
||||
return
|
||||
|
||||
print(f"[FeaturesUtils] Loading Synchformer from: {self._synchformer_ckpt}")
|
||||
state = torch.load(self._synchformer_ckpt, map_location="cpu", weights_only=False)
|
||||
|
||||
# Checkpoint may be raw state_dict or wrapped in {"model": ...}
|
||||
if isinstance(state, dict) and "model" in state:
|
||||
state_dict = state["model"]
|
||||
else:
|
||||
state_dict = state
|
||||
|
||||
self._sync_model = _SynchformerVisualEncoder(state_dict, self.device)
|
||||
self._sync_model.eval()
|
||||
|
||||
def encode_video_with_sync(self, sync_input):
|
||||
"""
|
||||
Args:
|
||||
sync_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
|
||||
Returns:
|
||||
sync_features: Tensor [num_segments, 768]
|
||||
"""
|
||||
if self._sync_model is None:
|
||||
raise RuntimeError(
|
||||
"[FeaturesUtils] Synchformer checkpoint not loaded. "
|
||||
"Pass synchformer_ckpt to FeaturesUtils or set --synchformer_ckpt."
|
||||
)
|
||||
frames = sync_input.squeeze(0).to(self.device) # [T, C, H, W]
|
||||
with torch.no_grad():
|
||||
return self._sync_model(frames)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Synchformer visual encoder — TimeSformer-style ViT-B/16
|
||||
# Architecture reverse-engineered from synchformer_state_dict.pth
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class _PatchEmbed(nn.Module):
|
||||
"""2D patch embedding: [B, 3, 224, 224] → [B, 196, 768]."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x).flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class _ViTAttn(nn.Module):
|
||||
"""ViT-style QKV attention (timm convention: qkv as single Linear)."""
|
||||
def __init__(self, dim=768, num_heads=12):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, D = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
|
||||
return self.proj((attn @ v).transpose(1, 2).reshape(B, N, D))
|
||||
|
||||
|
||||
class _BlockMLP(nn.Module):
|
||||
"""Two-layer MLP with GELU, keys fc1/fc2 to match checkpoint."""
|
||||
def __init__(self, dim=768, mlp_dim=3072):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, mlp_dim)
|
||||
self.fc2 = nn.Linear(mlp_dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(F.gelu(self.fc1(x)))
|
||||
|
||||
|
||||
class _TimeSformerBlock(nn.Module):
|
||||
"""
|
||||
Factorized space-time attention block.
|
||||
norm1 → spatial attn → norm3 → temporal attn → norm2 → MLP
|
||||
"""
|
||||
def __init__(self, dim=768, num_heads=12):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.attn = _ViTAttn(dim, num_heads)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.timeattn = _ViTAttn(dim, num_heads)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.mlp = _BlockMLP(dim)
|
||||
|
||||
def forward(self, x, T):
|
||||
# x: [T, N, D] (T frames treated as batch, N=197 spatial tokens)
|
||||
x = x + self.attn(self.norm1(x))
|
||||
# Temporal attention: for each spatial position, attend across T frames
|
||||
# [T, N, D] → [N, T, D] → attend → [N, T, D] → [T, N, D]
|
||||
xt = x.permute(1, 0, 2)
|
||||
xt = xt + self.timeattn(self.norm3(xt))
|
||||
x = xt.permute(1, 0, 2)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class _SpatialAttnAgg(nn.Module):
|
||||
"""
|
||||
Aggregates 196 spatial patches → 1 feature per frame using a
|
||||
TransformerEncoderLayer with a learnable CLS token.
|
||||
Key names match nn.TransformerEncoderLayer: self_attn, linear1, linear2, norm1, norm2.
|
||||
"""
|
||||
def __init__(self, dim=768, num_heads=12):
|
||||
super().__init__()
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
|
||||
self.linear1 = nn.Linear(dim, dim * 4)
|
||||
self.linear2 = nn.Linear(dim * 4, dim)
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [T, 196, 768] — spatial patches (CLS stripped)
|
||||
T = x.shape[0]
|
||||
cls = self.cls_token.expand(T, -1, -1)
|
||||
x = torch.cat([cls, x], dim=1) # [T, 197, 768]
|
||||
xn = self.norm1(x)
|
||||
x = x + self.self_attn(xn, xn, xn)[0]
|
||||
x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
|
||||
return x[:, 0, :] # [T, 768] — CLS per frame
|
||||
|
||||
|
||||
class _SynchformerVisualEncoder(nn.Module):
|
||||
"""
|
||||
TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.
|
||||
Processes video in segments of 8 frames → [T_aligned, 768] per-frame features.
|
||||
"""
|
||||
|
||||
def __init__(self, state_dict, device):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.segment_frames = 8
|
||||
|
||||
self.patch_embed = _PatchEmbed()
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
|
||||
self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
|
||||
self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
|
||||
self.norm = nn.LayerNorm(768)
|
||||
self.spatial_attn_agg = _SpatialAttnAgg()
|
||||
|
||||
# Load weights from vfeat_extractor.* prefix
|
||||
prefix = "vfeat_extractor."
|
||||
sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
|
||||
# Exclude 3D patch embed (we use 2D only)
|
||||
sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
|
||||
missing, unexpected = self.load_state_dict(sub, strict=False)
|
||||
print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
|
||||
if missing:
|
||||
print(f"[FeaturesUtils] missing keys (first 5): {missing[:5]}")
|
||||
|
||||
self.to(device)
|
||||
|
||||
def forward(self, frames):
|
||||
"""
|
||||
Args:
|
||||
frames: [T, C, H, W] float32 in [-1, 1], at 25fps
|
||||
Returns:
|
||||
[T_aligned, 768] — per-frame features (T_aligned = floor(T/8)*8)
|
||||
"""
|
||||
T = frames.shape[0]
|
||||
seg = self.segment_frames
|
||||
num_seg = max(1, T // seg)
|
||||
T_aligned = num_seg * seg
|
||||
|
||||
results = []
|
||||
for i in range(num_seg):
|
||||
chunk = frames[i * seg:(i + 1) * seg] # [8, C, H, W]
|
||||
results.append(self._forward_segment(chunk))
|
||||
return torch.cat(results, dim=0) # [T_aligned, 768]
|
||||
|
||||
def _forward_segment(self, x):
|
||||
# x: [8, 3, 224, 224]
|
||||
T = x.shape[0] # 8
|
||||
|
||||
# Patch embedding + CLS token
|
||||
x = self.patch_embed(x) # [8, 196, 768]
|
||||
cls = self.cls_token.expand(T, -1, -1)
|
||||
x = torch.cat([cls, x], dim=1) # [8, 197, 768]
|
||||
|
||||
# Positional + temporal embeddings
|
||||
x = x + self.pos_embed # broadcast (1,197,768)
|
||||
x = x + self.temp_embed.squeeze(0).unsqueeze(1) # (8,1,768) broadcast
|
||||
|
||||
# Transformer blocks (factorized space-time)
|
||||
for block in self.blocks:
|
||||
x = block(x, T)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
# Aggregate spatial patches → 1 feature per frame
|
||||
return self.spatial_attn_agg(x[:, 1:, :]) # [8, 768]
|
||||
@@ -1,194 +0,0 @@
|
||||
# ComfyUI-PrismAudio Design Document
|
||||
|
||||
**Date:** 2026-03-27
|
||||
**Status:** Approved
|
||||
|
||||
## Overview
|
||||
|
||||
ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
|
||||
|
||||
## Architecture
|
||||
|
||||
**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. No JAX/TensorFlow in the ComfyUI environment. Feature extraction via separate isolated environment.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
ComfyUI-PrismAudio/
|
||||
├── __init__.py # Node registration
|
||||
├── nodes/
|
||||
│ ├── __init__.py
|
||||
│ ├── model_loader.py # PrismAudioModelLoader
|
||||
│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz)
|
||||
│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge)
|
||||
│ ├── sampler.py # PrismAudioSampler
|
||||
│ ├── text_only.py # PrismAudioTextOnly
|
||||
│ └── utils.py # Shared helpers
|
||||
├── prismaudio_core/ # Extracted inference code from PrismAudio
|
||||
│ ├── __init__.py
|
||||
│ ├── configs/
|
||||
│ │ └── prismaudio.json
|
||||
│ ├── models/ # DiT, conditioners, autoencoders, etc.
|
||||
│ ├── inference/ # sampling.py, generation.py
|
||||
│ └── factory.py # create_model_from_config
|
||||
├── scripts/
|
||||
│ ├── extract_features.py # Standalone VideoPrism feature extraction
|
||||
│ └── environment.yml # Conda env for extraction (JAX + TF)
|
||||
├── requirements.txt # PyTorch-only deps (no JAX/TF)
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Nodes
|
||||
|
||||
### PrismAudioModelLoader
|
||||
|
||||
Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability |
|
||||
| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] |
|
||||
| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | |
|
||||
| **Output** | | |
|
||||
| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config |
|
||||
|
||||
**Token resolution order** (no widget — env/CLI only for security):
|
||||
1. `HF_TOKEN` environment variable
|
||||
2. `huggingface-cli login` cached token
|
||||
3. None — fails on gated models with clear error message linking to license page
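The resolution order above could be sketched as follows. `read_cached_token` is a hypothetical hook standing in for whatever reads the token saved by `huggingface-cli login`; it is injected so the sketch stays free of third-party imports. The error text is illustrative only.

```python
import os


def resolve_hf_token(environ=None, read_cached_token=lambda: None):
    """Return a token from HF_TOKEN, then from the CLI login cache, else None."""
    environ = os.environ if environ is None else environ
    token = environ.get("HF_TOKEN")
    if token:
        return token
    # Hypothetical hook: returns the cached CLI token, or None if not logged in.
    return read_cached_token() or None


def require_hf_token(repo_url, **kwargs):
    """Fail with a pointer to the license page when a gated model needs a token."""
    token = resolve_hf_token(**kwargs)
    if token is None:
        raise RuntimeError(
            f"No HuggingFace token found. Accept the license at {repo_url}, "
            "then set HF_TOKEN or run `huggingface-cli login`."
        )
    return token
```

Keeping the lookup out of any widget matches the design note above: nothing token-related ends up in the saved workflow JSON.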
|
||||
|
||||
**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually.
|
||||
|
||||
### PrismAudioFeatureLoader
|
||||
|
||||
Loads pre-computed `.npz` feature files for maximum quality video-to-audio.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| npz_path | STRING | Path to .npz file |
|
||||
| **Output** | | |
|
||||
| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features |
|
||||
|
||||
### PrismAudioFeatureExtractor
|
||||
|
||||
Subprocess bridge — extracts features from video using VideoPrism in an isolated environment.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| video | IMAGE | ComfyUI video frames tensor |
|
||||
| caption_cot | STRING | CoT description text |
|
||||
| python_env | STRING | Path to python binary with JAX/TF (default: "python") |
|
||||
| output_dir | STRING | Cache directory for .npz files (default: temp dir) |
|
||||
| **Output** | | |
|
||||
| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output |
|
||||
|
||||
**Caching:** Hashes video + text to avoid re-extraction on repeated runs.
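A content-hash cache key of this kind might look like the sketch below. The choice of SHA-256, the 16-character truncation, and the `feature_cache_key` name are all assumptions; the design document only states that video and text are hashed.

```python
import hashlib


def feature_cache_key(video_bytes, caption):
    """Derive a short, stable key from the raw video bytes and the caption text."""
    digest = hashlib.sha256()
    digest.update(video_bytes)
    digest.update(b"\0")  # separator so the video/text boundary is unambiguous
    digest.update(caption.encode("utf-8"))
    return digest.hexdigest()[:16]


key = feature_cache_key(b"\x00\x01\x02", "rain on a tin roof")
print(f"{key}.npz")
```

The same video and caption always produce the same key, so a repeated run can look for `<key>.npz` in the cache directory before launching the extraction subprocess.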
|
||||
|
||||
### PrismAudioSampler
|
||||
|
||||
Main generation node — takes model + features, produces audio.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| model | PRISMAUDIO_MODEL | From ModelLoader |
|
||||
| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor |
|
||||
| cot_description | STRING | Multiline CoT text |
|
||||
| duration | FLOAT | 1.0-30.0, defaults to video length |
|
||||
| steps | INT | 1-100, default 24 |
|
||||
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
|
||||
| seed | INT | Controls noise generation |
|
||||
| **Output** | | |
|
||||
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
|
||||
|
||||
**Pipeline:**
|
||||
1. Encode CoT text via T5-Gemma -> text_features
|
||||
2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond)
|
||||
3. Compute latent_seq_len = round(44100 / 2048 * duration)
|
||||
4. Generate noise [1, 64, latent_seq_len] from seed
|
||||
5. Discrete Euler sampling (rectified flow) with CFG
|
||||
6. VAE decode -> stereo waveform at 44100 Hz
|
||||
7. Normalize to [-1, 1], return as AUDIO
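Steps 3, 4, and 7 of the pipeline involve simple arithmetic that can be sketched in plain Python. The constants come from the list above; the function names are hypothetical, and treating step 7 as peak normalization is an assumption.

```python
SAMPLE_RATE = 44100    # output sample rate from step 6
DOWNSAMPLE = 2048      # audio samples per latent frame, from step 3
LATENT_CHANNELS = 64   # latent channel count, from step 4


def latent_seq_len(duration_s):
    """Step 3: number of latent frames for a clip of the given duration."""
    return round(SAMPLE_RATE / DOWNSAMPLE * duration_s)


def noise_shape(duration_s):
    """Step 4: shape of the initial noise tensor."""
    return (1, LATENT_CHANNELS, latent_seq_len(duration_s))


def peak_normalize(samples):
    """Step 7 (assumed to be peak normalization): scale so the largest magnitude is 1."""
    peak = max((abs(s) for s in samples), default=0.0)
    if peak == 0.0:
        return list(samples)  # silence stays silence, avoids dividing by zero
    return [s / peak for s in samples]


print(noise_shape(10.0))
print(peak_normalize([0.5, -2.0, 1.0]))
```

For a 10-second clip this gives 215 latent frames, since 44100 / 2048 is about 21.53 frames per second.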

### PrismAudioTextOnly

Text-to-audio without video input.

| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| model | PRISMAUDIO_MODEL | From ModelLoader |
| text_prompt | STRING | Text description |
| duration | FLOAT | 1.0-30.0 |
| steps | INT | 1-100, default 24 |
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
| seed | INT | Controls noise generation |
| **Output** | | |
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |

Uses empty tensors for the video/sync features; T5-Gemma encodes the text prompt.

## VRAM Management

Adaptive strategy using `comfy.model_management`:

| Available VRAM | Behavior |
|---|---|
| 24GB+ | Keep diffusion + VAE in VRAM |
| 12-24GB | Sequential offload between stages |
| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced |
| <8GB | Warn user, attempt with aggressive offload + fp16 |

Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()`
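
The table above could be turned into a small decision function. The sketch below is hypothetical: the name `pick_vram_strategy`, the returned keys, and the choice to treat each lower bound as inclusive are assumptions, and the free-memory figure would come from `mm.get_free_memory()` in practice.

```python
def pick_vram_strategy(free_gb):
    """Map free VRAM (in GB) to an offload plan following the tiers in the table."""
    if free_gb >= 24:
        return {"offload": "none", "force_fp16": False, "warn": False}
    if free_gb >= 12:
        return {"offload": "sequential", "force_fp16": False, "warn": False}
    if free_gb >= 8:
        return {"offload": "aggressive", "force_fp16": True, "warn": False}
    # Below 8 GB: still attempt generation, but tell the user it may fail.
    return {"offload": "aggressive", "force_fp16": True, "warn": True}


if __name__ == "__main__":
    for gb in (32, 16, 10, 6):
        print(gb, pick_vram_strategy(gb))
```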

## Feature Extraction Paths

### Path 1: Pre-computed .npz (FeatureLoader)
User runs `scripts/extract_features.py` externally in the extraction conda env. Loads result into ComfyUI. Original VideoPrism quality, zero ComfyUI env risk.

### Path 2: Subprocess bridge (FeatureExtractor)
Node calls extraction script as subprocess using a user-specified Python binary. Seamless in-ComfyUI experience, JAX runs isolated. Caches results by content hash.

### Path 3: Text-only (TextOnly node)
No video features needed. T5-Gemma text encoding only (PyTorch-native).

## Dependencies

### ComfyUI environment (`requirements.txt`)
```
einops>=0.7.0
safetensors
huggingface_hub
transformers>=4.52.3
k-diffusion>=0.1.1
```

flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`.
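
One way such runtime detection might look is shown below. This is a sketch, not the project's actual code: `has_module` and `pick_attention_backend` are hypothetical names, and the probe only checks that the `flash_attn` package is importable, not that its kernels work on the current GPU.

```python
import importlib.util


def has_module(name):
    """True if a top-level module can be found without importing it."""
    return importlib.util.find_spec(name) is not None


def pick_attention_backend(probe=has_module):
    """Return "flash_attn" when available, otherwise fall back to PyTorch SDPA."""
    return "flash_attn" if probe("flash_attn") else "sdpa"


if __name__ == "__main__":
    print("attention backend:", pick_attention_backend())
```

Passing the probe as a parameter keeps the choice testable on machines without flash-attn installed.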

### Extraction environment (`scripts/environment.yml`)
Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as a ready-made conda env file for one-command setup.

## Model Files

Stored in `ComfyUI/models/prismaudio/`:

| File | Size | Source |
|------|------|--------|
| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio |
| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio |
| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio |

T5-Gemma (`google/t5gemma-l-l-ul2-it`) is cached in the standard HuggingFace cache.

Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)`

## Design Decisions

- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes.
- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env.
- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes, since better models are available than the bundled Qwen2.5-VL.
- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`.
- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback.
File diff suppressed because it is too large
@@ -1,167 +0,0 @@

# SelVA Integration Design

**Date:** 2026-04-04
**Branch:** feature/selva-integration (new from master)
**Status:** Approved, ready for implementation

---

## Problem

PrismAudio's sync conditioning is text-agnostic: Synchformer extracts features from
all visual motion equally. In multi-source videos (person walking near a car), the DiT
receives unfocused sync guidance and struggles to match audio events to the correct
visual source.

SelVA (CVPR 2026, arXiv:2512.02650) solves this with TextSynchformer: text conditioning
is injected inside the Synchformer encoder via cross-attention, so sync features only
encode motion relevant to the requested sound. This is the core architectural improvement
needed for reliable V2A sync.

---

## Architecture

### New directory layout

```
selva_core/                    ← vendored SelVA source (model + ext + utils)
nodes/
  selva_model_loader.py
  selva_feature_extractor.py
  selva_sampler.py
```

### New custom types

- `SELVA_MODEL`: `{generator, video_enc, feature_utils, variant, strategy, dtype}`
- `SELVA_FEATURES`: `{clip_features, sync_features, duration}`

### No subprocess

SelVA is pure PyTorch. Feature extraction runs inline in ComfyUI, with no managed venv,
no JAX/TF, and no pip install on first run.

### Dependencies

Zero new pip packages. ComfyUI already ships:
- `open_clip_torch` (CLIP ViT-H-14-384, auto-downloads via `hf-hub:` on first use)
- `transformers` (flan-t5-base, auto-downloads from HuggingFace on first use)
- `torch`, `torchaudio`, `einops`

---

## Nodes

### `SelvaModelLoader` → `SELVA_MODEL`

| Input | Type | Default | Notes |
|---|---|---|---|
| variant | dropdown | medium_44k | small_16k / small_44k / medium_44k / large_44k |
| precision | dropdown | bf16 | bf16 / fp16 / fp32 |
| offload_strategy | dropdown | auto | auto / keep_in_vram / offload_to_cpu |

Resolves weights from `models/selva/`. Raises descriptive errors with download
instructions if files are missing.

### `SelvaFeatureExtractor` → `SELVA_FEATURES`, `FLOAT` (fps)

| Input | Type | Default | Notes |
|---|---|---|---|
| video | IMAGE | (none) | ComfyUI video tensor [T,H,W,C] |
| prompt | STRING | (none) | Used by TextSynchformer to select relevant motion |
| video_info | VHS_VIDEOINFO | opt | Auto-sets fps when connected |
| fps | FLOAT | 30.0 | Fallback fps if video_info not connected |
| cache_dir | STRING | "" | Empty = system temp dir |

Feature extraction steps (all inline, no subprocess):
1. Resize frames to 384×384 → CLIP video features `[B, T, 1024]`
2. Resize frames to 224×224 + encode prompt with flan-T5 → TextSynchformer → text-conditioned sync features `[B, T, 768]`
3. Save to `.npz` cache keyed by hash(frames[:1MB] + prompt + fps)
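
The cache key in step 3 can be illustrated with the standard library alone. In this sketch the frames are represented as raw bytes, and the name `cache_key` and the 16-character truncation are assumptions made for the example.

```python
import hashlib

ONE_MB = 1024 * 1024


def cache_key(frame_bytes, prompt, fps, max_bytes=ONE_MB):
    """Hash the first `max_bytes` of frame data together with prompt and fps."""
    h = hashlib.sha256()
    h.update(frame_bytes[:max_bytes])   # only the head of the video is hashed
    h.update(prompt.encode("utf-8"))
    h.update(str(fps).encode("utf-8"))
    return h.hexdigest()[:16]           # short, filename-friendly key


if __name__ == "__main__":
    frames = bytes(range(256)) * 8
    print(cache_key(frames, "dog barking", 30.0))
```

Hashing only the first megabyte keeps the key cheap to compute, at the cost that two clips sharing the same opening frames, prompt, and fps map to the same key.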

### `SelvaSampler` → `AUDIO`

| Input | Type | Default | Notes |
|---|---|---|---|
| model | SELVA_MODEL | (none) | |
| features | SELVA_FEATURES | (none) | |
| prompt | STRING | (none) | Should match extractor prompt; drives CLIP text guidance |
| negative_prompt | STRING | "" | Steers away from unwanted sounds |
| duration | FLOAT | 0.0 | 0 = auto from features duration |
| steps | INT | 25 | Euler steps (25 is SelVA default, fast) |
| cfg_strength | FLOAT | 4.5 | CFG scale (SelVA default) |
| seed | INT | 0 | |

Generation steps:
1. Encode prompt → CLIP text features (for MMAudio)
2. Encode negative prompt → empty conditions for CFG
3. `net_generator.preprocess_conditions(clip_f, sync_f, text_clip)`
4. Flow matching Euler ODE (`num_steps` iterations) with CFG
5. `feature_utils.decode(latent)` → mel spectrogram
6. `feature_utils.vocode(spec)` → waveform (BigVGAN for 16k, direct for 44k)
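
Step 4 combines two ideas: fixed-step Euler integration of a velocity field from t=0 to t=1, and classifier-free guidance that mixes conditional and unconditional velocities. The scalar sketch below illustrates both under stated assumptions: the guidance formula is the common `cfg * v_cond + (1 - cfg) * v_uncond` form, which is assumed here rather than confirmed from the SelVA source, and the function names are hypothetical.

```python
def cfg_velocity(v_cond, v_uncond, cfg_strength):
    """Classifier-free guidance: extrapolate away from the unconditional velocity."""
    return cfg_strength * v_cond + (1.0 - cfg_strength) * v_uncond


def euler_integrate(velocity_fn, x0, num_steps):
    """Integrate dx/dt = velocity_fn(t, x) from t=0 to t=1 with fixed Euler steps."""
    x = x0
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity_fn(t, x)
    return x


if __name__ == "__main__":
    # Toy velocity fields standing in for the generator's two forward passes.
    def guided(t, x):
        return cfg_velocity(v_cond=2.0, v_uncond=1.0, cfg_strength=4.5)

    print(euler_integrate(guided, x0=0.0, num_steps=25))
```

With `cfg_strength = 1.0` the guided velocity reduces to the conditional one, and when both velocities agree the guidance has no effect.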

**Note on dual prompt:** The extractor prompt is baked into sync_features via
TextSynchformer at extraction time. The sampler prompt drives CLIP text conditioning
at generation time. They should match; a tooltip explains this.

---

## Data Flow

```
[VHS LoadVideo] ──► [SelvaFeatureExtractor]
                          │ prompt: "dog barking"
                          │ video_info: (fps auto)
                          ▼
                    SELVA_FEATURES
                    {clip_features [B,T,1024],
                     sync_features [B,T,768],   ← text-conditioned
                     duration: 8.2s}
                          │
[SelvaModelLoader] ──► [SelvaSampler]
  variant: medium_44k     │ prompt: "dog barking"
  precision: bf16         │ negative: "wind noise"
                          │ cfg_strength: 4.5, steps: 25
                          ▼
                    AUDIO (44.1kHz or 16kHz)
```

---

## Model Weights

Location: `models/selva/`

```
video_enc_sup_5.pth              ← TextSynch, shared across all variants
generator_small_16k_sup_5.pth
generator_small_44k_sup_5.pth
generator_medium_44k_sup_5.pth
generator_large_44k_sup_5.pth
ext/
  v1-16.pth                      ← VAE for 16k variants
  v1-44.pth                      ← VAE for 44k variants
  best_netG.pt                   ← BigVGAN vocoder (16k only)
```

`synchformer_state_dict.pth` is reused from `models/prismaudio/`, so it is not duplicated.
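
The layout above implies a fixed set of files per variant. A minimal sketch of that mapping follows; `required_weight_files` is a hypothetical helper, and the paths are relative to `models/selva/` as listed.

```python
KNOWN_VARIANTS = ("small_16k", "small_44k", "medium_44k", "large_44k")


def required_weight_files(variant):
    """List the files under models/selva/ that a given variant needs."""
    if variant not in KNOWN_VARIANTS:
        raise ValueError(f"unknown variant: {variant!r}")
    rate = variant.split("_")[1][:2]           # "16" or "44"
    files = [
        "video_enc_sup_5.pth",                 # TextSynch, shared by all variants
        f"generator_{variant}_sup_5.pth",
        f"ext/v1-{rate}.pth",                  # VAE matching the sample rate
    ]
    if rate == "16":
        files.append("ext/best_netG.pt")       # BigVGAN vocoder, 16k only
    return files


if __name__ == "__main__":
    for name in KNOWN_VARIANTS:
        print(name, required_weight_files(name))
```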

---

## selva_core vendoring

Copy from `jnwnlee/selva` (pinned to a specific commit for stability):
- `selva_core/model/`: MMAudio, TextSynch, transformer layers, embeddings, flow matching
- `selva_core/ext/`: autoencoder, BigVGAN, synchformer, rotary embeddings, mel converters
- `selva_core/utils/`: transforms, generate() helper

Rename all internal imports from `selva.*` → `selva_core.*`.

---

## What stays the same

- All PrismAudio nodes unchanged
- `models/prismaudio/` unchanged
- Synchformer checkpoint shared (not duplicated)
- Branch: new `feature/selva-integration` off master (LoRA work stays separate)
@@ -1,738 +0,0 @@

# SelVA Integration Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Add three new ComfyUI nodes (SelvaModelLoader, SelvaFeatureExtractor, SelvaSampler) that run SelVA's text-conditioned V2A pipeline inline: no subprocess, no JAX, pure PyTorch.

**Architecture:** Vendor SelVA source into `selva_core/`, implement three nodes that mirror the PrismAudio pattern. `SelvaFeatureExtractor` takes `SELVA_MODEL` (needs TextSynchformer + CLIP/T5 from FeaturesUtils). `SelvaSampler` runs flow matching ODE with CFG and negative prompts.

**Tech Stack:** PyTorch, open_clip (already in ComfyUI), transformers (already in ComfyUI), torchaudio, einops, torchvision

---

## Design reference

`docs/plans/2026-04-04-selva-integration-design.md`

**Key facts from SelVA source:**
- CLIP input: `[B, T, C, 384, 384]` float32 `[0,1]`; normalization applied inside FeaturesUtils
- Sync input: `[B, T, C, 224, 224]` float32 `[-1,1]`; normalize with `mean=std=[0.5,0.5,0.5]` before passing
- CLIP frame rate: 8fps, Sync frame rate: 25fps
- CONFIG_16K: latent=250, clip=64, sync=192 at 8s
- CONFIG_44K: latent=345, clip=64, sync=192 at 8s
- Sync segments: 16-frame windows, 8-frame stride (overlapping, unlike PrismAudio's 8-frame non-overlapping)
- `net_generator.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)` must be called before each generation when duration ≠ 8s
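
The sequence lengths quoted above can be reproduced from the frame rates and window sizes. The sketch below is an illustration, not the vendored `SequenceConfig`: the class name `SeqLengths` is hypothetical, and the two downsample factors of 2 are assumptions chosen because they reproduce the listed numbers (the spectrogram frame rates 256 and 512 are the values this plan passes for 16k and 44k).

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class SeqLengths:
    duration: float
    sampling_rate: int
    spectrogram_frame_rate: int
    latent_downsample_rate: int = 2   # assumed
    clip_fps: int = 8
    sync_fps: int = 25
    sync_window: int = 16             # frames per sync segment
    sync_stride: int = 8              # overlap: stride is half the window
    sync_downsample_rate: int = 2     # assumed: 16 frames -> 8 tokens

    @property
    def latent_seq_len(self):
        frames = self.duration * self.sampling_rate / self.spectrogram_frame_rate
        return math.ceil(frames / self.latent_downsample_rate)

    @property
    def clip_seq_len(self):
        return int(self.duration * self.clip_fps)

    @property
    def sync_seq_len(self):
        num_frames = self.duration * self.sync_fps
        num_segments = (num_frames - self.sync_window) // self.sync_stride + 1
        return int(num_segments * self.sync_window / self.sync_downsample_rate)


if __name__ == "__main__":
    cfg_16k = SeqLengths(8.0, 16000, 256)
    cfg_44k = SeqLengths(8.0, 44100, 512)
    print(cfg_16k.latent_seq_len, cfg_16k.clip_seq_len, cfg_16k.sync_seq_len)
    print(cfg_44k.latent_seq_len, cfg_44k.clip_seq_len, cfg_44k.sync_seq_len)
```

For 8 s at 25 fps there are 200 sync frames, which gives 24 overlapping 16-frame windows and 192 sync tokens.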

---

## Task 1: Create branch and vendor selva_core

**Files:**
- Create: `selva_core/` (full directory tree)

**Step 1: Create new branch off master (not off feature/lora-trainer)**

```bash
git checkout master
git checkout -b feature/selva-integration
```

**Step 2: Clone SelVA and copy source**

```bash
git clone https://github.com/jnwnlee/selva.git /tmp/selva_src
cp -r /tmp/selva_src/selva /media/p5/Comfyui-Prismaudio/selva_core
```

**Step 3: Rename all internal imports**

```bash
cd /media/p5/Comfyui-Prismaudio/selva_core
find . -name "*.py" -exec sed -i \
  's/from selva\./from selva_core./g;
   s/import selva\./import selva_core./g' {} \;
```
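
The same rewrite can be expressed in Python, which makes it easy to check against a few sample lines before touching the tree. This is a sketch with a hypothetical helper name; unlike the `sed` expressions it also catches the bare forms `from selva import ...` and `import selva`, and it leaves already rewritten `selva_core` imports alone.

```python
import re

# Match "from selva" / "import selva" at the start of a (possibly indented) line.
# The trailing \b stops the pattern from matching "selva_core".
_IMPORT_RE = re.compile(r"^([ \t]*)(from|import)([ \t]+)selva\b", flags=re.MULTILINE)


def rewrite_imports(source):
    """Rewrite selva.* imports to selva_core.* in a block of Python source."""
    return _IMPORT_RE.sub(r"\1\2\3selva_core", source)


if __name__ == "__main__":
    sample = (
        "from selva.model.flow_matching import FlowMatching\n"
        "import selva.utils as u\n"
        "from selva_core.ext import autoencoder\n"
    )
    print(rewrite_imports(sample))
```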

**Step 4: Record the pinned commit**

```bash
cd /tmp/selva_src && git rev-parse HEAD
# Paste the hash into a comment at the top of selva_core/__init__.py
```

Edit `selva_core/__init__.py` to add at the top:
```python
# Vendored from https://github.com/jnwnlee/selva
# Pinned commit: <PASTE_HASH_HERE>
# Imports rewritten from selva.* → selva_core.*
```

**Step 5: Verify imports work**

```bash
cd /media/p5/Comfyui-Prismaudio
python -c "
from selva_core.model.networks_generator import MMAudio, get_my_mmaudio
from selva_core.model.networks_video_enc import TextSynch, get_my_textsynch
from selva_core.model.utils.features_utils import FeaturesUtils
from selva_core.model.flow_matching import FlowMatching
from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
print('selva_core imports OK')
print(f'CONFIG_16K: latent={CONFIG_16K.latent_seq_len} clip={CONFIG_16K.clip_seq_len} sync={CONFIG_16K.sync_seq_len}')
print(f'CONFIG_44K: latent={CONFIG_44K.latent_seq_len} clip={CONFIG_44K.clip_seq_len} sync={CONFIG_44K.sync_seq_len}')
"
```

Expected:
```
selva_core imports OK
CONFIG_16K: latent=250 clip=64 sync=192
CONFIG_44K: latent=345 clip=64 sync=192
```

**Step 6: Commit**

```bash
git add selva_core/
git commit -m "chore: vendor selva_core from jnwnlee/selva@<HASH>

Pure PyTorch SelVA source for SelvaModelLoader/FeatureExtractor/Sampler nodes.
Imports rewritten from selva.* to selva_core.*. No training code included."
```

---

## Task 2: Implement SelvaModelLoader

**Files:**
- Create: `nodes/selva_model_loader.py`
- Modify: `nodes/__init__.py`

**Step 1: Create `nodes/selva_model_loader.py`**

```python
import os
import torch
import folder_paths

from .utils import PRISMAUDIO_CATEGORY, get_offload_device, determine_offload_strategy

# Variant → (generator filename, mode, has_bigvgan)
_VARIANTS = {
    "small_16k":  ("generator_small_16k_sup_5.pth",  "16k", True),
    "small_44k":  ("generator_small_44k_sup_5.pth",  "44k", False),
    "medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
    "large_44k":  ("generator_large_44k_sup_5.pth",  "44k", False),
}

_SELVA_DIR = os.path.join(folder_paths.models_dir, "selva")


def _selva_path(*parts):
    return os.path.join(_SELVA_DIR, *parts)


def _require(path, hint):
    """Return path if it exists, else raise with a download hint."""
    if not os.path.exists(path):
        raise RuntimeError(
            f"[SelVA] Missing: {path}\n{hint}"
        )
    return path


class SelvaModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "variant": (list(_VARIANTS.keys()),),
                "precision": (["bf16", "fp16", "fp32"],),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
            }
        }

    RETURN_TYPES = ("SELVA_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_model(self, variant, precision, offload_strategy):
        from selva_core.model.networks_generator import get_my_mmaudio
        from selva_core.model.networks_video_enc import get_my_textsynch
        from selva_core.model.utils.features_utils import FeaturesUtils
        from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K

        gen_filename, mode, has_bigvgan = _VARIANTS[variant]

        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        strategy = determine_offload_strategy(offload_strategy)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Resolve weight paths
        video_enc_path = _require(
            _selva_path("video_enc_sup_5.pth"),
            "Download from https://huggingface.co/jnwnlee/selva and place in models/selva/"
        )
        gen_path = _require(
            _selva_path(gen_filename),
            f"Download {gen_filename} from https://huggingface.co/jnwnlee/selva and place in models/selva/"
        )
        # VAE files are named v1-16.pth / v1-44.pth (no "k"), per the Model Weights layout
        vae_name = f"v1-{mode.rstrip('k')}.pth"
        vae_path = _require(
            _selva_path("ext", vae_name),
            f"Download {vae_name} from MMAudio/SelVA release and place in models/selva/ext/"
        )
        synch_path = _require(
            os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth"),
            "Synchformer checkpoint missing from models/prismaudio/: download from FunAudioLLM/PrismAudio"
        )
        bigvgan_path = None
        if has_bigvgan:
            bigvgan_path = _require(
                _selva_path("ext", "best_netG.pt"),
                "Download best_netG.pt (BigVGAN 16k vocoder) from MMAudio release and place in models/selva/ext/"
            )

        print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
        net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
        net_video_enc.load_weights(
            torch.load(video_enc_path, map_location="cpu", weights_only=True)
        )

        print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
        seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
        net_generator = get_my_mmaudio(variant).to(device, dtype).eval()
        net_generator.load_weights(
            torch.load(gen_path, map_location="cpu", weights_only=True)
        )

        print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
        feature_utils = FeaturesUtils(
            tod_vae_ckpt=vae_path,
            synchformer_ckpt=synch_path,
            enable_conditions=True,
            mode=mode,
            bigvgan_vocoder_ckpt=bigvgan_path,
        ).to(device, dtype).eval()

        if strategy == "offload_to_cpu":
            net_generator.to(get_offload_device())
            net_video_enc.to(get_offload_device())
            feature_utils.to(get_offload_device())

        print(f"[SelVA] Model ready: variant={variant} dtype={dtype} strategy={strategy}", flush=True)

        return ({
            "generator": net_generator,
            "video_enc": net_video_enc,
            "feature_utils": feature_utils,
            "variant": variant,
            "mode": mode,
            "strategy": strategy,
            "dtype": dtype,
            "seq_cfg": seq_cfg,
        },)
```

**Step 2: Register in `nodes/__init__.py`**

In the `NODE_CLASS_MAPPINGS` dict, add:
```python
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
```

**Step 3: Verify node registers**

```bash
cd /media/p5/Comfyui-Prismaudio
python -c "
import sys; sys.path.insert(0, '.')
from nodes.selva_model_loader import SelvaModelLoader
print('inputs:', list(SelvaModelLoader.INPUT_TYPES()['required'].keys()))
print('outputs:', SelvaModelLoader.RETURN_TYPES)
"
```

Expected: `inputs: ['variant', 'precision', 'offload_strategy']`

**Step 4: Commit**

```bash
git add nodes/selva_model_loader.py nodes/__init__.py
git commit -m "feat: SelvaModelLoader node - loads TextSynch + MMAudio + FeaturesUtils"
```

---

## Task 3: Implement SelvaFeatureExtractor

**Files:**
- Create: `nodes/selva_feature_extractor.py`
- Modify: `nodes/__init__.py`

**Step 1: Create `nodes/selva_feature_extractor.py`**

```python
import os
import hashlib
import tempfile

import torch
import torch.nn.functional as F
import numpy as np

from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache

# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
_CLIP_SIZE = 384
_SYNC_SIZE = 224
_CLIP_FPS = 8
_SYNC_FPS = 25

# Sync normalization: [-1, 1] (from selva/utils/eval_utils.py load_video)
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)


def _sample_frames(video, source_fps, target_fps, duration):
    """Sample frames from [T,H,W,C] float32 [0,1] at target_fps."""
    T = video.shape[0]
    n_out = max(1, int(duration * target_fps))
    indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
    return video[indices]  # [N, H, W, C]


def _resize_frames(frames, size):
    """Resize [N,H,W,C] float32 [0,1] → [N,C,H,W] at target size."""
    x = frames.permute(0, 3, 1, 2)  # [N, C, H, W]
    x = F.interpolate(x, size=(size, size), mode="bicubic", align_corners=False)
    return x.clamp(0, 1)  # [N, C, H, W] float32


def _hash_inputs(video_tensor, prompt, fps, duration, variant):
    h = hashlib.sha256()
    h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024])
    h.update(prompt.encode())
    h.update(str(fps).encode())
    # Duration changes how many frames are sampled, so it must be part of the key
    h.update(f"{duration:.3f}".encode())
    h.update(variant.encode())
    return h.hexdigest()[:16]


class SelvaFeatureExtractor:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "video": ("IMAGE",),
                "prompt": ("STRING", {"default": "", "multiline": True,
                                      "tooltip": "Text prompt used by TextSynchformer to focus sync features on the relevant sound source. Should match the prompt used in SelvaSampler."}),
            },
            "optional": {
                "video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info to auto-set fps."}),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001}),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
                                       "tooltip": "Override duration in seconds. 0 = infer from video length and fps."}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory for cached .npz features. Empty = temp dir."}),
            },
        }

    RETURN_TYPES = ("SELVA_FEATURES", "FLOAT")
    RETURN_NAMES = ("features", "fps")
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
                         duration=0.0, cache_dir=""):
        if video_info is not None:
            fps = video_info["loaded_fps"]

        T = video.shape[0]
        if duration <= 0:
            duration = T / fps
        duration = min(duration, T / fps)  # clamp to actual video length

        if not prompt.strip():
            print("[SelVA] Warning: empty prompt; TextSynchformer sync features will be unfocused.", flush=True)

        # Cache
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
        os.makedirs(cache_dir, exist_ok=True)
        cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"])
        cached_path = os.path.join(cache_dir, f"{cache_key}.npz")

        if os.path.exists(cached_path):
            print(f"[SelVA] Using cached features: {cached_path}", flush=True)
            return (_load_cached(cached_path), float(fps))

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        feature_utils = model["feature_utils"]
        net_video_enc = model["video_enc"]

        # Move feature models to device
        if strategy == "offload_to_cpu":
            feature_utils.to(device)
            net_video_enc.to(device)
            soft_empty_cache()

        print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)

        with torch.no_grad():
            # --- CLIP frames: 384×384, [0,1], 8fps ---
            clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration)  # [N, H, W, C]
            clip_frames = _resize_frames(clip_frames, _CLIP_SIZE)          # [N, C, 384, 384]
            clip_input = clip_frames.unsqueeze(0).to(device, dtype)        # [1, N, C, 384, 384]
            print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps", flush=True)

            clip_features = feature_utils.encode_video_with_clip(clip_input)  # [1, N, 1024]

            # --- Sync frames: 224×224, [-1,1], 25fps ---
            sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration)
            if sync_frames.shape[0] < 16:
                # Pad by repeating last frame to reach the minimum of 16 needed for segmentation
                pad = 16 - sync_frames.shape[0]
                sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
            sync_frames = _resize_frames(sync_frames, _SYNC_SIZE)  # [N, C, 224, 224]
            # Normalize to [-1, 1]
            mean = _SYNC_MEAN.to(sync_frames.device)
            std = _SYNC_STD.to(sync_frames.device)
            sync_frames = (sync_frames - mean) / std
            sync_input = sync_frames.unsqueeze(0).to(device, dtype)  # [1, N, C, 224, 224]
            print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps", flush=True)

            # Encode T5 text + prepend supplementary tokens → text-conditioned sync features
            text_f_t5, text_mask = feature_utils.encode_text_t5([prompt])  # [1, L, 768], [1, L]
            text_f_t5, text_mask = net_video_enc.prepend_sup_text_tokens(text_f_t5, text_mask)
            sync_features = net_video_enc.encode_video_with_sync(
                sync_input, text_f=text_f_t5, text_mask=text_mask
            )  # [1, T_sync, 768]

        print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
        print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)

        # Offload back if needed
        if strategy == "offload_to_cpu":
            feature_utils.to(get_offload_device())
            net_video_enc.to(get_offload_device())
            soft_empty_cache()

        # Save cache
        np.savez(
            cached_path,
            clip_features=clip_features.cpu().float().numpy(),
            sync_features=sync_features.cpu().float().numpy(),
            duration=duration,
        )
        print(f"[SelVA] Features cached: {cached_path}", flush=True)

        features = {
            "clip_features": clip_features.cpu(),
            "sync_features": sync_features.cpu(),
            "duration": duration,
        }
        return (features, float(fps))


def _load_cached(path):
    data = np.load(path, allow_pickle=False)
    return {
        "clip_features": torch.from_numpy(data["clip_features"]),
        "sync_features": torch.from_numpy(data["sync_features"]),
        "duration": float(data["duration"]),
    }
```
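
The index arithmetic inside `_sample_frames` is easy to check in isolation. The sketch below repeats that calculation on plain integers, without tensors; `sample_indices` is a hypothetical stand-alone name for the same logic.

```python
def sample_indices(num_frames, source_fps, target_fps, duration):
    """Source-frame indices picked when resampling a clip to target_fps."""
    n_out = max(1, int(duration * target_fps))
    # Output frame i sits at time i / target_fps; map that to a source index
    # and clamp so short clips repeat their last frame instead of overflowing.
    return [min(int(i / target_fps * source_fps), num_frames - 1) for i in range(n_out)]


if __name__ == "__main__":
    # One second of 30 fps video resampled to the 8 fps CLIP rate.
    print(sample_indices(30, 30.0, 8, 1.0))
```

Downsampling picks every third or fourth frame for the 30 fps to 8 fps case, while a clip with too few frames has its last index repeated by the clamp.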

**Step 2: Register in `nodes/__init__.py`**

```python
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
```

**Step 3: Verify node registers**

```bash
python -c "
import sys; sys.path.insert(0, '.')
from nodes.selva_feature_extractor import SelvaFeatureExtractor
inputs = SelvaFeatureExtractor.INPUT_TYPES()
print('required:', list(inputs['required'].keys()))
print('optional:', list(inputs['optional'].keys()))
print('outputs:', SelvaFeatureExtractor.RETURN_TYPES)
"
```

Expected: `required: ['model', 'video', 'prompt']`

**Step 4: Commit**

```bash
git add nodes/selva_feature_extractor.py nodes/__init__.py
git commit -m "feat: SelvaFeatureExtractor - inline CLIP + TextSynchformer feature extraction"
```

---

## Task 4: Implement SelvaSampler

**Files:**
- Create: `nodes/selva_sampler.py`
- Modify: `nodes/__init__.py`

**Step 1: Create `nodes/selva_sampler.py`**

```python
import math
import torch
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY,
    get_device, get_offload_device, soft_empty_cache,
)


def _make_seq_cfg(duration, mode):
    """Compute sequence lengths for a given duration and mode."""
    from selva_core.model.sequence_config import SequenceConfig
    if mode == "16k":
        return SequenceConfig(duration=duration, sampling_rate=16000, spectrogram_frame_rate=256)
    else:
        return SequenceConfig(duration=duration, sampling_rate=44100, spectrogram_frame_rate=512)


class SelvaSampler:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "features": ("SELVA_FEATURES",),
                "prompt": ("STRING", {"default": "", "multiline": True,
                                      "tooltip": "Should match the prompt used in SelvaFeatureExtractor."}),
                "negative_prompt": ("STRING", {"default": "", "multiline": True,
                                               "tooltip": "Sounds to steer away from, e.g. 'wind noise, background music'."}),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
                                       "tooltip": "Audio duration in seconds. 0 = use duration from features."}),
                "steps": ("INT", {"default": 25, "min": 1, "max": 200}),
                "cfg_strength": ("FLOAT", {"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed):
        from selva_core.model.flow_matching import FlowMatching

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        net_generator = model["generator"]
        feature_utils = model["feature_utils"]
        mode = model["mode"]

        # Resolve duration
        if duration <= 0:
            if "duration" not in features:
                raise ValueError("[SelVA] duration=0 but features contain no duration field.")
            duration = features["duration"]
            print(f"[SelVA] Using video duration from features: {duration:.2f}s", flush=True)

        seq_cfg = _make_seq_cfg(duration, mode)
        sample_rate = seq_cfg.sampling_rate

        # Move models to device
        if strategy == "offload_to_cpu":
            net_generator.to(device)
            feature_utils.to(device)
            soft_empty_cache()

        clip_f = features["clip_features"].to(device, dtype)  # [1, T_clip, 1024]
        sync_f = features["sync_features"].to(device, dtype)  # [1, T_sync, 768]

        print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)
        print(f"[SelVA] seq_cfg: latent={seq_cfg.latent_seq_len} clip={seq_cfg.clip_seq_len} sync={seq_cfg.sync_seq_len}", flush=True)

        # Update model sequence lengths for this duration
        net_generator.update_seq_lengths(
            latent_seq_len=seq_cfg.latent_seq_len,
            clip_seq_len=seq_cfg.clip_seq_len,
            sync_seq_len=seq_cfg.sync_seq_len,
        )

        with torch.no_grad():
            # Encode text
            text_clip = feature_utils.encode_text_clip([prompt])  # [1, 77, D]

            # Encode the negative prompt (if any) for the unconditional CFG branch
            neg_text_clip = (
                feature_utils.encode_text_clip([negative_prompt])
                if negative_prompt.strip() else None
            )

            conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
            empty_conditions = net_generator.get_empty_conditions(
                bs=1, negative_text_features=neg_text_clip
            )

            # Sample initial noise
            rng = torch.Generator(device=device).manual_seed(seed)
            x0 = torch.randn(
                1, seq_cfg.latent_seq_len, net_generator.latent_dim,
                device=device, dtype=dtype, generator=rng
            )

            # Flow matching ODE (Euler)
            fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
            pbar = comfy.utils.ProgressBar(steps)

            _step_count = [0]

            # Wrap ODE to update progress bar on every velocity evaluation
            def ode_wrapper_tracked(t, x):
                _step_count[0] += 1
                pbar.update(1)
                return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)

            x1 = fm.to_data(ode_wrapper_tracked, x0)

        print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)

        # Decode: latent → mel → audio
        if strategy == "offload_to_cpu":
            feature_utils.to(device)
            soft_empty_cache()

        with torch.no_grad():
            x1_unnorm = net_generator.unnormalize(x1)
            spec = feature_utils.decode(x1_unnorm)
            audio = feature_utils.vocode(spec)  # [1, samples] or [1, 1, samples]

        if strategy == "offload_to_cpu":
            net_generator.to(get_offload_device())
            feature_utils.to(get_offload_device())
            soft_empty_cache()

        # Normalise to [-1, 1]
        audio = audio.float()
        if audio.dim() == 2:
            audio = audio.unsqueeze(1)  # [1, 1, samples]
        elif audio.dim() == 3 and audio.shape[1] != 1:
            audio = audio.mean(dim=1, keepdim=True)  # stereo → mono
|
||||
peak = audio.abs().max().clamp(min=1e-8)
|
||||
audio = (audio / peak).clamp(-1, 1)
|
||||
print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)
|
||||
|
||||
return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
|
||||
```
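
The sampling loop above can be illustrated with a toy, dependency-free sketch. It assumes that `FlowMatching.to_data` performs fixed-step Euler integration from t=0 to t=1 and that `ode_wrapper` blends conditional and unconditional velocities with the usual classifier-free guidance formula; neither detail is confirmed here. The names `cfg_velocity`, `euler_to_data`, and the two toy velocity fields are hypothetical and exist only for illustration.

```python
from typing import Callable, List

Vector = List[float]
VelocityFn = Callable[[float, Vector], Vector]


def cfg_velocity(cond: VelocityFn, uncond: VelocityFn, cfg_strength: float) -> VelocityFn:
    """Blend two velocity fields: v = v_uncond + cfg_strength * (v_cond - v_uncond)."""
    def blended(t: float, x: Vector) -> Vector:
        v_c, v_u = cond(t, x), uncond(t, x)
        return [u + cfg_strength * (c - u) for c, u in zip(v_c, v_u)]
    return blended


def euler_to_data(fn: VelocityFn, x0: Vector, num_steps: int) -> Vector:
    """Integrate dx/dt = fn(t, x) from t=0 to t=1 with fixed-step Euler."""
    x = list(x0)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        v = fn(i * dt, x)  # one velocity evaluation per step
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Toy fields (assumptions): constant velocities, so Euler is exact here.
def toy_cond(t: float, x: Vector) -> Vector:
    return [2.0 for _ in x]


def toy_uncond(t: float, x: Vector) -> Vector:
    return [1.0 for _ in x]


guided = cfg_velocity(toy_cond, toy_uncond, cfg_strength=4.5)
x1 = euler_to_data(guided, [0.0, 0.5], num_steps=25)
print(x1)
```

With `cfg_strength=1.0` the blend reduces to the conditional field alone; larger values push the sample further in the direction that the prompt adds over the empty (or negative) condition.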

**Step 2: Register in `nodes/__init__.py`**

```python
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
```

**Step 3: Verify node registers**

```bash
python -c "
import sys; sys.path.insert(0, '.')
from nodes.selva_sampler import SelvaSampler
inputs = SelvaSampler.INPUT_TYPES()
print('inputs:', list(inputs['required'].keys()))
print('outputs:', SelvaSampler.RETURN_TYPES)
"
```

Expected: `inputs: ['model', 'features', 'prompt', 'negative_prompt', 'duration', 'steps', 'cfg_strength', 'seed']`

**Step 4: Commit**

```bash
git add nodes/selva_sampler.py nodes/__init__.py
git commit -m "feat: SelvaSampler — flow matching ODE with CFG + negative prompts"
```

---

## Task 5: Create example workflow and push

**Files:**
- Create: `workflows/selva_video_to_audio.json`

**Step 1: Create workflow JSON**

Create `workflows/selva_video_to_audio.json` with this node graph:
- LoadVideo (VHS) → IMAGE + VHS_VIDEOINFO
- SelvaModelLoader → SELVA_MODEL
- SelvaFeatureExtractor (takes IMAGE + VHS_VIDEOINFO + SELVA_MODEL, prompt) → SELVA_FEATURES
- SelvaSampler (takes SELVA_MODEL + SELVA_FEATURES, prompt, negative_prompt) → AUDIO
- PreviewAudio (takes AUDIO)

Set defaults: variant=medium_44k, precision=bf16, steps=25, cfg_strength=4.5, duration=0 (the sampler then uses the video duration stored in the features).
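
The `duration=0` default relies on the sampler falling back to the duration stored with the features. A minimal sketch of that fallback, mirroring the check in `generate`; the standalone helper name `resolve_duration` is hypothetical and not part of the plan:

```python
def resolve_duration(duration: float, features: dict) -> float:
    """Return the requested duration, or the one stored in `features` when duration <= 0."""
    if duration > 0:
        return duration
    if "duration" not in features:
        raise ValueError("[SelVA] duration=0 but features contain no duration field.")
    return float(features["duration"])


print(resolve_duration(0, {"duration": 8.0}))    # falls back to the stored video duration
print(resolve_duration(5.0, {"duration": 8.0}))  # an explicit positive value wins
```

Features that carry no `duration` field therefore need an explicit positive duration on the sampler node.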

**Step 2: Push branch**

```bash
git push -u origin feature/selva-integration
```

---

## Task 6: Smoke test

**Step 1: Check all three nodes are importable from ComfyUI's perspective**

```bash
cd /media/p5/Comfyui-Prismaudio
python -c "
import sys; sys.path.insert(0, '.')
import nodes
m = nodes.NODE_CLASS_MAPPINGS
print('SelVA nodes:', [k for k in m if 'Selva' in k])
assert 'SelvaModelLoader' in m
assert 'SelvaFeatureExtractor' in m
assert 'SelvaSampler' in m
print('All SelVA nodes registered OK')
"
```

**Step 2: Verify no import errors in full node load**

```bash
python -c "
import sys; sys.path.insert(0, '.')
from nodes.selva_model_loader import SelvaModelLoader
from nodes.selva_feature_extractor import SelvaFeatureExtractor
from nodes.selva_sampler import SelvaSampler
print('All imports clean')
"
```

**Step 3: Final commit with any fixes**

```bash
git add -A
git commit -m "fix: selva integration smoke test fixes (if any)"
git push
```

---

## Notes

- `FeaturesUtils.train()` is overridden to always call `super().train(False)`, so SelVA models are always in eval mode.
- `net_generator.update_seq_lengths` recalculates the rotary position embeddings; call it before every generation, since the duration may vary between runs.
- ProgressBar tracking: `FlowMatching.to_data` calls `fn(t, x)` once per Euler step, so wrapping `ode_wrapper` with a counter gives accurate progress.
- `feature_utils.vocode` returns audio at 16 kHz for `small_16k` (uses BigVGAN) and at 44.1 kHz for the 44k variants (uses the VAE mel decoder directly).
- If `encode_text_t5` or `encode_text_clip` fails with missing-model errors on first run, HuggingFace is still downloading `flan-t5-base` and `apple/DFN5B-CLIP-ViT-H-14-384`; this is expected and takes a few minutes, once.
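
The progress-tracking note can be sketched without ComfyUI or PyTorch. The stand-ins below are assumptions: `FakeProgressBar` mimics only the `update` method that the sampler calls on `comfy.utils.ProgressBar`, and `run_euler` is a hypothetical solver that, as the note describes for `FlowMatching.to_data`, calls `fn(t, x)` once per step.

```python
from typing import Callable

StepFn = Callable[[float, float], float]


class FakeProgressBar:
    """Hypothetical stand-in exposing only update(), as used by the sampler."""

    def __init__(self, total: int) -> None:
        self.total = total
        self.current = 0

    def update(self, n: int) -> None:
        self.current += n


def run_euler(fn: StepFn, x0: float, num_steps: int) -> float:
    """Hypothetical solver: exactly one call to fn per Euler step."""
    x, dt = x0, 1.0 / num_steps
    for i in range(num_steps):
        x += dt * fn(i * dt, x)
    return x


def with_progress(fn: StepFn, pbar: FakeProgressBar) -> StepFn:
    """Wrap fn so that every solver call advances the progress bar by one."""
    def tracked(t: float, x: float) -> float:
        pbar.update(1)
        return fn(t, x)
    return tracked


def unit_velocity(t: float, x: float) -> float:
    """Toy velocity field used only for this demonstration."""
    return 1.0


steps = 25
pbar = FakeProgressBar(steps)
x1 = run_euler(with_progress(unit_velocity, pbar), x0=0.0, num_steps=steps)
print(f"{pbar.current}/{pbar.total} steps reported")
```

A solver that evaluates the function more than once per step (a midpoint or Runge-Kutta scheme, for example) would over-count with this wrapper, so the bar total would then need to be the number of function evaluations rather than the number of steps.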
|
||||
+4
-8
@@ -2,13 +2,9 @@ NODE_CLASS_MAPPINGS = {}
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {}
|
||||
|
||||
_NODES = {
|
||||
"PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"),
|
||||
"PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"),
|
||||
"PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"),
|
||||
"PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"),
|
||||
"PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"),
|
||||
"PrismAudioLoRATrainer": (".lora_trainer", "PrismAudioLoRATrainer", "PrismAudio LoRA Trainer"),
|
||||
"PrismAudioLoRALoader": (".lora_loader", "PrismAudioLoRALoader", "PrismAudio LoRA Loader"),
|
||||
"SelvaModelLoader": (".selva_model_loader", "SelvaModelLoader", "SelVA Model Loader"),
|
||||
"SelvaFeatureExtractor": (".selva_feature_extractor", "SelvaFeatureExtractor", "SelVA Feature Extractor"),
|
||||
"SelvaSampler": (".selva_sampler", "SelvaSampler", "SelVA Sampler"),
|
||||
}
|
||||
|
||||
for key, (module_path, class_name, display_name) in _NODES.items():
|
||||
@@ -18,4 +14,4 @@ for key, (module_path, class_name, display_name) in _NODES.items():
|
||||
NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name)
|
||||
NODE_DISPLAY_NAME_MAPPINGS[key] = display_name
|
||||
except (ImportError, AttributeError) as e:
|
||||
print(f"[PrismAudio] Skipping {key}: {e}")
|
||||
print(f"[SelVA] Skipping {key}: {e}")
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import subprocess
|
||||
import tempfile
|
||||
import torch
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY
|
||||
from .feature_loader import PrismAudioFeatureLoader
|
||||
|
||||
# Managed venv created automatically when python_env is left as default
|
||||
_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
|
||||
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
|
||||
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")
|
||||
|
||||
def _jax_package():
|
||||
"""Return the correct jax extra for the current CUDA version."""
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
cuda_ver = torch.version.cuda or ""
|
||||
major = int(cuda_ver.split(".")[0]) if cuda_ver else 0
|
||||
if major >= 13:
|
||||
return "jax[cuda13]"
|
||||
elif major >= 12:
|
||||
return "jax[cuda12]"
|
||||
except Exception:
|
||||
pass
|
||||
return "jax" # CPU fallback
|
||||
|
||||
|
||||
_EXTRACT_PACKAGES = [
|
||||
"torch", "torchaudio", "torchvision",
|
||||
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
|
||||
"tensorflow-cpu>=2.16.0",
|
||||
# jax CUDA extra is resolved at install time based on detected CUDA version
|
||||
_jax_package(), "flax",
|
||||
"transformers", "decord", "einops", "numpy",
|
||||
"git+https://github.com/google-deepmind/videoprism.git",
|
||||
]
|
||||
|
||||
|
||||
def _pip_install(pip, *packages, label=None):
|
||||
"""Install one or more packages with visible output; raise on failure."""
|
||||
tag = label or packages[0]
|
||||
print(f"[PrismAudio] installing {tag} ...", flush=True)
|
||||
result = subprocess.run(
|
||||
[pip, "install", "--progress-bar", "on"] + list(packages),
|
||||
capture_output=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"[PrismAudio] Failed to install {tag} (exit {result.returncode}). "
|
||||
"See pip output above for details."
|
||||
)
|
||||
print(f"[PrismAudio] {tag} OK", flush=True)
|
||||
|
||||
|
||||
def _ensure_extract_env():
|
||||
"""Create and populate the managed venv on first use."""
|
||||
if os.path.exists(_MANAGED_PYTHON):
|
||||
return _MANAGED_PYTHON
|
||||
|
||||
import shutil
|
||||
if os.path.exists(_MANAGED_VENV):
|
||||
print("[PrismAudio] Removing incomplete venv and retrying...", flush=True)
|
||||
shutil.rmtree(_MANAGED_VENV)
|
||||
|
||||
print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True)
|
||||
subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True)
|
||||
|
||||
pip = os.path.join(_MANAGED_VENV, "bin", "pip")
|
||||
|
||||
print("[PrismAudio] Upgrading pip...", flush=True)
|
||||
subprocess.run([pip, "install", "--upgrade", "pip"], check=True)
|
||||
|
||||
total = len(_EXTRACT_PACKAGES)
|
||||
print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True)
|
||||
|
||||
for i, pkg in enumerate(_EXTRACT_PACKAGES, 1):
|
||||
label = pkg.split("/")[-1] if pkg.startswith("git+") else pkg.split(">=")[0].split("==")[0].split("[")[0]
|
||||
print(f"[PrismAudio] [{i}/{total}] {label}", flush=True)
|
||||
_pip_install(pip, pkg, label=label)
|
||||
|
||||
print("[PrismAudio] Feature-extraction env ready.", flush=True)
|
||||
return _MANAGED_PYTHON
|
||||
|
||||
|
||||
def _hash_inputs(video_tensor, cot_text, fps):
|
||||
"""Create a hash of the inputs for caching."""
|
||||
h = hashlib.sha256()
|
||||
h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed
|
||||
h.update(cot_text.encode())
|
||||
h.update(str(fps).encode()) # fps affects frame sampling — must be part of the key
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
def _save_frames_to_npy(video_tensor, output_path):
|
||||
"""Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8.
|
||||
|
||||
Lossless — avoids H.264 encode/decode roundtrip.
|
||||
"""
|
||||
import numpy as np
|
||||
frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8")
|
||||
np.save(output_path, frames_np)
|
||||
|
||||
|
||||
class PrismAudioFeatureExtractor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"video": ("IMAGE",),
|
||||
"caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
|
||||
},
|
||||
"optional": {
|
||||
"video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info output to auto-set fps."}),
|
||||
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Ignored if video_info is connected."}),
|
||||
"python_env": (["managed_env", "comfyui_env"], {"tooltip": "managed_env: auto-created isolated venv with JAX/TF (recommended). comfyui_env: current ComfyUI Python — WARNING: may conflict with existing packages and destabilize ComfyUI."}),
|
||||
"cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
|
||||
"hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("PRISMAUDIO_FEATURES", "FLOAT")
|
||||
RETURN_NAMES = ("features", "fps")
|
||||
FUNCTION = "extract_features"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def extract_features(self, video, caption_cot, video_info=None, fps=30.0, python_env="managed_env", cache_dir="", hf_token=""):
|
||||
# Resolve fps from VHS video_info if connected
|
||||
if video_info is not None:
|
||||
fps = video_info["loaded_fps"]
|
||||
|
||||
if not caption_cot.strip():
|
||||
print("[PrismAudio] Warning: caption_cot is empty — text features will be degenerate. "
|
||||
"Provide a descriptive chain-of-thought caption for best results.", flush=True)
|
||||
|
||||
# Resolve python binary
|
||||
if python_env == "comfyui_env":
|
||||
print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
|
||||
"Installing them here may conflict with existing packages and destabilize ComfyUI.", flush=True)
|
||||
python_bin = sys.executable
|
||||
else:
|
||||
python_bin = _ensure_extract_env()
|
||||
|
||||
# Determine cache directory
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Check cache
|
||||
cache_hash = _hash_inputs(video, caption_cot, fps)
|
||||
cached_path = os.path.join(cache_dir, f"{cache_hash}.npz")
|
||||
if os.path.exists(cached_path):
|
||||
print(f"[PrismAudio] Using cached features: {cached_path}")
|
||||
loader = PrismAudioFeatureLoader()
|
||||
features, = loader.load_features(cached_path)
|
||||
return (features, float(fps))
|
||||
|
||||
# Save frames to temp file (lossless .npy, no codec roundtrip)
|
||||
import time
|
||||
t0 = time.perf_counter()
|
||||
frames = video.shape[0]
|
||||
print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True)
|
||||
with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
|
||||
tmp_video = tmp.name
|
||||
_save_frames_to_npy(video, tmp_video)
|
||||
print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True)
|
||||
|
||||
# Build subprocess command
|
||||
script_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"scripts", "extract_features.py"
|
||||
)
|
||||
|
||||
import folder_paths
|
||||
synchformer_ckpt = os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth")
|
||||
if not os.path.exists(synchformer_ckpt):
|
||||
raise RuntimeError(
|
||||
f"[PrismAudio] Synchformer checkpoint not found: {synchformer_ckpt}\n"
|
||||
"Download synchformer_state_dict.pth from FunAudioLLM/PrismAudio and place it in models/prismaudio/."
|
||||
)
|
||||
|
||||
cmd = [
|
||||
python_bin,
|
||||
script_path,
|
||||
"--video", tmp_video,
|
||||
"--cot_text", caption_cot,
|
||||
"--output", cached_path,
|
||||
"--source_fps", str(fps),
|
||||
"--synchformer_ckpt", synchformer_ckpt,
|
||||
]
|
||||
|
||||
# Build env: inherit current env, inject HF token if provided
|
||||
import copy
|
||||
env = copy.copy(os.environ)
|
||||
token = hf_token.strip() if hf_token else os.environ.get("HF_TOKEN", "")
|
||||
if token:
|
||||
env["HF_TOKEN"] = token
|
||||
env["HUGGING_FACE_HUB_TOKEN"] = token
|
||||
else:
|
||||
print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. "
|
||||
"Add your token in the hf_token input or set HF_TOKEN env var.", flush=True)
|
||||
|
||||
print(f"[PrismAudio] Extracting features via subprocess (output streams live)...")
|
||||
try:
|
||||
# capture_output=False: let stdout/stderr stream directly to ComfyUI logs
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=False,
|
||||
timeout=600, # 10 minute timeout
|
||||
env=env,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. "
|
||||
"See output above for details."
|
||||
)
|
||||
print("[PrismAudio] Feature extraction subprocess finished successfully.")
|
||||
finally:
|
||||
if os.path.exists(tmp_video):
|
||||
os.unlink(tmp_video)
|
||||
|
||||
# Load the extracted features
|
||||
loader = PrismAudioFeatureLoader()
|
||||
features, = loader.load_features(cached_path)
|
||||
return (features, float(fps))
|
||||
@@ -1,53 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from .utils import PRISMAUDIO_CATEGORY
|
||||
|
||||
# Keys consumed by the conditioners (video_features, text_features, sync_features)
|
||||
# global_video_features and global_text_features are NOT consumed by any conditioner
|
||||
# in the prismaudio.json config — they are unused.
|
||||
REQUIRED_KEYS = [
|
||||
"video_features",
|
||||
"text_features",
|
||||
"sync_features",
|
||||
]
|
||||
|
||||
|
||||
class PrismAudioFeatureLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
|
||||
RETURN_NAMES = ("features",)
|
||||
FUNCTION = "load_features"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def load_features(self, npz_path):
|
||||
if not os.path.exists(npz_path):
|
||||
raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")
|
||||
|
||||
data = np.load(npz_path, allow_pickle=True)
|
||||
|
||||
features = {}
|
||||
for key in REQUIRED_KEYS:
|
||||
if key in data:
|
||||
features[key] = torch.from_numpy(data[key]).float()
|
||||
else:
|
||||
print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
|
||||
# Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None
|
||||
# Sync_MLP requires length divisible by 8 (segments of 8 frames)
|
||||
if key == "sync_features":
|
||||
features[key] = torch.zeros(8, 768)
|
||||
else:
|
||||
features[key] = torch.zeros(1, 1024)
|
||||
|
||||
# Load duration if present
|
||||
if "duration" in data:
|
||||
features["duration"] = float(data["duration"])
|
||||
|
||||
return (features,)
|
||||
@@ -1,106 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY
|
||||
|
||||
|
||||
def _merge_lora_weights(dit: nn.Module, lora_state: dict, rank: int, alpha: float, strength: float):
|
||||
"""Add LoRA delta weights directly into the base model's nn.Linear tensors.
|
||||
|
||||
delta_W = (lora_B @ lora_A) * (alpha / rank) * strength
|
||||
applied as: linear.weight += delta_W
|
||||
|
||||
This is equivalent to LoRALinear at inference but requires no wrapper,
|
||||
no extra memory, and no change to the model's forward call graph.
|
||||
"""
|
||||
scale = (alpha / rank) * strength
|
||||
|
||||
# Group saved keys by module path
|
||||
a_map = {
|
||||
k.replace(".lora_A.weight", ""): v
|
||||
for k, v in lora_state.items() if k.endswith("lora_A.weight")
|
||||
}
|
||||
b_map = {
|
||||
k.replace(".lora_B.weight", ""): v
|
||||
for k, v in lora_state.items() if k.endswith("lora_B.weight")
|
||||
}
|
||||
|
||||
merged = 0
|
||||
for path, lora_A in a_map.items():
|
||||
if path not in b_map:
|
||||
print(f"[PrismAudio] LoRA merge: missing lora_B for {path}, skipping", flush=True)
|
||||
continue
|
||||
lora_B = b_map[path] # [out_features, rank]
|
||||
# delta_W: [out_features, in_features]
|
||||
delta_W = (lora_B.float() @ lora_A.float()) * scale
|
||||
|
||||
# Navigate to the parent module using PyTorch's get_submodule
|
||||
*parent_parts, child_name = path.split(".")
|
||||
try:
|
||||
parent = dit.get_submodule(".".join(parent_parts)) if parent_parts else dit
|
||||
except AttributeError as e:
|
||||
print(f"[PrismAudio] LoRA merge: could not find module '{path}': {e}", flush=True)
|
||||
continue
|
||||
|
||||
linear = getattr(parent, child_name, None)
|
||||
if not isinstance(linear, nn.Linear):
|
||||
print(f"[PrismAudio] LoRA merge: expected nn.Linear at '{path}', got {type(linear)}", flush=True)
|
||||
continue
|
||||
|
||||
linear.weight.data.add_(delta_W.to(linear.weight.dtype))
|
||||
merged += 1
|
||||
|
||||
print(f"[PrismAudio] LoRA merged {merged} layer(s) (strength={strength:.3f})", flush=True)
|
||||
|
||||
|
||||
class PrismAudioLoRALoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("PRISMAUDIO_MODEL",),
|
||||
"lora_path": ("STRING", {"default": "", "tooltip": "Path to .safetensors LoRA file produced by PrismAudio LoRA Trainer"}),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.05, "tooltip": "LoRA influence scale. 1.0 = full strength, 0.0 = base model only"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("PRISMAUDIO_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load_lora"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def load_lora(self, model, lora_path, strength):
|
||||
from safetensors.torch import load_file
|
||||
|
||||
if not os.path.exists(lora_path):
|
||||
raise FileNotFoundError(f"[PrismAudio] LoRA file not found: {lora_path}")
|
||||
|
||||
config_path = lora_path.replace(".safetensors", "_config.json")
|
||||
if not os.path.exists(config_path):
|
||||
raise FileNotFoundError(
|
||||
f"[PrismAudio] LoRA config not found: {config_path}\n"
|
||||
"Expected a _config.json alongside the .safetensors file."
|
||||
)
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
rank = config["rank"]
|
||||
alpha = config["alpha"]
|
||||
|
||||
lora_state = load_file(lora_path)
|
||||
|
||||
# Merge LoRA weights in-place into the DiT's base linear layers.
|
||||
# ComfyUI re-executes the upstream ModelLoader on the next queue run
|
||||
# when inputs change, providing a fresh base model as needed.
|
||||
dit = model["model"].model # DiTWrapper
|
||||
|
||||
if strength == 0.0:
|
||||
print("[PrismAudio] LoRA strength=0.0 — skipping merge, base model unchanged.", flush=True)
|
||||
return (model,)
|
||||
|
||||
_merge_lora_weights(dit, lora_state, rank, alpha, strength)
|
||||
|
||||
return (model,)
|
||||
@@ -1,284 +0,0 @@
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE,
|
||||
get_device, get_offload_device, soft_empty_cache,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LoRA primitives
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
"""Low-rank adapter wrapping a frozen nn.Linear."""
|
||||
|
||||
def __init__(self, linear: nn.Linear, rank: int, alpha: float):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
self.scale = alpha / rank
|
||||
in_f, out_f = linear.in_features, linear.out_features
|
||||
self.lora_A = nn.Linear(in_f, rank, bias=False)
|
||||
self.lora_B = nn.Linear(rank, out_f, bias=False)
|
||||
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B.weight)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x) + self.lora_B(self.lora_A(x)) * self.scale
|
||||
|
||||
|
||||
_TARGET_MODULE_PRESETS = {
|
||||
"attn_only": {"to_q", "to_kv", "to_qkv", "to_out"},
|
||||
"attn_ffn": {"to_q", "to_kv", "to_qkv", "to_out", "proj"},
|
||||
"full": {"to_q", "to_kv", "to_qkv", "to_out", "proj", "project_in", "project_out"},
|
||||
}
|
||||
|
||||
|
||||
def _apply_lora(module: nn.Module, target_attrs: set, rank: int, alpha: float):
|
||||
"""Recursively replace matching nn.Linear layers with LoRALinear."""
|
||||
for name, child in list(module.named_children()):
|
||||
if isinstance(child, nn.Linear) and name in target_attrs:
|
||||
setattr(module, name, LoRALinear(child, rank, alpha))
|
||||
else:
|
||||
_apply_lora(child, target_attrs, rank, alpha)
|
||||
|
||||
|
||||
def _unapply_lora(module: nn.Module):
|
||||
"""Replace LoRALinear back with the original frozen Linear (no weight merge)."""
|
||||
for name, child in list(module.named_children()):
|
||||
if isinstance(child, LoRALinear):
|
||||
child.linear.weight.requires_grad_(False)
|
||||
setattr(module, name, child.linear)
|
||||
else:
|
||||
_unapply_lora(child)
|
||||
|
||||
|
||||
def _get_lora_state_dict(module: nn.Module) -> dict:
|
||||
"""Return only LoRA parameter tensors from a module's state dict."""
|
||||
return {k: v for k, v in module.state_dict().items()
|
||||
if "lora_A" in k or "lora_B" in k}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_AUDIO_EXTS = (".wav", ".flac", ".mp3")
|
||||
|
||||
|
||||
def _scan_dataset(dataset_dir: str):
|
||||
"""Return list of (npz_path, audio_path) pairs matched by stem."""
|
||||
pairs = []
|
||||
for fname in os.listdir(dataset_dir):
|
||||
if not fname.endswith(".npz"):
|
||||
continue
|
||||
stem = os.path.join(dataset_dir, fname[:-4])
|
||||
for ext in _AUDIO_EXTS:
|
||||
audio_path = stem + ext
|
||||
if os.path.exists(audio_path):
|
||||
pairs.append((stem + ".npz", audio_path))
|
||||
break
|
||||
return sorted(pairs)
|
||||
|
||||
|
||||
def _load_audio(audio_path: str, device: torch.device) -> torch.Tensor:
|
||||
"""Load audio to [1, 2, samples] float32 tensor at SAMPLE_RATE."""
|
||||
import torchaudio
|
||||
waveform, sr = torchaudio.load(audio_path)
|
||||
if sr != SAMPLE_RATE:
|
||||
waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
|
||||
if waveform.shape[0] == 1:
|
||||
waveform = waveform.expand(2, -1)
|
||||
elif waveform.shape[0] > 2:
|
||||
waveform = waveform[:2]
|
||||
return waveform.unsqueeze(0).to(device) # [1, 2, samples]
|
||||
|
||||
|
||||
def _load_metadata(npz_path: str, device: torch.device, dtype: torch.dtype) -> dict:
|
||||
"""Load .npz features into a conditioner metadata dict."""
|
||||
import numpy as np
|
||||
data = np.load(npz_path, allow_pickle=True)
|
||||
video_feat = torch.from_numpy(data["video_features"]).float().to(device, dtype=dtype)
|
||||
text_feat = torch.from_numpy(data["text_features"]).float().to(device, dtype=dtype)
|
||||
sync_feat = torch.from_numpy(data["sync_features"]).float().to(device, dtype=dtype)
|
||||
has_video = bool(video_feat.abs().sum() > 0)
|
||||
return {
|
||||
"video_features": video_feat,
|
||||
"text_features": text_feat,
|
||||
"sync_features": sync_feat,
|
||||
"video_exist": torch.tensor(has_video),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trainer node
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PrismAudioLoRATrainer:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("PRISMAUDIO_MODEL",),
|
||||
"dataset_dir": ("STRING", {"default": "", "tooltip": "Directory containing paired .npz feature files and .wav/.flac audio files (matched by filename stem)"}),
|
||||
"output_path": ("STRING", {"default": "", "tooltip": "Save path for .safetensors weights. Empty = models/prismaudio/lora/"}),
|
||||
"lora_rank": ("INT", {"default": 64, "min": 1, "max": 512}),
|
||||
"lora_alpha": ("FLOAT", {"default": 64.0, "min": 1.0, "max": 1024.0}),
|
||||
"target_modules": (["attn_ffn", "attn_only", "full"], {"tooltip": "attn_only: Q/K/V/out only. attn_ffn: + FFN input (recommended). full: + transformer I/O projections"}),
|
||||
"learning_rate": ("FLOAT", {"default": 1e-4, "min": 1e-7, "max": 1e-2, "step": 1e-6}),
|
||||
"train_steps": ("INT", {"default": 1000, "min": 1, "max": 100000}),
|
||||
"cfg_dropout_prob": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 0.5, "step": 0.01, "tooltip": "Probability of dropping conditioning per step — preserves CFG ability at inference"}),
|
||||
"save_every": ("INT", {"default": 500, "min": 1, "max": 100000, "tooltip": "Save a checkpoint every N steps (in addition to final save)"}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("lora_path",)
|
||||
FUNCTION = "train"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def train(self, model, dataset_dir, output_path, lora_rank, lora_alpha,
|
||||
target_modules, learning_rate, train_steps, cfg_dropout_prob, save_every, seed):
|
||||
from safetensors.torch import save_file
|
||||
|
||||
device = get_device()
|
||||
dtype = model["dtype"]
|
||||
diffusion = model["model"]
|
||||
strategy = model["strategy"]
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
# Scan dataset
|
||||
pairs = _scan_dataset(dataset_dir)
|
||||
if not pairs:
|
||||
raise RuntimeError(f"[PrismAudio] No (.npz + audio) pairs found in: {dataset_dir}")
|
||||
print(f"[PrismAudio] LoRA training — {len(pairs)} sample(s), {train_steps} steps", flush=True)
|
||||
|
||||
# Resolve output path
|
||||
if not output_path:
|
||||
import folder_paths
|
||||
out_dir = os.path.join(folder_paths.models_dir, "prismaudio", "lora")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
output_path = os.path.join(out_dir, f"prismaudio_lora_r{lora_rank}.safetensors")
|
||||
|
||||
# Move model to device
|
||||
diffusion.model.to(device)
|
||||
diffusion.conditioner.to(device)
|
||||
diffusion.pretransform.to(device)
|
||||
|
||||
# Freeze all DiT params, then apply LoRA (adds trainable lora_A/lora_B)
|
||||
dit = diffusion.model # DiTWrapper
|
||||
for p in dit.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
target_attrs = _TARGET_MODULE_PRESETS[target_modules]
|
||||
_apply_lora(dit, target_attrs, lora_rank, lora_alpha)
|
||||
|
||||
# Cast LoRA params to model dtype and move to device
|
||||
for m in dit.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_A.to(device=device, dtype=dtype)
|
||||
m.lora_B.to(device=device, dtype=dtype)
|
||||
|
||||
trainable = [p for p in dit.parameters() if p.requires_grad]
|
||||
n_params = sum(p.numel() for p in trainable)
|
||||
print(f"[PrismAudio] LoRA trainable params: {n_params:,} ({n_params/1e6:.2f}M)", flush=True)
|
||||
|
||||
diffusion.conditioner.eval()
|
||||
diffusion.pretransform.eval()
|
||||
dit.train()
|
||||
|
||||
optimizer = torch.optim.AdamW(trainable, lr=learning_rate)
|
||||
|
||||
# GradScaler for fp16 to prevent underflow
|
||||
use_scaler = (dtype == torch.float16)
|
||||
scaler = torch.cuda.amp.GradScaler() if use_scaler else None
|
||||
|
||||
pbar = comfy.utils.ProgressBar(train_steps)
|
||||
|
||||
try:
|
||||
for step in range(1, train_steps + 1):
|
||||
npz_path, audio_path = random.choice(pairs)
|
||||
|
||||
with torch.no_grad():
|
||||
# Encode audio to latent space
|
||||
audio = _load_audio(audio_path, device)
|
||||
x0 = diffusion.pretransform.encode(audio.float()).to(dtype) # [1, 64, L]
|
||||
|
||||
# Build conditioning from features
|
||||
metadata = (_load_metadata(npz_path, device, dtype),)
|
||||
conditioning = diffusion.conditioner(metadata, device)
|
||||
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||
|
||||
# Rectified flow: interpolate between data and noise
|
||||
t = torch.rand(x0.shape[0], device=device, dtype=dtype) # [1]
|
||||
noise = torch.randn_like(x0)
|
||||
# t expanded for broadcast: [1] -> [1, 1, 1]
|
||||
t_bcast = t[:, None, None]
|
||||
x_t = (1.0 - t_bcast) * x0 + t_bcast * noise
|
||||
v_target = noise - x0
|
||||
|
||||
with torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||
v_pred = dit(x_t, t,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob=cfg_dropout_prob,
|
||||
**cond_inputs)
|
||||
|
||||
loss = F.mse_loss(v_pred.float(), v_target.float())
|
||||
|
||||
if use_scaler:
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % 50 == 0:
|
||||
print(f"[PrismAudio] step {step}/{train_steps} loss={loss.item():.6f}", flush=True)
|
||||
|
||||
if step % save_every == 0:
|
||||
ckpt_path = output_path.replace(".safetensors", f"_step{step}.safetensors")
|
||||
save_file(_get_lora_state_dict(dit), ckpt_path)
|
||||
print(f"[PrismAudio] Checkpoint: {ckpt_path}", flush=True)
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# Save final weights
|
||||
save_file(_get_lora_state_dict(dit), output_path)
|
||||
|
||||
# Save config alongside weights so the loader knows the structure
|
||||
config_path = output_path.replace(".safetensors", "_config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump({
|
||||
"rank": lora_rank,
|
||||
"alpha": lora_alpha,
|
||||
"target_modules": sorted(target_attrs),
|
||||
}, f, indent=2)
|
||||
|
||||
print(f"[PrismAudio] LoRA saved: {output_path}", flush=True)
|
||||
|
||||
finally:
|
||||
# Always restore model to base state — even on exception.
|
||||
# Without this, LoRA wrappers would persist in the cached model and
|
||||
# subsequent training runs would apply LoRA on top of existing LoRA.
|
||||
dit.eval()
|
||||
_unapply_lora(dit)
|
||||
|
||||
if strategy == "offload_to_cpu":
|
||||
diffusion.model.to(get_offload_device())
|
||||
diffusion.conditioner.to(get_offload_device())
|
||||
diffusion.pretransform.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
|
||||
return (output_path,)
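The rectified-flow interpolation and velocity target used in the training loop above can be sanity-checked without any framework. The sketch below is illustrative only: `interpolate` and `velocity_target` are hypothetical helper names, and plain Python lists stand in for the latent tensors.

```python
import random

def interpolate(x0, noise, t):
    """x_t = (1 - t) * x0 + t * noise, element-wise (t=0 is data, t=1 is noise)."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, noise)]

def velocity_target(x0, noise):
    """v = noise - x0: constant velocity along the straight data-to-noise path."""
    return [b - a for a, b in zip(x0, noise)]

rng = random.Random(0)
x0 = [rng.uniform(-1.0, 1.0) for _ in range(4)]   # stand-in for an encoded latent
noise = [rng.gauss(0.0, 1.0) for _ in range(4)]   # stand-in for randn_like(x0)
t = 0.3

x_t = interpolate(x0, noise, t)
v = velocity_target(x0, noise)

# Moving back along the velocity by t recovers the data sample,
# and moving forward by (1 - t) reaches the noise sample.
recovered_x0 = [xt - t * vi for xt, vi in zip(x_t, v)]
recovered_noise = [xt + (1.0 - t) * vi for xt, vi in zip(x_t, v)]
```

Because the target velocity does not depend on `t`, the MSE loss in the loop asks the network to predict the same straight-line direction at every point on the path.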
|
||||
@@ -1,154 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
|
||||
get_device, get_offload_device, determine_precision, determine_offload_strategy,
|
||||
soft_empty_cache, resolve_hf_token,
|
||||
)
|
||||
|
||||
# HuggingFace repo for auto-download
|
||||
HF_REPO_ID = "FunAudioLLM/PrismAudio"
|
||||
REQUIRED_FILES = {
|
||||
"diffusion": "prismaudio.ckpt",
|
||||
"vae": "vae.ckpt",
|
||||
"synchformer": "synchformer_state_dict.pth",
|
||||
}
|
||||
|
||||
|
||||
def _download_if_missing(filename, model_dir, hf_token=None):
|
||||
"""Download a model file from HuggingFace if not present locally."""
|
||||
filepath = os.path.join(model_dir, filename)
|
||||
if os.path.exists(filepath):
|
||||
return filepath
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
|
||||
try:
|
||||
downloaded = hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename=filename,
|
||||
local_dir=model_dir,
|
||||
token=hf_token or None,
|
||||
)
|
||||
return downloaded
|
||||
except Exception as e:
|
||||
if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
|
||||
raise RuntimeError(
|
||||
f"[PrismAudio] Model '{filename}' requires license acceptance. "
|
||||
f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
|
||||
f"then set HF_TOKEN env var or run: huggingface-cli login"
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
class PrismAudioModelLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
register_model_folder()
|
||||
return {
|
||||
"required": {
|
||||
"precision": (["auto", "fp32", "fp16", "bf16"],),
|
||||
"offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("PRISMAUDIO_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load_model"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def load_model(self, precision, offload_strategy):
|
||||
device = get_device()
|
||||
dtype = determine_precision(precision, device)
|
||||
strategy = determine_offload_strategy(offload_strategy)
|
||||
token = resolve_hf_token()
|
||||
model_dir = get_prismaudio_model_dir()
|
||||
|
||||
# Auto-download missing files
|
||||
for key, filename in REQUIRED_FILES.items():
|
||||
_download_if_missing(filename, model_dir, hf_token=token)
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"prismaudio_core", "configs", "prismaudio.json"
|
||||
)
|
||||
with open(config_path) as f:
|
||||
model_config = json.load(f)
|
||||
|
||||
# Create model from config
|
||||
from prismaudio_core.factory import create_model_from_config
|
||||
model = create_model_from_config(model_config)
|
||||
|
||||
# Load diffusion weights
|
||||
diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
|
||||
diffusion_state = comfy.utils.load_torch_file(diffusion_path)
|
||||
# Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...}
|
||||
if "state_dict" in diffusion_state:
|
||||
diffusion_state = diffusion_state["state_dict"]
|
||||
diff_result = model.load_state_dict(diffusion_state, strict=False)
|
||||
print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
|
||||
print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True)
|
||||
if diff_result.missing_keys:
|
||||
print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True)
|
||||
if diff_result.unexpected_keys:
|
||||
print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True)
|
||||
# Sample a few ckpt keys to verify prefix alignment
|
||||
sample_keys = list(diffusion_state.keys())[:5]
|
||||
print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True)
|
||||
|
||||
# Load VAE weights separately
|
||||
# Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat
|
||||
vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
|
||||
vae_full_state = comfy.utils.load_torch_file(vae_path)
|
||||
print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)
|
||||
# Sample raw keys to see actual prefix
|
||||
vae_sample_keys = list(vae_full_state.keys())[:8]
|
||||
print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True)
|
||||
# Strip "autoencoder." prefix from keys
|
||||
vae_state = {}
|
||||
prefix = "autoencoder."
|
||||
for k, v in vae_full_state.items():
|
||||
if k.startswith(prefix):
|
||||
vae_state[k[len(prefix):]] = v
|
||||
else:
|
||||
vae_state[k] = v
|
||||
print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
|
||||
# Sample model keys to compare
|
||||
model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
|
||||
print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True)
|
||||
# strict=False: vae.ckpt is a training checkpoint that also contains
|
||||
# discriminator, loss modules, and EMA wrappers not present in the
|
||||
# inference AudioAutoencoder — ignore those extra keys.
|
||||
# Load directly into the inner AudioAutoencoder to get IncompatibleKeys back
|
||||
# (AutoencoderPretransform.load_state_dict doesn't return the result)
|
||||
vae_result = model.pretransform.model.load_state_dict(vae_state, strict=False)
|
||||
print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
|
||||
if vae_result.missing_keys:
|
||||
print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)
|
||||
|
||||
# Apply precision: DiT + conditioners in user-selected dtype,
|
||||
# but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16
|
||||
model.model.to(dtype) # DiTWrapper
|
||||
model.conditioner.to(dtype) # MultiConditioner
|
||||
# model.pretransform stays in fp32
|
||||
|
||||
if strategy == "keep_in_vram":
|
||||
model = model.to(device)
|
||||
else:
|
||||
model = model.to(get_offload_device())
|
||||
|
||||
model.eval()
|
||||
|
||||
return ({
|
||||
"model": model,
|
||||
"dtype": dtype,
|
||||
"strategy": strategy,
|
||||
"config": model_config,
|
||||
"model_dir": model_dir,
|
||||
},)
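The prefix stripping and the missing/unexpected key report done by the loader above can be illustrated on plain dictionaries. This is a minimal sketch with hypothetical helper names (`strip_prefix`, `key_report`) and made-up key names; it does not call PyTorch.

```python
def strip_prefix(state_dict, prefix):
    """Return a new dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

def key_report(ckpt_keys, model_keys):
    """Return (missing, unexpected) key lists, as a non-strict load would report them."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    missing = sorted(model - ckpt)      # expected by the model, absent from the file
    unexpected = sorted(ckpt - model)   # present in the file, unused by the model
    return missing, unexpected

# Made-up checkpoint: one autoencoder weight plus a training-only discriminator weight
ckpt = {"autoencoder.encoder.weight": 1, "discriminator.weight": 2}
stripped = strip_prefix(ckpt, "autoencoder.")
missing, unexpected = key_report(stripped, ["encoder.weight", "decoder.weight"])
```

A non-empty `missing` list after stripping usually means the prefix assumption is wrong, which is why the loader prints sample keys from both sides.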
|
||||
@@ -1,183 +0,0 @@
|
||||
import torch
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
||||
get_device, get_offload_device, soft_empty_cache,
|
||||
)
|
||||
|
||||
|
||||
class PrismAudioSampler:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("PRISMAUDIO_MODEL",),
|
||||
"features": ("PRISMAUDIO_FEATURES",),
|
||||
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
|
||||
"steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
|
||||
"cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
|
||||
"sync_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.05, "tooltip": "Scale factor for sync conditioning. Higher values tighten audio-visual sync at the cost of audio naturalness; 0.0 disables sync guidance entirely."}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
RETURN_NAMES = ("audio",)
|
||||
FUNCTION = "generate"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
|
||||
def generate(self, model, features, duration, steps, cfg_scale, sync_strength, seed):
|
||||
device = get_device()
|
||||
dtype = model["dtype"]
|
||||
strategy = model["strategy"]
|
||||
diffusion = model["model"]
|
||||
|
||||
# Resolve duration: 0 means use video duration from features
|
||||
if duration <= 0:
|
||||
if "duration" not in features:
|
||||
raise ValueError("[PrismAudio] duration=0 but features contain no duration. Set duration manually or use PrismAudioFeatureExtractor.")
|
||||
duration = features["duration"]
|
||||
print(f"[PrismAudio] Using video duration from features: {duration:.2f}s", flush=True)
|
||||
|
||||
# Compute latent dimensions
|
||||
latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)
|
||||
|
||||
# Sync temporal coverage diagnostic
|
||||
sync_frames = features["sync_features"].shape[0]
|
||||
sync_duration_covered = sync_frames / 25.0 # Synchformer always extracts at 25fps
|
||||
print(f"[PrismAudio] sync: {sync_frames} frames @ 25fps = {sync_duration_covered:.2f}s | "
|
||||
f"audio target: {latent_length} latent frames = {duration:.2f}s", flush=True)
|
||||
if abs(sync_duration_covered - duration) > 0.5:
|
||||
print(f"[PrismAudio] Warning: sync coverage ({sync_duration_covered:.2f}s) differs from "
|
||||
f"audio duration ({duration:.2f}s) by more than 0.5s — consider re-extracting features "
|
||||
f"with the correct video duration.", flush=True)
|
||||
|
||||
# Note: no seq length config needed — the model adapts to input tensor shapes
|
||||
# dynamically via its transformer architecture.
|
||||
|
||||
# Determine if video features are present (not all zeros)
|
||||
has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0
|
||||
|
||||
video_feat = features["video_features"].to(device, dtype=dtype)
|
||||
sync_feat = features["sync_features"].to(device, dtype=dtype)
|
||||
|
||||
# Build metadata as a TUPLE of dicts (one per batch sample)
|
||||
# MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
|
||||
sample_meta = {
|
||||
"video_features": video_feat,
|
||||
"text_features": features["text_features"].to(device, dtype=dtype),
|
||||
"sync_features": sync_feat,
|
||||
"video_exist": torch.tensor(has_video),
|
||||
}
|
||||
metadata = (sample_meta,)
|
||||
|
||||
# Move model to device if offloaded
|
||||
if strategy == "offload_to_cpu":
|
||||
diffusion.model.to(device)
|
||||
diffusion.conditioner.to(device)
|
||||
soft_empty_cache()
|
||||
|
||||
with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
|
||||
# Run conditioning
|
||||
conditioning = diffusion.conditioner(metadata, device)
|
||||
|
||||
# Handle missing video: substitute learned empty embeddings
|
||||
if not has_video:
|
||||
_substitute_empty_features(diffusion, conditioning, device, dtype)
|
||||
|
||||
# Scale sync conditioning after the conditioner MLP (clean linear scale,
|
||||
# avoids SiLU nonlinearity in Sync_MLP). The CFG null path always uses zeros,
|
||||
# so this directly scales the sync guidance magnitude: cfg_scale * (strength*cond - 0).
|
||||
# Only applied when video is present — T2A uses learned empty_sync_feat, not raw sync.
|
||||
if has_video and sync_strength != 1.0 and 'sync_features' in conditioning:
|
||||
conditioning['sync_features'][0] = conditioning['sync_features'][0] * sync_strength
|
||||
|
||||
# Assemble conditioning inputs for the DiT
|
||||
cond_inputs = diffusion.get_conditioning_inputs(conditioning)
|
||||
|
||||
# Generate noise from seed (MPS doesn't support torch.Generator)
|
||||
gen_device = "cpu" if device.type == "mps" else device
|
||||
generator = torch.Generator(device=gen_device).manual_seed(seed)
|
||||
noise = torch.randn(
|
||||
[1, IO_CHANNELS, latent_length],
|
||||
generator=generator,
|
||||
device=gen_device,
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
# Sample with progress bar
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
from prismaudio_core.inference.sampling import sample_discrete_euler
|
||||
|
||||
def on_step(info):
|
||||
pbar.update(1)
|
||||
|
||||
fakes = sample_discrete_euler(
|
||||
diffusion.model,
|
||||
noise,
|
||||
steps,
|
||||
callback=on_step,
|
||||
**cond_inputs,
|
||||
cfg_scale=cfg_scale,
|
||||
batch_cfg=True,
|
||||
)
|
||||
|
||||
fakes_f = fakes.float()
|
||||
print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)
|
||||
|
||||
# Offload diffusion model and conditioner before VAE decode
|
||||
if strategy == "offload_to_cpu":
|
||||
diffusion.model.to(get_offload_device())
|
||||
diffusion.conditioner.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
diffusion.pretransform.to(device)
|
||||
|
||||
# VAE decode in fp32 (snake activations overflow in fp16)
|
||||
with torch.amp.autocast(device_type=device.type, enabled=False):
|
||||
audio = diffusion.pretransform.decode(fakes_f)
|
||||
|
||||
# Offload VAE
|
||||
if strategy == "offload_to_cpu":
|
||||
diffusion.pretransform.to(get_offload_device())
|
||||
soft_empty_cache()
|
||||
|
||||
# Peak normalize then clamp (matching reference: div by max abs before clamp)
|
||||
audio = audio.float()
|
||||
pre_norm_std = audio.std().item()
|
||||
pre_norm_peak = audio.abs().max().item()
|
||||
peak = audio.abs().max().clamp(min=1e-8)
|
||||
audio = (audio / peak).clamp(-1, 1)
|
||||
print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
|
||||
|
||||
# Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
|
||||
return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
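Two small pieces of arithmetic in `generate` above are easy to check by hand: the latent length and the peak normalization. The sketch below uses plain Python; the default values 44100 and 2048 are placeholders chosen for illustration, since the actual `SAMPLE_RATE` and `DOWNSAMPLING_RATIO` constants live in `utils` and are not shown here.

```python
def latent_length(duration_s, sample_rate=44100, downsampling_ratio=2048):
    """Number of latent frames for a clip; the default rates are assumed values."""
    return round(sample_rate * duration_s / downsampling_ratio)

def peak_normalize(samples, eps=1e-8):
    """Divide by the largest absolute sample (floored at eps), then clamp to [-1, 1]."""
    peak = max(max(abs(s) for s in samples), eps)
    return [max(-1.0, min(1.0, s / peak)) for s in samples]

frames = latent_length(10.0)
normalized = peak_normalize([0.5, -2.0, 1.0])
silent = peak_normalize([0.0, 0.0])
```

The `eps` floor mirrors the `clamp(min=1e-8)` in the node, so an all-zero waveform stays silent instead of dividing by zero.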
|
||||
|
||||
|
||||
def _substitute_empty_features(diffusion, conditioning, device, dtype):
|
||||
"""Replace video/sync conditioning with learned empty embeddings when video is absent.
|
||||
|
||||
empty_clip_feat and empty_sync_feat are learned null embeddings in the conditioner
|
||||
output space (1024-dim). Passing zero features through bias-free Cond_MLP produces
|
||||
near-zero activations, NOT the learned null signal the model was trained with.
|
||||
|
||||
The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
|
||||
"""
|
||||
dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model
|
||||
|
||||
# Substitute video_features with learned empty_clip_feat
|
||||
if hasattr(dit, 'empty_clip_feat') and 'video_features' in conditioning:
|
||||
empty = dit.empty_clip_feat.to(device, dtype=dtype) # [1, 1024]
|
||||
batch_size = conditioning['video_features'][0].shape[0]
|
||||
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
||||
conditioning['video_features'][0] = empty_expanded
|
||||
conditioning['video_features'][1] = torch.ones(batch_size, 1, device=device)
|
||||
|
||||
# Substitute sync_features with learned empty_sync_feat
|
||||
if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
|
||||
empty = dit.empty_sync_feat.to(device, dtype=dtype) # [1, 1024]
|
||||
batch_size = conditioning['sync_features'][0].shape[0]
|
||||
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
||||
conditioning['sync_features'][0] = empty_expanded
|
||||
conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
|
||||
@@ -0,0 +1,364 @@
|
||||
import os
|
||||
import hashlib
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import comfy.utils
|
||||
|
||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
|
||||
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
|
||||
_CLIP_SIZE = 384
|
||||
_SYNC_SIZE = 224
|
||||
_CLIP_FPS = 8
|
||||
_SYNC_FPS = 25
|
||||
|
||||
# Sync normalization applied externally: maps [0,1] → [-1,1] with mean=std=0.5
|
||||
_SYNC_MEAN = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
||||
_SYNC_STD = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
|
||||
|
||||
|
||||
def _sample_frames(video, source_fps, target_fps, duration):
    """Sample frames from [T,H,W,C] float32 at target_fps; returns [N,H,W,C]."""
    T = video.shape[0]
    n_out = max(1, int(duration * target_fps))
    indices = [min(int(i / target_fps * source_fps), T - 1) for i in range(n_out)]
    return video[indices]
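The index formula in `_sample_frames` can be tried on its own to see which source frames are picked. This is a framework-free sketch; `sample_indices` is a hypothetical name that reproduces only the index computation, not the tensor indexing.

```python
def sample_indices(n_frames, source_fps, target_fps, duration):
    """Source-frame indices for resampling to target_fps (nearest earlier frame)."""
    n_out = max(1, int(duration * target_fps))
    return [min(int(i / target_fps * source_fps), n_frames - 1) for i in range(n_out)]

# One second of 30 fps video sampled at 8 fps (the CLIP rate)
clip_idx = sample_indices(30, 30.0, 8, 1.0)

# One second of 10 fps video sampled at 25 fps (the sync rate): frames repeat
sync_idx = sample_indices(10, 10.0, 25, 1.0)
```

When the target rate exceeds the source rate, frames are duplicated rather than interpolated, and the `min(..., n_frames - 1)` clamp keeps the last index in range.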
|
||||
|
||||
|
||||
def _resize_frames(frames, size):
    """Resize [N,H,W,C] float32 [0,1] → [N,C,H,W] at target size."""
    x = frames.permute(0, 3, 1, 2)  # [N, C, H, W]
    x = F.interpolate(x.float(), size=(size, size), mode="bicubic", align_corners=False)
    return x.clamp(0.0, 1.0)  # [N, C, H, W]
|
||||
|
||||
|
||||
def _compute_mask_bbox(mask, frame_h, frame_w, margin=0.1, square=True):
    """
    Compute a bounding box around the union of all mask frames.

    mask:   [M, H', W'] float [0,1]
    square: if True, expand bbox to a square and shift into frame bounds;
            if False, apply margin independently on each axis (rect crop).
    Returns (y0, x0, y1, x1) in pixel coords clamped to (frame_h, frame_w).
    """
    if mask.shape[1] != frame_h or mask.shape[2] != frame_w:
        m = F.interpolate(
            mask.float().unsqueeze(1), size=(frame_h, frame_w), mode="nearest-exact"
        ).squeeze(1)
    else:
        m = mask.float()

    union = (m > 0.5).max(dim=0).values  # [H, W] bool

    if not union.any():
        if square:
            # Empty mask: center square crop
            side = min(frame_h, frame_w)
            cy, cx = frame_h // 2, frame_w // 2
            y0 = max(0, cy - side // 2)
            x0 = max(0, cx - side // 2)
            return y0, x0, min(frame_h, y0 + side), min(frame_w, x0 + side)
        else:
            # Empty mask: return full frame (no meaningful rect to crop to)
            return 0, 0, frame_h, frame_w

    ys = union.any(dim=1).nonzero(as_tuple=True)[0]
    xs = union.any(dim=0).nonzero(as_tuple=True)[0]
    y0, y1 = int(ys[0]), int(ys[-1]) + 1
    x0, x1 = int(xs[0]), int(xs[-1]) + 1

    if square:
        side = max(y1 - y0, x1 - x0)
        pad = int(side * margin)
        side += 2 * pad

        cy = (y0 + y1) // 2
        cx = (x0 + x1) // 2
        y0n = cy - side // 2
        x0n = cx - side // 2
        y1n = y0n + side
        x1n = x0n + side

        # Shift into frame bounds to preserve square shape
        if y0n < 0: y1n -= y0n; y0n = 0
        if y1n > frame_h: y0n -= y1n - frame_h; y1n = frame_h
        if x0n < 0: x1n -= x0n; x0n = 0
        if x1n > frame_w: x0n -= x1n - frame_w; x1n = frame_w

        return max(0, int(y0n)), max(0, int(x0n)), min(frame_h, int(y1n)), min(frame_w, int(x1n))
    else:
        pad_y = int(max(1, y1 - y0) * margin)
        pad_x = int(max(1, x1 - x0) * margin)
        return max(0, y0 - pad_y), max(0, x0 - pad_x), min(frame_h, y1 + pad_y), min(frame_w, x1 + pad_x)
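The "shift into frame bounds" step of the square branch works the same way on each axis, so it can be studied in one dimension. The sketch below uses a hypothetical helper, `fit_interval`, that mirrors the per-axis logic of `_compute_mask_bbox` with plain integers.

```python
def fit_interval(center, side, limit):
    """Place an interval of length `side` around `center`, shifted into [0, limit]."""
    lo = center - side // 2
    hi = lo + side
    if lo < 0:            # spilled past the start: slide right
        hi -= lo
        lo = 0
    if hi > limit:        # spilled past the end: slide left
        lo -= hi - limit
        hi = limit
    # Final clamp only matters when `side` is larger than the axis itself
    return max(0, lo), min(limit, hi)

near_start = fit_interval(10, 40, 100)
near_end = fit_interval(95, 40, 100)
too_large = fit_interval(50, 200, 100)
```

Sliding keeps the full side length whenever it fits. When the padded side exceeds the frame on one axis, the final clamp shortens that axis only, so the crop is no longer square in that case.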
|
||||
|
||||
|
||||
def _apply_mask(frames, mask, source_fps, target_fps, mask_strength=1.0):
    """
    Apply a ComfyUI MASK to resized frames.

    frames:        [N, C, H, W] float [0,1]
    mask:          [M, H', W'] float [0,1]; M=1 static or M=T per-frame
    source_fps:    original video fps (for accurate temporal sampling)
    target_fps:    sampling fps of this frame set (CLIP_FPS or SYNC_FPS)
    mask_strength: 0=no effect, 1=full masking; background filled with 0.5 (neutral gray)

    Background pixels are filled with 0.5 rather than 0, which is less
    out-of-distribution for CLIP and maps to 0 (neutral) after [-1,1]
    normalization on the sync path.
    """
    N, C, H, W = frames.shape
    M = mask.shape[0]
    mask_f = mask.float().unsqueeze(1)  # [M, 1, H', W']
    if mask_f.shape[2] != H or mask_f.shape[3] != W:
        mask_f = F.interpolate(mask_f, size=(H, W), mode="nearest-exact")  # [M, 1, H, W]

    # Temporal sampling: use the same index formula as _sample_frames for accuracy
    if M == 1:
        mask_f = mask_f.expand(N, -1, -1, -1)
    else:
        indices = [min(int(i / target_fps * source_fps), M - 1) for i in range(N)]
        mask_f = mask_f[indices]  # [N, 1, H, W]

    mask_f = mask_f.to(frames.device)

    # alpha=1 on foreground, (1-strength) on background → blend toward neutral gray
    alpha = 1.0 - mask_strength * (1.0 - mask_f)
    return frames * alpha + 0.5 * (1.0 - alpha)
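The blend at the end of `_apply_mask` is a per-pixel formula, so a scalar version is enough to see its behavior at the extremes. This is a minimal sketch; `blend_pixel` and `normalize_sync` are hypothetical names for illustration.

```python
def blend_pixel(value, mask, strength, fill=0.5):
    """Blend a pixel toward `fill` where the mask is 0, scaled by `strength`."""
    alpha = 1.0 - strength * (1.0 - mask)   # 1 on foreground, (1 - strength) on background
    return value * alpha + fill * (1.0 - alpha)

def normalize_sync(value, mean=0.5, std=0.5):
    """Map [0, 1] to [-1, 1], as done on the sync path."""
    return (value - mean) / std

foreground = blend_pixel(0.9, mask=1.0, strength=1.0)    # untouched
background = blend_pixel(0.9, mask=0.0, strength=1.0)    # fully replaced by gray
disabled = blend_pixel(0.9, mask=0.0, strength=0.0)      # strength 0: no effect
half = blend_pixel(0.9, mask=0.0, strength=0.5)          # halfway to gray
neutral = normalize_sync(background)                     # gray lands on 0
```

The last line shows why 0.5 was chosen as the fill: after the sync normalization a fully masked pixel sits exactly at the center of the input range.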
|
||||
|
||||
|
||||
def _hash_inputs(video_tensor, prompt, fps, duration, variant, mask=None,
                 mask_strength=1.0, mask_clip=True, mask_sync=True,
                 crop_to_mask=False, crop_rect=False, crop_margin=0.1):
    h = hashlib.sha256()
    raw = video_tensor.cpu().numpy().tobytes()
    n = len(raw)
    chunk = 512 * 1024  # 512 KB per sample
    h.update(raw[:chunk])
    h.update(raw[n // 2: n // 2 + chunk])
    h.update(raw[max(0, n - chunk):])
    if mask is not None:
        raw_m = mask.cpu().numpy().tobytes()
        nm = len(raw_m)
        chunk_m = 256 * 1024
        h.update(raw_m[:chunk_m])
        h.update(raw_m[nm // 2: nm // 2 + chunk_m])
        h.update(raw_m[max(0, nm - chunk_m):])
        h.update(str(round(mask_strength, 4)).encode())
        h.update(str(mask_clip).encode())
        h.update(str(mask_sync).encode())
        h.update(str(crop_to_mask).encode())
        h.update(str(crop_rect).encode())
        if crop_to_mask or crop_rect:
            h.update(str(round(crop_margin, 4)).encode())
    h.update(prompt.encode())
    h.update(str(fps).encode())
    h.update(str(round(duration, 3)).encode())  # resolved duration affects frame count
    h.update(variant.encode())
    return h.hexdigest()[:32]
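`_hash_inputs` hashes only three windows of the raw bytes (start, middle, end) instead of the full tensor. The sketch below isolates that idea with the standard library; `sampled_digest` is a hypothetical name and the tiny chunk size is chosen only to make the windows easy to reason about.

```python
import hashlib

def sampled_digest(raw: bytes, chunk: int = 512 * 1024) -> str:
    """SHA-256 over the first, middle and last `chunk` bytes of `raw`."""
    n = len(raw)
    h = hashlib.sha256()
    h.update(raw[:chunk])
    h.update(raw[n // 2: n // 2 + chunk])
    h.update(raw[max(0, n - chunk):])
    return h.hexdigest()[:32]

data = bytearray(100)                      # 100 zero bytes
base = sampled_digest(bytes(data), chunk=10)

inside = bytearray(data)
inside[5] = 1                              # falls in the first window [0, 10)
changed = sampled_digest(bytes(inside), chunk=10)

outside = bytearray(data)
outside[30] = 1                            # falls between the sampled windows
unchanged = sampled_digest(bytes(outside), chunk=10)
```

The trade-off is visible in the last case: an edit that touches only bytes outside the three windows produces the same key, so a cache keyed this way can return stale features for such inputs.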
|
||||
|
||||
|
||||
class SelvaFeatureExtractor:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("SELVA_MODEL",),
|
||||
"video": ("IMAGE",),
|
||||
"prompt": ("STRING", {
|
||||
"default": "", "multiline": True,
|
||||
"tooltip": "Describes the sounds to generate. Used to focus the visual sync features on motion relevant to the prompt — more specific prompts produce cleaner audio sync. Wire the prompt output directly to the Sampler so you only type it once.",
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"video_info": ("VHS_VIDEOINFO", {
|
||||
"tooltip": "VHS_VIDEOINFO from VHS LoadVideo. Automatically sets the correct source fps — always connect this when loading video with VHS nodes.",
|
||||
}),
|
||||
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001,
|
||||
"tooltip": "Source fps of the input video. Ignored when video_info is connected."}),
|
||||
"duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
|
||||
"tooltip": "Clip duration in seconds. 0 = use the full video length. Clamped to actual video length if too long."}),
|
||||
"cache_dir": ("STRING", {"default": "",
|
||||
"tooltip": "Where to store extracted feature files (.npz). Leave empty for the system temp directory. Reusing the same directory enables instant cache hits on re-runs."}),
|
||||
"mask": ("MASK", {
|
||||
"tooltip": "Optional segmentation mask [T,H,W] float [0,1]. Background pixels are blended toward neutral gray (0.5) before encoding, scaled by mask_strength; useful when multiple objects compete for the same sound. Static (1-frame) and per-frame masks are both supported. Connect SAM2 or Grounding DINO+SAM output.",
|
||||
}),
|
||||
"mask_strength": ("FLOAT", {
|
||||
"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||
"tooltip": "How strongly to suppress the background. 1.0 = full neutral fill; 0.0 = no masking effect. Values in between blend smoothly.",
|
||||
}),
|
||||
"mask_clip": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Apply the mask to CLIP visual features (384px). Disable if you want CLIP to see the full scene context while sync features stay focused.",
|
||||
}),
|
||||
"mask_sync": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Apply the mask to TextSynchformer sync features (224px). This is the primary path for isolating which object's motion drives the audio.",
|
||||
}),
|
||||
"crop_to_mask": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "Experimental. Crops frames to a square region around the mask bounding box before resizing. The model sees an undistorted view of the subject. Requires mask. Takes priority over crop_rect.",
|
||||
}),
|
||||
"crop_rect": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "Experimental. Crops frames to a rectangle around the mask bounding box (with margin) before resizing. The model still stretches the crop to a square, but only sees the region around the target element. Simpler than crop_to_mask. Requires mask.",
|
||||
}),
|
||||
"crop_margin": ("FLOAT", {
|
||||
"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.05,
|
||||
"tooltip": "Margin added around the bounding box as a fraction of the bbox size. Shared by crop_to_mask and crop_rect. 0.1 = 10% on each side.",
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
|
||||
RETURN_NAMES = ("features", "fps", "prompt")
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"Extracted feature bundle — connect to Sampler.",
|
||||
"Source fps of the video — wire to VHS_VideoCombine frame_rate.",
|
||||
"The prompt used during extraction — wire to Sampler prompt to avoid re-typing.",
|
||||
)
|
||||
FUNCTION = "extract_features"
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
DESCRIPTION = "Extracts CLIP visual features and text-conditioned sync features from a video. Results are cached — re-running with the same inputs is instant."
|
||||
|
||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||
duration=0.0, cache_dir="", mask=None,
|
||||
mask_strength=1.0, mask_clip=True, mask_sync=True,
|
||||
crop_to_mask=False, crop_rect=False, crop_margin=0.1):
|
||||
if video_info is not None:
|
||||
fps = video_info["loaded_fps"]
|
||||
|
||||
T = video.shape[0]
|
||||
if duration <= 0:
|
||||
duration = T / fps
|
||||
duration = min(duration, T / fps) # clamp to actual video length
|
||||
|
||||
if not prompt.strip():
|
||||
print("[SelVA] Warning: empty prompt — TextSynchformer sync features will be unfocused.", flush=True)
|
||||
|
||||
# Cache
|
||||
if not cache_dir:
|
||||
cache_dir = os.path.join(tempfile.gettempdir(), "selva_features")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
cache_key = _hash_inputs(video, prompt, fps, duration, model["variant"], mask=mask,
|
||||
mask_strength=mask_strength, mask_clip=mask_clip, mask_sync=mask_sync,
|
||||
crop_to_mask=crop_to_mask, crop_rect=crop_rect, crop_margin=crop_margin)
|
||||
cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
|
||||
|
||||
if os.path.exists(cached_path):
|
||||
print(f"[SelVA] Using cached features: {cached_path}", flush=True)
|
||||
cached = _load_cached(cached_path)
|
||||
return (cached, float(fps), cached.get("prompt", prompt))
|
||||
|
||||
device = get_device()
|
||||
dtype = model["dtype"]
|
||||
strategy = model["strategy"]
|
||||
feature_utils = model["feature_utils"]
|
||||
net_video_enc = model["video_enc"]
|
||||
|
||||
if strategy == "offload_to_cpu":
|
||||
feature_utils.to(device)
|
||||
net_video_enc.to(device)
|
||||
soft_empty_cache()
|
||||
|
||||
print(f"[SelVA] Extracting features: duration={duration:.2f}s fps={fps:.3f} prompt='{prompt[:60]}'", flush=True)
|
||||
pbar = comfy.utils.ProgressBar(3)
|
||||
|
||||
# Pre-compute crop bbox once from the original-resolution mask
|
||||
crop_bbox = None
|
||||
if mask is not None and (crop_to_mask or crop_rect):
|
||||
H_vid, W_vid = video.shape[1], video.shape[2]
|
||||
_square = crop_to_mask # crop_to_mask takes priority; crop_rect is rect-only
|
||||
crop_bbox = _compute_mask_bbox(mask, H_vid, W_vid, crop_margin, square=_square)
|
||||
cy0, cx0, cy1, cx1 = crop_bbox
|
||||
_mode = "square" if _square else "rect"
|
||||
print(f"[SelVA] Mask crop ({_mode}): y={cy0}:{cy1} x={cx0}:{cx1} "
|
||||
f"({cy1-cy0}×{cx1-cx0}px from {H_vid}×{W_vid})", flush=True)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# --- CLIP frames: [1, N, C, 384, 384] float32 [0,1] ---
|
||||
clip_frames = _sample_frames(video, fps, _CLIP_FPS, duration) # [N, H, W, C]
|
||||
if crop_bbox is not None:
|
||||
cy0, cx0, cy1, cx1 = crop_bbox
|
||||
clip_frames = clip_frames[:, cy0:cy1, cx0:cx1, :]
|
||||
clip_frames = _resize_frames(clip_frames, _CLIP_SIZE) # [N, C, 384, 384]
|
||||
if mask is not None and mask_clip:
|
||||
clip_frames = _apply_mask(clip_frames, mask, fps, _CLIP_FPS, mask_strength)
|
||||
clip_input = clip_frames.unsqueeze(0).to(device, dtype) # [1, N, C, 384, 384]
|
||||
_clip_tag = f"(masked strength={mask_strength})" if mask is not None and mask_clip else ("(mask skipped)" if mask is not None else "")
|
||||
print(f"[SelVA] CLIP frames: {clip_frames.shape[0]} @ {_CLIP_FPS}fps → 384px {_clip_tag}", flush=True)
|
||||
|
||||
clip_features = feature_utils.encode_video_with_clip(clip_input) # [1, N, 1024]
|
||||
pbar.update(1)
|
||||
|
||||
                # --- Sync frames: [1, N, C, 224, 224] float32 [-1,1] ---
                sync_frames = _sample_frames(video, fps, _SYNC_FPS, duration)  # [N, H, W, C]
                if crop_bbox is not None:
                    cy0, cx0, cy1, cx1 = crop_bbox
                    sync_frames = sync_frames[:, cy0:cy1, cx0:cx1, :]
                sync_frames = _resize_frames(sync_frames, _SYNC_SIZE)  # [N, C, 224, 224]
                if mask is not None and mask_sync:
                    sync_frames = _apply_mask(sync_frames, mask, fps, _SYNC_FPS, mask_strength)
                # Pad to minimum 16 frames (TextSynchformer segment size)
                if sync_frames.shape[0] < 16:
                    pad = 16 - sync_frames.shape[0]
                    sync_frames = torch.cat([sync_frames, sync_frames[-1:].expand(pad, -1, -1, -1)], dim=0)
                # Normalize [0,1] → [-1,1]
                mean = _SYNC_MEAN.to(sync_frames.device)
                std = _SYNC_STD.to(sync_frames.device)
                sync_frames = (sync_frames - mean) / std
                sync_input = sync_frames.unsqueeze(0).to(device, dtype)  # [1, N, C, 224, 224]
                _sync_tag = f"(masked strength={mask_strength})" if mask is not None and mask_sync else ("(mask skipped)" if mask is not None else "")
                print(f"[SelVA] Sync frames: {sync_frames.shape[0]} @ {_SYNC_FPS}fps → 224px {_sync_tag}", flush=True)

                # Encode T5 text + prepend supplementary tokens → text-conditioned sync features
                text_f, text_mask = feature_utils.encode_text_t5([prompt])  # [1, L, D], [1, L]
                pbar.update(1)
                text_f, text_mask = net_video_enc.prepend_sup_text_tokens(text_f, text_mask)
                sync_features = net_video_enc.encode_video_with_sync(
                    sync_input, text_f=text_f, text_mask=text_mask
                )  # [1, T_sync, 768]
                pbar.update(1)

                print(f"[SelVA] clip_features: {tuple(clip_features.shape)}", flush=True)
                print(f"[SelVA] sync_features: {tuple(sync_features.shape)}", flush=True)

        finally:
            if strategy == "offload_to_cpu":
                feature_utils.to(get_offload_device())
                net_video_enc.to(get_offload_device())
                soft_empty_cache()

        np.savez(
            cached_path,
            clip_features=clip_features.cpu().float().numpy(),
            sync_features=sync_features.cpu().float().numpy(),
            duration=float(duration),
            prompt=np.array(prompt),
            variant=np.array(model["variant"]),
        )
        print(f"[SelVA] Features cached: {cached_path}", flush=True)

        return ({
            "clip_features": clip_features.cpu(),
            "sync_features": sync_features.cpu(),
            "duration": float(duration),
            "prompt": prompt,
            "variant": model["variant"],
        }, float(fps), prompt)


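The extractor above pads the sync stream to at least 16 frames by repeating the last frame. A minimal pure-Python sketch of that padding rule follows. The helper name `pad_to_min_length` is hypothetical and operates on a plain list, whereas the real code uses `torch.cat` on a frame tensor.

```python
def pad_to_min_length(frames, min_len=16):
    """Repeat the last item until the sequence has at least min_len entries.

    Hypothetical helper for illustration only; it mirrors the idea of the
    `sync_frames[-1:].expand(pad, ...)` padding in the diff.
    """
    if not frames:
        raise ValueError("cannot pad an empty sequence")
    missing = max(0, min_len - len(frames))
    return list(frames) + [frames[-1]] * missing


short_clip = ["f0", "f1", "f2"]
padded = pad_to_min_length(short_clip)
print(len(padded), padded[-1])
```

Sequences that already have 16 or more entries pass through unchanged, matching the `if sync_frames.shape[0] < 16` guard.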
def _load_cached(path):
    data = np.load(path, allow_pickle=False)
    features = {
        "clip_features": torch.from_numpy(data["clip_features"]),
        "sync_features": torch.from_numpy(data["sync_features"]),
        "duration": float(data["duration"]),
    }
    if "prompt" in data:
        features["prompt"] = str(data["prompt"])
    if "variant" in data:
        features["variant"] = str(data["variant"])
    return features
@@ -0,0 +1,171 @@
import os
from pathlib import Path
import torch
import folder_paths

from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy

# Variant → (generator filename, mode, has_bigvgan)
_VARIANTS = {
    "small_16k": ("generator_small_16k_sup_5.pth", "16k", True),
    "small_44k": ("generator_small_44k_sup_5.pth", "44k", False),
    "medium_44k": ("generator_medium_44k_sup_5.pth", "44k", False),
    "large_44k": ("generator_large_44k_sup_5.pth", "44k", False),
}

_SELVA_DIR = Path(folder_paths.models_dir) / "selva"
_PRISMAUDIO_DIR = Path(folder_paths.models_dir) / "prismaudio"


_HF_REPO = "jnwnlee/SelVA"

# filename → (hf_repo_path, expected_md5 or None to skip check)
# Note: 44k generators are named 44khz in the HF repo; md5=None since the
# original download_utils had the wrong filenames so those md5s are unverified.
_WEIGHTS = {
    "video_enc_sup_5.pth": ("weights/video_enc_sup_5.pth", "ff09a6dc36148536ee4db97eba081d05"),
    "generator_small_16k_sup_5.pth": ("weights/generator_small_16k_sup_5.pth", "1cb0f0deec52de37f67b1fd9965337d0"),
    "generator_small_44k_sup_5.pth": ("weights/generator_small_44khz_sup_5.pth", None),
    "generator_medium_44k_sup_5.pth":("weights/generator_medium_44khz_sup_5.pth", None),
    "generator_large_44k_sup_5.pth": ("weights/generator_large_44khz_sup_5.pth", None),
    "v1-16.pth": ("ext_weights/v1-16.pth", "69f56803f59a549a1a507c93859fd4d7"),
    "v1-44.pth": ("ext_weights/v1-44.pth", "fab020275fa44c6589820ce025191600"),
    "best_netG.pt": ("ext_weights/best_netG.pt", "eeaf372a38a9c31c362120aba2dde292"),
    "synchformer_state_dict.pth": ("ext_weights/synchformer_state_dict.pth", "5b2f5594b0730f70e41e549b7c94390c"),
}


def _md5(path):
    import hashlib
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8 * 1024 * 1024), b""):
            h.update(chunk)
    return h.hexdigest()


def _ensure(filename, subdir=None):
    """Return path to weight file. Re-downloads if missing or MD5 mismatch."""
    import shutil
    from huggingface_hub import hf_hub_download

    dest_dir = _SELVA_DIR / subdir if subdir else _SELVA_DIR
    dest_path = dest_dir / filename

    entry = _WEIGHTS.get(filename)
    if entry is None:
        raise ValueError(f"[SelVA] Unknown weight file: {filename}")
    repo_path, expected_md5 = entry

    if dest_path.exists():
        if expected_md5 is None:
            return str(dest_path)
        actual = _md5(dest_path)
        if actual == expected_md5:
            return str(dest_path)
        print(f"[SelVA] {filename}: MD5 mismatch ({actual} ≠ {expected_md5}), re-downloading...", flush=True)
        dest_path.unlink()

    print(f"[SelVA] Downloading {filename} from {_HF_REPO}...", flush=True)
    dest_dir.mkdir(parents=True, exist_ok=True)
    cached = hf_hub_download(repo_id=_HF_REPO, filename=repo_path)
    shutil.copy2(cached, dest_path)
    print(f"[SelVA] Saved to {dest_path}", flush=True)
    return str(dest_path)


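The check-then-download logic in `_ensure` can be exercised without any network access. The sketch below uses only the standard library and reproduces the decision `_ensure` makes before downloading. `file_md5` and `needs_download` are hypothetical names for illustration, not functions from this repository.

```python
import hashlib
import os
import tempfile


def file_md5(path, chunk_size=8 * 1024 * 1024):
    """Hash a file in fixed-size chunks so large weights never load fully into RAM."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def needs_download(path, expected_md5):
    """True if the file is missing or its checksum differs.

    expected_md5=None skips verification, as for the 44k generators above.
    """
    if not os.path.exists(path):
        return True
    if expected_md5 is None:
        return False
    return file_md5(path) != expected_md5


payload = b"selva" * 1000
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(payload)
    demo_path = tmp.name

good = hashlib.md5(payload).hexdigest()
print(needs_download(demo_path, good))      # checksum matches
print(needs_download(demo_path, "0" * 32))  # checksum differs
os.remove(demo_path)
```

Note that MD5 here only guards against truncated or corrupted downloads; it is not a defence against tampering.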
def _synchformer_path():
    """Return synchformer path, reusing models/prismaudio/ if already present."""
    prismaudio_path = _PRISMAUDIO_DIR / "synchformer_state_dict.pth"
    if prismaudio_path.exists():
        return str(prismaudio_path)
    return _ensure("synchformer_state_dict.pth")


class SelvaModelLoader:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "variant": (list(_VARIANTS.keys()), {
                    "tooltip": "Model size and output sample rate. small_16k is fastest (16 kHz). 44k variants output 44.1 kHz. larger = better quality, more VRAM.",
                }),
                "precision": (["bf16", "fp16", "fp32"], {
                    "tooltip": "Compute dtype. bf16 is recommended on Ampere+ GPUs. fp16 for older NVIDIA hardware. fp32 if you see NaN outputs.",
                }),
                "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"], {
                    "tooltip": "auto picks keep_in_vram if ≥16 GB VRAM is free, otherwise offload_to_cpu. offload_to_cpu moves weights to RAM between nodes, saving VRAM at the cost of speed.",
                }),
            }
        }

    RETURN_TYPES = ("SELVA_MODEL",)
    RETURN_NAMES = ("model",)
    OUTPUT_TOOLTIPS = ("Loaded model bundle — connect to Feature Extractor and Sampler.",)
    FUNCTION = "load_model"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Loads the SelVA generator, TextSynchformer encoder, CLIP, T5, and VAE. Weights are auto-downloaded from HuggingFace on first use."

    def load_model(self, variant, precision, offload_strategy):
        from selva_core.model.networks_generator import get_my_mmaudio
        from selva_core.model.networks_video_enc import get_my_textsynch
        from selva_core.model.utils.features_utils import FeaturesUtils
        from selva_core.model.sequence_config import CONFIG_16K, CONFIG_44K

        gen_filename, mode, has_bigvgan = _VARIANTS[variant]

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if precision == "bf16" and device.type == "cuda" and not torch.cuda.is_bf16_supported():
            print("[SelVA] Warning: bf16 not supported on this GPU — falling back to fp16.", flush=True)
            precision = "fp16"
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        strategy = determine_offload_strategy(offload_strategy)

        print("[SelVA] Resolving weights (auto-downloading if missing)...", flush=True)
        video_enc_path = _ensure("video_enc_sup_5.pth")
        gen_path = _ensure(gen_filename)
        vae_name = "v1-16.pth" if mode == "16k" else "v1-44.pth"
        vae_path = _ensure(vae_name, subdir="ext")
        synch_path = _synchformer_path()
        bigvgan_path = _ensure("best_netG.pt", subdir="ext") if has_bigvgan else None

        print(f"[SelVA] Loading TextSynch from {video_enc_path}", flush=True)
        net_video_enc = get_my_textsynch("depth1").to(device, dtype).eval()
        net_video_enc.load_weights(
            torch.load(video_enc_path, map_location="cpu", weights_only=False)
        )

        print(f"[SelVA] Loading MMAudio ({variant}) from {gen_path}", flush=True)
        seq_cfg = CONFIG_16K if mode == "16k" else CONFIG_44K
        net_generator = get_my_mmaudio(variant).to(device, dtype).eval()
        net_generator.load_weights(
            torch.load(gen_path, map_location="cpu", weights_only=False)
        )

        print("[SelVA] Loading FeaturesUtils (CLIP + T5 + Synchformer + VAE)...", flush=True)
        feature_utils = FeaturesUtils(
            tod_vae_ckpt=vae_path,
            synchformer_ckpt=synch_path,
            enable_conditions=True,
            mode=mode,
            bigvgan_vocoder_ckpt=bigvgan_path,
            need_vae_encoder=False,
        ).to(device, dtype).eval()

        if strategy == "offload_to_cpu":
            net_generator.to(get_offload_device())
            net_video_enc.to(get_offload_device())
            feature_utils.to(get_offload_device())

        print(f"[SelVA] Model ready: variant={variant} dtype={dtype} strategy={strategy}", flush=True)

        return ({
            "generator": net_generator,
            "video_enc": net_video_enc,
            "feature_utils": feature_utils,
            "variant": variant,
            "mode": mode,
            "strategy": strategy,
            "dtype": dtype,
            "seq_cfg": seq_cfg,
        },)
@@ -0,0 +1,175 @@
import torch
import comfy.utils
import comfy.model_management

from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache


class SelvaSampler:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("SELVA_MODEL",),
                "features": ("SELVA_FEATURES",),
                "prompt": ("STRING", {
                    "default": "", "multiline": True,
                    "tooltip": "Sound description for CLIP text conditioning. Leave empty to reuse the prompt from the Feature Extractor (wire its prompt output here). Changing this without re-extracting features shifts CLIP conditioning but not sync features.",
                }),
                "negative_prompt": ("STRING", {
                    "default": "", "multiline": False,
                    "tooltip": "Sounds to suppress, e.g. 'speech, music, wind noise'. Steered away from via CFG. Leave empty for unconditional guidance baseline.",
                }),
                "duration": ("FLOAT", {
                    "default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1,
                    "tooltip": "Output audio length in seconds. 0 = match the video duration stored in features.",
                }),
                "steps": ("INT", {"default": 25, "min": 1, "max": 200,
                                  "tooltip": "Euler steps for the flow matching ODE. 25 is the SelVA default. Diminishing returns above 50; below 10 may sound rough."}),
                "cfg_strength": ("FLOAT", {"default": 4.5, "min": 1.0, "max": 20.0, "step": 0.1,
                                           "tooltip": "Classifier-free guidance scale. Higher values follow the prompt more strictly but can introduce artifacts. SelVA default is 4.5; useful range is roughly 3–7."}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
            "optional": {
                "normalize": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Peak-normalize output to [-1, 1]. Disable to preserve the raw decoder output level.",
                }),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    OUTPUT_TOOLTIPS = ("Generated audio waveform — connect to VHS_VideoCombine or Save Audio.",)
    FUNCTION = "generate"
    CATEGORY = SELVA_CATEGORY
    DESCRIPTION = "Generates audio from video features using SelVA's flow matching ODE. Supports text prompts and negative prompts via classifier-free guidance."

    def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed, normalize=True):
        import dataclasses
        from selva_core.model.flow_matching import FlowMatching

        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        net_generator = model["generator"]
        feature_utils = model["feature_utils"]

        # Validate that features were extracted with the same model variant
        feat_variant = features.get("variant")
        if feat_variant is not None and feat_variant != model["variant"]:
            raise ValueError(
                f"[SelVA] Variant mismatch: features were extracted with '{feat_variant}' "
                f"but model is '{model['variant']}'. Re-run the Feature Extractor with the current model."
            )

        # Resolve prompt: use override if given, otherwise fall back to features prompt
        if not prompt or not prompt.strip():
            prompt = features.get("prompt", "")
            if prompt:
                print(f"[SelVA] Using prompt from features: '{prompt[:60]}'", flush=True)
            else:
                print("[SelVA] Warning: no prompt in features or sampler — CLIP text conditioning will be empty.", flush=True)

        # Resolve duration
        if duration <= 0:
            if "duration" not in features:
                raise ValueError("[SelVA] duration=0 but features contain no duration field.")
            duration = features["duration"]
            print(f"[SelVA] Using video duration from features: {duration:.2f}s", flush=True)

        # Derive sequence config for this duration from the model's mode template
        seq_cfg = dataclasses.replace(model["seq_cfg"], duration=duration)
        sample_rate = seq_cfg.sampling_rate

        if strategy == "offload_to_cpu":
            net_generator.to(device)
            feature_utils.to(device)
            soft_empty_cache()

        try:
            clip_f = features["clip_features"].to(device, dtype)  # [1, T_clip, 1024]
            sync_f = features["sync_features"].to(device, dtype)  # [1, T_sync, 768]

            print(f"[SelVA] clip_f={tuple(clip_f.shape)} sync_f={tuple(sync_f.shape)}", flush=True)

            # Update model rotary position embeddings for actual feature shapes and duration.
            # Use actual feature dimensions (not seq_cfg) to avoid rounding assertion mismatches.
            net_generator.update_seq_lengths(
                latent_seq_len=seq_cfg.latent_seq_len,
                clip_seq_len=clip_f.shape[1],
                sync_seq_len=sync_f.shape[1],
            )
            print(f"[SelVA] seq: latent={seq_cfg.latent_seq_len} clip={clip_f.shape[1]} sync={sync_f.shape[1]}", flush=True)

            with torch.no_grad():
                # Encode text conditioning
                text_clip = feature_utils.encode_text_clip([prompt])  # [1, 77, D]

                # Encode negative prompt (or use empty conditions)
                neg_text_clip = feature_utils.encode_text_clip([negative_prompt]) \
                    if negative_prompt.strip() else None

                conditions = net_generator.preprocess_conditions(clip_f, sync_f, text_clip)
                empty_conditions = net_generator.get_empty_conditions(
                    bs=1, negative_text_features=neg_text_clip
                )

                # Initial noise (MPS doesn't support torch.Generator on device)
                gen_device = "cpu" if device.type == "mps" else device
                rng = torch.Generator(device=gen_device).manual_seed(seed)
                x0 = torch.randn(
                    1, seq_cfg.latent_seq_len, net_generator.latent_dim,
                    device=gen_device, dtype=dtype, generator=rng,
                ).to(device)

                # Flow matching ODE (Euler)
                fm = FlowMatching(min_sigma=0, inference_mode="euler", num_steps=steps)
                pbar = comfy.utils.ProgressBar(steps)

                def ode_wrapper_tracked(t, x):
                    comfy.model_management.throw_exception_if_processing_interrupted()
                    pbar.update(1)
                    return net_generator.ode_wrapper(t, x, conditions, empty_conditions, cfg_strength)

                try:
                    x1 = fm.to_data(ode_wrapper_tracked, x0)
                except torch.cuda.OutOfMemoryError:
                    raise RuntimeError(
                        "[SelVA] CUDA out of memory during generation. Try switching offload_strategy "
                        "to 'offload_to_cpu', using a smaller variant, or reducing duration."
                    )

            print(f"[SelVA] latent stats: mean={x1.float().mean():.4f} std={x1.float().std():.4f}", flush=True)

            # Decode: latent → mel → audio
            try:
                with torch.no_grad():
                    x1_unnorm = net_generator.unnormalize(x1)
                    spec = feature_utils.decode(x1_unnorm)  # latent → mel spectrogram
                    audio = feature_utils.vocode(spec)  # mel → waveform
            except torch.cuda.OutOfMemoryError:
                raise RuntimeError(
                    "[SelVA] CUDA out of memory during decode/vocode. Try switching offload_strategy "
                    "to 'offload_to_cpu', using a smaller variant, or reducing duration."
                )

        finally:
            if strategy == "offload_to_cpu":
                net_generator.to(get_offload_device())
                feature_utils.to(get_offload_device())
                soft_empty_cache()

        # Ensure [1, 1, samples] and normalize to [-1,1]
        audio = audio.float()
        if audio.dim() == 2:
            audio = audio.unsqueeze(1)
        elif audio.dim() == 3 and audio.shape[1] != 1:
            audio = audio.mean(dim=1, keepdim=True)  # stereo → mono

        if normalize:
            peak = audio.abs().max().clamp(min=1e-8)
            audio = (audio / peak).clamp(-1, 1)
        print(f"[SelVA] audio: shape={tuple(audio.shape)} sr={sample_rate}", flush=True)

        return ({"waveform": audio.cpu(), "sample_rate": sample_rate},)
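The sampler delegates integration to `FlowMatching.to_data` and guidance to `net_generator.ode_wrapper`, neither of which appears in this diff. As a rough illustration of what an Euler flow-matching loop with classifier-free guidance typically looks like, here is a toy pure-Python sketch. It assumes the common guidance form `v_uncond + cfg * (v_cond - v_uncond)` and integration from t=0 to t=1; the function names and the toy velocity field are hypothetical, not SelVA's implementation.

```python
def cfg_velocity(v_cond, v_uncond, cfg_strength):
    """Blend conditional and unconditional velocities (standard CFG form, assumed)."""
    return [u + cfg_strength * (c - u) for c, u in zip(v_cond, v_uncond)]


def euler_integrate(velocity_fn, x0, num_steps):
    """Fixed-step Euler integration of dx/dt = velocity_fn(t, x) from t=0 to t=1."""
    x = list(x0)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        v = velocity_fn(t, x)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


noise = [0.5, -1.0, 2.0]   # stands in for the seeded x0 latent
target = [1.0, 1.0, 1.0]   # stands in for the data the flow should reach


def toy_velocity(t, x):
    # A straight-line flow has constant velocity (target - noise).
    v_cond = [b - a for a, b in zip(noise, target)]
    v_uncond = [0.0] * len(x)
    return cfg_velocity(v_cond, v_uncond, cfg_strength=1.0)


x1 = euler_integrate(toy_velocity, noise, num_steps=25)
print(x1)
```

The velocity function is called once per step, which is why the node sizes its progress bar to `steps` and updates it inside the wrapper.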
@@ -1,160 +0,0 @@
import torch
import comfy.model_management as mm
import comfy.utils

from .utils import (
    PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
    get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
)
from .sampler import _substitute_empty_features


class PrismAudioTextOnly:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}),
                "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
                "steps": ("INT", {"default": 100, "min": 1, "max": 100}),
                "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Encode text with T5-Gemma
        text_features = _encode_text_t5(text_prompt, device, dtype)

        # Build metadata: tuple of one dict per sample
        # Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence)
        # Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768]
        # These will be substituted with learned empty embeddings after conditioning
        sample_meta = {
            "video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
            "text_features": text_features.to(device, dtype=dtype),
            "sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
            "video_exist": torch.tensor(False),
        }
        metadata = (sample_meta,)

        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            conditioning = diffusion.conditioner(metadata, device)

            # Substitute empty features for video/sync
            _substitute_empty_features(diffusion, conditioning, device, dtype)

            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (MPS doesn't support torch.Generator)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

        fakes_f = fakes.float()
        print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)

        if strategy == "offload_to_cpu":
            diffusion.model.to(get_offload_device())
            diffusion.conditioner.to(get_offload_device())
            soft_empty_cache()
            diffusion.pretransform.to(device)

        # VAE decode in fp32 (snake activations overflow in fp16)
        with torch.amp.autocast(device_type=device.type, enabled=False):
            audio = diffusion.pretransform.decode(fakes_f)

        if strategy == "offload_to_cpu":
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp
        audio = audio.float()
        pre_norm_std = audio.std().item()
        pre_norm_peak = audio.abs().max().item()
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
        print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)

        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)


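The removed node derives its latent length from `SAMPLE_RATE`, `DOWNSAMPLING_RATIO`, and the requested duration. The arithmetic can be checked in isolation with the constants from the removed `utils.py` (44100 and 2048). The function name `latent_length` below is only illustrative.

```python
SAMPLE_RATE = 44100        # Hz, from the removed utils.py
DOWNSAMPLING_RATIO = 2048  # audio samples per latent frame


def latent_length(duration_s):
    """Number of latent frames for a clip of the given duration in seconds."""
    return round(SAMPLE_RATE * duration_s / DOWNSAMPLING_RATIO)


print(latent_length(10.0))  # 441000 / 2048 is about 215.33
# The removed config's sample_size of 397312 is an exact multiple of the ratio:
print(397312 // DOWNSAMPLING_RATIO, 397312 % DOWNSAMPLING_RATIO)
```

Because of the rounding, the decoded audio length can differ from the requested duration by up to half a latent frame, roughly 23 ms at these settings.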
# T5-Gemma encoder singleton
_t5_model = None
_t5_tokenizer = None


def _encode_text_t5(text, device, dtype):
    """Encode text using T5-Gemma.

    Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference
    FeaturesUtils.encode_t5_text() implementation.
    No truncation applied (matching reference behavior).
    """
    global _t5_model, _t5_tokenizer

    if _t5_model is None:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        model_id = "google/t5gemma-l-l-ul2-it"
        token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        _t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder()
        _t5_model.eval()

    _t5_model.to(device, dtype=dtype)

    tokens = _t5_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        outputs = _t5_model(**tokens)

    # Move T5 off GPU after encoding to save VRAM
    _t5_model.to("cpu")
    soft_empty_cache()

    return outputs.last_hidden_state.squeeze(0)  # [seq_len, dim]
+3
-46
@@ -1,21 +1,7 @@
import os
import torch
import folder_paths
import comfy.model_management as mm

PRISMAUDIO_CATEGORY = "PrismAudio"
SAMPLE_RATE = 44100
DOWNSAMPLING_RATIO = 2048
IO_CHANNELS = 64

def get_prismaudio_model_dir():
    model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
    os.makedirs(model_dir, exist_ok=True)
    return model_dir

def register_model_folder():
    model_dir = get_prismaudio_model_dir()
    folder_paths.add_model_folder_path("prismaudio", model_dir)
SELVA_CATEGORY = "SelVA"

def get_device():
    return mm.get_torch_device()
@@ -23,42 +9,13 @@ def get_device():
def get_offload_device():
    return mm.unet_offload_device()

def get_free_memory(device=None):
    if device is None:
        device = get_device()
    return mm.get_free_memory(device)

def soft_empty_cache():
    mm.soft_empty_cache()

def determine_precision(preference, device):
    if preference != "auto":
        return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
    if device.type == "cpu":
        return torch.float32
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

def determine_offload_strategy(preference):
    if preference != "auto":
        return preference
    free_mem = get_free_memory()
    gb = free_mem / (1024 ** 3)
    if gb >= 24:
    free_mem = mm.get_free_memory(get_device())
    if free_mem / (1024 ** 3) >= 16:
        return "keep_in_vram"
    else:
        return "offload_to_cpu"

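This hunk interleaves the old and new bodies of `determine_offload_strategy`: the old version kept weights in VRAM at 24 GiB free, the new one at 16 GiB. The decision rule itself is small enough to sketch without ComfyUI. In the version below the free-memory value is passed in as a parameter instead of being read from `mm.get_free_memory`, and `pick_offload_strategy` is a hypothetical name.

```python
GIB = 1024 ** 3


def pick_offload_strategy(preference, free_bytes, threshold_gib=16):
    """Return the explicit preference, or choose by free device memory when 'auto'."""
    if preference != "auto":
        return preference
    if free_bytes / GIB >= threshold_gib:
        return "keep_in_vram"
    return "offload_to_cpu"


print(pick_offload_strategy("auto", 20 * GIB))                    # new 16 GiB threshold
print(pick_offload_strategy("auto", 20 * GIB, threshold_gib=24))  # old 24 GiB threshold
print(pick_offload_strategy("offload_to_cpu", 80 * GIB))          # explicit choice wins
```

Taking the free-memory reading as an argument keeps the rule testable on machines without a GPU.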
def try_import_flash_attn():
    try:
        import flash_attn
        return flash_attn
    except ImportError:
        return None

def resolve_hf_token():
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        return env_token
    return None

@@ -1,5 +0,0 @@
|
||||
"""
|
||||
PrismAudio core inference modules.
|
||||
Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
|
||||
Only inference-critical code — no training, no JAX/TF dependencies.
|
||||
"""
|
||||
@@ -1,141 +0,0 @@
{
    "model_type": "diffusion_cond",
    "sample_size": 397312,
    "sample_rate": 44100,
    "audio_channels": 2,
    "model": {
        "pretransform": {
            "type": "autoencoder",
            "iterate_batch": true,
            "config": {
                "encoder": {
                    "type": "oobleck",
                    "config": {
                        "in_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 8, 8],
                        "latent_dim": 128,
                        "use_snake": true
                    }
                },
                "decoder": {
                    "type": "oobleck",
                    "config": {
                        "out_channels": 2,
                        "channels": 128,
                        "c_mults": [1, 2, 4, 8, 16],
                        "strides": [2, 4, 4, 8, 8],
                        "latent_dim": 64,
                        "use_snake": true,
                        "final_tanh": false
                    }
                },
                "bottleneck": {
                    "type": "vae"
                },
                "latent_dim": 64,
                "downsampling_ratio": 2048,
                "io_channels": 2
            }
        },
        "conditioning": {
            "configs": [
                {
                    "id": "video_features",
                    "type": "cond_mlp",
                    "config": {
                        "dim": 1024,
                        "output_dim": 1024
                    }
                },
                {
                    "id": "text_features",
                    "type": "cond_mlp",
                    "config": {
                        "dim": 1024,
                        "output_dim": 1024
                    }
                },
                {
                    "id": "sync_features",
                    "type": "sync_mlp",
                    "config": {
                        "dim": 768,
                        "output_dim": 1024
                    }
                }
            ],
            "cond_dim": 768
        },
        "diffusion": {
            "cross_attention_cond_ids": ["video_features","text_features"],
            "add_cond_ids": ["video_features"],
            "sync_cond_ids": ["sync_features"],
            "type": "dit",
            "diffusion_objective": "rectified_flow",
            "config": {
                "io_channels": 64,
                "embed_dim": 1024,
                "depth": 24,
                "num_heads": 16,
                "cond_token_dim": 1024,
                "add_token_dim": 1024,
                "sync_token_dim": 1024,
                "project_cond_tokens": false,
                "transformer_type": "continuous_transformer",
                "attn_kwargs":{
                    "qk_norm": "rns"
                },
                "use_gated": true,
                "use_sync_gated": true
            }
        },
        "io_channels": 64
    },
    "training": {
        "use_ema": true,
        "log_loss_info": false,
        "cfg_dropout_prob": 0.1,
        "pre_encoded": true,
        "timestep_sampler": "trunc_logit_normal",
        "optimizer_configs": {
            "diffusion": {
                "optimizer": {
                    "type": "AdamW",
                    "config": {
                        "lr": 1e-4,
                        "betas": [0.9, 0.999],
                        "weight_decay": 1e-3
                    }
                },
                "scheduler": {
                    "type": "InverseLR",
                    "config": {
                        "inv_gamma": 100000,
                        "power": 0.5,
                        "warmup": 0.99
                    }
                }
            }
        },
        "demo": {
            "demo_every": 5000,
            "demo_steps": 24,
            "num_demos": 10,
            "demo_cond": [
                "dataset/videoprism/test/0Cu33yBwAPg_000060.npz",
                "dataset/videoprism/test/bmKtI808DsU_000009.npz",
                "dataset/videoprism/test/VC0c22cJTbM_000424.npz",
                "dataset/videoprism/test/F3gsbUTdc2U_000090.npz",
                "dataset/videoprism/test/WatvT8A8iug_000100.npz",
                "dataset/videoprism/test/0nvBTp-q7tU_000112.npz",
                "dataset/videoprism/test/3-PFuDkTM48_000080.npz",
                "dataset/videoprism/test/luSAuu-BoPs_000232.npz",
                "dataset/videoprism/test/__8UJxW0aOQ_000002.npz",
                "dataset/videoprism/test/_0m_YMpQayA_000168.npz"
            ],
            "demo_cfg_scales": [5]
        }
    }
}
@@ -1,413 +0,0 @@
"""
Model factory functions for PrismAudio inference.

Extracted from:
- PrismAudio/models/factory.py
- PrismAudio/models/autoencoders.py (create_autoencoder_from_config)
- PrismAudio/models/diffusion.py (create_diffusion_cond_from_config)
- PrismAudio/models/conditioners.py (create_multi_conditioner_from_conditioning_config)

Source: https://github.com/FunAudioLLM/ThinkSound (prismaudio branch)
Only inference-critical factory functions are retained.
"""

import json
import typing as tp
from typing import Dict, Any

import numpy as np


def create_model_from_config(model_config):
    model_type = model_config.get('model_type', None)

    assert model_type is not None, 'model_type must be specified in model config'

    if model_type == 'autoencoder':
        return create_autoencoder_from_config(model_config)
    elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond":
        return create_diffusion_cond_from_config(model_config)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')


def create_pretransform_from_config(pretransform_config, sample_rate):
    pretransform_type = pretransform_config.get('type', None)

    assert pretransform_type is not None, 'type must be specified in pretransform config'

    if pretransform_type == 'autoencoder':
        from prismaudio_core.models.pretransforms import AutoencoderPretransform

        # Create fake top-level config to pass sample rate to autoencoder constructor
        # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
        autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
        autoencoder = create_autoencoder_from_config(autoencoder_config)

        scale = pretransform_config.get("scale", 1.0)
        model_half = pretransform_config.get("model_half", False)
        iterate_batch = pretransform_config.get("iterate_batch", False)
        chunked = pretransform_config.get("chunked", False)

        pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
    elif pretransform_type == 'wavelet':
        raise NotImplementedError("wavelet pretransform type is not supported")
    elif pretransform_type == 'pqmf':
        from prismaudio_core.models.pretransforms import PQMFPretransform
        pqmf_config = pretransform_config["config"]
        pretransform = PQMFPretransform(**pqmf_config)
    elif pretransform_type == 'dac_pretrained':
        from prismaudio_core.models.pretransforms import PretrainedDACPretransform
        pretrained_dac_config = pretransform_config["config"]
        pretransform = PretrainedDACPretransform(**pretrained_dac_config)
    elif pretransform_type == "audiocraft_pretrained":
        from prismaudio_core.models.pretransforms import AudiocraftCompressionPretransform

        audiocraft_config = pretransform_config["config"]
        pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
    else:
        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')

    enable_grad = pretransform_config.get('enable_grad', False)
    pretransform.enable_grad = enable_grad

    pretransform.eval().requires_grad_(pretransform.enable_grad)
|
||||
|
||||
return pretransform
|
||||
|
||||
|
||||
def create_bottleneck_from_config(bottleneck_config):
|
||||
bottleneck_type = bottleneck_config.get('type', None)
|
||||
|
||||
assert bottleneck_type is not None, 'type must be specified in bottleneck config'
|
||||
|
||||
if bottleneck_type == 'tanh':
|
||||
from prismaudio_core.models.bottleneck import TanhBottleneck
|
||||
bottleneck = TanhBottleneck()
|
||||
elif bottleneck_type == 'vae':
|
||||
from prismaudio_core.models.bottleneck import VAEBottleneck
|
||||
bottleneck = VAEBottleneck()
|
||||
elif bottleneck_type == 'rvq':
|
||||
from prismaudio_core.models.bottleneck import RVQBottleneck
|
||||
|
||||
quantizer_params = {
|
||||
"dim": 128,
|
||||
"codebook_size": 1024,
|
||||
"num_quantizers": 8,
|
||||
"decay": 0.99,
|
||||
"kmeans_init": True,
|
||||
"kmeans_iters": 50,
|
||||
"threshold_ema_dead_code": 2,
|
||||
}
|
||||
|
||||
quantizer_params.update(bottleneck_config["config"])
|
||||
|
||||
bottleneck = RVQBottleneck(**quantizer_params)
|
||||
elif bottleneck_type == "dac_rvq":
|
||||
from prismaudio_core.models.bottleneck import DACRVQBottleneck
|
||||
|
||||
bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
|
||||
|
||||
elif bottleneck_type == 'rvq_vae':
|
||||
from prismaudio_core.models.bottleneck import RVQVAEBottleneck
|
||||
|
||||
quantizer_params = {
|
||||
"dim": 128,
|
||||
"codebook_size": 1024,
|
||||
"num_quantizers": 8,
|
||||
"decay": 0.99,
|
||||
"kmeans_init": True,
|
||||
"kmeans_iters": 50,
|
||||
"threshold_ema_dead_code": 2,
|
||||
}
|
||||
|
||||
quantizer_params.update(bottleneck_config["config"])
|
||||
|
||||
bottleneck = RVQVAEBottleneck(**quantizer_params)
|
||||
|
||||
elif bottleneck_type == 'dac_rvq_vae':
|
||||
from prismaudio_core.models.bottleneck import DACRVQVAEBottleneck
|
||||
bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
|
||||
elif bottleneck_type == 'l2_norm':
|
||||
from prismaudio_core.models.bottleneck import L2Bottleneck
|
||||
bottleneck = L2Bottleneck()
|
||||
elif bottleneck_type == "wasserstein":
|
||||
from prismaudio_core.models.bottleneck import WassersteinBottleneck
|
||||
bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
|
||||
elif bottleneck_type == "fsq":
|
||||
from prismaudio_core.models.bottleneck import FSQBottleneck
|
||||
bottleneck = FSQBottleneck(**bottleneck_config["config"])
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
|
||||
|
||||
requires_grad = bottleneck_config.get('requires_grad', True)
|
||||
if not requires_grad:
|
||||
for param in bottleneck.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return bottleneck
|
||||
|
||||
|
||||
def create_autoencoder_from_config(config: Dict[str, Any]):
|
||||
"""Create an AudioAutoencoder from a config dictionary.
|
||||
|
||||
Originally in PrismAudio/models/autoencoders.py.
|
||||
"""
|
||||
from prismaudio_core.models.autoencoders import (
|
||||
AudioAutoencoder,
|
||||
create_encoder_from_config,
|
||||
create_decoder_from_config,
|
||||
)
|
||||
|
||||
ae_config = config["model"]
|
||||
|
||||
encoder = create_encoder_from_config(ae_config["encoder"])
|
||||
decoder = create_decoder_from_config(ae_config["decoder"])
|
||||
|
||||
bottleneck = ae_config.get("bottleneck", None)
|
||||
|
||||
latent_dim = ae_config.get("latent_dim", None)
|
||||
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||
downsampling_ratio = ae_config.get("downsampling_ratio", None)
|
||||
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||
io_channels = ae_config.get("io_channels", None)
|
||||
assert io_channels is not None, "io_channels must be specified in model config"
|
||||
sample_rate = config.get("sample_rate", None)
|
||||
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||
|
||||
in_channels = ae_config.get("in_channels", None)
|
||||
out_channels = ae_config.get("out_channels", None)
|
||||
|
||||
pretransform = ae_config.get("pretransform", None)
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
||||
|
||||
if bottleneck is not None:
|
||||
bottleneck = create_bottleneck_from_config(bottleneck)
|
||||
|
||||
soft_clip = ae_config["decoder"].get("soft_clip", False)
|
||||
|
||||
return AudioAutoencoder(
|
||||
encoder,
|
||||
decoder,
|
||||
io_channels=io_channels,
|
||||
latent_dim=latent_dim,
|
||||
downsampling_ratio=downsampling_ratio,
|
||||
sample_rate=sample_rate,
|
||||
bottleneck=bottleneck,
|
||||
pretransform=pretransform,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
soft_clip=soft_clip
|
||||
)
|
||||
|
||||
|
||||
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]):
|
||||
"""Create a MultiConditioner from a conditioning config dictionary.
|
||||
|
||||
Originally in PrismAudio/models/conditioners.py.
|
||||
"""
|
||||
from prismaudio_core.models.conditioners import (
|
||||
MultiConditioner,
|
||||
T5Conditioner,
|
||||
CLAPTextConditioner,
|
||||
CLIPTextConditioner,
|
||||
MetaCLIPTextConditioner,
|
||||
CLAPAudioConditioner,
|
||||
Cond_MLP,
|
||||
Global_MLP,
|
||||
Sync_MLP,
|
||||
Cond_MLP_1,
|
||||
Cond_ConvMLP,
|
||||
Cond_MLP_Global,
|
||||
Cond_MLP_Global_1,
|
||||
Cond_MLP_Global_2,
|
||||
Video_Global,
|
||||
Video_Sync,
|
||||
Text_Linear,
|
||||
CLIPConditioner,
|
||||
IntConditioner,
|
||||
NumberConditioner,
|
||||
PhonemeConditioner,
|
||||
TokenizerLUTConditioner,
|
||||
PretransformConditioner,
|
||||
mm_unchang,
|
||||
)
|
||||
from prismaudio_core.models.utils import load_ckpt_state_dict
|
||||
|
||||
conditioners = {}
|
||||
cond_dim = config["cond_dim"]
|
||||
|
||||
default_keys = config.get("default_keys", {})
|
||||
|
||||
for conditioner_info in config["configs"]:
|
||||
id = conditioner_info["id"]
|
||||
|
||||
conditioner_type = conditioner_info["type"]
|
||||
|
||||
conditioner_config = {"output_dim": cond_dim}
|
||||
|
||||
conditioner_config.update(conditioner_info["config"])
|
||||
if conditioner_type == "t5":
|
||||
conditioners[id] = T5Conditioner(**conditioner_config)
|
||||
elif conditioner_type == "clap_text":
|
||||
conditioners[id] = CLAPTextConditioner(**conditioner_config)
|
||||
elif conditioner_type == "clip_text":
|
||||
conditioners[id] = CLIPTextConditioner(**conditioner_config)
|
||||
elif conditioner_type == "metaclip_text":
|
||||
conditioners[id] = MetaCLIPTextConditioner(**conditioner_config)
|
||||
elif conditioner_type == "clap_audio":
|
||||
conditioners[id] = CLAPAudioConditioner(**conditioner_config)
|
||||
elif conditioner_type == "cond_mlp":
|
||||
conditioners[id] = Cond_MLP(**conditioner_config)
|
||||
elif conditioner_type == "global_mlp":
|
||||
conditioners[id] = Global_MLP(**conditioner_config)
|
||||
elif conditioner_type == "sync_mlp":
|
||||
conditioners[id] = Sync_MLP(**conditioner_config)
|
||||
elif conditioner_type == "cond_mlp_1":
|
||||
conditioners[id] = Cond_MLP_1(**conditioner_config)
|
||||
elif conditioner_type == "cond_convmlp":
|
||||
conditioners[id] = Cond_ConvMLP(**conditioner_config)
|
||||
elif conditioner_type == "cond_mlp_global":
|
||||
conditioners[id] = Cond_MLP_Global(**conditioner_config)
|
||||
elif conditioner_type == "cond_mlp_global_1":
|
||||
conditioners[id] = Cond_MLP_Global_1(**conditioner_config)
|
||||
elif conditioner_type == "cond_mlp_global_2":
|
||||
conditioners[id] = Cond_MLP_Global_2(**conditioner_config)
|
||||
elif conditioner_type == "video_global":
|
||||
conditioners[id] = Video_Global(**conditioner_config)
|
||||
elif conditioner_type == "video_sync":
|
||||
conditioners[id] = Video_Sync(**conditioner_config)
|
||||
elif conditioner_type == "text_linear":
|
||||
conditioners[id] = Text_Linear(**conditioner_config)
|
||||
elif conditioner_type == "video_clip":
|
||||
conditioners[id] = CLIPConditioner(**conditioner_config)
|
||||
elif conditioner_type == "int":
|
||||
conditioners[id] = IntConditioner(**conditioner_config)
|
||||
elif conditioner_type == "number":
|
||||
conditioners[id] = NumberConditioner(**conditioner_config)
|
||||
elif conditioner_type == "phoneme":
|
||||
conditioners[id] = PhonemeConditioner(**conditioner_config)
|
||||
elif conditioner_type == "lut":
|
||||
conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
|
||||
elif conditioner_type == "pretransform":
|
||||
sample_rate = conditioner_config.pop("sample_rate", None)
|
||||
assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
|
||||
|
||||
pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
|
||||
|
||||
if conditioner_config.get("pretransform_ckpt_path", None) is not None:
|
||||
pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
|
||||
|
||||
conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
|
||||
elif conditioner_type == "mm_unchang":
|
||||
conditioners[id] = mm_unchang(**conditioner_config)
|
||||
else:
|
||||
raise ValueError(f"Unknown conditioner type: {conditioner_type}")
|
||||
|
||||
return MultiConditioner(conditioners, default_keys=default_keys)
|
||||
|
||||
|
||||
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||
"""Create a ConditionedDiffusionModelWrapper from a config dictionary.
|
||||
|
||||
Originally in PrismAudio/models/diffusion.py.
|
||||
"""
|
||||
from prismaudio_core.models.diffusion import (
|
||||
ConditionedDiffusionModelWrapper,
|
||||
MMConditionedDiffusionModelWrapper,
|
||||
UNetCFG1DWrapper,
|
||||
UNet1DCondWrapper,
|
||||
DiTWrapper,
|
||||
)
|
||||
|
||||
model_config = config["model"]
|
||||
|
||||
model_type = config["model_type"]
|
||||
|
||||
diffusion_config = model_config.get('diffusion', None)
|
||||
assert diffusion_config is not None, "Must specify diffusion config"
|
||||
|
||||
diffusion_model_type = diffusion_config.get('type', None)
|
||||
assert diffusion_model_type is not None, "Must specify diffusion model type"
|
||||
|
||||
diffusion_model_config = diffusion_config.get('config', None)
|
||||
assert diffusion_model_config is not None, "Must specify diffusion model config"
|
||||
|
||||
    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'mmdit':
        raise NotImplementedError("mmdit diffusion model type is not supported")
    else:
        # Fail early with a clear message instead of a NameError on diffusion_model later.
        raise NotImplementedError(f"Unknown diffusion model type: {diffusion_model_type}")
|
||||
|
||||
io_channels = model_config.get('io_channels', None)
|
||||
assert io_channels is not None, "Must specify io_channels in model config"
|
||||
|
||||
sample_rate = config.get('sample_rate', None)
|
||||
assert sample_rate is not None, "Must specify sample_rate in config"
|
||||
|
||||
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
|
||||
|
||||
conditioning_config = model_config.get('conditioning', None)
|
||||
|
||||
conditioner = None
|
||||
if conditioning_config is not None:
|
||||
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
|
||||
|
||||
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
|
||||
add_cond_ids = diffusion_config.get('add_cond_ids', [])
|
||||
sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
|
||||
global_cond_ids = diffusion_config.get('global_cond_ids', [])
|
||||
input_concat_ids = diffusion_config.get('input_concat_ids', [])
|
||||
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
|
||||
mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
|
||||
zero_init = diffusion_config.get('zero_init', False)
|
||||
pretransform = model_config.get("pretransform", None)
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = create_pretransform_from_config(pretransform, sample_rate)
|
||||
min_input_length = pretransform.downsampling_ratio
|
||||
else:
|
||||
min_input_length = 1
|
||||
|
||||
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
|
||||
min_input_length *= np.prod(diffusion_model_config["factors"])
|
||||
elif diffusion_model_type == "dit":
|
||||
min_input_length *= diffusion_model.model.patch_size
|
||||
|
||||
# Get the proper wrapper class
|
||||
|
||||
extra_kwargs = {}
|
||||
|
||||
    if model_type == "mm_diffusion_cond":
        wrapper_fn = MMConditionedDiffusionModelWrapper
        extra_kwargs["diffusion_objective"] = diffusion_objective
        extra_kwargs["mm_cond_ids"] = mm_cond_ids
    elif model_type in ("diffusion_cond", "diffusion_cond_inpaint", "diffusion_infill"):
        wrapper_fn = ConditionedDiffusionModelWrapper
        extra_kwargs["diffusion_objective"] = diffusion_objective
    elif model_type == "diffusion_prior":
        raise NotImplementedError("diffusion_prior model type is not supported")
    else:
        # Fail with a clear message instead of a NameError on wrapper_fn below.
        raise NotImplementedError(f"Unknown model type: {model_type}")
|
||||
|
||||
return wrapper_fn(
|
||||
diffusion_model,
|
||||
conditioner,
|
||||
min_input_length=min_input_length,
|
||||
sample_rate=sample_rate,
|
||||
cross_attn_cond_ids=cross_attention_ids,
|
||||
global_cond_ids=global_cond_ids,
|
||||
input_concat_ids=input_concat_ids,
|
||||
prepend_cond_ids=prepend_cond_ids,
|
||||
add_cond_ids=add_cond_ids,
|
||||
sync_cond_ids=sync_cond_ids,
|
||||
pretransform=pretransform,
|
||||
io_channels=io_channels,
|
||||
zero_init=zero_init,
|
||||
**extra_kwargs
|
||||
)
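
The asserts in `create_diffusion_cond_from_config` imply a minimum set of config keys. As a rough sketch, the check below walks a config dictionary and reports which of those keys are absent before any model is built. `missing_keys`, `REQUIRED_PATHS`, and the values in `example_config` are hypothetical and only illustrate the shape of the config; they are not part of `prismaudio_core` and are not PrismAudio's real settings.

```python
# Hypothetical helper, not part of prismaudio_core: it mirrors the keys that
# create_diffusion_cond_from_config reads and asserts on.
REQUIRED_PATHS = [
    ("model_type",),
    ("sample_rate",),
    ("model", "io_channels"),
    ("model", "diffusion", "type"),
    ("model", "diffusion", "config"),
]


def missing_keys(config):
    """Return dotted paths of required keys that are absent or None."""
    missing = []
    for path in REQUIRED_PATHS:
        node = config
        for key in path:
            node = node.get(key) if isinstance(node, dict) else None
            if node is None:
                break
        if node is None:
            missing.append(".".join(path))
    return missing


# Illustrative config; the values are placeholders.
example_config = {
    "model_type": "diffusion_cond",
    "sample_rate": 44100,
    "model": {
        "io_channels": 64,
        "diffusion": {"type": "dit", "config": {}},
    },
}

print(missing_keys(example_config))  # []
print(missing_keys({"model_type": "diffusion_cond", "model": {}}))
```

Collecting every missing key in one pass is a design choice for this sketch; the factory itself stops at the first failing assert.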
|
||||
@@ -1,4 +0,0 @@
|
||||
from .sampling import sample_discrete_euler
from .utils import set_audio_channels, prepare_audio

__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]
|
||||
@@ -1,29 +0,0 @@
|
||||
import torch


@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Discrete Euler sampler for rectified flow, with optional callback.

    Modified from PrismAudio to add callback parameter for ComfyUI progress reporting.
    Original uses tqdm internally.

    Args:
        model: The diffusion model (DiTWrapper)
        x: Initial noise tensor [B, C, T]
        steps: Number of sampling steps
        sigma_max: Maximum sigma (default 1.0 for rectified flow)
        callback: Optional callable({"i": step, "x": current_x}) for progress
        **extra_args: Passed to model(); includes cross_attn_cond, add_cond,
            sync_cond, cfg_scale, batch_cfg, etc.
    """
    t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)

    for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])):
        dt = t_next - t_curr
        t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
        x = x + dt * model(x, t_curr_tensor, **extra_args)
        if callback is not None:
            callback({"i": i, "x": x})

    return x
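
To see what the Euler loop in `sample_discrete_euler` does without loading a model, here is a minimal pure-Python sketch of the same update rule on a list of floats. `euler_sample` and `toy_model` are hypothetical names. The toy velocity field assumes the straight-line interpolation `x_t = (1 - t) * data + t * noise`, which is an assumption made for illustration and not a statement about how PrismAudio was trained.

```python
def euler_sample(model, x, steps, sigma_max=1.0, callback=None):
    """Pure-Python analogue of the sampler loop, operating on a list of floats."""
    # Same schedule as torch.linspace(sigma_max, 0, steps + 1).
    ts = [sigma_max * (1.0 - i / steps) for i in range(steps + 1)]
    for i, (t_curr, t_next) in enumerate(zip(ts[:-1], ts[1:])):
        dt = t_next - t_curr  # negative: time runs from sigma_max down to 0
        velocity = model(x, t_curr)
        x = [xi + dt * vi for xi, vi in zip(x, velocity)]
        if callback is not None:
            callback({"i": i, "x": x})
    return x


# Toy setup (assumed for illustration): with x_t = (1 - t) * data + t * noise,
# the exact velocity dx/dt is noise - data, independent of t.
data = [0.5, -1.0]
noise = [2.0, 3.0]


def toy_model(x, t):
    return [n - d for n, d in zip(noise, data)]


progress = []
result = euler_sample(toy_model, noise, steps=8,
                      callback=lambda info: progress.append(info["i"]))
print(result)    # close to data
print(progress)  # one callback per step
```

Because the toy velocity is constant, the Euler steps land on `data` regardless of the step count. A learned velocity field is only approximately straight, which is why `steps` matters in practice.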
|
||||
@@ -1,62 +0,0 @@
|
||||
import torch
import torch.nn.functional as F
from torchaudio import transforms as T


def set_audio_channels(audio, target_channels):
    """Convert audio tensor to target number of channels.

    Args:
        audio: Audio tensor of shape [B, C, T]
        target_channels: Desired number of channels (1 for mono, 2 for stereo)

    Returns:
        Audio tensor with the target number of channels.
    """
    if target_channels == 1:
        # Convert to mono
        audio = audio.mean(1, keepdim=True)
    elif target_channels == 2:
        # Convert to stereo
        if audio.shape[1] == 1:
            audio = audio.repeat(1, 2, 1)
        elif audio.shape[1] > 2:
            audio = audio[:, :2, :]
    return audio


def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Resample, pad/trim, and convert channels of an audio tensor.

    Args:
        audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T])
        in_sr: Input sample rate
        target_sr: Target sample rate
        target_length: Target length in samples (padded or cropped)
        target_channels: Target number of channels
        device: Torch device to place the audio on

    Returns:
        Audio tensor of shape [B, target_channels, target_length] on device.
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)

    # Add batch dimension
    if audio.dim() == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:
        audio = audio.unsqueeze(0)

    # Pad or crop to target_length
    if audio.shape[-1] < target_length:
        audio = F.pad(audio, (0, target_length - audio.shape[-1]))
    elif audio.shape[-1] > target_length:
        audio = audio[:, :, :target_length]

    audio = set_audio_channels(audio, target_channels)

    return audio
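
The padding, cropping, and channel rules in `prepare_audio` and `set_audio_channels` can be tried on plain lists. This is a rough sketch that leaves out resampling and the batch dimension. `fit_length` and `fit_channels` are hypothetical stand-ins that work on a `[C][T]` list of lists; they are not functions from this package.

```python
def fit_length(audio, target_length):
    """Right-pad with zeros or crop so every channel has target_length samples."""
    return [(ch + [0.0] * (target_length - len(ch)))[:target_length] for ch in audio]


def fit_channels(audio, target_channels):
    """audio is a list of channels, each a list of samples ([C][T])."""
    if target_channels == 1:
        # Mono: average across channels, sample by sample.
        n = len(audio)
        return [[sum(frame) / n for frame in zip(*audio)]]
    if target_channels == 2:
        if len(audio) == 1:
            return [list(audio[0]), list(audio[0])]  # duplicate the mono channel
        if len(audio) > 2:
            return [list(ch) for ch in audio[:2]]  # keep the first two channels
    return audio


stereo = [[1.0, 3.0, 5.0], [3.0, 5.0, 7.0]]

# Same order as prepare_audio: fix the length first, then the channel count.
padded = fit_channels(fit_length(stereo, 5), 1)
cropped = fit_channels(fit_length(stereo, 2), 1)
widened = fit_channels([[1.0, 2.0]], 2)

print(padded)   # [[2.0, 4.0, 6.0, 0.0, 0.0]]
print(cropped)  # [[2.0, 4.0]]
print(widened)  # [[1.0, 2.0], [1.0, 2.0]]
```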
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
PrismAudio model modules for inference.
|
||||
|
||||
Re-exports create_model_from_config from the factory module.
|
||||
"""
|
||||
|
||||
from prismaudio_core.factory import create_model_from_config
|
||||
|
||||
__all__ = ["create_model_from_config"]
|
||||
File diff suppressed because it is too large
@@ -1,821 +0,0 @@
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchaudio import transforms as T
|
||||
from alias_free_torch import Activation1d
|
||||
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
||||
from typing import Literal, Dict, Any
|
||||
|
||||
from .blocks import SnakeBeta
|
||||
from .bottleneck import Bottleneck, DiscreteBottleneck
|
||||
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
||||
from .pretransforms import Pretransform
|
||||
from .utils import checkpoint
|
||||
|
||||
|
||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Minimal stub for inference.utils.prepare_audio used by autoencoders."""
    import torchaudio.transforms as T
    import torch

    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)

    if audio.shape[0] > target_channels:
        audio = audio[:target_channels]
    elif audio.shape[0] < target_channels:
        audio = audio.repeat(target_channels // audio.shape[0] + 1, 1)[:target_channels]

    if audio.shape[-1] < target_length:
        audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[-1]))
    elif audio.shape[-1] > target_length:
        audio = audio[..., :target_length]

    return audio.unsqueeze(0)
|
||||
|
||||
|
||||
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
|
||||
from prismaudio_core.factory import create_pretransform_from_config
|
||||
return create_pretransform_from_config(pretransform, sample_rate)
|
||||
|
||||
|
||||
def _lazy_create_bottleneck_from_config(bottleneck):
|
||||
from prismaudio_core.factory import create_bottleneck_from_config
|
||||
return create_bottleneck_from_config(bottleneck)
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
|
||||
if activation == "elu":
|
||||
act = nn.ELU()
|
||||
elif activation == "snake":
|
||||
act = SnakeBeta(channels)
|
||||
elif activation == "none":
|
||||
act = nn.Identity()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation {activation}")
|
||||
|
||||
if antialias:
|
||||
act = Activation1d(act)
|
||||
|
||||
return act
|
||||
|
||||
class ResidualUnit(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
|
||||
super().__init__()
|
||||
|
||||
self.dilation = dilation
|
||||
|
||||
padding = (dilation * (7-1)) // 2
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=7, dilation=dilation, padding=padding),
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
|
||||
WNConv1d(in_channels=out_channels, out_channels=out_channels,
|
||||
kernel_size=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
|
||||
#x = checkpoint(self.layers, x)
|
||||
x = self.layers(x)
|
||||
|
||||
return x + res
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
|
||||
super().__init__()
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=1, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=3, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=in_channels,
|
||||
out_channels=in_channels, dilation=9, use_snake=use_snake),
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||
WNConv1d(in_channels=in_channels, out_channels=out_channels,
|
||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
|
||||
super().__init__()
|
||||
|
||||
if use_nearest_upsample:
|
||||
upsample_layer = nn.Sequential(
|
||||
nn.Upsample(scale_factor=stride, mode="nearest"),
|
||||
WNConv1d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2*stride,
|
||||
stride=1,
|
||||
bias=False,
|
||||
padding='same')
|
||||
)
|
||||
else:
|
||||
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
|
||||
|
||||
self.layers = nn.Sequential(
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
|
||||
upsample_layer,
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=1, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=3, use_snake=use_snake),
|
||||
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
|
||||
dilation=9, use_snake=use_snake),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
class OobleckEncoder(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels=2,
|
||||
channels=128,
|
||||
latent_dim=32,
|
||||
c_mults = [1, 2, 4, 8],
|
||||
strides = [2, 4, 8, 8],
|
||||
use_snake=False,
|
||||
antialias_activation=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
c_mults = [1] + c_mults
|
||||
|
||||
self.depth = len(c_mults)
|
||||
|
||||
layers = [
|
||||
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
|
||||
]
|
||||
|
||||
for i in range(self.depth-1):
|
||||
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
|
||||
|
||||
layers += [
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
|
||||
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
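
The default `strides = [2, 4, 8, 8]` in `OobleckEncoder` determine how much the encoder shortens its input. The sketch below applies the standard `Conv1d` output-length formula to the strided convolution at the end of each `EncoderBlock` (kernel `2*stride`, padding `ceil(stride/2)`). The other convolutions in the encoder use kernel and padding pairs that preserve length, so they are skipped. `conv1d_out_len` and `oobleck_latent_len` are hypothetical helpers written for this illustration, assuming dilation 1 on the strided convolutions as in the code above.

```python
import math


def conv1d_out_len(length, kernel_size, stride, padding):
    """Standard Conv1d output length for dilation 1."""
    return (length + 2 * padding - kernel_size) // stride + 1


def oobleck_latent_len(num_samples, strides=(2, 4, 8, 8)):
    """Follow only the strided conv that closes each EncoderBlock."""
    length = num_samples
    for s in strides:
        length = conv1d_out_len(length, kernel_size=2 * s, stride=s,
                                padding=math.ceil(s / 2))
    return length


print(math.prod((2, 4, 8, 8)))   # 512
print(oobleck_latent_len(2048))  # 4
```

With the default strides, 2048 input samples map to 4 latent frames, one frame per 512 samples. The `downsampling_ratio` passed to `AudioAutoencoder` is read from the config separately, so this arithmetic only shows what the default strides imply.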
|
||||
|
||||
|
||||
class OobleckDecoder(nn.Module):
|
||||
def __init__(self,
|
||||
out_channels=2,
|
||||
channels=128,
|
||||
latent_dim=32,
|
||||
c_mults = [1, 2, 4, 8],
|
||||
strides = [2, 4, 8, 8],
|
||||
use_snake=False,
|
||||
antialias_activation=False,
|
||||
use_nearest_upsample=False,
|
||||
final_tanh=True):
|
||||
super().__init__()
|
||||
|
||||
c_mults = [1] + c_mults
|
||||
|
||||
self.depth = len(c_mults)
|
||||
|
||||
layers = [
|
||||
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
|
||||
]
|
||||
|
||||
for i in range(self.depth-1, 0, -1):
|
||||
layers += [DecoderBlock(
|
||||
in_channels=c_mults[i]*channels,
|
||||
out_channels=c_mults[i-1]*channels,
|
||||
stride=strides[i-1],
|
||||
use_snake=use_snake,
|
||||
antialias_activation=antialias_activation,
|
||||
use_nearest_upsample=use_nearest_upsample
|
||||
)
|
||||
]
|
||||
|
||||
layers += [
|
||||
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
|
||||
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
|
||||
nn.Tanh() if final_tanh else nn.Identity()
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class DACEncoderWrapper(nn.Module):
|
||||
def __init__(self, in_channels=1, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
from dac.model.dac import Encoder as DACEncoder
|
||||
|
||||
latent_dim = kwargs.pop("latent_dim", None)
|
||||
|
||||
encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
|
||||
self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
# Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
|
||||
self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
|
||||
|
||||
if in_channels != 1:
|
||||
self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
x = self.proj_out(x)
|
||||
return x
|
||||
|
||||
class DACDecoderWrapper(nn.Module):
|
||||
def __init__(self, latent_dim, out_channels=1, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
from dac.model.dac import Decoder as DACDecoder
|
||||
|
||||
self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
|
||||
def forward(self, x):
|
||||
return self.decoder(x)
|
||||
|
||||
class AudioAutoencoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder,
|
||||
decoder,
|
||||
latent_dim,
|
||||
downsampling_ratio,
|
||||
sample_rate,
|
||||
io_channels=2,
|
||||
bottleneck: Bottleneck = None,
|
||||
pretransform: Pretransform = None,
|
||||
in_channels = None,
|
||||
out_channels = None,
|
||||
soft_clip = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.downsampling_ratio = downsampling_ratio
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.io_channels = io_channels
|
||||
self.in_channels = io_channels
|
||||
self.out_channels = io_channels
|
||||
|
||||
self.min_length = self.downsampling_ratio
|
||||
|
||||
if in_channels is not None:
|
||||
self.in_channels = in_channels
|
||||
|
||||
if out_channels is not None:
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.bottleneck = bottleneck
|
||||
|
||||
self.encoder = encoder
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
self.pretransform = pretransform
|
||||
|
||||
self.soft_clip = soft_clip
|
||||
|
||||
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
|
||||
|
||||
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
||||
|
||||
info = {}
|
||||
|
||||
if self.pretransform is not None and not skip_pretransform:
|
||||
if self.pretransform.enable_grad:
|
||||
if iterate_batch:
|
||||
audios = []
|
||||
for i in range(audio.shape[0]):
|
||||
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||
audio = torch.cat(audios, dim=0)
|
||||
else:
|
||||
audio = self.pretransform.encode(audio)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if iterate_batch:
|
||||
audios = []
|
||||
for i in range(audio.shape[0]):
|
||||
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||
audio = torch.cat(audios, dim=0)
|
||||
else:
|
||||
audio = self.pretransform.encode(audio)
|
||||
|
||||
if self.encoder is not None:
|
||||
if iterate_batch:
|
||||
latents = []
|
||||
for i in range(audio.shape[0]):
|
||||
latents.append(self.encoder(audio[i:i+1]))
|
||||
latents = torch.cat(latents, dim=0)
|
||||
else:
|
||||
latents = self.encoder(audio)
|
||||
else:
|
||||
latents = audio
|
||||
|
||||
if self.bottleneck is not None:
|
||||
# TODO: Add iterate batch logic, needs to merge the info dicts
|
||||
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
|
||||
|
||||
info.update(bottleneck_info)
|
||||
|
||||
if return_info:
|
||||
return latents, info
|
||||
|
||||
return latents
|
||||
|
||||
def decode(self, latents, iterate_batch=False, **kwargs):
|
||||
|
||||
if self.bottleneck is not None:
|
||||
if iterate_batch:
|
||||
decoded = []
|
||||
for i in range(latents.shape[0]):
|
||||
decoded.append(self.bottleneck.decode(latents[i:i+1]))
|
||||
latents = torch.cat(decoded, dim=0)
|
||||
else:
|
||||
latents = self.bottleneck.decode(latents)
|
||||
|
||||
if iterate_batch:
|
||||
decoded = []
|
||||
for i in range(latents.shape[0]):
|
||||
decoded.append(self.decoder(latents[i:i+1]))
|
||||
decoded = torch.cat(decoded, dim=0)
|
||||
else:
|
||||
decoded = self.decoder(latents, **kwargs)
|
||||
|
||||
if self.pretransform is not None:
|
||||
if self.pretransform.enable_grad:
|
||||
if iterate_batch:
|
||||
decodeds = []
|
||||
for i in range(decoded.shape[0]):
|
||||
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||
decoded = torch.cat(decodeds, dim=0)
|
||||
else:
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if iterate_batch:
|
||||
decodeds = []
|
||||
for i in range(decoded.shape[0]):
|
||||
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||
decoded = torch.cat(decodeds, dim=0)
|
||||
else:
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
|
||||
if self.soft_clip:
|
||||
decoded = torch.tanh(decoded)
|
||||
|
||||
return decoded
|
||||
|
||||
def decode_tokens(self, tokens, **kwargs):
|
||||
'''
|
||||
Decode discrete tokens to audio
|
||||
Only works with discrete autoencoders
|
||||
'''
|
||||
|
||||
assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
|
||||
|
||||
latents = self.bottleneck.decode_tokens(tokens, **kwargs)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
|
||||
def preprocess_audio_for_encoder(self, audio, in_sr):
|
||||
'''
|
||||
Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
|
||||
If the model is mono, stereo audio will be converted to mono.
|
||||
Audio will be silence-padded to be a multiple of the model's downsampling ratio.
|
||||
Audio will be resampled to the model's sample rate.
|
||||
The output will have batch size 1 and be shape (1 x Channels x Length)
|
||||
'''
|
||||
return self.preprocess_audio_list_for_encoder([audio], [in_sr])
|
||||
|
||||
def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
|
||||
'''
|
||||
Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
|
||||
The audio in that list can be of different lengths and channels.
|
||||
in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
|
||||
All audio will be resampled to the model's sample rate.
|
||||
Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
|
||||
If the model is mono, all audio will be converted to mono.
|
||||
The output will be a tensor of shape (Batch x Channels x Length)
|
||||
'''
|
||||
batch_size = len(audio_list)
|
||||
if isinstance(in_sr_list, int):
|
||||
in_sr_list = [in_sr_list]*batch_size
|
||||
assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
|
||||
new_audio = []
|
||||
max_length = 0
|
||||
# resample & find the max length
|
||||
for i in range(batch_size):
|
||||
audio = audio_list[i]
|
||||
in_sr = in_sr_list[i]
|
||||
if len(audio.shape) == 3 and audio.shape[0] == 1:
|
||||
# batchsize 1 was given by accident. Just squeeze it.
|
||||
audio = audio.squeeze(0)
|
||||
elif len(audio.shape) == 1:
|
||||
# Mono signal, channel dimension is missing, unsqueeze it in
|
||||
audio = audio.unsqueeze(0)
|
||||
assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
|
||||
# Resample audio
|
||||
if in_sr != self.sample_rate:
|
||||
resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
|
||||
audio = resample_tf(audio)
|
||||
new_audio.append(audio)
|
||||
if audio.shape[-1] > max_length:
|
||||
max_length = audio.shape[-1]
|
||||
# Pad every audio to the same length, multiple of model's downsampling ratio
|
||||
padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
|
||||
for i in range(batch_size):
|
||||
# Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
|
||||
new_audio[i] = prepare_audio(new_audio[i], in_sr=self.sample_rate, target_sr=self.sample_rate, target_length=padded_audio_length,
|
||||
target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
|
||||
# convert to tensor
|
||||
return torch.stack(new_audio)
|
||||
|
||||
def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
||||
'''
|
||||
Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
|
||||
If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
|
||||
Overlap and chunk_size params are both measured in number of latents (not audio samples)
|
||||
and therefore you likely could use the same values with decode_audio.
|
||||
An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
|
||||
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
||||
You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
|
||||
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
||||
Smaller chunk_size uses less memory, but more compute.
|
||||
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
||||
For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
||||
'''
|
||||
if not chunked:
|
||||
# default behavior. Encode the entire audio in parallel
|
||||
return self.encode(audio, **kwargs)
|
||||
else:
|
||||
# CHUNKED ENCODING
|
||||
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
|
||||
samples_per_latent = self.downsampling_ratio
|
||||
total_size = audio.shape[2] # in samples
|
||||
batch_size = audio.shape[0]
|
||||
chunk_size *= samples_per_latent # converting metric in latents to samples
|
||||
overlap *= samples_per_latent # converting metric in latents to samples
|
||||
hop_size = chunk_size - overlap
|
||||
chunks = []
|
||||
for i in range(0, total_size - chunk_size + 1, hop_size):
|
||||
chunk = audio[:,:,i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
if i+chunk_size != total_size:
|
||||
# Final chunk
|
||||
chunk = audio[:,:,-chunk_size:]
|
||||
chunks.append(chunk)
|
||||
chunks = torch.stack(chunks)
|
||||
num_chunks = chunks.shape[0]
|
||||
# Note: y_size might be a different value from the latent length used in diffusion training
|
||||
# because we can encode audio of varying lengths
|
||||
# However, the audio should've been padded to a multiple of samples_per_latent by now.
|
||||
y_size = total_size // samples_per_latent
|
||||
# Create an empty latent, we will populate it with chunks as we encode them
|
||||
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
|
||||
for i in range(num_chunks):
|
||||
x_chunk = chunks[i,:]
|
||||
# encode the chunk
|
||||
y_chunk = self.encode(x_chunk)
|
||||
# figure out where to put the latents along the time axis
|
||||
if i == num_chunks-1:
|
||||
# final chunk always goes at the end
|
||||
t_end = y_size
|
||||
t_start = t_end - y_chunk.shape[2]
|
||||
else:
|
||||
t_start = i * hop_size // samples_per_latent
|
||||
t_end = t_start + chunk_size // samples_per_latent
|
||||
# remove the edges of the overlaps
|
||||
ol = overlap//samples_per_latent//2
|
||||
chunk_start = 0
|
||||
chunk_end = y_chunk.shape[2]
|
||||
if i > 0:
|
||||
# no overlap for the start of the first chunk
|
||||
t_start += ol
|
||||
chunk_start += ol
|
||||
if i < num_chunks-1:
|
||||
# no overlap for the end of the last chunk
|
||||
t_end -= ol
|
||||
chunk_end -= ol
|
||||
# paste the encoded chunk into our y_final output latents
|
||||
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
||||
return y_final
|
||||
|
||||
def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
|
||||
'''
|
||||
Decode latents to audio.
|
||||
If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
|
||||
An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
|
||||
Every autoencoder will have a different receptive field size, and thus ideal overlap.
|
||||
You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
|
||||
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
|
||||
Smaller chunk_size uses less memory, but more compute.
|
||||
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
|
||||
For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
|
||||
'''
|
||||
if not chunked:
|
||||
# default behavior. Decode the entire latent in parallel
|
||||
return self.decode(latents, **kwargs)
|
||||
else:
|
||||
# chunked decoding
|
||||
hop_size = chunk_size - overlap
|
||||
total_size = latents.shape[2]
|
||||
batch_size = latents.shape[0]
|
||||
chunks = []
|
||||
for i in range(0, total_size - chunk_size + 1, hop_size):
|
||||
chunk = latents[:,:,i:i+chunk_size]
|
||||
chunks.append(chunk)
|
||||
if i+chunk_size != total_size:
|
||||
# Final chunk
|
||||
chunk = latents[:,:,-chunk_size:]
|
||||
chunks.append(chunk)
|
||||
chunks = torch.stack(chunks)
|
||||
num_chunks = chunks.shape[0]
|
||||
# samples_per_latent is just the downsampling ratio
|
||||
samples_per_latent = self.downsampling_ratio
|
||||
# Create an empty waveform, we will populate it with chunks as we decode them
|
||||
y_size = total_size * samples_per_latent
|
||||
y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
|
||||
for i in range(num_chunks):
|
||||
x_chunk = chunks[i,:]
|
||||
# decode the chunk
|
||||
y_chunk = self.decode(x_chunk)
|
||||
# figure out where to put the audio along the time domain
|
||||
if i == num_chunks-1:
|
||||
# final chunk always goes at the end
|
||||
t_end = y_size
|
||||
t_start = t_end - y_chunk.shape[2]
|
||||
else:
|
||||
t_start = i * hop_size * samples_per_latent
|
||||
t_end = t_start + chunk_size * samples_per_latent
|
||||
# remove the edges of the overlaps
|
||||
ol = (overlap//2) * samples_per_latent
|
||||
chunk_start = 0
|
||||
chunk_end = y_chunk.shape[2]
|
||||
if i > 0:
|
||||
# no overlap for the start of the first chunk
|
||||
t_start += ol
|
||||
chunk_start += ol
|
||||
if i < num_chunks-1:
|
||||
# no overlap for the end of the last chunk
|
||||
t_end -= ol
|
||||
chunk_end -= ol
|
||||
# paste the chunked audio into our y_final output audio
|
||||
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
|
||||
return y_final
|
||||
|
||||
|
||||
class DiffusionAutoencoder(AudioAutoencoder):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion: ConditionedDiffusionModel,
|
||||
diffusion_downsampling_ratio,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.diffusion = diffusion
|
||||
|
||||
self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
|
||||
|
||||
if self.encoder is not None:
|
||||
# Shrink the initial encoder parameters to avoid saturated latents
|
||||
with torch.no_grad():
|
||||
for param in self.encoder.parameters():
|
||||
param *= 0.5
|
||||
|
||||
def decode(self, latents, steps=100):
|
||||
|
||||
upsampled_length = latents.shape[2] * self.downsampling_ratio
|
||||
|
||||
if self.bottleneck is not None:
|
||||
latents = self.bottleneck.decode(latents)
|
||||
|
||||
if self.decoder is not None:
|
||||
latents = self.decoder(latents)
|
||||
|
||||
# Upsample latents to match diffusion length
|
||||
if latents.shape[2] != upsampled_length:
|
||||
latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
|
||||
|
||||
noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
|
||||
from prismaudio_core.inference.sampling import sample
|
||||
decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
|
||||
|
||||
if self.pretransform is not None:
|
||||
if self.pretransform.enable_grad:
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
|
||||
return decoded
|
||||
|
||||
# AE factories
|
||||
|
||||
def create_encoder_from_config(encoder_config: Dict[str, Any]):
|
||||
encoder_type = encoder_config.get("type", None)
|
||||
assert encoder_type is not None, "Encoder type must be specified"
|
||||
|
||||
if encoder_type == "oobleck":
|
||||
encoder = OobleckEncoder(
|
||||
**encoder_config["config"]
|
||||
)
|
||||
|
||||
elif encoder_type == "seanet":
|
||||
from encodec.modules import SEANetEncoder
|
||||
seanet_encoder_config = encoder_config["config"]
|
||||
|
||||
# SEANet encoder expects strides in reverse order
|
||||
seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
|
||||
encoder = SEANetEncoder(
|
||||
**seanet_encoder_config
|
||||
)
|
||||
elif encoder_type == "dac":
|
||||
dac_config = encoder_config["config"]
|
||||
|
||||
encoder = DACEncoderWrapper(**dac_config)
|
||||
elif encoder_type == "local_attn":
|
||||
from .local_attention import TransformerEncoder1D
|
||||
|
||||
local_attn_config = encoder_config["config"]
|
||||
|
||||
encoder = TransformerEncoder1D(
|
||||
**local_attn_config
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder type {encoder_type}")
|
||||
|
||||
requires_grad = encoder_config.get("requires_grad", True)
|
||||
if not requires_grad:
|
||||
for param in encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return encoder
|
||||
|
||||
def create_decoder_from_config(decoder_config: Dict[str, Any]):
|
||||
decoder_type = decoder_config.get("type", None)
|
||||
assert decoder_type is not None, "Decoder type must be specified"
|
||||
|
||||
if decoder_type == "oobleck":
|
||||
decoder = OobleckDecoder(
|
||||
**decoder_config["config"]
|
||||
)
|
||||
elif decoder_type == "seanet":
|
||||
from encodec.modules import SEANetDecoder
|
||||
|
||||
decoder = SEANetDecoder(
|
||||
**decoder_config["config"]
|
||||
)
|
||||
elif decoder_type == "dac":
|
||||
dac_config = decoder_config["config"]
|
||||
|
||||
decoder = DACDecoderWrapper(**dac_config)
|
||||
elif decoder_type == "local_attn":
|
||||
from .local_attention import TransformerDecoder1D
|
||||
|
||||
local_attn_config = decoder_config["config"]
|
||||
|
||||
decoder = TransformerDecoder1D(
|
||||
**local_attn_config
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown decoder type {decoder_type}")
|
||||
|
||||
requires_grad = decoder_config.get("requires_grad", True)
|
||||
if not requires_grad:
|
||||
for param in decoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return decoder
|
||||
|
||||
def create_autoencoder_from_config(config: Dict[str, Any]):
|
||||
|
||||
ae_config = config["model"]
|
||||
|
||||
encoder = create_encoder_from_config(ae_config["encoder"])
|
||||
decoder = create_decoder_from_config(ae_config["decoder"])
|
||||
|
||||
bottleneck = ae_config.get("bottleneck", None)
|
||||
|
||||
latent_dim = ae_config.get("latent_dim", None)
|
||||
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||
downsampling_ratio = ae_config.get("downsampling_ratio", None)
|
||||
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||
io_channels = ae_config.get("io_channels", None)
|
||||
assert io_channels is not None, "io_channels must be specified in model config"
|
||||
sample_rate = config.get("sample_rate", None)
|
||||
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||
|
||||
in_channels = ae_config.get("in_channels", None)
|
||||
out_channels = ae_config.get("out_channels", None)
|
||||
|
||||
pretransform = ae_config.get("pretransform", None)
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
|
||||
|
||||
if bottleneck is not None:
|
||||
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
||||
|
||||
soft_clip = ae_config["decoder"].get("soft_clip", False)
|
||||
|
||||
return AudioAutoencoder(
|
||||
encoder,
|
||||
decoder,
|
||||
io_channels=io_channels,
|
||||
latent_dim=latent_dim,
|
||||
downsampling_ratio=downsampling_ratio,
|
||||
sample_rate=sample_rate,
|
||||
bottleneck=bottleneck,
|
||||
pretransform=pretransform,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
soft_clip=soft_clip
|
||||
)
|
||||
|
||||
def create_diffAE_from_config(config: Dict[str, Any]):
|
||||
|
||||
diffae_config = config["model"]
|
||||
|
||||
if "encoder" in diffae_config:
|
||||
encoder = create_encoder_from_config(diffae_config["encoder"])
|
||||
else:
|
||||
encoder = None
|
||||
|
||||
if "decoder" in diffae_config:
|
||||
decoder = create_decoder_from_config(diffae_config["decoder"])
|
||||
else:
|
||||
decoder = None
|
||||
|
||||
diffusion_model_type = diffae_config["diffusion"]["type"]
|
||||
|
||||
if diffusion_model_type == "DAU1d":
|
||||
diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
|
||||
elif diffusion_model_type == "adp_1d":
|
||||
diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
|
||||
elif diffusion_model_type == "dit":
|
||||
diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
|
||||
|
||||
latent_dim = diffae_config.get("latent_dim", None)
|
||||
assert latent_dim is not None, "latent_dim must be specified in model config"
|
||||
downsampling_ratio = diffae_config.get("downsampling_ratio", None)
|
||||
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
|
||||
io_channels = diffae_config.get("io_channels", None)
|
||||
assert io_channels is not None, "io_channels must be specified in model config"
|
||||
sample_rate = config.get("sample_rate", None)
|
||||
assert sample_rate is not None, "sample_rate must be specified in model config"
|
||||
|
||||
bottleneck = diffae_config.get("bottleneck", None)
|
||||
|
||||
pretransform = diffae_config.get("pretransform", None)
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
|
||||
|
||||
if bottleneck is not None:
|
||||
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
|
||||
|
||||
diffusion_downsampling_ratio = None
|
||||
|
||||
if diffusion_model_type == "DAU1d":
|
||||
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
|
||||
elif diffusion_model_type == "adp_1d":
|
||||
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
|
||||
elif diffusion_model_type == "dit":
|
||||
diffusion_downsampling_ratio = 1
|
||||
|
||||
return DiffusionAutoencoder(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
diffusion=diffusion,
|
||||
io_channels=io_channels,
|
||||
sample_rate=sample_rate,
|
||||
latent_dim=latent_dim,
|
||||
downsampling_ratio=downsampling_ratio,
|
||||
diffusion_downsampling_ratio=diffusion_downsampling_ratio,
|
||||
bottleneck=bottleneck,
|
||||
pretransform=pretransform
|
||||
)
|
||||
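The chunked paths in `encode_audio` and `decode_audio` trim half of the overlap from each interior chunk edge before pasting the chunk into the output. A minimal pure-Python sketch of that placement logic, working on plain lists with one output sample per latent, might look like the following. `stitch_chunks` and `smooth` are hypothetical names introduced here for illustration and are not part of the repository; the toy `smooth` function is an assumption standing in for a real decoder with a small receptive field.

```python
def smooth(chunk):
    """Toy 'decoder' (assumption): 3-tap moving average with zero padding.

    Its receptive field is one neighbour on each side, so only the first and
    last sample of every chunk are affected by the chunk boundary.
    """
    padded = [0.0] + list(chunk) + [0.0]
    return [(padded[j] + padded[j + 1] + padded[j + 2]) / 3 for j in range(len(chunk))]


def stitch_chunks(latents, chunk_size, overlap, decode=list):
    """Process `latents` in overlapping chunks and stitch the results.

    Mirrors the placement logic of the chunked decode path, simplified to
    1-D lists and a downsampling ratio of 1.
    """
    total = len(latents)
    if total < chunk_size:
        raise ValueError("input must be at least chunk_size long")
    hop = chunk_size - overlap
    if hop <= 0:
        raise ValueError("overlap must be smaller than chunk_size")

    # Regular chunk starts, plus a final chunk aligned to the end if needed.
    starts = list(range(0, total - chunk_size + 1, hop))
    if starts[-1] + chunk_size != total:
        starts.append(total - chunk_size)

    ol = overlap // 2  # samples trimmed from each interior chunk edge
    out = [0.0] * total
    last = len(starts) - 1
    for i, start in enumerate(starts):
        y = decode(latents[start:start + chunk_size])
        t_start, t_end = start, start + chunk_size
        c_start, c_end = 0, len(y)
        if i > 0:      # keep the untouched start of the first chunk
            t_start += ol
            c_start += ol
        if i < last:   # keep the untouched end of the last chunk
            t_end -= ol
            c_end -= ol
        out[t_start:t_end] = y[c_start:c_end]
    return out


signal = [float(i) for i in range(16)]
reference = smooth(signal)
max_diff_no_overlap = max(abs(a - b) for a, b in zip(stitch_chunks(signal, 8, 0, smooth), reference))
max_diff_overlap = max(abs(a - b) for a, b in zip(stitch_chunks(signal, 8, 4, smooth), reference))
```

Comparing the chunked result against the unchunked one, as in the last three lines, is the empirical check the docstrings suggest for picking an overlap. In this toy setup the difference is nonzero with `overlap=0` and vanishes once the trimmed region covers the receptive field. Note that an odd `overlap` would leave a one-sample gap between neighbouring chunks, since only `overlap // 2` is trimmed on each side.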
@@ -1,331 +0,0 @@
|
||||
from functools import reduce
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch.backends.cuda import sdp_kernel
|
||||
from packaging import version
|
||||
|
||||
from dac.nn.layers import Snake1d
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, main, skip=None):
|
||||
super().__init__()
|
||||
self.main = nn.Sequential(*main)
|
||||
self.skip = skip if skip else nn.Identity()
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input) + self.skip(input)
|
||||
|
||||
class ResConvBlock(ResidualBlock):
|
||||
def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
|
||||
skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
|
||||
super().__init__([
|
||||
nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
|
||||
nn.GroupNorm(1, c_mid),
|
||||
Snake1d(c_mid) if use_snake else nn.GELU(),
|
||||
nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
|
||||
nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
|
||||
(Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
|
||||
], skip)
|
||||
|
||||
class SelfAttention1d(nn.Module):
|
||||
def __init__(self, c_in, n_head=1, dropout_rate=0.):
|
||||
super().__init__()
|
||||
assert c_in % n_head == 0
|
||||
self.norm = nn.GroupNorm(1, c_in)
|
||||
self.n_head = n_head
|
||||
self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
|
||||
self.out_proj = nn.Conv1d(c_in, c_in, 1)
|
||||
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||
|
||||
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
|
||||
|
||||
if not self.use_flash:
|
||||
return
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
|
||||
|
||||
if device_properties.major == 8 and device_properties.minor == 0:
|
||||
# Use flash attention for A100 GPUs
|
||||
self.sdp_kernel_config = (True, False, False)
|
||||
else:
|
||||
# Don't use flash attention for other GPUs
|
||||
self.sdp_kernel_config = (False, True, True)
|
||||
|
||||
def forward(self, input):
|
||||
n, c, s = input.shape
|
||||
qkv = self.qkv_proj(self.norm(input))
|
||||
qkv = qkv.view(
|
||||
[n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = k.shape[3]**-0.25
|
||||
|
||||
if self.use_flash:
|
||||
with sdp_kernel(*self.sdp_kernel_config):
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
|
||||
else:
|
||||
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
|
||||
y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
|
||||
|
||||
|
||||
return input + self.dropout(self.out_proj(y))
|
||||
|
||||
class SkipBlock(nn.Module):
|
||||
def __init__(self, *main):
|
||||
super().__init__()
|
||||
self.main = nn.Sequential(*main)
|
||||
|
||||
def forward(self, input):
|
||||
return torch.cat([self.main(input), input], dim=1)
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, in_features, out_features, std=1.):
|
||||
super().__init__()
|
||||
assert out_features % 2 == 0
|
||||
self.weight = nn.Parameter(torch.randn(
|
||||
[out_features // 2, in_features]) * std)
|
||||
|
||||
def forward(self, input):
|
||||
f = 2 * math.pi * input @ self.weight.T
|
||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||
|
||||
def expand_to_planes(input, shape):
|
||||
return input[..., None].repeat([1, 1, shape[2]])
|
||||
|
||||
_kernels = {
|
||||
'linear':
|
||||
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
'cubic':
|
||||
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
|
||||
0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||
'lanczos3':
|
||||
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
|
||||
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
|
||||
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
|
||||
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
|
||||
}
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel])
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer('kernel', kernel_1d)
|
||||
self.channels_last = channels_last
|
||||
|
||||
def forward(self, x):
|
||||
if self.channels_last:
|
||||
x = x.permute(0, 2, 1)
|
||||
x = F.pad(x, (self.pad,) * 2, self.pad_mode)
|
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(x.shape[1], device=x.device)
|
||||
weight[indices, indices] = self.kernel.to(weight)
|
||||
x = F.conv1d(x, weight, stride=2)
|
||||
if self.channels_last:
|
||||
x = x.permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer('kernel', kernel_1d)
|
||||
self.channels_last = channels_last
|
||||
|
||||
def forward(self, x):
|
||||
if self.channels_last:
|
||||
x = x.permute(0, 2, 1)
|
||||
x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(x.shape[1], device=x.device)
|
||||
weight[indices, indices] = self.kernel.to(weight)
|
||||
x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
if self.channels_last:
|
||||
x = x.permute(0, 2, 1)
|
||||
return x
|
||||
|
||||
def Downsample1d_2(
|
||||
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
|
||||
) -> nn.Module:
|
||||
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
|
||||
|
||||
return nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=factor * kernel_multiplier + 1,
|
||||
stride=factor,
|
||||
padding=factor * (kernel_multiplier // 2),
|
||||
)
|
||||
|
||||
|
||||
def Upsample1d_2(
|
||||
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
|
||||
) -> nn.Module:
|
||||
|
||||
if factor == 1:
|
||||
return nn.Conv1d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
|
||||
)
|
||||
|
||||
if use_nearest:
|
||||
return nn.Sequential(
|
||||
nn.Upsample(scale_factor=factor, mode="nearest"),
|
||||
nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
),
|
||||
)
|
||||
else:
|
||||
return nn.ConvTranspose1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=factor * 2,
|
||||
stride=factor,
|
||||
padding=factor // 2 + factor % 2,
|
||||
output_padding=factor % 2,
|
||||
)
|
||||
|
||||
def zero_init(layer):
|
||||
nn.init.zeros_(layer.weight)
|
||||
if layer.bias is not None:
|
||||
nn.init.zeros_(layer.bias)
|
||||
return layer
|
||||
|
||||
class AdaRMSNorm(nn.Module):
|
||||
def __init__(self, features, cond_features, eps=1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
|
||||
|
||||
def extra_repr(self):
|
||||
return f"eps={self.eps}"
|
||||
|
||||
def forward(self, x, cond):
|
||||
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
|
||||
|
||||
def normalize(x, eps=1e-4):
|
||||
dim = list(range(1, x.ndim))
|
||||
n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
|
||||
alpha = np.sqrt(n.numel() / x.numel())
|
||||
return x / torch.add(eps, n, alpha=alpha)
|
||||
|
||||
class ForcedWNConv1d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
|
||||
|
||||
def forward(self, x):
|
||||
if self.training:
|
||||
with torch.no_grad():
|
||||
self.weight.copy_(normalize(self.weight))
|
||||
|
||||
fan_in = self.weight[0].numel()
|
||||
|
||||
w = normalize(self.weight) / math.sqrt(fan_in)
|
||||
|
||||
return F.conv1d(x, w, padding='same')
|
||||
|
||||
# Kernels
|
||||
|
||||
use_compile = True
|
||||
|
||||
def compile(function, *args, **kwargs):
|
||||
if not use_compile:
|
||||
return function
|
||||
try:
|
||||
return torch.compile(function, *args, **kwargs)
|
||||
except RuntimeError:
|
||||
return function
|
||||
|
||||
|
||||
@compile
|
||||
def linear_geglu(x, weight, bias=None):
|
||||
x = x @ weight.mT
|
||||
if bias is not None:
|
||||
x = x + bias
|
||||
x, gate = x.chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
@compile
|
||||
def rms_norm(x, scale, eps):
|
||||
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
||||
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
||||
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
||||
return x * scale.to(x.dtype)
|
||||
|
||||
# Layers
|
||||
|
||||
class LinearGEGLU(nn.Linear):
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
super().__init__(in_features, out_features * 2, bias=bias)
|
||||
self.out_features = out_features
|
||||
|
||||
def forward(self, x):
|
||||
return linear_geglu(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, shape, fix_scale = False, eps=1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
if fix_scale:
|
||||
self.register_buffer("scale", torch.ones(shape))
|
||||
else:
|
||||
self.scale = nn.Parameter(torch.ones(shape))
|
||||
|
||||
def extra_repr(self):
|
||||
return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
|
||||
|
||||
def forward(self, x):
|
||||
return rms_norm(x, self.scale, self.eps)
|
||||
|
||||
def snake_beta(x, alpha, beta):
|
||||
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
|
||||
|
||||
# try:
|
||||
# snake_beta = torch.compile(snake_beta)
|
||||
# except RuntimeError:
|
||||
# pass
|
||||
|
||||
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
||||
# License available in LICENSES/LICENSE_NVIDIA.txt
|
||||
class SnakeBeta(nn.Module):
|
||||
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
|
||||
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = snake_beta(x, alpha, beta)
|
||||
|
||||
return x
|
||||
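For reference, the `snake_beta` activation used by `SnakeBeta` computes `x + (1 / (beta + eps)) * sin(alpha * x) ** 2`, with `alpha` and `beta` exponentiated first when `alpha_logscale` is set. A scalar sketch using only the standard library, assuming the log-scale parameterisation, is shown below. `snake_beta_scalar` is a hypothetical helper written for this illustration and does not exist in the repository.

```python
import math

EPS = 1e-9  # same guard value as no_div_by_zero in SnakeBeta


def snake_beta_scalar(x, log_alpha=0.0, log_beta=0.0):
    """Scalar version of snake_beta with log-scale parameters (illustrative).

    With the zero initialisation used for log-scale parameters,
    alpha = beta = exp(0) = 1.
    """
    alpha = math.exp(log_alpha)
    beta = math.exp(log_beta)
    return x + (1.0 / (beta + EPS)) * math.sin(x * alpha) ** 2


at_zero = snake_beta_scalar(0.0)
at_half_pi = snake_beta_scalar(math.pi / 2)
```

Because `beta` is positive after exponentiation, the added periodic term is never negative, so the output is always at least `x`; `alpha` controls the frequency of the ripple and `beta` its magnitude.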
@@ -1,355 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from vector_quantize_pytorch import ResidualVQ, FSQ
|
||||
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, is_discrete: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.is_discrete = is_discrete
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def decode(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
class DiscreteBottleneck(Bottleneck):
|
||||
def __init__(self, num_quantizers, codebook_size, tokens_id):
|
||||
super().__init__(is_discrete=True)
|
||||
|
||||
self.num_quantizers = num_quantizers
|
||||
self.codebook_size = codebook_size
|
||||
self.tokens_id = tokens_id
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
class TanhBottleneck(Bottleneck):
|
||||
def __init__(self):
|
||||
super().__init__(is_discrete=False)
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
x = torch.tanh(x)
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
def vae_sample(mean, scale):
|
||||
stdev = nn.functional.softplus(scale) + 1e-4
|
||||
var = stdev * stdev
|
||||
logvar = torch.log(var)
|
||||
latents = torch.randn_like(mean) * stdev + mean
|
||||
|
||||
kl = (mean * mean + var - logvar - 1).sum(1).mean()
|
||||
|
||||
return latents, kl
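The `kl` value in `vae_sample` sums `mean^2 + var - log(var) - 1` over channels, which is twice the usual closed-form KL divergence between `N(mean, var)` and `N(0, 1)` because the factor 1/2 is left out. A minimal scalar sketch of that per-element term, using only the standard library (the names `kl_term` and `UNIT_SCALE` are illustrative):

```python
import math


def kl_term(mean, scale):
    """Per-element term summed in vae_sample: mean^2 + var - log(var) - 1."""
    stdev = math.log1p(math.exp(scale)) + 1e-4  # softplus(scale) + 1e-4
    var = stdev * stdev
    return mean * mean + var - math.log(var) - 1.0


# Raw scale whose softplus is 1 - 1e-4, so that stdev comes out as 1.
UNIT_SCALE = math.log(math.expm1(1.0 - 1e-4))
```

The term vanishes for a unit Gaussian (`kl_term(0.0, UNIT_SCALE)` is about zero) and grows as the mean or the standard deviation moves away from that reference.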
|
||||
|
||||
class VAEBottleneck(Bottleneck):
|
||||
def __init__(self):
|
||||
super().__init__(is_discrete=False)
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
info = {}
|
||||
|
||||
mean, scale = x.chunk(2, dim=1)
|
||||
|
||||
x, kl = vae_sample(mean, scale)
|
||||
|
||||
info["kl"] = kl
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
def compute_mean_kernel(x, y):
|
||||
kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
|
||||
return torch.exp(-kernel_input).mean()
|
||||
|
||||
def compute_mmd(latents):
|
||||
latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
|
||||
noise = torch.randn_like(latents_reshaped)
|
||||
|
||||
latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
|
||||
noise_kernel = compute_mean_kernel(noise, noise)
|
||||
latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
|
||||
|
||||
mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
|
||||
return mmd.mean()
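`compute_mmd` compares the latents against fresh standard-normal noise with a kernel of the form `exp(-mean_sq_diff / dim)`. The following pure-Python sketch mirrors that arithmetic on lists of vectors, so it can be run without PyTorch. The function names, sample sizes and the shift of 3.0 are choices made for this illustration only.

```python
import math
import random


def mean_kernel(xs, ys):
    """Mean of exp(-mean_sq_diff / dim) over all pairs, as in compute_mean_kernel."""
    dim = len(xs[0])
    total = 0.0
    for x in xs:
        for y in ys:
            mean_sq = sum((a - b) ** 2 for a, b in zip(x, y)) / dim
            total += math.exp(-mean_sq / dim)
    return total / (len(xs) * len(ys))


def mmd_to_standard_normal(rows, rng):
    """Biased MMD estimate between `rows` and fresh N(0, 1) noise."""
    noise = [[rng.gauss(0.0, 1.0) for _ in row] for row in rows]
    return mean_kernel(rows, rows) + mean_kernel(noise, noise) - 2 * mean_kernel(rows, noise)


rng = random.Random(0)
matched = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(64)]
shifted = [[v + 3.0 for v in row] for row in matched]
mmd_matched = mmd_to_standard_normal(matched, rng)
mmd_shifted = mmd_to_standard_normal(shifted, rng)
```

Latents that already look like standard-normal samples give a small value, while the shifted set gives a much larger one, which is the signal the Wasserstein bottleneck reports under `info["mmd"]`.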
|
||||
|
||||
class WassersteinBottleneck(Bottleneck):
|
||||
def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
|
||||
super().__init__(is_discrete=False)
|
||||
|
||||
self.noise_augment_dim = noise_augment_dim
|
||||
self.bypass_mmd = bypass_mmd
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
if self.training and return_info:
|
||||
if self.bypass_mmd:
|
||||
mmd = torch.tensor(0.0)
|
||||
else:
|
||||
mmd = compute_mmd(x)
|
||||
|
||||
info["mmd"] = mmd
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
|
||||
if self.noise_augment_dim > 0:
|
||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||
x.shape[-1]).type_as(x)
|
||||
x = torch.cat([x, noise], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
class L2Bottleneck(Bottleneck):
|
||||
def __init__(self):
|
||||
super().__init__(is_discrete=False)
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
x = F.normalize(x, dim=1)
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return F.normalize(x, dim=1)
|
||||
|
||||
class RVQBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, **quantizer_kwargs):
|
||||
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
||||
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
||||
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
info = {}
|
||||
|
||||
x = rearrange(x, "b c n -> b n c")
|
||||
x, indices, loss = self.quantizer(x)
|
||||
x = rearrange(x, "b n c -> b c n")
|
||||
|
||||
info["quantizer_indices"] = indices
|
||||
info["quantizer_loss"] = loss.mean()
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
latents = self.quantizer.get_outputs_from_indices(codes)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
class RVQVAEBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, **quantizer_kwargs):
|
||||
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
|
||||
self.quantizer = ResidualVQ(**quantizer_kwargs)
|
||||
self.num_quantizers = quantizer_kwargs["num_quantizers"]
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
x, kl = vae_sample(*x.chunk(2, dim=1))
|
||||
|
||||
info["kl"] = kl
|
||||
|
||||
x = rearrange(x, "b c n -> b n c")
|
||||
x, indices, loss = self.quantizer(x)
|
||||
x = rearrange(x, "b n c -> b c n")
|
||||
|
||||
info["quantizer_indices"] = indices
|
||||
info["quantizer_loss"] = loss.mean()
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
latents = self.quantizer.get_outputs_from_indices(codes)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
class DACRVQBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
|
||||
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
||||
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
||||
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
||||
self.quantize_on_decode = quantize_on_decode
|
||||
self.noise_augment_dim = noise_augment_dim
|
||||
|
||||
def encode(self, x, return_info=False, **kwargs):
|
||||
info = {}
|
||||
|
||||
info["pre_quantizer"] = x
|
||||
|
||||
if self.quantize_on_decode:
|
||||
return (x, info) if return_info else x
|
||||
|
||||
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
|
||||
|
||||
output = {
|
||||
"z": z,
|
||||
"codes": codes,
|
||||
"latents": latents,
|
||||
"vq/commitment_loss": commitment_loss,
|
||||
"vq/codebook_loss": codebook_loss,
|
||||
}
|
||||
|
||||
output["vq/commitment_loss"] /= self.num_quantizers
|
||||
output["vq/codebook_loss"] /= self.num_quantizers
|
||||
|
||||
info.update(output)
|
||||
|
||||
if return_info:
|
||||
return output["z"], info
|
||||
|
||||
return output["z"]
|
||||
|
||||
def decode(self, x):
|
||||
|
||||
if self.quantize_on_decode:
|
||||
x = self.quantizer(x)[0]
|
||||
|
||||
if self.noise_augment_dim > 0:
|
||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||
x.shape[-1]).type_as(x)
|
||||
x = torch.cat([x, noise], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
latents, _, _ = self.quantizer.from_codes(codes)
|
||||
|
||||
return self.decode(latents, **kwargs)
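The early return in `encode` when `quantize_on_decode` is set depends on how Python groups a conditional expression next to a tuple comma: the conditional binds tighter than the comma. A small sketch with hypothetical function names shows the difference between the two spellings.

```python
def early_return_unparenthesized(x, info, return_info):
    # Parsed as: return (x, (info if return_info else x))
    return x, info if return_info else x


def early_return_parenthesized(x, info, return_info):
    # The tuple is built only when return_info is true.
    return (x, info) if return_info else x
```

Without the parentheses the function always returns a 2-tuple, so a caller that asked for `return_info=False` would receive `(x, x)` instead of `x`.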
|
||||
|
||||
class DACRVQVAEBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
|
||||
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
|
||||
self.quantizer = DACResidualVQ(**quantizer_kwargs)
|
||||
self.num_quantizers = quantizer_kwargs["n_codebooks"]
|
||||
self.quantize_on_decode = quantize_on_decode
|
||||
|
||||
def encode(self, x, return_info=False, n_quantizers: int = None):
|
||||
info = {}
|
||||
|
||||
mean, scale = x.chunk(2, dim=1)
|
||||
|
||||
x, kl = vae_sample(mean, scale)
|
||||
|
||||
info["pre_quantizer"] = x
|
||||
info["kl"] = kl
|
||||
|
||||
if self.quantize_on_decode:
|
||||
return (x, info) if return_info else x
|
||||
|
||||
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
|
||||
|
||||
output = {
|
||||
"z": z,
|
||||
"codes": codes,
|
||||
"latents": latents,
|
||||
"vq/commitment_loss": commitment_loss,
|
||||
"vq/codebook_loss": codebook_loss,
|
||||
}
|
||||
|
||||
output["vq/commitment_loss"] /= self.num_quantizers
|
||||
output["vq/codebook_loss"] /= self.num_quantizers
|
||||
|
||||
info.update(output)
|
||||
|
||||
if return_info:
|
||||
return output["z"], info
|
||||
|
||||
return output["z"]
|
||||
|
||||
def decode(self, x):
|
||||
|
||||
if self.quantize_on_decode:
|
||||
x = self.quantizer(x)[0]
|
||||
|
||||
return x
|
||||
|
||||
def decode_tokens(self, codes, **kwargs):
|
||||
latents, _, _ = self.quantizer.from_codes(codes)
|
||||
|
||||
return self.decode(latents, **kwargs)
|
||||
|
||||
class FSQBottleneck(DiscreteBottleneck):
|
||||
def __init__(self, noise_augment_dim=0, **kwargs):
|
||||
super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
|
||||
|
||||
self.noise_augment_dim = noise_augment_dim
|
||||
|
||||
self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
|
||||
|
||||
def encode(self, x, return_info=False):
|
||||
info = {}
|
||||
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
|
||||
x = rearrange(x, "b c n -> b n c")
|
||||
x, indices = self.quantizer(x)
|
||||
x = rearrange(x, "b n c -> b c n")
|
||||
|
||||
x = x.to(orig_dtype)
|
||||
|
||||
# Reorder indices to match the expected format
|
||||
indices = rearrange(indices, "b n q -> b q n")
|
||||
|
||||
info["quantizer_indices"] = indices
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
else:
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
|
||||
if self.noise_augment_dim > 0:
|
||||
noise = torch.randn(x.shape[0], self.noise_augment_dim,
|
||||
x.shape[-1]).type_as(x)
|
||||
x = torch.cat([x, noise], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
def decode_tokens(self, tokens, **kwargs):
|
||||
latents = self.quantizer.indices_to_codes(tokens)
|
||||
|
||||
return self.decode(latents, **kwargs)
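`FSQBottleneck` reports `codebook_size` as the product of `levels`, since each latent dimension can take one of `levels[i]` values. The sketch below shows that count, plus one possible way to pack per-dimension level indices into a single token with a mixed-radix encoding. The helpers are hypothetical and are not taken from `vector_quantize_pytorch`, whose internal index layout may differ.

```python
import math


def fsq_codebook_size(levels):
    """Number of distinct codes for one FSQ codebook: the product of the levels."""
    return math.prod(levels)


def to_index(digits, levels):
    """Mixed-radix encoding of per-dimension level indices (illustrative only)."""
    index, base = 0, 1
    for digit, level in zip(digits, levels):
        index += digit * base
        base *= level
    return index


def from_index(index, levels):
    """Inverse of to_index."""
    digits = []
    for level in levels:
        digits.append(index % level)
        index //= level
    return digits
```

For example, `levels = [8, 5, 5, 5]` yields 1000 codes, and the largest digit vector `[7, 4, 4, 4]` maps to index 999 under this encoding.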
|
||||
File diff suppressed because it is too large
@@ -1,884 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import typing as tp
|
||||
|
||||
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
|
||||
from .conditioners import MultiConditioner
|
||||
from .dit import DiffusionTransformer
|
||||
from .pretransforms import Pretransform
|
||||
|
||||
from .adp import UNetCFG1d, UNet1d
|
||||
|
||||
# Lazy imports for factory functions to avoid circular imports
|
||||
def _get_create_pretransform_from_config():
|
||||
from prismaudio_core.factory import create_pretransform_from_config
|
||||
return create_pretransform_from_config
|
||||
|
||||
def _get_create_multi_conditioner_from_conditioning_config():
|
||||
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
|
||||
return create_multi_conditioner_from_conditioning_config
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
class DiffusionModelWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model: DiffusionModel,
|
||||
io_channels,
|
||||
sample_size,
|
||||
sample_rate,
|
||||
min_input_length,
|
||||
pretransform: tp.Optional[Pretransform] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.io_channels = io_channels
|
||||
self.sample_size = sample_size
|
||||
self.sample_rate = sample_rate
|
||||
self.min_input_length = min_input_length
|
||||
|
||||
self.model = model
|
||||
|
||||
if pretransform is not None:
|
||||
self.pretransform = pretransform
|
||||
else:
|
||||
self.pretransform = None
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
return self.model(x, t, **kwargs)
|
||||
|
||||
class ConditionedDiffusionModel(nn.Module):
|
||||
def __init__(self,
|
||||
*args,
|
||||
supports_cross_attention: bool = False,
|
||||
supports_input_concat: bool = False,
|
||||
supports_global_cond: bool = False,
|
||||
supports_prepend_cond: bool = False,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supports_cross_attention = supports_cross_attention
|
||||
self.supports_input_concat = supports_input_concat
|
||||
self.supports_global_cond = supports_global_cond
|
||||
self.supports_prepend_cond = supports_prepend_cond
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cross_attn_cond: torch.Tensor = None,
|
||||
cross_attn_mask: torch.Tensor = None,
|
||||
input_concat_cond: torch.Tensor = None,
|
||||
global_embed: torch.Tensor = None,
|
||||
prepend_cond: torch.Tensor = None,
|
||||
prepend_cond_mask: torch.Tensor = None,
|
||||
cfg_scale: float = 1.0,
|
||||
cfg_dropout_prob: float = 0.0,
|
||||
batch_cfg: bool = False,
|
||||
rescale_cfg: bool = False,
|
||||
**kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
class ConditionedDiffusionModelWrapper(nn.Module):
|
||||
"""
|
||||
A diffusion model that takes in conditioning
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model: ConditionedDiffusionModel,
|
||||
conditioner: MultiConditioner,
|
||||
io_channels,
|
||||
sample_rate,
|
||||
min_input_length: int,
|
||||
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
|
||||
zero_init: bool = False,
|
||||
pretransform: tp.Optional[Pretransform] = None,
|
||||
cross_attn_cond_ids: tp.List[str] = [],
|
||||
global_cond_ids: tp.List[str] = [],
|
||||
input_concat_ids: tp.List[str] = [],
|
||||
prepend_cond_ids: tp.List[str] = [],
|
||||
add_cond_ids: tp.List[str] = [],
|
||||
sync_cond_ids: tp.List[str] = [],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
self.conditioner = conditioner
|
||||
self.io_channels = io_channels
|
||||
self.sample_rate = sample_rate
|
||||
self.diffusion_objective = diffusion_objective
|
||||
self.pretransform = pretransform
|
||||
self.cross_attn_cond_ids = cross_attn_cond_ids
|
||||
self.global_cond_ids = global_cond_ids
|
||||
self.input_concat_ids = input_concat_ids
|
||||
self.prepend_cond_ids = prepend_cond_ids
|
||||
self.add_cond_ids = add_cond_ids
|
||||
self.sync_cond_ids = sync_cond_ids
|
||||
self.min_input_length = min_input_length
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
if zero_init is True:
|
||||
self.conditioner.apply(_basic_init)
|
||||
self.model.model.initialize_weights()
|
||||
|
||||
|
||||
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
|
||||
cross_attention_input = None
|
||||
cross_attention_masks = None
|
||||
global_cond = None
|
||||
input_concat_cond = None
|
||||
prepend_cond = None
|
||||
prepend_cond_mask = None
|
||||
add_input = None
|
||||
sync_input = None
|
||||
|
||||
if len(self.cross_attn_cond_ids) > 0:
|
||||
# Concatenate all cross-attention inputs over the sequence dimension
|
||||
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
|
||||
cross_attention_input = []
|
||||
cross_attention_masks = []
|
||||
|
||||
for key in self.cross_attn_cond_ids:
|
||||
cross_attn_in, cross_attn_mask = conditioning_tensors[key]
|
||||
|
||||
# Add sequence dimension if it's not there
|
||||
if len(cross_attn_in.shape) == 2:
|
||||
cross_attn_in = cross_attn_in.unsqueeze(1)
|
||||
# cross_attn_mask = cross_attn_mask.unsqueeze(1)
|
||||
|
||||
cross_attention_input.append(cross_attn_in)
|
||||
cross_attention_masks.append(cross_attn_mask)
|
||||
|
||||
cross_attention_input = torch.cat(cross_attention_input, dim=1)
|
||||
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
|
||||
|
||||
if len(self.add_cond_ids) > 0:
|
||||
# Concatenate all add_cond inputs over the last (channel) dimension
|
||||
# Assumes that the add_cond inputs are of shape (batch, seq, channels)
|
||||
add_input = []
|
||||
|
||||
for key in self.add_cond_ids:
|
||||
add_in = conditioning_tensors[key][0]
|
||||
|
||||
# Add sequence dimension if it's not there
|
||||
if len(add_in.shape) == 2:
|
||||
add_in = add_in.unsqueeze(1)
|
||||
# add_in = add_in.transpose(1,2)
|
||||
# add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False)
|
||||
# add_in = add_in.transpose(1,2)
|
||||
add_input.append(add_in)
|
||||
|
||||
add_input = torch.cat(add_input, dim=2)
|
||||
|
||||
if len(self.sync_cond_ids) > 0:
|
||||
# Concatenate all sync_cond inputs over the last (channel) dimension
|
||||
# Assumes that the sync_cond inputs are of shape (batch, seq, channels)
|
||||
sync_input = []
|
||||
|
||||
for key in self.sync_cond_ids:
|
||||
sync_in = conditioning_tensors[key][0]
|
||||
|
||||
# Add sequence dimension if it's not there
|
||||
if len(sync_in.shape) == 2:
|
||||
sync_in = sync_in.unsqueeze(1)
|
||||
sync_input.append(sync_in)
|
||||
|
||||
sync_input = torch.cat(sync_input, dim=2)
|
||||
|
||||
if len(self.global_cond_ids) > 0:
|
||||
# Sum all global conditioning inputs element-wise
|
||||
# Assumes that the global conditioning inputs share the same shape, (batch, channels)
|
||||
global_conds = []
|
||||
for key in self.global_cond_ids:
|
||||
global_cond_input = conditioning_tensors[key][0]
|
||||
if len(global_cond_input.shape) == 2:
|
||||
global_cond_input = global_cond_input.unsqueeze(1)
|
||||
global_conds.append(global_cond_input)
|
||||
|
||||
# # Concatenate over the channel dimension
|
||||
# if global_conds[0].shape[-1] == 768:
|
||||
# global_cond = torch.cat(global_conds, dim=-1)
|
||||
# else:
|
||||
# global_cond = sum(global_conds)
|
||||
global_cond = sum(global_conds)
|
||||
# global_cond = torch.cat(global_conds, dim=-1)
|
||||
|
||||
if len(global_cond.shape) == 3:
|
||||
global_cond = global_cond.squeeze(1)
|
||||
|
||||
if len(self.input_concat_ids) > 0:
|
||||
# Concatenate all input concat conditioning inputs over the channel dimension
|
||||
# Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
|
||||
input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
|
||||
|
||||
if len(self.prepend_cond_ids) > 0:
|
||||
# Concatenate all prepend conditioning inputs over the sequence dimension
|
||||
# Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
|
||||
prepend_conds = []
|
||||
prepend_cond_masks = []
|
||||
|
||||
for key in self.prepend_cond_ids:
|
||||
prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
|
||||
if len(prepend_cond_input.shape) == 2:
|
||||
prepend_cond_input = prepend_cond_input.unsqueeze(1)
|
||||
prepend_conds.append(prepend_cond_input)
|
||||
prepend_cond_masks.append(prepend_cond_mask)
|
||||
|
||||
prepend_cond = torch.cat(prepend_conds, dim=1)
|
||||
prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
|
||||
|
||||
if negative:
|
||||
return {
|
||||
"negative_cross_attn_cond": cross_attention_input,
|
||||
"negative_cross_attn_mask": cross_attention_masks,
|
||||
"negative_global_cond": global_cond,
|
||||
"negative_input_concat_cond": input_concat_cond
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"cross_attn_cond": cross_attention_input,
|
||||
"cross_attn_mask": cross_attention_masks,
|
||||
"global_cond": global_cond,
|
||||
"input_concat_cond": input_concat_cond,
|
||||
"prepend_cond": prepend_cond,
|
||||
"prepend_cond_mask": prepend_cond_mask,
|
||||
"add_cond": add_input,
|
||||
"sync_cond": sync_input
|
||||
}
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
||||
return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
from prismaudio_core.inference.generation import generate_diffusion_cond
|
||||
return generate_diffusion_cond(self, *args, **kwargs)
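`get_conditioning_inputs` merges conditioner outputs in different ways: cross-attention and prepend inputs are concatenated along the sequence axis (`dim=1`), `add_cond` and `sync_cond` inputs along the last axis (`dim=2`), and global inputs are summed. The shape bookkeeping can be sketched without tensors. `cat_shape` and the example shapes below are hypothetical and only illustrate what `torch.cat` requires.

```python
def cat_shape(shapes, dim):
    """Shape produced by torch.cat on tensors with these shapes along `dim`."""
    out = list(shapes[0])
    for shape in shapes[1:]:
        # Every dimension other than `dim` has to match exactly.
        assert len(shape) == len(out)
        assert all(a == b for i, (a, b) in enumerate(zip(out, shape)) if i != dim)
        out[dim] += shape[dim]
    return tuple(out)


# Hypothetical conditioner outputs, each (batch, seq, channels).
cross_attn = cat_shape([(2, 77, 768), (2, 10, 768)], dim=1)  # longer sequence
sync = cat_shape([(2, 194, 256), (2, 194, 512)], dim=2)      # wider channels
```

Because the global inputs are summed rather than concatenated, they need identical shapes, whereas the concatenated groups only need to agree on the axes that are not being joined.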
|
||||
|
||||
class UNetCFG1DWrapper(ConditionedDiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
|
||||
|
||||
self.model = UNetCFG1d(*args, **kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
for param in self.model.parameters():
|
||||
param *= 0.5
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
t,
|
||||
cross_attn_cond=None,
|
||||
cross_attn_mask=None,
|
||||
input_concat_cond=None,
|
||||
global_cond=None,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob: float = 0.0,
|
||||
batch_cfg: bool = False,
|
||||
rescale_cfg: bool = False,
|
||||
negative_cross_attn_cond=None,
|
||||
negative_cross_attn_mask=None,
|
||||
negative_global_cond=None,
|
||||
negative_input_concat_cond=None,
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
**kwargs):
|
||||
channels_list = None
|
||||
if input_concat_cond is not None:
|
||||
channels_list = [input_concat_cond]
|
||||
|
||||
outputs = self.model(
|
||||
x,
|
||||
t,
|
||||
embedding=cross_attn_cond,
|
||||
embedding_mask=cross_attn_mask,
|
||||
features=global_cond,
|
||||
channels_list=channels_list,
|
||||
embedding_scale=cfg_scale,
|
||||
embedding_mask_proba=cfg_dropout_prob,
|
||||
batch_cfg=batch_cfg,
|
||||
rescale_cfg=rescale_cfg,
|
||||
negative_embedding=negative_cross_attn_cond,
|
||||
negative_embedding_mask=negative_cross_attn_mask,
|
||||
**kwargs)
|
||||
|
||||
return outputs
|
||||
|
||||
class UNet1DCondWrapper(ConditionedDiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
|
||||
|
||||
self.model = UNet1d(*args, **kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
for param in self.model.parameters():
|
||||
param *= 0.5
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
t,
|
||||
input_concat_cond=None,
|
||||
global_cond=None,
|
||||
cross_attn_cond=None,
|
||||
cross_attn_mask=None,
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob: float = 0.0,
|
||||
batch_cfg: bool = False,
|
||||
rescale_cfg: bool = False,
|
||||
negative_cross_attn_cond=None,
|
||||
negative_cross_attn_mask=None,
|
||||
negative_global_cond=None,
|
||||
negative_input_concat_cond=None,
|
||||
**kwargs):
|
||||
|
||||
channels_list = None
|
||||
if input_concat_cond is not None:
|
||||
|
||||
# Interpolate input_concat_cond to the same length as x
|
||||
if input_concat_cond.shape[2] != x.shape[2]:
|
||||
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
||||
|
||||
channels_list = [input_concat_cond]
|
||||
|
||||
outputs = self.model(
|
||||
x,
|
||||
t,
|
||||
features=global_cond,
|
||||
channels_list=channels_list,
|
||||
**kwargs)
|
||||
|
||||
return outputs
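`UNet1DCondWrapper.forward` stretches or shrinks `input_concat_cond` along time with `mode='nearest'` so that it lines up with `x`. As a rough sketch of nearest-neighbour resampling, assuming a floor-based index mapping (PyTorch's exact rounding rules are not restated here), a 1-D version could look like this. `nearest_resize` is a made-up helper name.

```python
def nearest_resize(values, out_len):
    """Nearest-neighbour resize of a 1-D sequence to `out_len` samples."""
    in_len = len(values)
    # Output position i reads from input position floor(i * in_len / out_len).
    return [values[(i * in_len) // out_len] for i in range(out_len)]
```

Upsampling repeats samples and downsampling drops them, so no new values are invented. This is a different choice from the linear interpolation used in `DiffusionAttnUnet1D.forward`.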
|
||||
|
||||
class UNet1DUncondWrapper(DiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
|
||||
|
||||
self.io_channels = in_channels
|
||||
|
||||
with torch.no_grad():
|
||||
for param in self.model.parameters():
|
||||
param *= 0.5
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
return self.model(x, t, **kwargs)
|
||||
|
||||
class DAU1DCondWrapper(ConditionedDiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
|
||||
|
||||
self.model = DiffusionAttnUnet1D(*args, **kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
for param in self.model.parameters():
|
||||
param *= 0.5
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
t,
|
||||
input_concat_cond=None,
|
||||
cross_attn_cond=None,
|
||||
cross_attn_mask=None,
|
||||
global_cond=None,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob: float = 0.0,
|
||||
batch_cfg: bool = False,
|
||||
rescale_cfg: bool = False,
|
||||
negative_cross_attn_cond=None,
|
||||
negative_cross_attn_mask=None,
|
||||
negative_global_cond=None,
|
||||
negative_input_concat_cond=None,
|
||||
prepend_cond=None,
|
||||
**kwargs):
|
||||
|
||||
return self.model(x, t, cond = input_concat_cond)
|
||||
|
||||
class DiffusionAttnUnet1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
io_channels = 2,
|
||||
depth=14,
|
||||
n_attn_layers = 6,
|
||||
channels = [128, 128, 256, 256] + [512] * 10,
|
||||
cond_dim = 0,
|
||||
cond_noise_aug = False,
|
||||
kernel_size = 5,
|
||||
learned_resample = False,
|
||||
strides = [2] * 13,
|
||||
conv_bias = True,
|
||||
use_snake = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cond_noise_aug = cond_noise_aug
|
||||
|
||||
self.io_channels = io_channels
|
||||
|
||||
if self.cond_noise_aug:
|
||||
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
|
||||
|
||||
self.timestep_embed = FourierFeatures(1, 16)
|
||||
|
||||
attn_layer = depth - n_attn_layers
|
||||
|
||||
strides = [1] + strides
|
||||
|
||||
block = nn.Identity()
|
||||
|
||||
conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
|
||||
|
||||
for i in range(depth, 0, -1):
|
||||
c = channels[i - 1]
|
||||
stride = strides[i-1]
|
||||
if stride > 2 and not learned_resample:
|
||||
raise ValueError("Must have stride 2 without learned resampling")
|
||||
|
||||
if i > 1:
|
||||
c_prev = channels[i - 2]
|
||||
add_attn = i >= attn_layer and n_attn_layers > 0
|
||||
block = SkipBlock(
|
||||
Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
|
||||
conv_block(c_prev, c, c),
|
||||
SelfAttention1d(
|
||||
c, c // 32) if add_attn else nn.Identity(),
|
||||
conv_block(c, c, c),
|
||||
SelfAttention1d(
|
||||
c, c // 32) if add_attn else nn.Identity(),
|
||||
conv_block(c, c, c),
|
||||
SelfAttention1d(
|
||||
c, c // 32) if add_attn else nn.Identity(),
|
||||
block,
|
||||
conv_block(c * 2 if i != depth else c, c, c),
|
||||
SelfAttention1d(
|
||||
c, c // 32) if add_attn else nn.Identity(),
|
||||
conv_block(c, c, c),
|
||||
SelfAttention1d(
|
||||
c, c // 32) if add_attn else nn.Identity(),
|
||||
conv_block(c, c, c_prev),
|
||||
SelfAttention1d(c_prev, c_prev //
|
||||
32) if add_attn else nn.Identity(),
|
||||
Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
|
||||
)
|
||||
else:
|
||||
cond_embed_dim = 16 if not self.cond_noise_aug else 32
|
||||
block = nn.Sequential(
|
||||
conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
|
||||
conv_block(c, c, c),
|
||||
conv_block(c, c, c),
|
||||
block,
|
||||
conv_block(c * 2, c, c),
|
||||
conv_block(c, c, c),
|
||||
conv_block(c, c, io_channels, is_last=True),
|
||||
)
|
||||
self.net = block
|
||||
|
||||
with torch.no_grad():
|
||||
for param in self.net.parameters():
|
||||
param *= 0.5
|
||||
|
||||
def forward(self, x, t, cond=None, cond_aug_scale=None):
|
||||
|
||||
timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
|
||||
|
||||
inputs = [x, timestep_embed]
|
||||
|
||||
if cond is not None:
|
||||
if cond.shape[2] != x.shape[2]:
|
||||
cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
|
||||
|
||||
if self.cond_noise_aug:
|
||||
# Get a random number between 0 and 1, uniformly sampled
|
||||
if cond_aug_scale is None:
|
||||
aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
|
||||
else:
|
||||
aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
|
||||
|
||||
# Add noise to the conditioning signal
|
||||
cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
|
||||
|
||||
# Get embedding for noise cond level, reusing timestep_embed
|
||||
aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
|
||||
|
||||
inputs.append(aug_level_embed)
|
||||
|
||||
inputs.append(cond)
|
||||
|
||||
outputs = self.net(torch.cat(inputs, dim=1))
|
||||
|
||||
return outputs
|
||||
|
||||
class DiTWrapper(ConditionedDiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
|
||||
|
||||
self.model = DiffusionTransformer(*args, **kwargs)
|
||||
# with torch.no_grad():
|
||||
# for param in self.model.parameters():
|
||||
# param *= 0.5
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
t,
|
||||
cross_attn_cond=None,
|
||||
cross_attn_mask=None,
|
||||
negative_cross_attn_cond=None,
|
||||
negative_cross_attn_mask=None,
|
||||
input_concat_cond=None,
|
||||
negative_input_concat_cond=None,
|
||||
global_cond=None,
|
||||
negative_global_cond=None,
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob: float = 0.0,
|
||||
batch_cfg: bool = True,
|
||||
rescale_cfg: bool = False,
|
||||
scale_phi: float = 0.0,
|
||||
**kwargs):
|
||||
|
||||
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
|
||||
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
|
||||
|
||||
return self.model(
|
||||
x,
|
||||
t,
|
||||
cross_attn_cond=cross_attn_cond,
|
||||
cross_attn_cond_mask=cross_attn_mask,
|
||||
negative_cross_attn_cond=negative_cross_attn_cond,
|
||||
negative_cross_attn_mask=negative_cross_attn_mask,
|
||||
input_concat_cond=input_concat_cond,
|
||||
prepend_cond=prepend_cond,
|
||||
prepend_cond_mask=prepend_cond_mask,
|
||||
cfg_scale=cfg_scale,
|
||||
cfg_dropout_prob=cfg_dropout_prob,
|
||||
scale_phi=scale_phi,
|
||||
global_embed=global_cond,
|
||||
**kwargs)
|
||||
|
||||
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
||||
"""
|
||||
A diffusion model that takes in conditioning
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
conditioner: MultiConditioner,
|
||||
io_channels,
|
||||
sample_rate,
|
||||
min_input_length: int,
|
||||
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
|
||||
pretransform: tp.Optional[Pretransform] = None,
|
||||
cross_attn_cond_ids: tp.List[str] = [],
|
||||
global_cond_ids: tp.List[str] = [],
|
||||
input_concat_ids: tp.List[str] = [],
|
||||
prepend_cond_ids: tp.List[str] = [],
|
||||
add_cond_ids: tp.List[str] = [],
|
||||
mm_cond_ids: tp.List[str] = [],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
self.conditioner = conditioner
|
||||
self.io_channels = io_channels
|
||||
self.sample_rate = sample_rate
|
||||
self.diffusion_objective = diffusion_objective
|
||||
self.pretransform = pretransform
|
||||
self.cross_attn_cond_ids = cross_attn_cond_ids
|
||||
self.global_cond_ids = global_cond_ids
|
||||
self.input_concat_ids = input_concat_ids
|
||||
self.prepend_cond_ids = prepend_cond_ids
|
||||
self.add_cond_ids = add_cond_ids
|
||||
self.min_input_length = min_input_length
|
||||
self.mm_cond_ids = mm_cond_ids
|
||||
|
||||
assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
|
||||
assert "metaclip_features" in self.mm_cond_ids, "metaclip_features must be specified in mm_cond_ids for MMDiTWrapper"
|
||||
assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
|
||||
assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"
|
||||
# assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper"
|
||||
|
||||
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
|
||||
assert not negative, "negative conditioning is not supported for MMDiTWrapper"
|
||||
cross_attention_input = None
|
||||
cross_attention_masks = None
|
||||
global_cond = None
|
||||
input_concat_cond = None
|
||||
prepend_cond = None
|
||||
prepend_cond_mask = None
|
||||
add_input = None
|
||||
inpaint_masked_input = None
|
||||
t5_features = None
|
||||
metaclip_global_text_features = None
|
||||
clip_f = conditioning_tensors["metaclip_features"]
|
||||
sync_f = conditioning_tensors["sync_features"]
|
||||
text_f = conditioning_tensors["metaclip_text_features"]
|
||||
if 'inpaint_masked_input' in conditioning_tensors.keys():
|
||||
inpaint_masked_input = conditioning_tensors["inpaint_masked_input"]
|
||||
if 't5_features' in conditioning_tensors.keys():
|
||||
t5_features = conditioning_tensors["t5_features"]
|
||||
if 'metaclip_global_text_features' in conditioning_tensors.keys():
|
||||
metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"]
|
||||
return {
|
||||
"clip_f": clip_f,
|
||||
"sync_f": sync_f,
|
||||
"text_f": text_f,
|
||||
"inpaint_masked_input": inpaint_masked_input,
|
||||
"t5_features": t5_features,
|
||||
"metaclip_global_text_features": metaclip_global_text_features
|
||||
}
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
|
||||
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
from prismaudio_core.inference.generation import generate_diffusion_cond
|
||||
return generate_diffusion_cond(self, *args, **kwargs)
|
||||
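The wrapper above requires three multimodal conditioning keys and forwards three optional ones as `None` when absent. A minimal, torch-free sketch of that key mapping follows. It assumes plain Python objects stand in for tensors, and `map_conditioning` is a hypothetical helper name used only for illustration, not part of the code base.

```python
# Hypothetical stand-in for get_conditioning_inputs: plain values replace tensors.
REQUIRED = {
    "metaclip_features": "clip_f",
    "sync_features": "sync_f",
    "metaclip_text_features": "text_f",
}
OPTIONAL = ("inpaint_masked_input", "t5_features", "metaclip_global_text_features")


def map_conditioning(conditioning_tensors):
    """Rename required keys to model kwargs and default optional keys to None."""
    missing = [key for key in REQUIRED if key not in conditioning_tensors]
    if missing:
        raise KeyError(f"missing required conditioning keys: {missing}")
    inputs = {REQUIRED[key]: conditioning_tensors[key] for key in REQUIRED}
    for key in OPTIONAL:
        inputs[key] = conditioning_tensors.get(key)
    return inputs
```

The returned dictionary can then be unpacked as keyword arguments, in the same way `forward` unpacks the result of `get_conditioning_inputs`.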
|
||||
class DiTUncondWrapper(DiffusionModel):
|
||||
def __init__(
|
||||
self,
|
||||
io_channels,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs)
|
||||
|
||||
self.io_channels = io_channels
|
||||
|
||||
with torch.no_grad():
|
||||
for param in self.model.parameters():
|
||||
param *= 0.5
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
return self.model(x, t, **kwargs)
|
||||
|
||||
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
|
||||
diffusion_uncond_config = config["model"]
|
||||
|
||||
model_type = diffusion_uncond_config.get('type', None)
|
||||
|
||||
diffusion_config = diffusion_uncond_config.get('config', {})
|
||||
|
||||
assert model_type is not None, "Must specify model type in config"
|
||||
|
||||
pretransform = diffusion_uncond_config.get("pretransform", None)
|
||||
|
||||
sample_size = config.get("sample_size", None)
|
||||
assert sample_size is not None, "Must specify sample size in config"
|
||||
|
||||
sample_rate = config.get("sample_rate", None)
|
||||
assert sample_rate is not None, "Must specify sample rate in config"
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
||||
min_input_length = pretransform.downsampling_ratio
|
||||
else:
|
||||
min_input_length = 1
|
||||
|
||||
if model_type == 'DAU1d':
|
||||
|
||||
model = DiffusionAttnUnet1D(
|
||||
**diffusion_config
|
||||
)
|
||||
|
||||
elif model_type == "adp_uncond_1d":
|
||||
|
||||
model = UNet1DUncondWrapper(
|
||||
**diffusion_config
|
||||
)
|
||||
|
||||
elif model_type == "dit":
|
||||
model = DiTUncondWrapper(
|
||||
**diffusion_config
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||
|
||||
return DiffusionModelWrapper(model,
|
||||
io_channels=model.io_channels,
|
||||
sample_size=sample_size,
|
||||
sample_rate=sample_rate,
|
||||
pretransform=pretransform,
|
||||
min_input_length=min_input_length)
|
||||
|
||||
def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
|
||||
diffusion_uncond_config = config["model"]
|
||||
|
||||
|
||||
diffusion_config = diffusion_uncond_config.get('diffusion', {})
|
||||
model_type = diffusion_config.get('type', None)
|
||||
model_config = diffusion_config.get("config",{})
|
||||
assert model_type is not None, "Must specify model type in config"
|
||||
|
||||
pretransform = diffusion_uncond_config.get("pretransform", None)
|
||||
|
||||
sample_size = config.get("sample_size", None)
|
||||
assert sample_size is not None, "Must specify sample size in config"
|
||||
|
||||
sample_rate = config.get("sample_rate", None)
|
||||
assert sample_rate is not None, "Must specify sample rate in config"
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
||||
min_input_length = pretransform.downsampling_ratio
|
||||
else:
|
||||
min_input_length = 1
|
||||
|
||||
if model_type == 'DAU1d':
|
||||
|
||||
model = DiffusionAttnUnet1D(
|
||||
**model_config
|
||||
)
|
||||
|
||||
elif model_type == "adp_uncond_1d":
|
||||
|
||||
# Default io_channels to 64; merge it into the kwargs so it is not passed twice when the config also sets it
model = UNet1DUncondWrapper(
**{"io_channels": 64, **model_config}
)
|
||||
elif model_type == "dit":
|
||||
model = DiTUncondWrapper(
|
||||
**model_config
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||
|
||||
return DiffusionModelWrapper(model,
|
||||
io_channels=model.io_channels,
|
||||
sample_size=sample_size,
|
||||
sample_rate=sample_rate,
|
||||
pretransform=pretransform,
|
||||
min_input_length=min_input_length)
|
||||
|
||||
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
|
||||
|
||||
model_config = config["model"]
|
||||
|
||||
model_type = config["model_type"]
|
||||
|
||||
diffusion_config = model_config.get('diffusion', None)
|
||||
assert diffusion_config is not None, "Must specify diffusion config"
|
||||
|
||||
diffusion_model_type = diffusion_config.get('type', None)
|
||||
assert diffusion_model_type is not None, "Must specify diffusion model type"
|
||||
|
||||
diffusion_model_config = diffusion_config.get('config', None)
|
||||
assert diffusion_model_config is not None, "Must specify diffusion model config"
|
||||
|
||||
if diffusion_model_type == 'adp_cfg_1d':
|
||||
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
|
||||
elif diffusion_model_type == 'adp_1d':
|
||||
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
|
||||
elif diffusion_model_type == 'dit':
|
||||
diffusion_model = DiTWrapper(**diffusion_model_config)
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
|
||||
|
||||
io_channels = model_config.get('io_channels', None)
|
||||
assert io_channels is not None, "Must specify io_channels in model config"
|
||||
|
||||
sample_rate = config.get('sample_rate', None)
|
||||
assert sample_rate is not None, "Must specify sample_rate in config"
|
||||
|
||||
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
|
||||
|
||||
conditioning_config = model_config.get('conditioning', None)
|
||||
|
||||
conditioner = None
|
||||
if conditioning_config is not None:
|
||||
conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config)
|
||||
|
||||
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
|
||||
add_cond_ids = diffusion_config.get('add_cond_ids', [])
|
||||
sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
|
||||
global_cond_ids = diffusion_config.get('global_cond_ids', [])
|
||||
input_concat_ids = diffusion_config.get('input_concat_ids', [])
|
||||
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
|
||||
mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
|
||||
zero_init = diffusion_config.get('zero_init', False)
|
||||
pretransform = model_config.get("pretransform", None)
|
||||
|
||||
if pretransform is not None:
|
||||
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
|
||||
min_input_length = pretransform.downsampling_ratio
|
||||
else:
|
||||
min_input_length = 1
|
||||
|
||||
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
|
||||
min_input_length *= np.prod(diffusion_model_config["factors"])
|
||||
elif diffusion_model_type == "dit":
|
||||
min_input_length *= diffusion_model.model.patch_size
|
||||
|
||||
# Get the proper wrapper class
|
||||
|
||||
extra_kwargs = {}
|
||||
|
||||
if model_type == "mm_diffusion_cond":
|
||||
wrapper_fn = MMConditionedDiffusionModelWrapper
|
||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||
extra_kwargs["mm_cond_ids"] = mm_cond_ids
|
||||
|
||||
elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
|
||||
wrapper_fn = ConditionedDiffusionModelWrapper
|
||||
extra_kwargs["diffusion_objective"] = diffusion_objective
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||
|
||||
return wrapper_fn(
|
||||
diffusion_model,
|
||||
conditioner,
|
||||
min_input_length=min_input_length,
|
||||
sample_rate=sample_rate,
|
||||
cross_attn_cond_ids=cross_attention_ids,
|
||||
global_cond_ids=global_cond_ids,
|
||||
input_concat_ids=input_concat_ids,
|
||||
prepend_cond_ids=prepend_cond_ids,
|
||||
add_cond_ids=add_cond_ids,
|
||||
sync_cond_ids=sync_cond_ids,
|
||||
pretransform=pretransform,
|
||||
io_channels=io_channels,
|
||||
zero_init=zero_init,
|
||||
**extra_kwargs
|
||||
)
|
||||
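`create_diffusion_cond_from_config` derives `min_input_length` from the pretransform's downsampling ratio, multiplied by either the product of the U-Net `factors` or the DiT `patch_size`. The sketch below restates that arithmetic in isolation. It uses `math.prod` (Python 3.8+) where the original uses `np.prod`, and the function name and the example numbers in the usage note are assumptions chosen for illustration.

```python
import math


def min_input_length(diffusion_model_type, downsampling_ratio=1, factors=(), patch_size=1):
    """Smallest input length (in samples) the model can process.

    downsampling_ratio is 1 when no pretransform is configured.
    """
    length = downsampling_ratio
    if diffusion_model_type in ("adp_cfg_1d", "adp_1d"):
        # U-Net variants shrink the sequence by every factor in turn
        length *= math.prod(factors)
    elif diffusion_model_type == "dit":
        # The DiT groups patch_size latent frames into one token
        length *= patch_size
    return length
```

For example, a hypothetical pretransform with a downsampling ratio of 2048 combined with a DiT of patch size 2 gives a minimum input length of 4096 samples.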
@@ -1,539 +0,0 @@
|
||||
import typing as tp
import math
import torch
# from beartype.typing import Tuple
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from .blocks import FourierFeatures
from .transformer import ContinuousTransformer
from .utils import mask_from_frac_lengths, resample
|
||||
|
||||
class DiffusionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
io_channels=32,
|
||||
patch_size=1,
|
||||
embed_dim=768,
|
||||
cond_token_dim=0,
|
||||
project_cond_tokens=True,
|
||||
global_cond_dim=0,
|
||||
project_global_cond=True,
|
||||
input_concat_dim=0,
|
||||
prepend_cond_dim=0,
|
||||
cond_ctx_dim=0,
|
||||
depth=12,
|
||||
num_heads=8,
|
||||
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
|
||||
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
|
||||
timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
|
||||
add_token_dim=0,
|
||||
sync_token_dim=0,
|
||||
use_mlp=False,
|
||||
use_zero_init=False,
|
||||
**kwargs):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.cond_token_dim = cond_token_dim
|
||||
|
||||
# Timestep embeddings
|
||||
timestep_features_dim = 256
|
||||
self.timestep_cond_type = timestep_cond_type
|
||||
self.timestep_features = FourierFeatures(1, timestep_features_dim)
|
||||
|
||||
if timestep_cond_type == "global":
|
||||
timestep_embed_dim = embed_dim
|
||||
elif timestep_cond_type == "input_concat":
|
||||
timestep_embed_dim = embed_dim  # to_timestep_embed always projects to embed_dim; the original assert read this name before it was assigned
|
||||
input_concat_dim += timestep_embed_dim
|
||||
|
||||
self.to_timestep_embed = nn.Sequential(
|
||||
nn.Linear(timestep_features_dim, embed_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim, bias=True),
|
||||
)
|
||||
self.use_mlp = use_mlp
|
||||
if cond_token_dim > 0:
|
||||
# Conditioning tokens
|
||||
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
|
||||
self.to_cond_embed = nn.Sequential(
|
||||
nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
|
||||
)
|
||||
else:
|
||||
cond_embed_dim = 0
|
||||
|
||||
if global_cond_dim > 0:
|
||||
# Global conditioning
|
||||
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
|
||||
self.to_global_embed = nn.Sequential(
|
||||
nn.Linear(global_cond_dim, global_embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(global_embed_dim, global_embed_dim, bias=False)
|
||||
)
|
||||
if add_token_dim > 0:
|
||||
# Additive conditioning tokens
|
||||
add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
|
||||
self.to_add_embed = nn.Sequential(
|
||||
nn.Linear(add_token_dim, add_embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(add_embed_dim, add_embed_dim, bias=False)
|
||||
)
|
||||
else:
|
||||
add_embed_dim = 0
|
||||
|
||||
if sync_token_dim > 0:
|
||||
# Sync conditioning tokens
|
||||
sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
|
||||
self.to_sync_embed = nn.Sequential(
|
||||
nn.Linear(sync_token_dim, sync_embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(sync_embed_dim, sync_embed_dim, bias=False)
|
||||
)
|
||||
else:
|
||||
sync_embed_dim = 0
|
||||
|
||||
|
||||
if prepend_cond_dim > 0:
|
||||
# Prepend conditioning
|
||||
self.to_prepend_embed = nn.Sequential(
|
||||
nn.Linear(prepend_cond_dim, embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
)
|
||||
|
||||
self.input_concat_dim = input_concat_dim
|
||||
|
||||
dim_in = io_channels + self.input_concat_dim
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Transformer
|
||||
|
||||
self.transformer_type = transformer_type
|
||||
|
||||
self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
|
||||
self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
|
||||
self.global_cond_type = global_cond_type
|
||||
if self.transformer_type == "continuous_transformer":
|
||||
|
||||
global_dim = None
|
||||
|
||||
if self.global_cond_type == "adaLN":
|
||||
# The global conditioning is projected to the embed_dim already at this point
|
||||
global_dim = embed_dim
|
||||
|
||||
self.transformer = ContinuousTransformer(
|
||||
dim=embed_dim,
|
||||
depth=depth,
|
||||
dim_heads=embed_dim // num_heads,
|
||||
dim_in=dim_in * patch_size,
|
||||
dim_out=io_channels * patch_size,
|
||||
cross_attend = cond_token_dim > 0,
|
||||
cond_token_dim = cond_embed_dim,
|
||||
global_cond_dim=global_dim,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
|
||||
|
||||
self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
|
||||
self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
|
||||
nn.init.zeros_(self.preprocess_conv.weight)
|
||||
nn.init.zeros_(self.postprocess_conv.weight)
|
||||
|
||||
|
||||
def initialize_weights(self):
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
# if isinstance(module, nn.Conv1d):
|
||||
# if module.bias is not None:
|
||||
# nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02)
|
||||
nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02)
|
||||
|
||||
# Zero-out output layers:
|
||||
if self.global_cond_type == "adaLN":
|
||||
for block in self.transformer.layers:
|
||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
nn.init.constant_(self.empty_clip_feat, 0)
|
||||
nn.init.constant_(self.empty_sync_feat, 0)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
mask=None,
|
||||
cross_attn_cond=None,
|
||||
cross_attn_cond_mask=None,
|
||||
input_concat_cond=None,
|
||||
global_embed=None,
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
add_cond=None,
|
||||
add_masks=None,
|
||||
sync_cond=None,
|
||||
return_info=False,
|
||||
**kwargs):
|
||||
|
||||
if cross_attn_cond is not None:
|
||||
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
|
||||
|
||||
if global_embed is not None:
|
||||
# Project the global conditioning to the embedding dimension
|
||||
global_embed = self.to_global_embed(global_embed)
|
||||
|
||||
prepend_inputs = None
|
||||
prepend_mask = None
|
||||
prepend_length = 0
|
||||
if prepend_cond is not None:
|
||||
# Project the prepend conditioning to the embedding dimension
|
||||
prepend_cond = self.to_prepend_embed(prepend_cond)
|
||||
|
||||
prepend_inputs = prepend_cond
|
||||
if prepend_cond_mask is not None:
|
||||
prepend_mask = prepend_cond_mask
|
||||
|
||||
if input_concat_cond is not None:
|
||||
# reshape from (b, n, c) to (b, c, n)
|
||||
if input_concat_cond.shape[1] != x.shape[1]:
|
||||
input_concat_cond = input_concat_cond.transpose(1,2)
|
||||
# Interpolate input_concat_cond to the same length as x
|
||||
# if input_concat_cond.shape[1] != x.shape[2]:
|
||||
# input_concat_cond = input_concat_cond.transpose(1,2)
|
||||
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
|
||||
# input_concat_cond = input_concat_cond.transpose(1,2)
|
||||
# if len(global_embed.shape) == 2:
|
||||
# global_embed = global_embed.unsqueeze(1)
|
||||
# global_embed = global_embed + input_concat_cond
|
||||
x = torch.cat([x, input_concat_cond], dim=1)
|
||||
|
||||
# Get the batch of timestep embeddings
|
||||
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
|
||||
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
|
||||
if self.timestep_cond_type == "global":
|
||||
if global_embed is not None:
|
||||
if len(global_embed.shape) == 3:
|
||||
timestep_embed = timestep_embed.unsqueeze(1)
|
||||
global_embed = global_embed + timestep_embed
|
||||
else:
|
||||
global_embed = timestep_embed
|
||||
elif self.timestep_cond_type == "input_concat":
|
||||
x = torch.cat([x, timestep_embed.unsqueeze(2).expand(-1, -1, x.shape[2])], dim=1)  # (b, embed_dim) -> (b, embed_dim, t) so it can be concatenated on the channel axis
|
||||
|
||||
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
|
||||
if self.global_cond_type == "prepend" and global_embed is not None:
|
||||
if prepend_inputs is None:
|
||||
# Prepend inputs are just the global embed, and the mask is all ones
|
||||
if len(global_embed.shape) == 2:
|
||||
prepend_inputs = global_embed.unsqueeze(1)
|
||||
else:
|
||||
prepend_inputs = global_embed
|
||||
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
|
||||
else:
|
||||
# Prepend inputs are the prepend conditioning + the global embed
|
||||
if len(global_embed.shape) == 2:
|
||||
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
|
||||
else:
|
||||
prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1)
|
||||
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
|
||||
|
||||
prepend_length = prepend_inputs.shape[1]
|
||||
|
||||
x = self.preprocess_conv(x) + x
|
||||
x = rearrange(x, "b c t -> b t c")
|
||||
|
||||
extra_args = {}
|
||||
|
||||
if self.global_cond_type == "adaLN":
|
||||
extra_args["global_cond"] = global_embed
|
||||
|
||||
if self.patch_size > 1:
|
||||
b, seq_len, c = x.shape
|
||||
|
||||
# Compute the amount of padding needed
|
||||
pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size
|
||||
|
||||
if pad_amount > 0:
|
||||
# Pad along the time dimension
|
||||
x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
|
||||
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
|
||||
|
||||
if add_cond is not None:
|
||||
# Interpolate add_cond to the same length as x
|
||||
# if self.use_mlp:
|
||||
add_cond = self.to_add_embed(add_cond)
|
||||
if add_cond.shape[1] != x.shape[1]:
|
||||
add_cond = add_cond.transpose(1,2)
|
||||
add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
|
||||
add_cond = add_cond.transpose(1,2)
|
||||
# add_cond = resample(add_cond, x)
|
||||
|
||||
if sync_cond is not None:
|
||||
sync_cond = self.to_sync_embed(sync_cond)
|
||||
|
||||
if self.transformer_type == "continuous_transformer":
|
||||
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
|
||||
|
||||
if return_info:
|
||||
output, info = output
|
||||
|
||||
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
|
||||
|
||||
if self.patch_size > 1:
|
||||
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
|
||||
# Remove the padding added earlier
|
||||
if pad_amount > 0:
|
||||
output = output[:, :, :seq_len]
|
||||
|
||||
output = self.postprocess_conv(output) + output
|
||||
|
||||
if return_info:
|
||||
return output, info
|
||||
|
||||
return output
|
||||
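When `patch_size > 1`, `_forward` pads the sequence up to the next multiple of the patch size, folds each patch into the channel axis, and trims the padding after the transformer. The length bookkeeping can be sketched without tensors as below. The helper names are hypothetical and only mirror the `pad_amount` expression used in the method.

```python
def pad_amount(seq_len, patch_size):
    """Frames of zero padding needed to reach a multiple of patch_size."""
    return (patch_size - seq_len % patch_size) % patch_size


def patched_length(seq_len, patch_size):
    """Number of tokens the transformer sees after patching."""
    return (seq_len + pad_amount(seq_len, patch_size)) // patch_size


def restored_length(seq_len, patch_size):
    """Length after unpatching and trimming back to the original seq_len."""
    unpatched = patched_length(seq_len, patch_size) * patch_size
    return min(unpatched, seq_len)
```

The outer modulo in `pad_amount` is what keeps the padding at zero when the length is already a multiple of the patch size.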
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
cross_attn_cond=None,
|
||||
cross_attn_cond_mask=None,
|
||||
negative_cross_attn_cond=None,
|
||||
negative_cross_attn_mask=None,
|
||||
input_concat_cond=None,
|
||||
global_embed=None,
|
||||
negative_global_embed=None,
|
||||
prepend_cond=None,
|
||||
prepend_cond_mask=None,
|
||||
add_cond=None,
|
||||
sync_cond=None,
|
||||
cfg_scale=1.0,
|
||||
cfg_dropout_prob=0.0,
|
||||
causal=False,
|
||||
scale_phi=0.0,
|
||||
mask=None,
|
||||
return_info=False,
|
||||
**kwargs):
|
||||
|
||||
assert not causal, "Causal mode is not supported for DiffusionTransformer"
|
||||
bsz, a, b = x.shape
|
||||
model_dtype = next(self.parameters()).dtype
|
||||
x = x.to(model_dtype)
|
||||
t = t.to(model_dtype)
|
||||
|
||||
if cross_attn_cond is not None:
|
||||
cross_attn_cond = cross_attn_cond.to(model_dtype)
|
||||
|
||||
if negative_cross_attn_cond is not None:
|
||||
negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype)
|
||||
|
||||
if input_concat_cond is not None:
|
||||
input_concat_cond = input_concat_cond.to(model_dtype)
|
||||
|
||||
if global_embed is not None:
|
||||
global_embed = global_embed.to(model_dtype)
|
||||
|
||||
if negative_global_embed is not None:
|
||||
negative_global_embed = negative_global_embed.to(model_dtype)
|
||||
|
||||
if prepend_cond is not None:
|
||||
prepend_cond = prepend_cond.to(model_dtype)
|
||||
|
||||
if add_cond is not None:
|
||||
add_cond = add_cond.to(model_dtype)
|
||||
|
||||
if sync_cond is not None:
|
||||
sync_cond = sync_cond.to(model_dtype)
|
||||
|
||||
if cross_attn_cond_mask is not None:
|
||||
cross_attn_cond_mask = cross_attn_cond_mask.bool()
|
||||
|
||||
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
|
||||
|
||||
if prepend_cond_mask is not None:
|
||||
prepend_cond_mask = prepend_cond_mask.bool()
|
||||
|
||||
|
||||
# CFG dropout
|
||||
if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
|
||||
if cross_attn_cond is not None:
|
||||
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
||||
dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
|
||||
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
|
||||
|
||||
if prepend_cond is not None:
|
||||
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
||||
dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
|
||||
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
|
||||
|
||||
if add_cond is not None:
|
||||
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
|
||||
dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool)
|
||||
add_cond = torch.where(dropout_mask, null_embed, add_cond)
|
||||
|
||||
if sync_cond is not None:
|
||||
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
|
||||
dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool)
|
||||
sync_cond = torch.where(dropout_mask, null_embed, sync_cond)
|
||||
|
||||
if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
|
||||
# Classifier-free guidance
|
||||
# Concatenate conditioned and unconditioned inputs on the batch dimension
|
||||
batch_inputs = torch.cat([x, x], dim=0)
|
||||
batch_timestep = torch.cat([t, t], dim=0)
|
||||
if global_embed is not None and global_embed.shape[0] == bsz:
|
||||
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
|
||||
elif global_embed is not None:
|
||||
batch_global_cond = global_embed
|
||||
else:
|
||||
batch_global_cond = None
|
||||
|
||||
if input_concat_cond is not None and input_concat_cond.shape[0] == bsz:
|
||||
batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
|
||||
elif input_concat_cond is not None:
|
||||
batch_input_concat_cond = input_concat_cond
|
||||
else:
|
||||
batch_input_concat_cond = None
|
||||
|
||||
batch_cond = None
|
||||
batch_cond_masks = None
|
||||
|
||||
# Handle CFG for cross-attention conditioning
|
||||
if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz:
|
||||
|
||||
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
|
||||
|
||||
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
|
||||
if negative_cross_attn_cond is not None:
|
||||
|
||||
# If there's a negative cross-attention mask, set the masked tokens to the null embed
|
||||
if negative_cross_attn_mask is not None:
|
||||
negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
|
||||
|
||||
negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
|
||||
|
||||
batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
|
||||
|
||||
else:
|
||||
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
|
||||
|
||||
if cross_attn_cond_mask is not None:
|
||||
batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
|
||||
elif cross_attn_cond is not None:
|
||||
batch_cond = cross_attn_cond
|
||||
else:
|
||||
batch_cond = None
|
||||
|
||||
batch_prepend_cond = None
|
||||
batch_prepend_cond_mask = None
|
||||
|
||||
if prepend_cond is not None and prepend_cond.shape[0] == bsz:
|
||||
|
||||
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
||||
|
||||
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
|
||||
|
||||
if prepend_cond_mask is not None:
|
||||
batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
|
||||
elif prepend_cond is not None:
|
||||
batch_prepend_cond = prepend_cond
|
||||
else:
|
||||
batch_prepend_cond = None
|
||||
|
||||
batch_add_cond = None
|
||||
|
||||
# Handle CFG for additive conditioning
|
||||
if add_cond is not None and add_cond.shape[0] == bsz:
|
||||
|
||||
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
|
||||
|
||||
|
||||
batch_add_cond = torch.cat([add_cond, null_embed], dim=0)
|
||||
elif add_cond is not None:
|
||||
batch_add_cond = add_cond
|
||||
else:
|
||||
batch_add_cond = None
|
||||
|
||||
batch_sync_cond = None
|
||||
|
||||
# Handle CFG for sync conditioning
|
||||
if sync_cond is not None and sync_cond.shape[0] == bsz:
|
||||
|
||||
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
|
||||
|
||||
|
||||
batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0)
|
||||
elif sync_cond is not None:
|
||||
batch_sync_cond = sync_cond
|
||||
else:
|
||||
batch_sync_cond = None
|
||||
|
||||
if mask is not None:
|
||||
batch_masks = torch.cat([mask, mask], dim=0)
|
||||
else:
|
||||
batch_masks = None
|
||||
|
||||
batch_output = self._forward(
|
||||
batch_inputs,
|
||||
batch_timestep,
|
||||
cross_attn_cond=batch_cond,
|
||||
cross_attn_cond_mask=batch_cond_masks,
|
||||
mask = batch_masks,
|
||||
input_concat_cond=batch_input_concat_cond,
|
||||
global_embed = batch_global_cond,
|
||||
prepend_cond = batch_prepend_cond,
|
||||
prepend_cond_mask = batch_prepend_cond_mask,
|
||||
add_cond = batch_add_cond,
|
||||
sync_cond = batch_sync_cond,
|
||||
return_info = return_info,
|
||||
**kwargs)
|
||||
|
||||
if return_info:
|
||||
batch_output, info = batch_output
|
||||
|
||||
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
|
||||
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
|
||||
|
||||
# CFG Rescale
|
||||
if scale_phi != 0.0:
|
||||
cond_out_std = cond_output.std(dim=1, keepdim=True)
|
||||
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
|
||||
output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
|
||||
else:
|
||||
output = cfg_output
|
||||
|
||||
if return_info:
|
||||
return output, info
|
||||
|
||||
return output
|
||||
|
||||
else:
|
||||
return self._forward(
|
||||
x,
|
||||
t,
|
||||
cross_attn_cond=cross_attn_cond,
|
||||
cross_attn_cond_mask=cross_attn_cond_mask,
|
||||
input_concat_cond=input_concat_cond,
|
||||
global_embed=global_embed,
|
||||
prepend_cond=prepend_cond,
|
||||
prepend_cond_mask=prepend_cond_mask,
|
||||
add_cond=add_cond,
|
||||
sync_cond=sync_cond,
|
||||
mask=mask,
|
||||
return_info=return_info,
|
||||
**kwargs
|
||||
)
|
||||
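The classifier-free guidance branch of `forward` combines the conditioned and unconditioned outputs as `uncond + (cond - uncond) * cfg_scale`, then optionally rescales the result toward the standard deviation of the conditioned output, weighted by `scale_phi`. The sketch below applies the same formulas to a single flat list of floats. This is a simplifying assumption: the original computes the standard deviation along the channel dimension of a tensor, and `cfg_combine` is a hypothetical name.

```python
import statistics


def cfg_combine(cond, uncond, cfg_scale, scale_phi=0.0):
    """Classifier-free guidance with optional std rescaling on 1-D data."""
    cfg = [u + (c - u) * cfg_scale for c, u in zip(cond, uncond)]
    if scale_phi == 0.0:
        return cfg
    # statistics.stdev is the sample (n - 1) estimator, matching torch.std's default
    ratio = statistics.stdev(cond) / statistics.stdev(cfg)
    return [scale_phi * (v * ratio) + (1 - scale_phi) * v for v in cfg]
```

With `cfg_scale=1.0` the result equals the conditioned output, which is why the method skips the doubled batch in that case.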
@@ -1,275 +0,0 @@
|
||||
import torch

from einops import rearrange
from torch import nn

from .blocks import AdaRMSNorm
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
from .utils import checkpoint
|
||||
|
||||
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
|
||||
class ContinuousLocalTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth,
|
||||
dim_in = None,
|
||||
dim_out = None,
|
||||
causal = False,
|
||||
local_attn_window_size = 64,
|
||||
heads = 8,
|
||||
ff_mult = 2,
|
||||
cond_dim = 0,
|
||||
cross_attn_cond_dim = 0,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dim_head = dim//heads
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
|
||||
|
||||
self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
|
||||
|
||||
self.local_attn_window_size = local_attn_window_size
|
||||
|
||||
self.cond_dim = cond_dim
|
||||
|
||||
self.cross_attn_cond_dim = cross_attn_cond_dim
|
||||
|
||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
|
||||
|
||||
for _ in range(depth):
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
|
||||
Attention(
|
||||
dim=dim,
|
||||
dim_heads=dim_head,
|
||||
causal=causal,
|
||||
zero_init_output=True,
|
||||
natten_kernel_size=local_attn_window_size,
|
||||
),
|
||||
Attention(
|
||||
dim=dim,
|
||||
dim_heads=dim_head,
|
||||
dim_context = cross_attn_cond_dim,
|
||||
zero_init_output=True
|
||||
) if self.cross_attn_cond_dim > 0 else nn.Identity(),
|
||||
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
|
||||
FeedForward(dim = dim, mult = ff_mult, no_bias=True)
|
||||
]))
|
||||
|
||||
def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
|
||||
|
||||
x = checkpoint(self.project_in, x)
|
||||
|
||||
if prepend_cond is not None:
|
||||
x = torch.cat([prepend_cond, x], dim=1)
|
||||
|
||||
pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
||||
|
||||
for attn_norm, attn, xattn, ff_norm, ff in self.layers:
|
||||
|
||||
residual = x
|
||||
if cond is not None:
|
||||
x = checkpoint(attn_norm, x, cond)
|
||||
else:
|
||||
x = checkpoint(attn_norm, x)
|
||||
|
||||
x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
|
||||
|
||||
if cross_attn_cond is not None:
|
||||
x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
|
||||
|
||||
residual = x
|
||||
|
||||
if cond is not None:
|
||||
x = checkpoint(ff_norm, x, cond)
|
||||
else:
|
||||
x = checkpoint(ff_norm, x)
|
||||
|
||||
x = checkpoint(ff, x) + residual
|
||||
|
||||
return checkpoint(self.project_out, x)
|
||||
|
||||
class TransformerDownsampleBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
embed_dim = 768,
|
||||
depth = 3,
|
||||
heads = 12,
|
||||
downsample_ratio = 2,
|
||||
local_attn_window_size = 64,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.downsample_ratio = downsample_ratio
|
||||
|
||||
self.transformer = ContinuousLocalTransformer(
|
||||
dim=embed_dim,
|
||||
depth=depth,
|
||||
heads=heads,
|
||||
local_attn_window_size=local_attn_window_size,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
|
||||
|
||||
self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
x = checkpoint(self.project_in, x)
|
||||
|
||||
# Compute
|
||||
x = self.transformer(x)
|
||||
|
||||
# Trade sequence length for channels
|
||||
x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
|
||||
|
||||
# Project back to embed dim
|
||||
x = checkpoint(self.project_down, x)
|
||||
|
||||
return x
|
||||
|
||||
class TransformerUpsampleBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
embed_dim,
|
||||
depth = 3,
|
||||
heads = 12,
|
||||
upsample_ratio = 2,
|
||||
local_attn_window_size = 64,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.upsample_ratio = upsample_ratio
|
||||
|
||||
self.transformer = ContinuousLocalTransformer(
|
||||
dim=embed_dim,
|
||||
depth=depth,
|
||||
heads=heads,
|
||||
local_attn_window_size = local_attn_window_size,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
|
||||
|
||||
self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# Project to embed dim
|
||||
x = checkpoint(self.project_in, x)
|
||||
|
||||
# Project to increase channel dim
|
||||
x = checkpoint(self.project_up, x)
|
||||
|
||||
# Trade channels for sequence length
|
||||
x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
|
||||
|
||||
# Compute
|
||||
x = self.transformer(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
embed_dims = [96, 192, 384, 768],
|
||||
heads = [12, 12, 12, 12],
|
||||
depths = [3, 3, 3, 3],
|
||||
ratios = [2, 2, 2, 2],
|
||||
local_attn_window_size = 64,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
|
||||
for layer in range(len(depths)):
|
||||
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
|
||||
|
||||
layers.append(
|
||||
TransformerDownsampleBlock1D(
|
||||
in_channels = prev_dim,
|
||||
embed_dim = embed_dims[layer],
|
||||
heads = heads[layer],
|
||||
depth = depths[layer],
|
||||
downsample_ratio = ratios[layer],
|
||||
local_attn_window_size = local_attn_window_size,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
|
||||
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c n -> b n c")
|
||||
x = checkpoint(self.project_in, x)
|
||||
x = self.layers(x)
|
||||
x = checkpoint(self.project_out, x)
|
||||
x = rearrange(x, "b n c -> b c n")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoder1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
embed_dims = [768, 384, 192, 96],
|
||||
heads = [12, 12, 12, 12],
|
||||
depths = [3, 3, 3, 3],
|
||||
ratios = [2, 2, 2, 2],
|
||||
local_attn_window_size = 64,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
layers = []
|
||||
|
||||
for layer in range(len(depths)):
|
||||
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
|
||||
|
||||
layers.append(
|
||||
TransformerUpsampleBlock1D(
|
||||
in_channels = prev_dim,
|
||||
embed_dim = embed_dims[layer],
|
||||
heads = heads[layer],
|
||||
depth = depths[layer],
|
||||
upsample_ratio = ratios[layer],
|
||||
local_attn_window_size = local_attn_window_size,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
|
||||
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = rearrange(x, "b c n -> b n c")
|
||||
x = checkpoint(self.project_in, x)
|
||||
x = self.layers(x)
|
||||
x = checkpoint(self.project_out, x)
|
||||
x = rearrange(x, "b n c -> b c n")
|
||||
return x
|
||||
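The two `rearrange` patterns used by `TransformerDownsampleBlock1D` and `TransformerUpsampleBlock1D` are easier to follow on a tiny example. The sketch below reimplements them on plain nested lists for a single batch item. The function names are illustrative and are not part of the deleted module.

```python
def fold_time_into_channels(seq, ratio):
    """Mimic "b (n r) c -> b n (c r)" for one batch item.

    seq is a list of frames, each frame a list of channel values.
    """
    if len(seq) % ratio:
        raise ValueError("sequence length must be divisible by ratio")
    channels = len(seq[0])
    out = []
    for n in range(len(seq) // ratio):
        group = seq[n * ratio:(n + 1) * ratio]
        # Channel index is the slow axis, position within the group the fast axis.
        out.append([group[j][c] for c in range(channels) for j in range(ratio)])
    return out


def unfold_channels_into_time(seq, ratio):
    """Mimic "b n (c r) -> b (n r) c" for one batch item (inverse of the above)."""
    channels = len(seq[0]) // ratio
    out = []
    for frame in seq:
        for j in range(ratio):
            out.append([frame[c * ratio + j] for c in range(channels)])
    return out


frames = [[1, 2], [3, 4], [5, 6], [7, 8]]      # 4 time steps, 2 channels
folded = fold_time_into_channels(frames, 2)    # 2 time steps, 4 channels
restored = unfold_channels_into_time(folded, 2)
```

The downsample block halves the sequence length by widening each frame and then projects back to `embed_dim`. The upsample block widens first and then unfolds, so the two reshapes are exact inverses of each other.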
@@ -1 +0,0 @@
|
||||
# mmmodules package
|
||||
@@ -1 +0,0 @@
|
||||
# mmmodules.model package
|
||||
@@ -1,393 +0,0 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from scipy.optimize import fmin
|
||||
from scipy.signal import firwin, kaiserord
|
||||
|
||||
class PQMF(nn.Module):
|
||||
"""
|
||||
Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
|
||||
Uses a polyphase representation, which is computationally more efficient for real-time use.
|
||||
|
||||
Parameters:
|
||||
- attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
|
||||
- num_bands (int): Number of desired frequency bands. It must be a power of 2.
|
||||
"""
|
||||
|
||||
def __init__(self, attenuation, num_bands):
|
||||
super(PQMF, self).__init__()
|
||||
|
||||
# Ensure num_bands is a power of 2
|
||||
is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
|
||||
assert is_power_of_2, "'num_bands' must be a power of 2."
|
||||
|
||||
# Create the prototype filter
|
||||
prototype_filter = design_prototype_filter(attenuation, num_bands)
|
||||
filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
|
||||
padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
|
||||
|
||||
# Register filters and settings
|
||||
self.register_buffer("filter_bank", padded_filter_bank)
|
||||
self.register_buffer("prototype", prototype_filter)
|
||||
self.num_bands = num_bands
|
||||
|
||||
def forward(self, signal):
|
||||
"""Decompose the signal into multiple frequency bands."""
|
||||
# If signal is not a pytorch tensor of Batch x Channels x Length, convert it
|
||||
signal = prepare_signal_dimensions(signal)
|
||||
# The signal length must be a multiple of num_bands. Pad it with zeros.
|
||||
signal = pad_signal(signal, self.num_bands)
|
||||
# Split the padded signal into sub-bands with the polyphase analysis filter bank
|
||||
signal = polyphase_analysis(signal, self.filter_bank)
|
||||
return apply_alias_cancellation(signal)
|
||||
|
||||
def inverse(self, bands):
|
||||
"""Reconstruct the original signal from the frequency bands."""
|
||||
bands = apply_alias_cancellation(bands)
|
||||
return polyphase_synthesis(bands, self.filter_bank)
|
||||
|
||||
|
||||
def prepare_signal_dimensions(signal):
|
||||
"""
|
||||
Rearrange signal into Batch x Channels x Length.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : torch.Tensor or numpy.ndarray
|
||||
The input signal.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Preprocessed signal tensor.
|
||||
"""
|
||||
# Convert numpy to torch tensor
|
||||
if isinstance(signal, np.ndarray):
|
||||
signal = torch.from_numpy(signal)
|
||||
|
||||
# Ensure tensor
|
||||
if not isinstance(signal, torch.Tensor):
|
||||
raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
|
||||
|
||||
# Modify dimension of signal to Batch x Channels x Length
|
||||
if signal.dim() == 1:
|
||||
# This is just a mono signal. Unsqueeze to 1 x 1 x Length
|
||||
signal = signal.unsqueeze(0).unsqueeze(0)
|
||||
elif signal.dim() == 2:
|
||||
# This is a multi-channel signal (e.g. stereo)
|
||||
# Rearrange so that larger dimension (Length) is last
|
||||
if signal.shape[0] > signal.shape[1]:
|
||||
signal = signal.T
|
||||
# Unsqueeze to 1 x Channels x Length
|
||||
signal = signal.unsqueeze(0)
|
||||
return signal
|
||||
|
||||
def pad_signal(signal, num_bands):
|
||||
"""
|
||||
Pads the signal to make its length divisible by the given number of bands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : torch.Tensor
|
||||
The input signal tensor, where the last dimension represents the signal length.
|
||||
|
||||
num_bands : int
|
||||
The number of bands by which the signal length should be divisible.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The padded signal tensor. If the original signal length was already divisible
|
||||
by num_bands, returns the original signal unchanged.
|
||||
"""
|
||||
remainder = signal.shape[-1] % num_bands
|
||||
if remainder > 0:
|
||||
padding_size = num_bands - remainder
|
||||
signal = nn.functional.pad(signal, (0, padding_size))
|
||||
return signal
|
||||
|
||||
def generate_modulated_filter_bank(prototype_filter, num_bands):
|
||||
"""
|
||||
Generate a QMF bank of cosine modulated filters based on a given prototype filter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prototype_filter : torch.Tensor
|
||||
The prototype filter used as the basis for modulation.
|
||||
num_bands : int
|
||||
The number of desired subbands or filters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
A bank of cosine modulated filters.
|
||||
"""
|
||||
|
||||
# Initialize indices for modulation.
|
||||
subband_indices = torch.arange(num_bands).reshape(-1, 1)
|
||||
|
||||
# Calculate the length of the prototype filter.
|
||||
filter_length = prototype_filter.shape[-1]
|
||||
|
||||
# Generate symmetric time indices centered around zero.
|
||||
time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
|
||||
|
||||
# Calculate phase offsets to ensure orthogonality between subbands.
|
||||
phase_offsets = (-1)**subband_indices * np.pi / 4
|
||||
|
||||
# Compute the cosine modulation function.
|
||||
modulation = torch.cos(
|
||||
(2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
|
||||
)
|
||||
|
||||
# Apply modulation to the prototype filter.
|
||||
modulated_filters = 2 * prototype_filter * modulation
|
||||
|
||||
return modulated_filters
|
||||
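The modulation in `generate_modulated_filter_bank` can be checked by hand on a very short prototype. The sketch below is a plain-Python restatement of the same formula, not the module's implementation, and `cosine_modulated_bank` is a hypothetical name. It assumes an odd-length prototype, because the time indices run from `-(L // 2)` to `L // 2` inclusive.

```python
import math


def cosine_modulated_bank(prototype, num_bands):
    """Return num_bands cosine-modulated copies of an odd-length prototype."""
    if len(prototype) % 2 == 0:
        raise ValueError("prototype length must be odd")
    half = len(prototype) // 2
    time_indices = range(-half, half + 1)
    bank = []
    for k in range(num_bands):
        # Alternating +pi/4 / -pi/4 phase offset per sub-band.
        phase = (-1) ** k * math.pi / 4
        step = (2 * k + 1) * math.pi / (2 * num_bands)
        bank.append([2 * h * math.cos(step * n + phase)
                     for n, h in zip(time_indices, prototype)])
    return bank


bank = cosine_modulated_bank([0.0, 1.0, 0.0], num_bands=2)
```

With an even-length prototype the time indices would have one more entry than the filter. This may be why `design_kaiser_lowpass` rounds its estimated length up to an odd number.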
|
||||
|
||||
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
|
||||
"""
|
||||
Design a lowpass filter using the Kaiser window.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
angular_cutoff : float
|
||||
The angular frequency cutoff of the filter.
|
||||
attenuation : float
|
||||
The desired stopband attenuation in decibels (dB).
|
||||
filter_length : int, optional
|
||||
Desired length of the filter. If not provided, it's computed based on the given specs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
The designed lowpass filter coefficients.
|
||||
"""
|
||||
|
||||
estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
|
||||
|
||||
# Ensure the estimated length is odd.
|
||||
estimated_length = 2 * (estimated_length // 2) + 1
|
||||
|
||||
if filter_length is None:
|
||||
filter_length = estimated_length
|
||||
|
||||
return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, fs=2 * np.pi)
|
||||
|
||||
|
||||
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
|
||||
"""
|
||||
Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
|
||||
|
||||
Parameters
|
||||
----------
|
||||
angular_cutoff : float
|
||||
Angular frequency cutoff of the filter.
|
||||
attenuation : float
|
||||
Desired stopband attenuation in dB.
|
||||
num_bands : int
|
||||
Number of bands for the multiband filter system.
|
||||
filter_length : int or None
|
||||
Desired length of the filter.
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
The computed objective (loss) value for the given filter specs.
|
||||
"""
|
||||
|
||||
filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
|
||||
convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
|
||||
|
||||
return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
|
||||
|
||||
|
||||
def design_prototype_filter(attenuation, num_bands, filter_length=None):
|
||||
"""
|
||||
Design the optimal prototype filter for a multiband system given the desired specs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
attenuation : float
|
||||
The desired stopband attenuation in dB.
|
||||
num_bands : int
|
||||
Number of bands for the multiband filter system.
|
||||
filter_length : int, optional
|
||||
Desired length of the filter. If not provided, it's computed based on the given specs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
The optimal prototype filter coefficients.
|
||||
"""
|
||||
|
||||
optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
|
||||
1 / num_bands, disp=0)[0]
|
||||
|
||||
prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
|
||||
return torch.tensor(prototype_filter, dtype=torch.float32)
|
||||
|
||||
def pad_to_nearest_power_of_two(x):
|
||||
"""
|
||||
Pads the input tensor 'x' on both sides such that its last dimension
|
||||
becomes the nearest larger power of two.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
x : torch.Tensor
|
||||
The input tensor to be padded.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor
|
||||
The padded tensor.
|
||||
"""
|
||||
current_length = x.shape[-1]
|
||||
target_length = 2**math.ceil(math.log2(current_length))
|
||||
|
||||
total_padding = target_length - current_length
|
||||
left_padding = total_padding // 2
|
||||
right_padding = total_padding - left_padding
|
||||
|
||||
return nn.functional.pad(x, (left_padding, right_padding))
|
||||
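The padding arithmetic in `pad_signal` and `pad_to_nearest_power_of_two` can be written without tensors. The sketch below uses hypothetical helper names and swaps the floating-point `math.log2` for integer bit operations; that substitution is an illustration, not what the module does.

```python
def right_padding_for_multiple(length, num_bands):
    """Zeros to append so that length becomes a multiple of num_bands."""
    remainder = length % num_bands
    return 0 if remainder == 0 else num_bands - remainder


def split_padding_to_power_of_two(length):
    """(left, right) padding that grows length to the next power of two."""
    if length < 1:
        raise ValueError("length must be positive")
    target = 1 << (length - 1).bit_length()  # smallest power of two >= length
    total = target - length
    left = total // 2
    return left, total - left


def is_power_of_two(n):
    """Integer-only equivalent of the log2-based check in PQMF.__init__."""
    return n > 0 and (n & (n - 1)) == 0
```

When the total padding is odd, the extra sample goes on the right, matching `right_padding = total_padding - left_padding` in the original.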
|
||||
def apply_alias_cancellation(x):
|
||||
"""
|
||||
Applies alias cancellation by inverting the sign of every
|
||||
second element of every second row, starting from the second
|
||||
row's first element in a tensor.
|
||||
|
||||
This operation helps ensure that the aliasing introduced in
|
||||
each band during the decomposition will be counteracted during
|
||||
the reconstruction.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
x : torch.Tensor
|
||||
The input tensor.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor
|
||||
Tensor with specific elements' sign inverted for alias cancellation.
|
||||
"""
|
||||
|
||||
# Create a mask of the same shape as 'x', initialized with all ones
|
||||
mask = torch.ones_like(x)
|
||||
|
||||
# Update specific elements in the mask to -1 to perform inversion
|
||||
mask[..., 1::2, ::2] = -1
|
||||
|
||||
# Apply the mask to the input tensor 'x'
|
||||
return x * mask
|
||||
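The sign pattern produced by `mask[..., 1::2, ::2] = -1` is shown below on nested lists, with bands as rows and time frames as columns. The helper names are illustrative only.

```python
def alias_cancellation_signs(num_bands, num_frames):
    """-1 on every even frame of every odd band, +1 elsewhere."""
    return [[-1 if (band % 2 == 1 and frame % 2 == 0) else 1
             for frame in range(num_frames)]
            for band in range(num_bands)]


def apply_signs(bands):
    """Multiply a bands x frames grid by the alias-cancellation signs."""
    signs = alias_cancellation_signs(len(bands), len(bands[0]))
    return [[value * sign for value, sign in zip(row, sign_row)]
            for row, sign_row in zip(bands, signs)]


grid = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
once = apply_signs(grid)
twice = apply_signs(once)
```

Every sign is either +1 or -1, so applying the mask twice returns the original values. That makes it consistent for `PQMF.forward` to apply it after analysis and for `PQMF.inverse` to apply it again before synthesis.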
|
||||
def ensure_odd_length(tensor):
|
||||
"""
|
||||
Pads the last dimension of a tensor to ensure its size is odd.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
tensor : torch.Tensor
|
||||
Input tensor whose last dimension might need padding.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor
|
||||
The original tensor if its last dimension was already odd,
|
||||
or the padded tensor with an odd-sized last dimension.
|
||||
"""
|
||||
|
||||
last_dim_size = tensor.shape[-1]
|
||||
|
||||
if last_dim_size % 2 == 0:
|
||||
tensor = nn.functional.pad(tensor, (0, 1))
|
||||
|
||||
return tensor
|
||||
|
||||
def polyphase_analysis(signal, filter_bank):
|
||||
"""
|
||||
Applies the polyphase method to efficiently analyze the signal using a filter bank.
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
signal : torch.Tensor
|
||||
Input signal tensor with shape (Batch x Channels x Length).
|
||||
|
||||
filter_bank : torch.Tensor
|
||||
Filter bank tensor with shape (Bands x Length).
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor
|
||||
Signal split into sub-bands. (Batch x Channels x Bands x Length)
|
||||
"""
|
||||
|
||||
num_bands = filter_bank.shape[0]
|
||||
num_channels = signal.shape[1]
|
||||
|
||||
# Rearrange signal for polyphase processing.
|
||||
# Also combine Batch x Channel into one dimension for now.
|
||||
#signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
|
||||
signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
|
||||
|
||||
# Rearrange the filter bank for matching signal shape
|
||||
filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
|
||||
|
||||
# Apply convolution with appropriate padding to maintain spatial dimensions
|
||||
padding = filter_bank.shape[-1] // 2
|
||||
filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
|
||||
|
||||
# Truncate the last dimension post-convolution to adjust the output shape
|
||||
filtered_signal = filtered_signal[..., :-1]
|
||||
# Rearrange the first dimension back into Batch x Channels
|
||||
filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
|
||||
|
||||
return filtered_signal
|
||||
|
||||
def polyphase_synthesis(signal, filter_bank):
|
||||
"""
|
||||
Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
signal : torch.Tensor
|
||||
Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
|
||||
|
||||
filter_bank : torch.Tensor
|
||||
Analysis filter bank (shape: Bands x Length).
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Reconstructed signal (shape: Batch x Channels x Length)
|
||||
"""
|
||||
|
||||
num_bands = filter_bank.shape[0]
|
||||
num_channels = signal.shape[1]
|
||||
|
||||
# Rearrange the filter bank
|
||||
filter_bank = filter_bank.flip(-1)
|
||||
filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
|
||||
|
||||
# Combine Batch x Channels into one dimension for now.
|
||||
signal = rearrange(signal, "b c n t -> (b c) n t")
|
||||
|
||||
# Apply convolution with appropriate padding
|
||||
padding_amount = filter_bank.shape[-1] // 2 + 1
|
||||
reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
|
||||
|
||||
# Scale the result
|
||||
reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
|
||||
|
||||
# Reorganize the output and truncate
|
||||
reconstructed_signal = reconstructed_signal.flip(1)
|
||||
reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
|
||||
reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
|
||||
|
||||
return reconstructed_signal
|
||||
@@ -1,239 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
class Pretransform(nn.Module):
|
||||
def __init__(self, enable_grad, io_channels, is_discrete):
|
||||
super().__init__()
|
||||
|
||||
self.is_discrete = is_discrete
|
||||
self.io_channels = io_channels
|
||||
self.encoded_channels = None
|
||||
self.downsampling_ratio = None
|
||||
|
||||
self.enable_grad = enable_grad
|
||||
|
||||
def encode(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def decode(self, z):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
def decode_tokens(self, tokens):
|
||||
raise NotImplementedError
|
||||
|
||||
class AutoencoderPretransform(Pretransform):
|
||||
def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
|
||||
super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
|
||||
self.model = model
|
||||
self.model.requires_grad_(False).eval()
|
||||
self.scale=scale
|
||||
self.downsampling_ratio = model.downsampling_ratio
|
||||
self.io_channels = model.io_channels
|
||||
self.sample_rate = model.sample_rate
|
||||
|
||||
self.model_half = model_half
|
||||
self.iterate_batch = iterate_batch
|
||||
|
||||
self.encoded_channels = model.latent_dim
|
||||
|
||||
self.chunked = chunked
|
||||
self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
|
||||
self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
|
||||
|
||||
if self.model_half:
|
||||
self.model.half()
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
|
||||
if self.model_half:
|
||||
x = x.half()
|
||||
self.model.to(torch.float16)
|
||||
|
||||
encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
|
||||
|
||||
if self.model_half:
|
||||
encoded = encoded.float()
|
||||
|
||||
return encoded / self.scale
|
||||
|
||||
def decode(self, z, **kwargs):
|
||||
z = z * self.scale
|
||||
|
||||
if self.model_half:
|
||||
z = z.half()
|
||||
self.model.to(torch.float16)
|
||||
|
||||
decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
|
||||
|
||||
if self.model_half:
|
||||
decoded = decoded.float()
|
||||
|
||||
return decoded
|
||||
|
||||
def tokenize(self, x, **kwargs):
|
||||
assert self.model.is_discrete, "Cannot tokenize with a continuous model"
|
||||
|
||||
_, info = self.model.encode(x, return_info = True, **kwargs)
|
||||
|
||||
return info[self.model.bottleneck.tokens_id]
|
||||
|
||||
def decode_tokens(self, tokens, **kwargs):
|
||||
assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
|
||||
|
||||
return self.model.decode_tokens(tokens, **kwargs)
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
self.model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
class PQMFPretransform(Pretransform):
|
||||
def __init__(self, attenuation=100, num_bands=16):
|
||||
# TODO: Fix PQMF to take in in-channels
|
||||
super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
|
||||
from .pqmf import PQMF
|
||||
self.pqmf = PQMF(attenuation, num_bands)
|
||||
|
||||
|
||||
def encode(self, x):
|
||||
# x is (Batch x Channels x Time)
|
||||
x = self.pqmf.forward(x)
|
||||
# pqmf.forward returns (Batch x Channels x Bands x Time)
|
||||
# but Pretransform needs Batch x Channels x Time
|
||||
# so concatenate channels and bands into one axis
|
||||
return rearrange(x, "b c n t -> b (c n) t")
|
||||
|
||||
def decode(self, x):
|
||||
# x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
|
||||
x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
|
||||
# returns (Batch x Channels x Time)
|
||||
return self.pqmf.inverse(x)
|
||||
|
||||
class PretrainedDACPretransform(Pretransform):
|
||||
def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
|
||||
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
|
||||
|
||||
import dac
|
||||
|
||||
model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
|
||||
|
||||
self.model = dac.DAC.load(model_path)
|
||||
|
||||
self.quantize_on_decode = quantize_on_decode
|
||||
|
||||
if model_type == "44khz":
|
||||
self.downsampling_ratio = 512
|
||||
else:
|
||||
self.downsampling_ratio = 320
|
||||
|
||||
self.io_channels = 1
|
||||
|
||||
self.scale = scale
|
||||
|
||||
self.chunked = chunked
|
||||
|
||||
self.encoded_channels = self.model.latent_dim
|
||||
|
||||
self.num_quantizers = self.model.n_codebooks
|
||||
|
||||
self.codebook_size = self.model.codebook_size
|
||||
|
||||
def encode(self, x):
|
||||
|
||||
latents = self.model.encoder(x)
|
||||
|
||||
if self.quantize_on_decode:
|
||||
output = latents
|
||||
else:
|
||||
z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
|
||||
output = z
|
||||
|
||||
if self.scale != 1.0:
|
||||
output = output / self.scale
|
||||
|
||||
return output
|
||||
|
||||
def decode(self, z):
|
||||
|
||||
if self.scale != 1.0:
|
||||
z = z * self.scale
|
||||
|
||||
if self.quantize_on_decode:
|
||||
z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
|
||||
|
||||
return self.model.decode(z)
|
||||
|
||||
def tokenize(self, x):
|
||||
return self.model.encode(x)[1]
|
||||
|
||||
def decode_tokens(self, tokens):
|
||||
latents = self.model.quantizer.from_codes(tokens)
|
||||
return self.model.decode(latents)
|
||||
|
||||
class AudiocraftCompressionPretransform(Pretransform):
|
||||
def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
|
||||
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
|
||||
|
||||
try:
|
||||
from audiocraft.models import CompressionModel
|
||||
except ImportError:
|
||||
raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
|
||||
|
||||
self.model = CompressionModel.get_pretrained(model_type)
|
||||
|
||||
self.quantize_on_decode = quantize_on_decode
|
||||
|
||||
self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
|
||||
|
||||
self.sample_rate = self.model.sample_rate
|
||||
|
||||
self.io_channels = self.model.channels
|
||||
|
||||
self.scale = scale
|
||||
|
||||
#self.encoded_channels = self.model.latent_dim
|
||||
|
||||
self.num_quantizers = self.model.num_codebooks
|
||||
|
||||
self.codebook_size = self.model.cardinality
|
||||
|
||||
self.model.to(torch.float16).eval().requires_grad_(False)
|
||||
|
||||
def encode(self, x):
|
||||
|
||||
assert False, "Audiocraft compression models do not support continuous encoding"
|
||||
|
||||
# latents = self.model.encoder(x)
|
||||
|
||||
# if self.quantize_on_decode:
|
||||
# output = latents
|
||||
# else:
|
||||
# z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
|
||||
# output = z
|
||||
|
||||
# if self.scale != 1.0:
|
||||
# output = output / self.scale
|
||||
|
||||
# return output
|
||||
|
||||
def decode(self, z):
|
||||
|
||||
assert False, "Audiocraft compression models do not support continuous decoding"
|
||||
|
||||
# if self.scale != 1.0:
|
||||
# z = z * self.scale
|
||||
|
||||
# if self.quantize_on_decode:
|
||||
# z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
|
||||
|
||||
# return self.model.decode(z)
|
||||
|
||||
def tokenize(self, x):
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
return self.model.encode(x.to(torch.float16))[0]
|
||||
|
||||
def decode_tokens(self, tokens):
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
return self.model.decode(tokens)
|
||||
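The `scale` convention shared by the pretransforms above is that `encode` divides the latents by `scale` and `decode` multiplies them back before decoding. The toy class below illustrates that round trip together with the effect of `downsampling_ratio` on sequence length. `ToyScaledCodec` and its averaging "encoder" are hypothetical stand-ins, not a model from the repository.

```python
class ToyScaledCodec:
    """Hypothetical stand-in for an autoencoder wrapped by a pretransform."""

    def __init__(self, scale=1.0, downsampling_ratio=2):
        self.scale = scale
        self.downsampling_ratio = downsampling_ratio

    def encode(self, samples):
        r = self.downsampling_ratio
        if len(samples) % r:
            raise ValueError("sample count must be a multiple of the ratio")
        # Toy encoder: one latent per group of r samples (their mean).
        latents = [sum(samples[i:i + r]) / r for i in range(0, len(samples), r)]
        # Same convention as AutoencoderPretransform.encode: divide by scale.
        return [z / self.scale for z in latents]

    def decode(self, latents):
        # Same convention as AutoencoderPretransform.decode: multiply first.
        rescaled = [z * self.scale for z in latents]
        # Toy decoder: hold each latent for r output samples.
        return [z for z in rescaled for _ in range(self.downsampling_ratio)]


codec = ToyScaledCodec(scale=4.0, downsampling_ratio=2)
latents = codec.encode([2.0, 2.0, 6.0, 6.0])
audio = codec.decode(latents)
```

Here four samples become two latents, and the latents seen by a downstream model are four times smaller than the raw encoder output.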
@@ -1,989 +0,0 @@
|
||||
from functools import reduce, partial
|
||||
from packaging import version
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from torch.cuda.amp import autocast
|
||||
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||
from typing import Callable, Literal
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
||||
HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
HAS_FLASH_ATTN = False
|
||||
flash_attn_kvpacked_func = None
|
||||
flash_attn_func = None
|
||||
|
||||
from .utils import compile, checkpoint
|
||||
try:
|
||||
import natten
|
||||
except ImportError:
|
||||
natten = None
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
|
||||
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
|
||||
|
||||
def create_causal_mask(i, j, device):
|
||||
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
|
||||
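The `triu(j - i + 1)` offset in `create_causal_mask` matters when there are more keys than queries. The sketch below reproduces the same boolean pattern with nested lists. `causal_mask` is an illustrative name, and reading `True` as "blocked" follows the x-transformers convention this code was adapted from, as far as I understand it.

```python
def causal_mask(i, j):
    """Boolean i x j grid: True where key column `col` lies in the future of query row `row`.

    Equivalent to torch.ones((i, j), dtype=torch.bool).triu(j - i + 1).
    """
    offset = j - i + 1
    return [[(col - row) >= offset for col in range(j)] for row in range(i)]


square = causal_mask(3, 3)   # 3 queries, 3 keys
cached = causal_mask(2, 4)   # 2 new queries, 4 keys (2 of them from earlier steps)
```

In the square case each query is blocked from strictly later positions. With extra earlier keys, the last query can see every key and each preceding query sees one key fewer.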
|
||||
def or_reduce(masks):
|
||||
head, *body = masks
|
||||
for rest in body:
|
||||
head = head | rest
|
||||
return head
|
||||
|
||||
# positional embeddings
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.max_seq_len = max_seq_len
|
||||
self.emb = nn.Embedding(max_seq_len, dim)
|
||||
|
||||
def forward(self, x, pos = None, seq_start_pos = None):
|
||||
seq_len, device = x.shape[1], x.device
|
||||
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
|
||||
|
||||
if pos is None:
|
||||
pos = torch.arange(seq_len, device = device)
|
||||
|
||||
if seq_start_pos is not None:
|
||||
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
|
||||
|
||||
pos_emb = self.emb(pos)
|
||||
pos_emb = pos_emb * self.scale
|
||||
return pos_emb
|
||||
|
||||
class ScaledSinusoidalEmbedding(nn.Module):
|
||||
def __init__(self, dim, theta = 10000):
|
||||
super().__init__()
|
||||
assert (dim % 2) == 0, 'dimension must be divisible by 2'
|
||||
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
|
||||
|
||||
half_dim = dim // 2
|
||||
freq_seq = torch.arange(half_dim).float() / half_dim
|
||||
inv_freq = theta ** -freq_seq
|
||||
self.register_buffer('inv_freq', inv_freq, persistent = False)
|
||||
|
||||
def forward(self, x, pos = None, seq_start_pos = None):
|
||||
seq_len, device = x.shape[1], x.device
|
||||
|
||||
if pos is None:
|
||||
pos = torch.arange(seq_len, device = device)
|
||||
|
||||
if seq_start_pos is not None:
|
||||
pos = pos - seq_start_pos[..., None]
|
||||
|
||||
emb = einsum('i, j -> i j', pos, self.inv_freq)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
|
||||
return emb * self.scale
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
use_xpos = False,
|
||||
scale_base = 512,
|
||||
interpolation_factor = 1.,
|
||||
base = 10000,
|
||||
base_rescale_factor = 1.
|
||||
):
|
||||
super().__init__()
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
base *= base_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
assert interpolation_factor >= 1.
|
||||
self.interpolation_factor = interpolation_factor
|
||||
|
||||
if not use_xpos:
|
||||
self.register_buffer('scale', None)
|
||||
return
|
||||
|
||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
||||
|
||||
self.scale_base = scale_base
|
||||
self.register_buffer('scale', scale)
|
||||
|
||||
def forward_from_seq_len(self, seq_len):
|
||||
device = self.inv_freq.device
|
||||
|
||||
t = torch.arange(seq_len, device = device)
|
||||
return self.forward(t)
|
||||
|
||||
    @autocast(enabled = False)
    def forward(self, t):
        device = self.inv_freq.device

        t = t.to(torch.float32)

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # xpos scaling: one power per position, centered on the middle of the sequence
        seq_len = t.shape[-1]
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
|
||||
|
||||
def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)
|
||||
|
||||
@autocast(enabled = False)
|
||||
def apply_rotary_pos_emb(t, freqs, scale = 1):
|
||||
out_dtype = t.dtype
|
||||
|
||||
# cast to float32 if necessary for numerical stability
|
||||
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
|
||||
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
|
||||
freqs, t = freqs.to(dtype), t.to(dtype)
|
||||
freqs = freqs[-seq_len:, :]
|
||||
|
||||
if t.ndim == 4 and freqs.ndim == 3:
|
||||
freqs = rearrange(freqs, 'b n d -> b 1 n d')
|
||||
|
||||
# partial rotary embeddings, Wang et al. GPT-J
|
||||
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
|
||||
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
||||
|
||||
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
|
||||
|
||||
return torch.cat((t, t_unrotated), dim = -1)
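# Illustrative sketch (not part of the module): rotary embeddings on a single
# vector in plain Python, assuming full (not partial) rotation and no xpos scale.
# apply_rotary is a hypothetical helper; rotate_half mirrors the function above.

```python
import math

def rotate_half(x):
    """Split x into halves (x1, x2) and return (-x2, x1)."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1

def apply_rotary(x, pos, base=10000.0):
    """Rotate each pair (x[i], x[i + half]) by the angle pos * inv_freq[i]."""
    dim = len(x)
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    angles = [pos * f for f in inv_freq] * 2  # same as cat((freqs, freqs))
    rotated = rotate_half(x)
    return [a * math.cos(th) + b * math.sin(th)
            for a, b, th in zip(x, rotated, angles)]
```

# Because each pair is rotated by an angle proportional to its position, the dot
# product between a rotated query and a rotated key depends only on the offset
# between their positions, and vector norms are unchanged.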
|
||||
|
||||
# norms
|
||||
class DynamicTanh(nn.Module):
    def __init__(self, dim, init_alpha=10.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        x = F.tanh(self.alpha * x)
        return self.gamma * x + self.beta
|
||||
|
||||
class RunningInstanceNorm(nn.Module):
|
||||
def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
|
||||
super().__init__()
|
||||
self.register_buffer("running_mean", torch.zeros(1,1,dim))
|
||||
self.register_buffer("running_std", torch.ones(1,1,dim))
|
||||
self.saturate = saturate
|
||||
self.eps = eps
|
||||
self.momentum = momentum
|
||||
self.dim = dim
|
||||
self.trainable_gain = trainable_gain
|
||||
if self.trainable_gain:
|
||||
self.gain = nn.Parameter(torch.ones(1))
|
||||
|
||||
def _update_stats(self, x):
|
||||
self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)
|
||||
self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
if self.training:
|
||||
self._update_stats(x)
|
||||
x = (x - self.running_mean) / self.running_std
|
||||
if self.saturate:
|
||||
x = torch.asinh(x)
|
||||
if self.trainable_gain:
|
||||
x = x * self.gain
|
||||
return x
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
|
||||
"""
|
||||
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if fix_scale:
|
||||
self.register_buffer("gamma", torch.ones(dim))
|
||||
else:
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
|
||||
if bias:
|
||||
self.beta = nn.Parameter(torch.zeros(dim))
|
||||
else:
|
||||
self.register_buffer("beta", torch.zeros(dim))
|
||||
|
||||
self.eps = eps
|
||||
|
||||
self.force_fp32 = force_fp32
|
||||
|
||||
def forward(self, x):
|
||||
if not self.force_fp32:
|
||||
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps)
|
||||
else:
|
||||
output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
|
||||
return output.to(x.dtype)
|
||||
|
||||
class LayerScale(nn.Module):
    def __init__(self, dim, init_val = 1e-5):
        super().__init__()
        self.scale = nn.Parameter(torch.full([dim], init_val))
    def forward(self, x):
        return x * self.scale
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_in,
|
||||
dim_out,
|
||||
activation: Callable,
|
||||
use_conv = False,
|
||||
conv_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
|
||||
self.use_conv = use_conv
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_conv:
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.proj(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
else:
|
||||
x = self.proj(x)
|
||||
|
||||
x, gate = x.chunk(2, dim = -1)
|
||||
return x * self.act(gate)
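# Illustrative sketch (not part of the module): the gating step of GLU.forward
# on a plain list, assuming the projection has already been applied and the
# activation is SiLU (the default chosen by FeedForward). Names are hypothetical.

```python
import math

def silu(z):
    """SiLU / swish: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def glu_gate(projected):
    """Split the projected vector into (value, gate) halves and gate the value."""
    half = len(projected) // 2
    value, gate = projected[:half], projected[half:]
    return [v * silu(g) for v, g in zip(value, gate)]
```

# A gate of 0 suppresses its value entirely, while a large positive gate scales
# the value roughly by the gate itself, since silu(g) approaches g for large g.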
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_out = None,
|
||||
mult = 4,
|
||||
no_bias = False,
|
||||
glu = True,
|
||||
use_conv = False,
|
||||
conv_kernel_size = 3,
|
||||
zero_init_output = True,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
# Default to SwiGLU
|
||||
|
||||
activation = nn.SiLU()
|
||||
|
||||
dim_out = dim if dim_out is None else dim_out
|
||||
|
||||
if glu:
|
||||
linear_in = GLU(dim, inner_dim, activation)
|
||||
else:
|
||||
linear_in = nn.Sequential(
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
|
||||
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||
activation
|
||||
)
|
||||
|
||||
linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
|
||||
|
||||
# init last linear layer to 0
|
||||
if zero_init_output:
|
||||
nn.init.zeros_(linear_out.weight)
|
||||
if not no_bias:
|
||||
nn.init.zeros_(linear_out.bias)
|
||||
|
||||
|
||||
self.ff = nn.Sequential(
|
||||
linear_in,
|
||||
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
|
||||
linear_out,
|
||||
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.ff(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_heads = 64,
|
||||
dim_context = None,
|
||||
causal = False,
|
||||
zero_init_output=True,
|
||||
qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
|
||||
differential = False,
|
||||
feat_scale = False
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_heads = dim_heads
|
||||
self.differential = differential
|
||||
|
||||
dim_kv = dim_context if dim_context is not None else dim
|
||||
|
||||
self.num_heads = dim // dim_heads
|
||||
self.kv_heads = dim_kv // dim_heads
|
||||
|
||||
if dim_context is not None:
|
||||
if differential:
|
||||
self.to_q = nn.Linear(dim, dim * 2, bias=False)
|
||||
self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
|
||||
else:
|
||||
self.to_q = nn.Linear(dim, dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
|
||||
else:
|
||||
if differential:
|
||||
self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
|
||||
else:
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Linear(dim, dim, bias=False)
|
||||
|
||||
if zero_init_output:
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
|
||||
if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
|
||||
raise ValueError(f'qk_norm must be one of ["l2", "ln", "rns", "dyt", "none"], got {qk_norm}')
|
||||
|
||||
self.qk_norm = qk_norm
|
||||
if self.qk_norm == "ln":
|
||||
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
||||
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
||||
elif self.qk_norm == 'rns':
|
||||
self.q_norm = nn.RMSNorm(dim_heads)
|
||||
self.k_norm = nn.RMSNorm(dim_heads)
|
||||
elif self.qk_norm == 'dyt':
|
||||
self.q_norm = DynamicTanh(dim_heads)
|
||||
self.k_norm = DynamicTanh(dim_heads)
|
||||
|
||||
self.sdp_kwargs = dict(
|
||||
enable_flash = True,
|
||||
enable_math = True,
|
||||
enable_mem_efficient = True
|
||||
)
|
||||
|
||||
self.feat_scale = feat_scale
|
||||
|
||||
if self.feat_scale:
|
||||
self.lambda_dc = nn.Parameter(torch.zeros(dim))
|
||||
self.lambda_hf = nn.Parameter(torch.zeros(dim))
|
||||
|
||||
self.causal = causal
|
||||
|
||||
@compile
|
||||
def apply_qk_layernorm(self, q, k):
|
||||
q_type = q.dtype
|
||||
k_type = k.dtype
|
||||
q = self.q_norm(q).to(q_type)
|
||||
k = self.k_norm(k).to(k_type)
|
||||
return q, k
|
||||
|
||||
|
||||
def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
|
||||
|
||||
if self.num_heads != self.kv_heads:
|
||||
# Repeat interleave kv_heads to match q_heads for grouped query attention
|
||||
heads_per_kv_head = self.num_heads // self.kv_heads
|
||||
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
|
||||
|
||||
flash_attn_available = HAS_FLASH_ATTN
|
||||
|
||||
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
|
||||
flex_attention_block_mask = None
|
||||
flex_attention_score_mod = None
|
||||
|
||||
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
|
||||
raise NotImplementedError(
|
||||
"FlexAttention is not available in this build. "
|
||||
"flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
|
||||
)
|
||||
elif flash_attn_available:
|
||||
fa_dtype_in = q.dtype
|
||||
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
|
||||
|
||||
if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
|
||||
q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))
|
||||
|
||||
out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])
|
||||
|
||||
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
|
||||
else:
|
||||
out = F.scaled_dot_product_attention(q, k, v, is_causal = causal)
|
||||
return out
|
||||
|
||||
|
||||
#@compile
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context = None,
|
||||
rotary_pos_emb = None,
|
||||
causal = None,
|
||||
flex_attention_block_mask = None,
|
||||
flex_attention_score_mod = None,
|
||||
flash_attn_sliding_window = None
|
||||
):
|
||||
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
|
||||
|
||||
kv_input = context if has_context else x
|
||||
|
||||
if hasattr(self, 'to_q'):
|
||||
# Use separate linear projections for q and k/v
|
||||
if self.differential:
|
||||
q, q_diff = self.to_q(x).chunk(2, dim=-1)
|
||||
q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
|
||||
q = torch.stack([q, q_diff], dim = 1)
|
||||
k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
|
||||
k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
|
||||
k = torch.stack([k, k_diff], dim = 1)
|
||||
else:
|
||||
q = self.to_q(x)
|
||||
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
|
||||
else:
|
||||
# Use fused linear projection
|
||||
if self.differential:
|
||||
q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
|
||||
q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
|
||||
q = torch.stack([q, q_diff], dim = 1)
|
||||
k = torch.stack([k, k_diff], dim = 1)
|
||||
else:
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
|
||||
|
||||
# Normalize q and k for cosine sim attention
|
||||
if self.qk_norm == "l2":
|
||||
q = F.normalize(q, dim=-1)
|
||||
k = F.normalize(k, dim=-1)
|
||||
elif self.qk_norm != "none":
|
||||
q, k = self.apply_qk_layernorm(q, k)
|
||||
|
||||
if rotary_pos_emb is not None:
|
||||
freqs, _ = rotary_pos_emb
|
||||
q_dtype = q.dtype
|
||||
k_dtype = k.dtype
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
freqs = freqs.to(torch.float32)
|
||||
if q.shape[-2] >= k.shape[-2]:
|
||||
ratio = q.shape[-2] / k.shape[-2]
|
||||
q_freqs, k_freqs = freqs, ratio * freqs
|
||||
else:
|
||||
ratio = k.shape[-2] / q.shape[-2]
|
||||
q_freqs, k_freqs = ratio * freqs, freqs
|
||||
q = apply_rotary_pos_emb(q, q_freqs)
|
||||
k = apply_rotary_pos_emb(k, k_freqs)
|
||||
q = q.to(v.dtype)
|
||||
k = k.to(v.dtype)
|
||||
|
||||
n, device = q.shape[-2], q.device
|
||||
|
||||
causal = self.causal if causal is None else causal
|
||||
|
||||
if n == 1 and causal:
|
||||
causal = False
|
||||
|
||||
if self.differential:
|
||||
q, q_diff = q.unbind(dim = 1)
|
||||
k, k_diff = k.unbind(dim = 1)
|
||||
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
||||
out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
||||
out = out - out_diff
|
||||
else:
|
||||
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
|
||||
|
||||
# merge heads
|
||||
out = rearrange(out, ' b h n d -> b n (h d)')
|
||||
|
||||
# Communicate between heads
|
||||
|
||||
# with autocast(enabled = False):
|
||||
# out_dtype = out.dtype
|
||||
# out = out.to(torch.float32)
|
||||
# out = self.to_out(out).to(out_dtype)
|
||||
out = self.to_out(out)
|
||||
|
||||
if self.feat_scale:
|
||||
out_dc = out.mean(dim=-2, keepdim=True)
|
||||
out_hf = out - out_dc
|
||||
|
||||
# Selectively modulate DC and high frequency components
|
||||
out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf
|
||||
|
||||
return out
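# Illustrative sketch (not part of the module): how apply_attn expands key/value
# heads when num_heads != kv_heads. The helper below is hypothetical and works on
# a list of head labels; it follows the element-wise repetition order of
# torch.repeat_interleave along the head dimension.

```python
def repeat_interleave_heads(kv_heads, num_q_heads):
    """Repeat each key/value head so that every query head has a partner."""
    assert num_q_heads % len(kv_heads) == 0, "query heads must be a multiple of kv heads"
    heads_per_kv_head = num_q_heads // len(kv_heads)
    # Each head is repeated consecutively: [k0, k0, k0, k1, k1, k1]
    return [head for head in kv_heads for _ in range(heads_per_kv_head)]
```

# With 2 key/value heads and 6 query heads, query heads 0-2 share the first
# key/value head and query heads 3-5 share the second.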
|
||||
|
||||
class ConformerModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
norm_kwargs = {},
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
|
||||
self.in_norm = LayerNorm(dim, **norm_kwargs)
|
||||
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||
self.glu = GLU(dim, dim, nn.SiLU())
|
||||
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
|
||||
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
|
||||
self.swish = nn.SiLU()
|
||||
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
|
||||
|
||||
#@compile
|
||||
def forward(self, x):
|
||||
x = self.in_norm(x)
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.pointwise_conv(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
x = self.glu(x)
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.depthwise_conv(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
x = self.mid_norm(x)
|
||||
x = self.swish(x)
|
||||
x = rearrange(x, 'b n d -> b d n')
|
||||
x = self.pointwise_conv_2(x)
|
||||
x = rearrange(x, 'b d n -> b n d')
|
||||
|
||||
return x
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_heads = 64,
|
||||
cross_attend = False,
|
||||
dim_context = None,
|
||||
global_cond_dim = None,
|
||||
causal = False,
|
||||
zero_init_branch_outputs = True,
|
||||
conformer = False,
|
||||
layer_ix = -1,
|
||||
remove_norms = False,
|
||||
add_rope = False,
|
||||
layer_scale = False,
|
||||
use_sync_block_film = False,
|
||||
attn_kwargs = {},
|
||||
ff_kwargs = {},
|
||||
norm_kwargs = {}
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_heads = min(dim_heads,dim)
|
||||
self.cross_attend = cross_attend
|
||||
self.dim_context = dim_context
|
||||
self.causal = causal
|
||||
if layer_scale and zero_init_branch_outputs:
|
||||
zero_init_branch_outputs = False
|
||||
|
||||
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
||||
|
||||
self.add_rope = add_rope
|
||||
|
||||
self.self_attn = Attention(
|
||||
dim,
|
||||
dim_heads = self.dim_heads,
|
||||
causal = causal,
|
||||
zero_init_output=zero_init_branch_outputs,
|
||||
**attn_kwargs
|
||||
)
|
||||
|
||||
self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
||||
|
||||
self.cross_attend = cross_attend
|
||||
if cross_attend:
|
||||
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
||||
self.cross_attn = Attention(
|
||||
dim,
|
||||
dim_heads = self.dim_heads,
|
||||
dim_context=dim_context,
|
||||
causal = causal,
|
||||
zero_init_output=zero_init_branch_outputs,
|
||||
**attn_kwargs
|
||||
)
|
||||
self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
||||
|
||||
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
|
||||
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
|
||||
self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
||||
|
||||
self.layer_ix = layer_ix
|
||||
|
||||
self.conformer = None
|
||||
if conformer:
|
||||
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
|
||||
self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()
|
||||
|
||||
self.global_cond_dim = global_cond_dim
|
||||
if global_cond_dim is not None:
|
||||
self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)
|
||||
|
||||
self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None
|
||||
|
||||
if use_sync_block_film:
|
||||
self.sync_film_generator = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 2, bias=False) # generate scale and shift from sync_cond
|
||||
)
|
||||
|
||||
@compile
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context = None,
|
||||
global_cond=None,
|
||||
rotary_pos_emb = None,
|
||||
self_attention_block_mask = None,
|
||||
self_attention_score_mod = None,
|
||||
cross_attention_block_mask = None,
|
||||
cross_attention_score_mod = None,
|
||||
self_attention_flash_sliding_window = None,
|
||||
cross_attention_flash_sliding_window = None,
|
||||
sync_cond = None,
|
||||
prepend_length=0
|
||||
):
|
||||
if rotary_pos_emb is None and self.add_rope:
|
||||
rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])
|
||||
|
||||
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
|
||||
if len(global_cond.shape) == 2:
|
||||
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1)
|
||||
else:
|
||||
|
||||
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1)
|
||||
|
||||
# self-attention with adaLN
|
||||
residual = x
|
||||
x = self.pre_norm(x)
|
||||
x = x * (1 + scale_self) + shift_self
|
||||
x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
|
||||
x = x * torch.sigmoid(1 - gate_self)
|
||||
x = self.self_attn_scale(x)
|
||||
x = x + residual
|
||||
|
||||
if context is not None and self.cross_attend:
|
||||
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer_scale(self.conformer(x))
|
||||
|
||||
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
|
||||
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
# feedforward with adaLN
|
||||
residual = x
|
||||
x = self.ff_norm(x)
|
||||
x = x * (1 + scale_ff) + shift_ff
|
||||
x = self.ff(x)
|
||||
x = x * torch.sigmoid(1 - gate_ff)
|
||||
x = self.ff_scale(x)
|
||||
x = x + residual
|
||||
|
||||
else:
|
||||
x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))
|
||||
|
||||
if context is not None and self.cross_attend:
|
||||
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
|
||||
|
||||
if self.conformer is not None:
|
||||
x = x + self.conformer_scale(self.conformer(x))
|
||||
|
||||
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
|
||||
prepend_part = x[:, :prepend_length, :]
|
||||
audio_part = x[:, prepend_length:, :]
|
||||
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
||||
modulated_audio_part = audio_part * (1 + scale) + shift
|
||||
x = torch.cat([prepend_part, modulated_audio_part], dim=1)
|
||||
|
||||
x = x + self.ff_scale(self.ff(self.ff_norm(x)))
|
||||
return x
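# Illustrative sketch (not part of the module): the adaLN pattern used by the
# self-attention and feedforward branches of TransformerBlock.forward, written
# per element in plain Python. Normalization is omitted for brevity, and
# adaln_branch and its arguments are hypothetical names.

```python
import math

def adaln_branch(x, scale, shift, gate, branch):
    """Modulate the input, run the branch, gate its output, add the residual."""
    # x * (1 + scale) + shift, as applied after the pre-norm in the block
    modulated = [v * (1 + s) + b for v, s, b in zip(x, scale, shift)]
    out = branch(modulated)
    # The block gates with sigmoid(1 - gate), not sigmoid(gate)
    gated = [o / (1 + math.exp(-(1 - g))) for o, g in zip(out, gate)]
    return [r + o for r, o in zip(x, gated)]
```

# With zero scale and shift, an identity branch, and gate = 1, the gate factor
# is sigmoid(0) = 0.5, so the block returns 1.5 * x.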
|
||||
|
||||
class ContinuousTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
*,
|
||||
dim_in = None,
|
||||
dim_out = None,
|
||||
dim_heads = 64,
|
||||
cross_attend=False,
|
||||
cond_token_dim=None,
|
||||
pre_cross_attn_ix=-1,
|
||||
final_cross_attn_ix=-1,
|
||||
global_cond_dim=None,
|
||||
causal=False,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_outputs=True,
|
||||
conformer=False,
|
||||
use_sinusoidal_emb=False,
|
||||
use_abs_pos_emb=False,
|
||||
abs_pos_emb_max_length=10000,
|
||||
num_memory_tokens=0,
|
||||
sliding_window=None,
|
||||
use_mlp=False,
|
||||
use_add_norm=False,
|
||||
use_gated=False,
|
||||
use_final_layer=False,
|
||||
use_zeros=False,
|
||||
use_conv=False,
|
||||
use_fusion_mlp=False,
|
||||
use_film=False,
|
||||
use_sync_film=False,
|
||||
use_sync_gated=False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.causal = causal
|
||||
self.layers = nn.ModuleList([])
|
||||
if use_mlp:
|
||||
self.project_in = nn.Sequential(
|
||||
nn.Linear(dim_in, dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim, bias=False)
|
||||
)
|
||||
else:
|
||||
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
|
||||
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
|
||||
self.video_temporal_conv = None
|
||||
self.audio_temporal_conv = None
|
||||
self.fusion_mlp = None
|
||||
if use_conv:
|
||||
self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
||||
self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
||||
if use_fusion_mlp:
|
||||
self.fusion_mlp = nn.Sequential(
|
||||
nn.Linear(dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
|
||||
if rotary_pos_emb:
|
||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
||||
else:
|
||||
self.rotary_pos_emb = None
|
||||
self.num_memory_tokens = num_memory_tokens
|
||||
if num_memory_tokens > 0:
|
||||
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
||||
|
||||
self.use_sinusoidal_emb = use_sinusoidal_emb
|
||||
if use_sinusoidal_emb:
|
||||
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
||||
|
||||
self.use_abs_pos_emb = use_abs_pos_emb
|
||||
if use_abs_pos_emb:
|
||||
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)
|
||||
|
||||
self.adaLN_modulation = None
|
||||
if global_cond_dim is not None:
|
||||
if use_final_layer:
|
||||
self.norm_final = LayerNorm(dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
dim, 2 * dim, bias=True
|
||||
),
|
||||
)
|
||||
|
||||
if use_zeros:
|
||||
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.project_out.weight, 0)
|
||||
self.global_cond_embedder = nn.Sequential(
|
||||
nn.Linear(global_cond_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 6)
|
||||
)
|
||||
if use_zeros:
|
||||
nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
|
||||
nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
|
||||
nn.init.constant_(self.global_cond_embedder[0].weight, 0)
|
||||
nn.init.constant_(self.global_cond_embedder[0].bias, 0)
|
||||
|
||||
self.final_cross_attn_ix = final_cross_attn_ix
|
||||
self.use_gated = use_gated
|
||||
self.use_film = use_film
|
||||
self.use_add_norm = use_add_norm
|
||||
if self.use_add_norm:
|
||||
self.add_norm = nn.LayerNorm(dim)
|
||||
if use_gated:
|
||||
self.gate = nn.Parameter(torch.ones(1, 1, dim))
|
||||
|
||||
if use_film:
|
||||
self.film_generator = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 2, bias=False) # generate scale and shift from add_cond
|
||||
)
|
||||
else:
|
||||
self.film_generator = None
|
||||
|
||||
if use_sync_film:
|
||||
self.sync_film_generator = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 2, bias=False) # generate scale and shift from sync_cond
|
||||
)
|
||||
else:
|
||||
self.sync_film_generator = None
|
||||
if use_sync_gated:
|
||||
self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
else:
|
||||
self.sync_gate = None
|
||||
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
for i in range(depth):
|
||||
should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
|
||||
# print(f"Layer {i} cross attends: {should_cross_attend}")
|
||||
self.layers.append(
|
||||
TransformerBlock(
|
||||
dim,
|
||||
dim_heads = dim_heads,
|
||||
cross_attend = should_cross_attend,
|
||||
dim_context = cond_token_dim,
|
||||
global_cond_dim = global_cond_dim,
|
||||
causal = causal,
|
||||
zero_init_branch_outputs = zero_init_branch_outputs,
|
||||
conformer=conformer,
|
||||
layer_ix=i,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
mask = None,
|
||||
prepend_embeds = None,
|
||||
prepend_mask = None,
|
||||
add_cond = None,
|
||||
sync_cond = None,
|
||||
global_cond = None,
|
||||
return_info = False,
|
||||
use_checkpointing = True,
|
||||
exit_layer_ix = None,
|
||||
video_dropout_prob = 0.0,
|
||||
**kwargs
|
||||
):
|
||||
batch, seq, device = *x.shape[:2], x.device
|
||||
model_dtype = next(self.parameters()).dtype
|
||||
x = x.to(model_dtype)
|
||||
|
||||
prepend_length = 0
|
||||
|
||||
info = {
|
||||
"hidden_states": [],
|
||||
}
|
||||
|
||||
x = self.project_in(x)
|
||||
if add_cond is not None:
|
||||
if self.use_gated:
|
||||
gate = torch.sigmoid(self.gate)
|
||||
x = x + gate * add_cond
|
||||
elif self.use_film:
|
||||
scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
|
||||
x = x * (1 + scale) + shift
|
||||
else:
|
||||
x = x + add_cond
|
||||
|
||||
if self.use_add_norm:
|
||||
x = self.add_norm(x)
|
||||
if self.fusion_mlp is not None:
|
||||
x = self.fusion_mlp(x)
|
||||
|
||||
if sync_cond is not None:
|
||||
# Resample sync_cond to match audio sequence length if needed
|
||||
if sync_cond.shape[1] != x.shape[1]:
|
||||
sync_cond = torch.nn.functional.interpolate(
|
||||
sync_cond.transpose(1, 2), size=x.shape[1],
|
||||
mode='linear', align_corners=False,
|
||||
).transpose(1, 2)
|
||||
if self.sync_film_generator is not None:
|
||||
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
|
||||
x = x * (1 + scale) + shift
|
||||
elif self.sync_gate is not None:
|
||||
gate_value = torch.sigmoid(self.sync_gate)
|
||||
x = x + gate_value * sync_cond
|
||||
# else:
|
||||
# x = x + sync_cond
|
||||
|
||||
if prepend_embeds is not None:
|
||||
prepend_length, prepend_dim = prepend_embeds.shape[1:]
|
||||
|
||||
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
|
||||
|
||||
x = torch.cat((prepend_embeds, x), dim = -2)
|
||||
|
||||
if self.num_memory_tokens > 0:
|
||||
memory_tokens = self.memory_tokens.expand(batch, -1, -1)
|
||||
x = torch.cat((memory_tokens, x), dim=1)
|
||||
|
||||
if self.rotary_pos_emb is not None:
|
||||
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
|
||||
else:
|
||||
rotary_pos_emb = None
|
||||
|
||||
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
|
||||
x = x + self.pos_emb(x)
|
||||
|
||||
if global_cond is not None and self.global_cond_embedder is not None:
|
||||
global_cond_embed = self.global_cond_embedder(global_cond)
|
||||
else:
|
||||
global_cond_embed = global_cond
|
||||
# Iterate over the transformer layers
|
||||
for layer_ix, layer in enumerate(self.layers):
|
||||
if use_checkpointing:
|
||||
x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
|
||||
|
||||
if return_info:
|
||||
info["hidden_states"].append(x)
|
||||
|
||||
if exit_layer_ix is not None and layer_ix == exit_layer_ix:
|
||||
x = x[:, self.num_memory_tokens:, :]
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
|
||||
return x
|
||||
|
||||
x = x[:, self.num_memory_tokens:, :]
|
||||
if global_cond is not None and self.adaLN_modulation is not None:
|
||||
if len(global_cond.shape) == 2:
|
||||
global_cond = global_cond.unsqueeze(1)
|
||||
shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.project_out(x)
|
||||
|
||||
if return_info:
|
||||
return x, info
|
||||
|
||||
return x
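# Illustrative sketch (not part of the module): resampling a 1-D conditioning
# track to a new length, as ContinuousTransformer.forward does for sync_cond.
# This assumes the half-pixel convention that, to my understanding, linear
# interpolation with align_corners=False uses; resample_linear is a hypothetical
# helper that operates on a single channel.

```python
def resample_linear(values, out_len):
    """Linearly resample a list of floats to out_len samples."""
    in_len = len(values)
    ratio = in_len / out_len
    out = []
    for i in range(out_len):
        # Map the centre of output sample i back into input coordinates
        src = max((i + 0.5) * ratio - 0.5, 0.0)
        lo = min(int(src), in_len - 1)
        hi = min(lo + 1, in_len - 1)
        w = src - lo
        out.append(values[lo] * (1 - w) + values[hi] * w)
    return out
```

# When the lengths already match, the values pass through unchanged, which is
# why the module only resamples when the sequence lengths differ.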
|
||||
@@ -1,180 +0,0 @@
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
|
||||
#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
def load_ckpt_state_dict(ckpt_path, prefix=None):
    if ckpt_path.endswith(".safetensors"):
        state_dict = load_file(ckpt_path)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    # Keep only the entries whose key starts with the given prefix, and strip that prefix
    filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict

    return filtered_state_dict
|
||||
|
||||
def remove_weight_norm_from_model(model):
|
||||
for module in model.modules():
|
||||
if hasattr(module, "weight"):
|
||||
remove_weight_norm(module)
|
||||
|
||||
return model
|
||||
|
||||
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
|
||||
# License can be found in LICENSES/LICENSE_META.txt
|
||||
|
||||
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
|
||||
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): The input tensor containing probabilities.
|
||||
num_samples (int): Number of samples to draw.
|
||||
replacement (bool): Whether to draw with replacement or not.
|
||||
Keyword args:
|
||||
generator (torch.Generator): A pseudorandom number generator for sampling.
|
||||
Returns:
|
||||
torch.Tensor: Last dimension contains num_samples indices
|
||||
sampled from the multinomial probability distribution
|
||||
located in the last dimension of tensor input.
|
||||
"""
|
||||
|
||||
if num_samples == 1:
|
||||
q = torch.empty_like(input).exponential_(1, generator=generator)
|
||||
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
|
||||
|
||||
input_ = input.reshape(-1, input.shape[-1])
|
||||
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
|
||||
output = output_.reshape(*list(input.shape[:-1]), -1)
|
||||
return output
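
The `num_samples == 1` branch of `multinomial` avoids `torch.multinomial` by dividing the probabilities by Exp(1) noise and taking the argmax, which draws index `i` with probability proportional to `input[i]`. A minimal pure-Python sketch of that trick, assuming a flat list of probabilities (the helper name `sample_index` is hypothetical and not part of the file):

```python
import random

def sample_index(probs, rng):
    """Draw one index with probability proportional to probs[i]."""
    # One Exp(1) draw per candidate; argmax of p / noise is a categorical sample.
    noise = [rng.expovariate(1.0) for _ in probs]
    scores = [p / q for p, q in zip(probs, noise)]
    return max(range(len(probs)), key=scores.__getitem__)

rng = random.Random(0)
probs = [0.1, 0.7, 0.2]
counts = [0, 0, 0]
for _ in range(20_000):
    counts[sample_index(probs, rng)] += 1

# Empirical frequencies should be close to probs
frequencies = [c / 20_000 for c in counts]
```

This works because `noise[i] / probs[i]` is exponentially distributed with rate `probs[i]`, and the minimum of independent exponentials falls on index `i` with probability `probs[i] / sum(probs)`.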
|
||||
|
||||
|
||||
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
|
||||
"""Sample next token from top K values along the last dimension of the input probs tensor.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
||||
k (int): The k in “top-k”.
|
||||
Returns:
|
||||
torch.Tensor: Sampled tokens.
|
||||
"""
|
||||
top_k_value, _ = torch.topk(probs, k, dim=-1)
|
||||
min_value_top_k = top_k_value[..., [-1]]
|
||||
probs *= (probs >= min_value_top_k).float()
|
||||
probs.div_(probs.sum(dim=-1, keepdim=True))
|
||||
next_token = multinomial(probs, num_samples=1)
|
||||
return next_token
|
||||
|
||||
|
||||
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
||||
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
|
||||
p (float): The p in “top-p”.
|
||||
Returns:
|
||||
torch.Tensor: Sampled tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort *= (~mask).float()
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
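
The filtering step of `sample_top_k` and `sample_top_p` can be illustrated without tensors. The sketch below is an assumption-laden restatement on plain Python lists (the names `top_k_filter` and `top_p_filter` are hypothetical). It returns the renormalized distribution rather than a sampled token.

```python
def top_k_filter(probs, k):
    """Zero every probability below the k-th largest, then renormalize."""
    threshold = sorted(probs, reverse=True)[k - 1]
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    return [p / total for p in kept]

def top_p_filter(probs, p):
    """Keep the highest-probability tokens until the mass before a token exceeds p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = [0.0] * len(probs)
    cumulative = 0.0
    for i in order:
        # Mirrors `probs_sum - probs_sort > p`: a token is dropped only when
        # the mass accumulated before it already exceeds p.
        if cumulative > p:
            break
        kept[i] = probs[i]
        cumulative += probs[i]
    total = sum(kept)
    return [q / total for q in kept]

probs = [0.1, 0.2, 0.3, 0.4]
top_k = top_k_filter(probs, 2)    # only the 0.3 and 0.4 entries survive
top_p = top_p_filter(probs, 0.5)  # 0.4 alone is below 0.5, so 0.3 is kept too
```

Unlike this sketch, the tensor versions above modify `probs` in place, so callers that still need the original distribution should pass a copy.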
|
||||
|
||||
def next_power_of_two(n):
|
||||
return 2 ** (n - 1).bit_length()
|
||||
|
||||
def next_multiple_of_64(n):
|
||||
return ((n + 63) // 64) * 64
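
A short worked example of the two rounding helpers may help when choosing padded sequence lengths. The functions are repeated here only so the snippet is self-contained, and the sample values are arbitrary.

```python
def next_power_of_two(n):
    return 2 ** (n - 1).bit_length()

def next_multiple_of_64(n):
    return ((n + 63) // 64) * 64

# (n, next power of two, next multiple of 64)
rows = [(n, next_power_of_two(n), next_multiple_of_64(n)) for n in (1, 63, 64, 65, 194)]
for row in rows:
    print(row)
```

Both helpers return `n` unchanged when it is already a power of two or a multiple of 64, respectively.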
|
||||
|
||||
|
||||
# mask construction helpers
|
||||
|
||||
def mask_from_start_end_indices(
|
||||
seq_len: int,
|
||||
start: Tensor,
|
||||
end: Tensor
|
||||
):
|
||||
assert start.shape == end.shape
|
||||
device = start.device
|
||||
|
||||
seq = torch.arange(seq_len, device = device, dtype = torch.long)
|
||||
seq = seq.reshape(*((-1,) * start.ndim), seq_len)
|
||||
seq = seq.expand(*start.shape, seq_len)
|
||||
|
||||
mask = seq >= start[..., None].long()
|
||||
mask &= seq < end[..., None].long()
|
||||
return mask
|
||||
|
||||
def mask_from_frac_lengths(
|
||||
seq_len: int,
|
||||
frac_lengths: Tensor
|
||||
):
|
||||
device = frac_lengths.device
|
||||
|
||||
lengths = (frac_lengths * seq_len).long()
|
||||
max_start = seq_len - lengths
|
||||
|
||||
rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
|
||||
start = (max_start * rand).clamp(min = 0)
|
||||
end = start + lengths
|
||||
|
||||
return mask_from_start_end_indices(seq_len, start, end)
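
The two mask helpers above reduce to "mark a contiguous span of a given fractional length at a random offset". A minimal list-based sketch for a single sequence, assuming scalar inputs instead of batched tensors (the names `span_mask` and `random_span_mask` are hypothetical):

```python
import random

def span_mask(seq_len, start, end):
    """True for positions i with start <= i < end."""
    return [start <= i < end for i in range(seq_len)]

def random_span_mask(seq_len, frac_length, rng):
    length = int(frac_length * seq_len)     # span length in positions
    max_start = seq_len - length            # latest start that still fits
    start = int(max_start * rng.random())   # truncation mirrors .long()
    return span_mask(seq_len, start, start + length)

fixed = span_mask(5, 1, 3)
sampled = random_span_mask(10, 0.3, random.Random(0))
```

Because the start is truncated toward zero and the span length is an integer, the masked span always has exactly `length` positions and never runs past the end of the sequence.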
|
||||
|
||||
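def _build_spline(video_feat, video_t, target_t):
    # Core cubic spline interpolation step.
    # NOTE: natural_cubic_spline_coeffs and NaturalCubicSpline come from torchcubicspline,
    # whose import is commented out at the top of this file.
    coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1))
    spline = NaturalCubicSpline(coeffs)
    return spline.evaluate(target_t).permute(0,2,1)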
|
||||
|
||||
def resample(video_feat, audio_latent):
    """
    Resample video features onto the audio latent time axis (both axes span 0 to 9 s).

    video_feat: [B, 72, D]
    audio_latent: [B, D', 194] or int
    """
    B, Tv, D = video_feat.shape

    if isinstance(audio_latent, torch.Tensor):
        # audio_latent is a tensor
        if audio_latent.shape[1] != 64:
            Ta = audio_latent.shape[1]
        else:
            Ta = audio_latent.shape[2]
    elif isinstance(audio_latent, int):
        # audio_latent is an int
        Ta = audio_latent
    else:
        raise TypeError("audio_latent must be either a tensor or an int")

    # Build the timestamps (key improvement)
    video_time = torch.linspace(0, 9, Tv, device=video_feat.device)
    audio_time = torch.linspace(0, 9, Ta, device=video_feat.device)

    # Arrange as 3D (Batch, Feature, Time)
    video_feat = video_feat.permute(0, 2, 1)  # [B, D, Tv]

    # Cubic spline interpolation
    aligned_video = _build_spline(video_feat, video_time, audio_time)  # [B, D, Ta]
    return aligned_video.permute(0, 2, 1)  # [B, Ta, D]
|
||||
|
||||
def checkpoint(function, *args, **kwargs):
|
||||
kwargs.setdefault("use_reentrant", False)
|
||||
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||
|
||||
import os
|
||||
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"
|
||||
|
||||
def compile(function, *args, **kwargs):
|
||||
|
||||
if enable_torch_compile:
|
||||
try:
|
||||
return torch.compile(function, *args, **kwargs)
|
||||
except RuntimeError:
|
||||
return function
|
||||
|
||||
return function
|
||||
@@ -1,12 +1,5 @@
|
||||
einops>=0.7.0
|
||||
einops-exts
|
||||
safetensors
|
||||
huggingface_hub
|
||||
transformers>=4.52.3
|
||||
k-diffusion>=0.1.1
|
||||
alias-free-torch
|
||||
descript-audio-codec
|
||||
vector-quantize-pytorch
|
||||
scipy
|
||||
tqdm
|
||||
torchaudio
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
name: prismaudio-extract
|
||||
channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.10
|
||||
- pip
|
||||
- ffmpeg<7
|
||||
- pip:
|
||||
- torch>=2.6.0
|
||||
- torchaudio>=2.6.0
|
||||
- torchvision>=0.21.0
|
||||
- tensorflow-cpu==2.15.0
|
||||
- jax
|
||||
- jaxlib
|
||||
- transformers>=4.52.3
|
||||
- decord
|
||||
- einops>=0.7.0
|
||||
- numpy
|
||||
- mediapy
|
||||
- git+https://github.com/google-deepmind/videoprism.git
|
||||
@@ -1,170 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone PrismAudio feature extraction script.
|
||||
Runs in a separate Python env with JAX/TF installed (auto-created by PrismAudioFeatureExtractor).
|
||||
|
||||
Usage:
|
||||
python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Add plugin root to sys.path so data_utils (and prismaudio_core) are importable
|
||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PLUGIN_DIR = os.path.dirname(_SCRIPT_DIR)
|
||||
if _PLUGIN_DIR not in sys.path:
|
||||
sys.path.insert(0, _PLUGIN_DIR)
|
||||
|
||||
|
||||
def _step(n, total, label):
|
||||
"""Print step header and return start time."""
|
||||
print(f"[extract] Step {n}/{total} — {label}...", flush=True)
|
||||
return time.perf_counter()
|
||||
|
||||
|
||||
def _done(t0, extra=""):
|
||||
elapsed = time.perf_counter() - t0
|
||||
suffix = f" {extra}" if extra else ""
|
||||
print(f"[extract] done in {elapsed:.1f}s{suffix}", flush=True)
|
||||
|
||||
|
||||
def main():
|
||||
t_total = time.perf_counter()
|
||||
|
||||
parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
|
||||
parser.add_argument("--video", required=True, help="Path to input video")
|
||||
parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
|
||||
parser.add_argument("--output", required=True, help="Output .npz path")
|
||||
parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
|
||||
parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
|
||||
parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)")
|
||||
parser.add_argument("--clip_fps", type=float, default=4.0)
|
||||
parser.add_argument("--clip_size", type=int, default=288)
|
||||
parser.add_argument("--sync_fps", type=float, default=25.0)
|
||||
parser.add_argument("--sync_size", type=int, default=224)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"[extract] Python : {sys.executable}", flush=True)
|
||||
print(f"[extract] Video : {args.video}", flush=True)
|
||||
print(f"[extract] Output : {args.output}", flush=True)
|
||||
print(f"[extract] CoT text : {args.cot_text[:80]}{'...' if len(args.cot_text) > 80 else ''}", flush=True)
|
||||
|
||||
if not os.path.exists(args.video):
|
||||
print(f"[extract] ERROR: video not found: {args.video}", flush=True)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"[extract] Device : {'cuda' if torch.cuda.is_available() else 'cpu'}", flush=True)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
t0 = _step(1, 6, "importing dependencies")
|
||||
from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
|
||||
import torchvision.transforms as T
|
||||
_done(t0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
t0 = _step(2, 6, "loading models (T5, VideoPrism, Synchformer)")
|
||||
feat_utils = FeaturesUtils(
|
||||
vae_config_path=args.vae_config,
|
||||
synchformer_ckpt=args.synchformer_ckpt,
|
||||
device=device,
|
||||
)
|
||||
_done(t0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
t0 = _step(3, 6, "reading and preprocessing video")
|
||||
if args.video.endswith(".npy"):
|
||||
all_frames = np.load(args.video) # [T, H, W, C] uint8
|
||||
fps = args.source_fps
|
||||
total_frames = all_frames.shape[0]
|
||||
duration = total_frames / fps
|
||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||
clip_frames = all_frames[clip_indices]
|
||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||
|
||||
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||
sync_frames = all_frames[sync_indices]
|
||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||
else:
|
||||
from decord import VideoReader, cpu
|
||||
vr = VideoReader(args.video, ctx=cpu(0))
|
||||
fps = vr.get_avg_fps()
|
||||
total_frames = len(vr)
|
||||
duration = total_frames / fps
|
||||
print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)
|
||||
|
||||
clip_indices = [int(i * fps / args.clip_fps) for i in range(max(1, int(duration * args.clip_fps)))]
|
||||
clip_indices = [min(i, total_frames - 1) for i in clip_indices]
|
||||
clip_frames = vr.get_batch(clip_indices).asnumpy()
|
||||
print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
|
||||
|
||||
# Synchformer processes in segments of 8; ensure at least 8 frames
|
||||
sync_indices = [int(i * fps / args.sync_fps) for i in range(max(8, int(duration * args.sync_fps)))]
|
||||
sync_indices = [min(i, total_frames - 1) for i in sync_indices]
|
||||
sync_frames = vr.get_batch(sync_indices).asnumpy()
|
||||
print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)
|
||||
|
||||
clip_transform = T.Compose([
|
||||
T.ToPILImage(),
|
||||
T.Resize(args.clip_size),
|
||||
T.CenterCrop(args.clip_size),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device)
|
||||
|
||||
sync_transform = T.Compose([
|
||||
T.ToPILImage(),
|
||||
T.Resize(args.sync_size),
|
||||
T.CenterCrop(args.sync_size),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device)
|
||||
_done(t0)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
t0 = _step(4, 6, "encoding text with T5-Gemma")
|
||||
text_features = feat_utils.encode_t5_text([args.cot_text])
|
||||
_done(t0, f"shape={tuple(text_features.shape)}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
t0 = _step(5, 6, "encoding video with VideoPrism")
|
||||
global_video_features, video_features, global_text_features = \
|
||||
feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
|
||||
_done(t0, f"video={tuple(video_features.shape)} global={tuple(global_video_features.shape)}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
t0 = _step(6, 6, "encoding video with Synchformer")
|
||||
sync_features = feat_utils.encode_video_with_sync(sync_input)
|
||||
_done(t0, f"shape={tuple(sync_features.shape)}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
t0 = time.perf_counter()
|
||||
print(f"[extract] Saving features to {args.output} ...", flush=True)
|
||||
np.savez(
|
||||
args.output,
|
||||
video_features=video_features.cpu().float().numpy(),
|
||||
global_video_features=global_video_features.cpu().float().numpy(),
|
||||
text_features=text_features.cpu().float().numpy(),
|
||||
global_text_features=global_text_features.cpu().float().numpy(),
|
||||
sync_features=sync_features.cpu().float().numpy(),
|
||||
caption_cot=args.cot_text,
|
||||
duration=duration,
|
||||
)
|
||||
print(f"[extract] Saved in {time.perf_counter() - t0:.1f}s", flush=True)
|
||||
print(f"[extract] Total time: {time.perf_counter() - t_total:.1f}s", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
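
The frame-index arithmetic used in step 3 of `main()` (for both the CLIP and the Synchformer streams) can be isolated into a small function. The sketch below assumes only what that step shows. The name `sample_indices` and its `min_frames` parameter are hypothetical.

```python
def sample_indices(total_frames, source_fps, target_fps, min_frames=1):
    """Pick source-frame indices that approximate target_fps sampling."""
    duration = total_frames / source_fps
    count = max(min_frames, int(duration * target_fps))
    indices = [int(i * source_fps / target_fps) for i in range(count)]
    # Clamp so very short clips repeat their last frame instead of overrunning
    return [min(i, total_frames - 1) for i in indices]

# 2 s of 30 fps video sampled at 4 fps
clip_indices = sample_indices(60, 30.0, 4.0)
# A 5-frame clip still yields the 8 frames one Synchformer segment needs
sync_indices = sample_indices(5, 30.0, 25.0, min_frames=8)
```

With these inputs the short clip is padded by repeating its last frame, which is why the script can guarantee at least 8 sync frames.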
|
||||
@@ -1,44 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
# Install the PrismAudio feature-extraction environment using pip venv.
|
||||
# Use this instead of environment.yml when conda is unavailable (e.g. NVIDIA Docker).
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/install_extract_env.sh [/path/to/venv]
|
||||
#
|
||||
# Default venv path: /opt/prismaudio-extract
|
||||
# After installation, point the PrismAudioFeatureExtractor node's python_env to:
|
||||
# <venv>/bin/python (Linux/Mac)
|
||||
# <venv>\Scripts\python.exe (Windows)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
VENV_DIR="${1:-/opt/prismaudio-extract}"
|
||||
|
||||
echo "[PrismAudio] Creating venv at: ${VENV_DIR}"
|
||||
python3 -m venv "${VENV_DIR}"
|
||||
|
||||
PIP="${VENV_DIR}/bin/pip"
|
||||
|
||||
echo "[PrismAudio] Upgrading pip..."
|
||||
"${PIP}" install --upgrade pip
|
||||
|
||||
echo "[PrismAudio] Installing PyTorch stack..."
|
||||
"${PIP}" install torch torchaudio torchvision
|
||||
|
||||
echo "[PrismAudio] Installing feature-extraction dependencies..."
|
||||
"${PIP}" install \
|
||||
"tensorflow-cpu>=2.16.0" \
|
||||
"jax[cpu]" \
|
||||
"jaxlib" \
|
||||
"transformers" \
|
||||
"decord" \
|
||||
"einops" \
|
||||
"numpy" \
|
||||
"mediapy"
|
||||
|
||||
echo "[PrismAudio] Installing VideoPrism..."
|
||||
"${PIP}" install "git+https://github.com/google-deepmind/videoprism.git"
|
||||
|
||||
echo ""
|
||||
echo "[PrismAudio] Done. Set python_env in PrismAudioFeatureExtractor to:"
|
||||
echo " ${VENV_DIR}/bin/python"
|
||||
@@ -0,0 +1,3 @@
|
||||
# Vendored from https://github.com/jnwnlee/selva
|
||||
# Pinned commit: d7d40a992aab58e7cf246055681a657e5d8b4a4d
|
||||
# Imports rewritten from selva.* → selva_core.*
|
||||
@@ -0,0 +1,190 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from av import AudioFrame
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
@dataclass
|
||||
class VideoInfo:
|
||||
duration_sec: float
|
||||
fps: Fraction
|
||||
clip_frames: torch.Tensor
|
||||
sync_frames: torch.Tensor
|
||||
all_frames: Optional[list[np.ndarray]]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.all_frames[0].shape[0]
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.all_frames[0].shape[1]
|
||||
|
||||
@classmethod
|
||||
def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
|
||||
fps: Fraction) -> 'VideoInfo':
|
||||
num_frames = int(duration_sec * fps)
|
||||
all_frames = [image_info.original_frame] * num_frames
|
||||
return cls(duration_sec=duration_sec,
|
||||
fps=fps,
|
||||
clip_frames=image_info.clip_frames,
|
||||
sync_frames=image_info.sync_frames,
|
||||
all_frames=all_frames)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageInfo:
|
||||
clip_frames: torch.Tensor
|
||||
sync_frames: torch.Tensor
|
||||
original_frame: Optional[np.ndarray]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.original_frame.shape[0]
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.original_frame.shape[1]
|
||||
|
||||
|
||||
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
|
||||
need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
|
||||
output_frames = [[] for _ in list_of_fps]
|
||||
next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
|
||||
time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
|
||||
all_frames = []
|
||||
|
||||
# container = av.open(video_path)
|
||||
with av.open(video_path) as container:
|
||||
stream = container.streams.video[0]
|
||||
fps = stream.guessed_rate
|
||||
stream.thread_type = 'AUTO'
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
frame_time = frame.time
|
||||
if frame_time < start_sec:
|
||||
continue
|
||||
if frame_time > end_sec:
|
||||
break
|
||||
|
||||
frame_np = None
|
||||
if need_all_frames:
|
||||
frame_np = frame.to_ndarray(format='rgb24')
|
||||
all_frames.append(frame_np)
|
||||
|
||||
for i, _ in enumerate(list_of_fps):
|
||||
this_time = frame_time
|
||||
while this_time >= next_frame_time_for_each_fps[i]:
|
||||
if frame_np is None:
|
||||
frame_np = frame.to_ndarray(format='rgb24')
|
||||
|
||||
output_frames[i].append(frame_np)
|
||||
next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
|
||||
|
||||
output_frames = [np.stack(frames) for frames in output_frames]
|
||||
return output_frames, all_frames, fps
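
The inner `while` loop of `read_frames` resamples one decoded stream to several target frame rates at once: a decoded frame is emitted once for every target time slot it has reached. A minimal sketch of that selection for a single target rate, operating on timestamps only (the name `select_frames` is hypothetical):

```python
def select_frames(frame_times, target_fps):
    """Return the source-frame indices emitted for one target frame rate."""
    selected = []
    next_time = 0.0
    delta = 1.0 / target_fps
    for index, t in enumerate(frame_times):
        # Repeat the frame while it is at or past the next target slot
        while t >= next_time:
            selected.append(index)
            next_time += delta
    return selected

# One second of 8 fps video, resampled down to 4 fps and up to 16 fps
times = [i / 8 for i in range(8)]
downsampled = select_frames(times, 4)
upsampled = select_frames(times, 16)
```

Downsampling skips source frames, while upsampling duplicates them. The example uses rates whose periods are exact in binary floating point. With other rates, accumulated rounding in `next_time` can shift a selection by one frame.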
|
||||
|
||||
|
||||
def normalize_video_chunk(video_chunk: torch.Tensor,
|
||||
expected_length: int,
|
||||
*,
|
||||
n_tolerance_frame: int = 1,
|
||||
desc: str = "") \
|
||||
-> torch.Tensor:
|
||||
# video_chunk: [T, H, W, C]
|
||||
if video_chunk.shape[0] < expected_length:
|
||||
if expected_length - video_chunk.shape[0] <= n_tolerance_frame:
|
||||
# copy the last frame to make it the right length
|
||||
log.warning(f'Video too short {desc}, padding {expected_length - video_chunk.shape[0]} frames with the last frame')
|
||||
video_chunk = torch.cat([video_chunk, video_chunk[-1:].repeat(expected_length - video_chunk.shape[0], 1, 1, 1)])
|
||||
assert video_chunk.shape[0] == expected_length
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f'Video too short {desc}, expected {expected_length}, got {video_chunk.shape[0]}'
|
||||
)
|
||||
video_chunk = video_chunk[:expected_length]
|
||||
if video_chunk.shape[0] != expected_length:
|
||||
raise RuntimeError(f'Video wrong length {desc}, '
|
||||
f'expected {expected_length}, '
|
||||
f'got {video_chunk.shape[0]}')
|
||||
|
||||
return video_chunk
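
The length policy in `normalize_video_chunk` is: pad with the last frame when at most `n_tolerance_frame` frames are missing, fail when more are missing, and truncate when there are too many. A list-based sketch of the same policy, assuming frames are arbitrary Python objects (the name `normalize_length` is hypothetical):

```python
def normalize_length(frames, expected_length, n_tolerance_frame=1):
    """Pad, truncate, or reject a frame list so it has expected_length items."""
    missing = expected_length - len(frames)
    if missing > n_tolerance_frame:
        raise RuntimeError(
            f'Video too short, expected {expected_length}, got {len(frames)}')
    if missing > 0:
        # Repeat the last frame to fill a small shortfall
        frames = frames + [frames[-1]] * missing
    return frames[:expected_length]

padded = normalize_length([1, 2, 3], 4)
truncated = normalize_length([1, 2, 3, 4, 5], 4)
```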
|
||||
|
||||
|
||||
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
|
||||
sampling_rate: int):
|
||||
container = av.open(output_path, 'w')
|
||||
output_video_stream = container.add_stream('h264', video_info.fps)
|
||||
output_video_stream.codec_context.bit_rate = int(10 * 1e6)  # 10 Mbps; cast because 1e6 is a float
|
||||
output_video_stream.width = video_info.width
|
||||
output_video_stream.height = video_info.height
|
||||
output_video_stream.pix_fmt = 'yuv420p'
|
||||
|
||||
output_audio_stream = container.add_stream('aac', sampling_rate)
|
||||
|
||||
# encode video
|
||||
for image in video_info.all_frames:
|
||||
image = av.VideoFrame.from_ndarray(image)
|
||||
packet = output_video_stream.encode(image)
|
||||
container.mux(packet)
|
||||
|
||||
for packet in output_video_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
# convert float tensor audio to numpy array
|
||||
audio_np = audio.numpy().astype(np.float32)
|
||||
audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
||||
audio_frame.sample_rate = sampling_rate
|
||||
|
||||
for packet in output_audio_stream.encode(audio_frame):
|
||||
container.mux(packet)
|
||||
|
||||
for packet in output_audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
container.close()
|
||||
|
||||
|
||||
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
|
||||
"""
|
||||
NOTE: I don't think we can get the exact video duration right without re-encoding
|
||||
so we are not using this but keeping it here for reference
|
||||
"""
|
||||
video = av.open(video_path)
|
||||
output = av.open(output_path, 'w')
|
||||
input_video_stream = video.streams.video[0]
|
||||
output_video_stream = output.add_stream(template=input_video_stream)
|
||||
output_audio_stream = output.add_stream('aac', sampling_rate)
|
||||
|
||||
duration_sec = audio.shape[-1] / sampling_rate
|
||||
|
||||
for packet in video.demux(input_video_stream):
|
||||
# We need to skip the "flushing" packets that `demux` generates.
|
||||
if packet.dts is None:
|
||||
continue
|
||||
# We need to assign the packet to the new stream.
|
||||
packet.stream = output_video_stream
|
||||
output.mux(packet)
|
||||
|
||||
# convert float tensor audio to numpy array
|
||||
audio_np = audio.numpy().astype(np.float32)
|
||||
audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
||||
audio_frame.sample_rate = sampling_rate
|
||||
|
||||
for packet in output_audio_stream.encode(audio_frame):
|
||||
output.mux(packet)
|
||||
|
||||
for packet in output_audio_stream.encode():
|
||||
output.mux(packet)
|
||||
|
||||
video.close()
|
||||
output.close()
|
||||
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
import logging
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from selva_core.data.vgg_sound import VGGSound
|
||||
from selva_core.data.eval.eval_video_dataset import VGGSound as VGGSoundEval
|
||||
from selva_core.data.eval.eval_video_dataset import InferenceVideoData, VGGMonoAudioBench
|
||||
from selva_core.data.eval.audiocaps import AudioCapsData
|
||||
from selva_core.data.mm_dataset import MultiModalDataset
|
||||
from selva_core.data.mixup import DataMixupCollate
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
# Re-seed randomness every time we start a worker
|
||||
def worker_init_fn(worker_id: int):
|
||||
worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
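
The seed arithmetic in `worker_init_fn` gives every DataLoader worker on every rank its own NumPy and `random` seed. A small sketch of that derivation with plain integers, assuming fewer than 1000 workers per rank (the name `derive_worker_seed` is hypothetical):

```python
def derive_worker_seed(base_seed, worker_id, local_rank):
    """Combine the torch seed, worker index, and rank into one seed."""
    # The modulo keeps the value within NumPy's 32-bit seed range
    return base_seed % (2**31) + worker_id + local_rank * 1000

base_seed = 123_456_789_012
seeds = {
    (rank, worker): derive_worker_seed(base_seed, worker, rank)
    for rank in range(2)
    for worker in range(4)
}
```

With 4 workers on each of 2 ranks, all 8 derived seeds are distinct, so augmentation randomness is not duplicated across workers.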
|
||||
|
||||
|
||||
def load_video_data(cfg: DictConfig, data_cfg: DictConfig, normalize_audio: bool = False,
|
||||
) -> Dataset:
|
||||
dataset = VGGSound(root=data_cfg.root,
|
||||
tsv_path=data_cfg.subset_name,
|
||||
sample_rate=16_000,
|
||||
duration_sec=8.0,
|
||||
normalize_audio=normalize_audio,
|
||||
mmap_dir=data_cfg.memmap_dir,
|
||||
tsv_tsynch_path=data_cfg.tsv_tsynch,
|
||||
mmap_tsync_dir=data_cfg.memmap_dir_tsynch,
|
||||
data_dim=cfg.data_dim
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
|
||||
raise NotImplementedError('Audio data loading is not implemented yet')
|
||||
|
||||
|
||||
def setup_training_datasets(cfg: DictConfig,
|
||||
generator: torch.Generator,
|
||||
) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
||||
if cfg.mini_train:
|
||||
vgg = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=True)
|
||||
dataset = MultiModalDataset([vgg], [])
|
||||
elif cfg.example_train:
|
||||
video = load_video_data(cfg, cfg.data.Example_video, normalize_audio=True)
|
||||
dataset = MultiModalDataset([video], [])
|
||||
else:
|
||||
vgg = load_video_data(cfg, cfg.data.VGGSound, normalize_audio=True)
|
||||
# load the largest one first
|
||||
# you can add more video/audio data upon demand, such as
|
||||
# clotho = load_audio_data(cfg, cfg.data.Clotho)
|
||||
dataset = MultiModalDataset([vgg], [])
|
||||
|
||||
batch_size = cfg.batch_size
|
||||
num_workers = cfg.num_workers
|
||||
pin_memory = cfg.pin_memory
|
||||
|
||||
if cfg.mixup.domain == 'data':
|
||||
mixup_params = cfg.mixup.params
|
||||
collate_fn = DataMixupCollate(generator=generator,
|
||||
**mixup_params)
|
||||
else:
|
||||
collate_fn = None
|
||||
|
||||
sampler, loader = construct_loader(dataset,
|
||||
batch_size,
|
||||
num_workers,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
return dataset, sampler, loader
|
||||
|
||||
|
||||
def setup_test_datasets(cfg: DictConfig,
|
||||
generator: torch.Generator,
|
||||
) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
||||
if cfg.example_train:
|
||||
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
|
||||
elif cfg.dataset.startswith('vggsound'):
|
||||
dataset = load_video_data(cfg, cfg.data.VGGSound_test, normalize_audio=False)
|
||||
else:
|
||||
raise NotImplementedError(f'Unknown dataset for test: {cfg.dataset}')
|
||||
|
||||
batch_size = cfg.batch_size
|
||||
num_workers = cfg.get('num_workers_val', cfg.num_workers)
|
||||
pin_memory = cfg.pin_memory
|
||||
|
||||
if cfg.mixup.domain == 'data':
|
||||
mixup_config = cfg.mixup.params
|
||||
collate_fn = DataMixupCollate(generator=generator,
|
||||
**mixup_config)
|
||||
else:
|
||||
collate_fn = None
|
||||
|
||||
sampler, loader = construct_loader(dataset,
|
||||
batch_size,
|
||||
num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
return dataset, sampler, loader
|
||||
|
||||
|
||||
def setup_val_datasets(cfg: DictConfig,
|
||||
generator: torch.Generator,
|
||||
) -> tuple[Dataset, DataLoader, DataLoader]:
|
||||
if cfg.example_train:
|
||||
dataset = load_video_data(cfg, cfg.data.Example_video, normalize_audio=False)
|
||||
else:
|
||||
dataset = load_video_data(cfg, cfg.data.VGGSound_val, normalize_audio=False)
|
||||
|
||||
val_batch_size = cfg.batch_size
|
||||
val_eval_batch_size = cfg.eval_batch_size
|
||||
num_workers = cfg.get('num_workers_val', cfg.num_workers)
|
||||
pin_memory = cfg.pin_memory
|
||||
|
||||
if cfg.mixup.domain == 'data':
|
||||
mixup_config = cfg.mixup.params
|
||||
collate_fn = DataMixupCollate(generator=generator,
|
||||
**mixup_config)
|
||||
else:
|
||||
collate_fn = None
|
||||
|
||||
_, val_loader = construct_loader(dataset,
|
||||
val_batch_size,
|
||||
num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=collate_fn)
|
||||
_, eval_loader = construct_loader(dataset,
|
||||
val_eval_batch_size,
|
||||
num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=collate_fn)
|
||||
|
||||
return dataset, val_loader, eval_loader
|
||||
|
||||
|
||||
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
|
||||
if dataset_name.startswith('audiocaps_full'):
|
||||
dataset = AudioCapsData(cfg.eval_data.audiocaps_full.audio_path,
|
||||
cfg.eval_data.audiocaps_full.csv_path)
|
||||
elif dataset_name.startswith('audiocaps'):
|
||||
dataset = AudioCapsData(cfg.eval_data.audiocaps.audio_path,
|
||||
cfg.eval_data.audiocaps.csv_path)
|
||||
elif dataset_name.startswith('vggsound'):
|
||||
dataset = VGGSound(cfg.eval_data.vggsound.video_path,
|
||||
cfg.eval_data.vggsound.csv_path,
|
||||
duration_sec=cfg.duration_s)
|
||||
elif dataset_name.startswith('infer_video'):
|
||||
dataset = InferenceVideoData(cfg.eval_data.infer_video.video_path,
|
||||
cfg.eval_data.infer_video.jsonl_path,
|
||||
duration_sec=cfg.duration_s)
|
||||
cfg.batch_size = 1
|
||||
elif dataset_name.startswith('example_video'):
|
||||
dataset = VGGSoundEval(cfg.eval_data.Example_video.video_path,
|
||||
cfg.eval_data.Example_video.csv_path,
|
||||
duration_sec=cfg.duration_s)
|
||||
elif dataset_name in ['vgg_monoaudio_intra', 'vgg_monoaudio_inter']:
|
||||
dataset = VGGMonoAudioBench(cfg.eval_data[dataset_name].video_path,
|
||||
cfg.eval_data[dataset_name].csv_path,
|
||||
duration_sec=cfg.duration_s)
|
||||
|
||||
else:
|
||||
raise ValueError(f'Invalid dataset name: {dataset_name}')
|
||||
|
||||
batch_size = cfg.batch_size
|
||||
num_workers = cfg.num_workers
|
||||
pin_memory = cfg.pin_memory
|
||||
_, loader = construct_loader(dataset,
|
||||
batch_size,
|
||||
num_workers,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
pin_memory=pin_memory,
|
||||
error_avoidance=True)
|
||||
return dataset, loader
|
||||
|
||||
|
||||
def error_avoidance_collate(batch):
|
||||
# Filter out None values
|
||||
batch = [item for item in batch if item is not None]
|
||||
if len(batch) == 0:
|
||||
return None
|
||||
return default_collate(batch)
|
||||
|
||||
|
||||
def construct_loader(dataset: Dataset,
                     batch_size: int,
                     num_workers: int,
                     *,
                     shuffle: bool = True,
                     drop_last: bool = True,
                     pin_memory: bool = False,
                     error_avoidance: bool = False,
                     collate_fn=None) -> tuple[DistributedSampler, DataLoader]:
    train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
    train_loader = DataLoader(dataset,
                              batch_size,
                              sampler=train_sampler,
                              num_workers=num_workers,
                              worker_init_fn=worker_init_fn,
                              drop_last=drop_last,
                              persistent_workers=num_workers > 0,
                              pin_memory=pin_memory,
                              collate_fn=error_avoidance_collate if error_avoidance else collate_fn)
    return train_sampler, train_loader
|
||||
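The `error_avoidance_collate` path relies on datasets returning `None` for samples that fail to load. A minimal sketch of that contract follows. It assumes a stand-in `toy_collate` (a hypothetical, stdlib-only helper, not part of this repository) in place of PyTorch's `default_collate`, so it only illustrates the filtering step:

```python
from typing import Optional


def toy_collate(batch: list[dict]) -> dict[str, list]:
    # Hypothetical stand-in for torch's default_collate: groups values by key.
    return {key: [item[key] for item in batch] for key in batch[0]}


def filter_none_collate(batch: list[Optional[dict]]) -> Optional[dict[str, list]]:
    # Drop samples whose loading failed (the dataset returned None for them).
    kept = [item for item in batch if item is not None]
    if not kept:
        # Every sample failed; the caller has to skip this batch.
        return None
    return toy_collate(kept)


good = {'name': 'a', 'caption': 'dog barking'}
collated = filter_none_collate([good, None, {'name': 'b', 'caption': 'rain'}])
empty = filter_none_collate([None, None])
```

A loop that consumes such a loader has to tolerate a `None` batch, for example with `if batch is None: continue`.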
@@ -0,0 +1,39 @@
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Union

import pandas as pd
import torch
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class AudioCapsData(Dataset):

    def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
        df = pd.read_csv(csv_path).to_dict(orient='records')

        # NOTE: audio_files is collected here, but the rows below are not filtered against it
        audio_files = sorted(os.listdir(audio_path))
        audio_files = set(
            [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])

        self.data = []
        for row in df:
            self.data.append({
                'name': row['name'],
                'caption': row['caption'],
            })

        self.audio_path = Path(audio_path)
        self.csv_path = Path(csv_path)

        log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')

    def __getitem__(self, idx: int) -> dict[str, str]:
        # each item is a {'name', 'caption'} record, not a tensor
        return self.data[idx]

    def __len__(self):
        return len(self.data)
|
||||
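`AudioCapsData` lists the audio directory but keeps every CSV row regardless of whether its file exists. If restricting the rows to files on disk is the intent (an assumption; the source does not say), one possible sketch is shown here. `rows_with_audio` is a hypothetical helper, and the stdlib `csv` module stands in for pandas:

```python
import csv
import os
import tempfile
from pathlib import Path


def rows_with_audio(audio_dir: Path, csv_path: Path) -> list[dict]:
    # Stems of the audio files actually present on disk.
    stems = {Path(f).stem for f in os.listdir(audio_dir) if f.endswith(('.wav', '.flac'))}
    with open(csv_path, newline='') as f:
        rows = list(csv.DictReader(f))
    # Keep only the rows whose 'name' has a matching audio file.
    return [{'name': r['name'], 'caption': r['caption']} for r in rows if r['name'] in stems]


with tempfile.TemporaryDirectory() as tmp:
    audio_dir = Path(tmp) / 'audio'
    audio_dir.mkdir()
    (audio_dir / 'clip_a.wav').touch()
    (audio_dir / 'notes.txt').touch()  # ignored: not an audio extension
    csv_path = Path(tmp) / 'captions.csv'
    csv_path.write_text('name,caption\nclip_a,a dog barks\nclip_b,rain falls\n')
    matched = rows_with_audio(audio_dir, csv_path)
```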
@@ -0,0 +1,237 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision.transforms import v2
|
||||
from torio.io import StreamingMediaDecoder
|
||||
|
||||
from selva_core.data.av_utils import normalize_video_chunk
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
_CLIP_SIZE = 384
|
||||
_CLIP_FPS = 8.0
|
||||
|
||||
_SYNC_SIZE = 224
|
||||
_SYNC_FPS = 25.0
|
||||
|
||||
|
||||
class VideoDataset(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_root: Union[str, Path],
|
||||
*,
|
||||
duration_sec: float = 8.0,
|
||||
clip_video_required: bool = False,
|
||||
):
|
||||
self.video_root = Path(video_root)
|
||||
self.duration_sec = duration_sec
|
||||
self.clip_video_required = clip_video_required
|
||||
|
||||
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
||||
self.sync_transform = v2.Compose([
|
||||
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
# v2.CenterCrop(_SYNC_SIZE),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
if self.clip_video_required:
|
||||
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
||||
self.clip_transform = v2.Compose([
|
||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
|
||||
# to be implemented by subclasses
|
||||
self.captions = {}
|
||||
self.negative_captions = {}
|
||||
self.videos = sorted(list(self.captions.keys()))
|
||||
|
||||
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
video_id = self.videos[idx]
|
||||
caption = self.captions[video_id]
|
||||
negative_caption = self.negative_captions.get(video_id, None)
|
||||
|
||||
reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
|
||||
reader.add_basic_video_stream(
|
||||
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
||||
frame_rate=_SYNC_FPS,
|
||||
format='rgb24',
|
||||
)
|
||||
if self.clip_video_required:
|
||||
reader.add_basic_video_stream(
|
||||
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
||||
frame_rate=_CLIP_FPS,
|
||||
format='rgb24',
|
||||
)
|
||||
|
||||
reader.fill_buffer()
|
||||
data_chunk = reader.pop_chunks()
|
||||
|
||||
sync_chunk = data_chunk[0]
|
||||
if sync_chunk is None:
|
||||
raise RuntimeError(f'Sync video returned None {video_id}')
|
||||
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
||||
n_tolerance_frame=3, desc=video_id)
|
||||
sync_chunk = self.sync_transform(sync_chunk)
|
||||
|
||||
if self.clip_video_required:
|
||||
clip_chunk = data_chunk[1]
|
||||
if clip_chunk is None:
|
||||
raise RuntimeError(f'CLIP video returned None {video_id}')
|
||||
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
||||
n_tolerance_frame=1, desc=video_id)
|
||||
clip_chunk = self.clip_transform(clip_chunk)
|
||||
|
||||
data = {
|
||||
'name': video_id,
|
||||
'caption': caption,
|
||||
'sync_video': sync_chunk,
|
||||
}
|
||||
if self.clip_video_required:
|
||||
data['clip_video'] = clip_chunk
|
||||
if negative_caption is not None:
|
||||
data['negative_caption'] = negative_caption
|
||||
|
||||
return data
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
try:
|
||||
return self.sample(idx)
|
||||
except Exception as e:
|
||||
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.captions)
|
||||
|
||||
|
||||
class VGGSound(VideoDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_root: Union[str, Path],
|
||||
csv_path: Union[str, Path],
|
||||
*,
|
||||
duration_sec: float = 8.0,
|
||||
clip_video_required: bool = False,
|
||||
):
|
||||
super().__init__(video_root, duration_sec=duration_sec,
|
||||
clip_video_required=clip_video_required)
|
||||
self.video_root = Path(video_root)
|
||||
self.csv_path = Path(csv_path)
|
||||
|
||||
videos = sorted(os.listdir(self.video_root))
|
||||
if local_rank == 0:
|
||||
log.info(f'{len(videos)} videos found in {video_root}')
|
||||
self.captions = {}
|
||||
|
||||
df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
|
||||
'split']).to_dict(orient='records')
|
||||
|
||||
videos_no_found = []
|
||||
for row in df:
|
||||
if row['split'] == 'test':
|
||||
start_sec = int(row['sec'])
|
||||
video_id = str(row['id'])
|
||||
# this is how our videos are named
|
||||
video_name = f'{video_id}_{start_sec:06d}'
|
||||
if video_name + '.mp4' not in videos:
|
||||
videos_no_found.append(video_name)
|
||||
continue
|
||||
|
||||
self.captions[video_name] = row['caption']
|
||||
|
||||
if local_rank == 0:
|
||||
log.info(f'{len(videos)} videos found in {video_root}')
|
||||
log.info(f'{len(self.captions)} useable videos found')
|
||||
if videos_no_found:
|
||||
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
|
||||
log.info(
|
||||
'A small amount is expected, as not all videos are still available on YouTube')
|
||||
|
||||
self.videos = sorted(list(self.captions.keys()))
|
||||
|
||||
|
||||
class InferenceVideoData(VideoDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_root: Union[str, Path],
|
||||
jsonl_root: Union[str, Path],
|
||||
*,
|
||||
duration_sec: float = 10.0,
|
||||
clip_video_required: bool = False,
|
||||
):
|
||||
super().__init__(video_root, duration_sec=duration_sec,
|
||||
clip_video_required=clip_video_required)
|
||||
self.video_root = Path(video_root)
|
||||
self.jsonl_root = Path(jsonl_root)
|
||||
|
||||
videos = sorted(os.listdir(self.video_root))
|
||||
        videos = [Path(v).stem for v in videos]  # remove extensions
|
||||
self.captions = {}
|
||||
|
||||
for v in videos:
|
||||
with open(self.jsonl_root / (v + '.jsonl')) as f:
|
||||
data = json.load(f)
|
||||
self.captions[v] = data['audio_prompt']
|
||||
self.negative_captions[v] = data.get('negative_audio_prompt', None)
|
||||
|
||||
if local_rank == 0:
|
||||
log.info(f'{len(videos)} videos found in {video_root}')
|
||||
|
||||
self.videos = videos
|
||||
|
||||
|
||||
class VGGMonoAudioBench(VideoDataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_root: Union[str, Path],
|
||||
csv_path: Union[str, Path],
|
||||
*,
|
||||
duration_sec: float = 8.0,
|
||||
clip_video_required: bool = False,
|
||||
):
|
||||
super().__init__(video_root, duration_sec=duration_sec,
|
||||
clip_video_required=clip_video_required)
|
||||
self.video_root = Path(video_root)
|
||||
self.csv_path = Path(csv_path)
|
||||
|
||||
videos = sorted(os.listdir(self.video_root))
|
||||
if local_rank == 0:
|
||||
log.info(f'{len(videos)} videos found in {video_root}')
|
||||
self.captions = {}
|
||||
self.negative_captions = {}
|
||||
|
||||
df = pd.read_csv(csv_path, header=0, usecols=['file_name', 'label', 'paired_label']
|
||||
).to_dict(orient='records')
|
||||
|
||||
videos_no_found = []
|
||||
for row in df:
|
||||
video_name = str(Path(row['file_name']).stem)
|
||||
if video_name + '.mp4' not in videos:
|
||||
videos_no_found.append(video_name)
|
||||
continue
|
||||
|
||||
self.captions[video_name] = row['label']
|
||||
self.negative_captions[video_name] = row['paired_label']
|
||||
|
||||
if local_rank == 0:
|
||||
log.info(f'{len(videos)} videos found in {video_root}')
|
||||
log.info(f'{len(self.captions)} useable videos found')
|
||||
if videos_no_found:
|
||||
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}!')
|
||||
|
||||
self.videos = sorted(list(self.captions.keys()))
|
||||
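The evaluation datasets above derive two things from plain arithmetic: the clip file name for the VGGSound test split and the number of frames expected per stream. A small sketch of both, using the constants from the file (the function names `clip_name` and `expected_frames` are hypothetical, introduced only for illustration):

```python
SYNC_FPS = 25.0  # mirrors _SYNC_FPS in the file above
CLIP_FPS = 8.0   # mirrors _CLIP_FPS in the file above


def clip_name(video_id: str, start_sec: int) -> str:
    # Same pattern as the test-split loop: id, underscore, zero-padded start second.
    return f'{video_id}_{start_sec:06d}'


def expected_frames(fps: float, duration_sec: float) -> int:
    # int() truncates, so a fractional frame at the end is dropped.
    return int(fps * duration_sec)


name = clip_name('abc123', 30)
sync_frames = expected_frames(SYNC_FPS, 8.0)
clip_frames = expected_frames(CLIP_FPS, 8.0)
```

With the default 8-second duration this gives 200 frames for the sync stream and 64 for the CLIP stream.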
@@ -0,0 +1,194 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision.transforms import v2
|
||||
from torio.io import StreamingMediaDecoder
|
||||
|
||||
from selva_core.data.av_utils import normalize_video_chunk
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
_CLIP_SIZE = 384
|
||||
_CLIP_FPS = 8.0
|
||||
|
||||
_SYNC_SIZE = 224
|
||||
_SYNC_FPS = 25.0
|
||||
|
||||
|
||||
class VGGSound(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Union[str, Path],
|
||||
*,
|
||||
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
||||
audio_required: bool = True,
|
||||
sample_rate: int = 16_000,
|
||||
duration_sec: float = 8.0,
|
||||
audio_samples: Optional[int] = None,
|
||||
normalize_audio: bool = False,
|
||||
clip_video_required: bool = True,
|
||||
):
|
||||
self.root = Path(root)
|
||||
self.audio_required = audio_required
|
||||
if audio_required:
|
||||
self.normalize_audio = normalize_audio
|
||||
if audio_samples is None:
|
||||
self.audio_samples = int(sample_rate * duration_sec)
|
||||
else:
|
||||
self.audio_samples = audio_samples
|
||||
effective_duration = audio_samples / sample_rate
|
||||
# make sure the duration is close enough, within 15ms
|
||||
assert abs(effective_duration - duration_sec) < 0.015, \
|
||||
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
||||
self.clip_video_required = clip_video_required
|
||||
|
||||
videos = sorted(os.listdir(self.root))
|
||||
videos = set([Path(v).stem for v in videos]) # remove extensions
|
||||
self.labels = {}
|
||||
self.videos = []
|
||||
missing_videos = []
|
||||
|
||||
# read the tsv for subset information
|
||||
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
||||
for record in df_list:
|
||||
id = record['id']
|
||||
label = record['label']
|
||||
if id in videos:
|
||||
self.labels[id] = label
|
||||
self.videos.append(id)
|
||||
else:
|
||||
missing_videos.append(id)
|
||||
|
||||
if local_rank == 0:
|
||||
log.info(f'{len(videos)} videos found in {root}')
|
||||
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
||||
log.info(f'{len(missing_videos)} videos missing in {root}')
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.duration_sec = duration_sec
|
||||
|
||||
if audio_required:
|
||||
self.expected_audio_length = self.audio_samples
|
||||
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
||||
if clip_video_required:
|
||||
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
||||
|
||||
self.sync_transform = v2.Compose([
|
||||
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
# v2.CenterCrop(_SYNC_SIZE),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
if clip_video_required:
|
||||
self.clip_transform = v2.Compose([
|
||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
if audio_required:
|
||||
self.resampler = {}
|
||||
|
||||
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
video_id = self.videos[idx]
|
||||
|
||||
label = self.labels[video_id]
|
||||
|
||||
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
||||
reader.add_basic_video_stream(
|
||||
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
||||
frame_rate=_SYNC_FPS,
|
||||
format='rgb24',
|
||||
)
|
||||
if self.audio_required:
|
||||
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
|
||||
if self.clip_video_required:
|
||||
reader.add_basic_video_stream(
|
||||
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
||||
frame_rate=_CLIP_FPS,
|
||||
format='rgb24',
|
||||
)
|
||||
|
||||
reader.fill_buffer()
|
||||
data_chunk = reader.pop_chunks()
|
||||
|
||||
sync_chunk = data_chunk[0]
|
||||
if sync_chunk is None:
|
||||
raise RuntimeError(f'Sync video returned None {video_id}')
|
||||
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
||||
n_tolerance_frame=3, desc=video_id)
|
||||
sync_chunk = self.sync_transform(sync_chunk)
|
||||
|
||||
if self.audio_required:
|
||||
audio_chunk = data_chunk[1]
|
||||
|
||||
if self.clip_video_required:
|
||||
clip_chunk = data_chunk[2 if self.audio_required else 1]
|
||||
if clip_chunk is None:
|
||||
raise RuntimeError(f'CLIP video returned None {video_id}')
|
||||
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
||||
n_tolerance_frame=1, desc=video_id)
|
||||
clip_chunk = self.clip_transform(clip_chunk)
|
||||
|
||||
# process audio
|
||||
if self.audio_required:
|
||||
sample_rate = int(reader.get_out_stream_info(1).sample_rate)
|
||||
audio_chunk = audio_chunk.transpose(0, 1)
|
||||
audio_chunk = audio_chunk.mean(dim=0) # mono
|
||||
            if self.normalize_audio:
                abs_max = audio_chunk.abs().max()
                # check for silence before scaling to avoid dividing by ~0
                if abs_max <= 1e-6:
                    raise RuntimeError(f'Audio is silent {video_id}')
                audio_chunk = audio_chunk * (0.95 / abs_max)
|
||||
|
||||
# resample
|
||||
if sample_rate == self.sample_rate:
|
||||
audio_chunk = audio_chunk
|
||||
else:
|
||||
if sample_rate not in self.resampler:
|
||||
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
||||
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
||||
sample_rate,
|
||||
self.sample_rate,
|
||||
lowpass_filter_width=64,
|
||||
rolloff=0.9475937167399596,
|
||||
resampling_method='sinc_interp_kaiser',
|
||||
beta=14.769656459379492,
|
||||
)
|
||||
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
||||
|
||||
if audio_chunk.shape[0] < self.expected_audio_length:
|
||||
raise RuntimeError(f'Audio too short {video_id}')
|
||||
audio_chunk = audio_chunk[:self.expected_audio_length]
|
||||
|
||||
data = {
|
||||
'id': video_id,
|
||||
'caption': label,
|
||||
'sync_video': sync_chunk,
|
||||
}
|
||||
|
||||
if self.audio_required:
|
||||
data['audio'] = audio_chunk
|
||||
if self.clip_video_required:
|
||||
data['clip_video'] = clip_chunk
|
||||
|
||||
return data
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
try:
|
||||
return self.sample(idx)
|
||||
except Exception as e:
|
||||
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
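The training dataset accepts either a duration or an explicit sample count and asserts that the two agree within 15 ms. A minimal sketch of that check in isolation; `resolve_audio_samples` is a hypothetical name, and it raises `ValueError` where the constructor uses `assert`:

```python
from typing import Optional


def resolve_audio_samples(sample_rate: int, duration_sec: float,
                          audio_samples: Optional[int] = None) -> int:
    if audio_samples is None:
        return int(sample_rate * duration_sec)
    effective_duration = audio_samples / sample_rate
    # Same 15 ms tolerance as the assertion in the constructor above.
    if abs(effective_duration - duration_sec) >= 0.015:
        raise ValueError(
            f'audio_samples {audio_samples} does not match duration_sec {duration_sec}')
    return audio_samples


default_len = resolve_audio_samples(16_000, 8.0)
rounded_len = resolve_audio_samples(16_000, 8.0, audio_samples=128_128)
try:
    resolve_audio_samples(16_000, 8.0, audio_samples=130_000)
    rejected = False
except ValueError:
    rejected = True
```

The tolerance lets a caller round the sample count to a multiple that suits the model (here 128,128 samples is 8 ms longer than 8 s) without tripping the check.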
@@ -0,0 +1,129 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import open_clip
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
class WavTextClipsDataset(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Union[str, Path],
|
||||
*,
|
||||
captions_tsv: Union[str, Path],
|
||||
clips_tsv: Union[str, Path],
|
||||
sample_rate: int,
|
||||
num_samples: int,
|
||||
normalize_audio: bool = False,
|
||||
reject_silent: bool = False,
|
||||
tokenizer_id: str = 'ViT-H-14-378-quickgelu',
|
||||
):
|
||||
self.root = Path(root)
|
||||
self.sample_rate = sample_rate
|
||||
self.num_samples = num_samples
|
||||
self.normalize_audio = normalize_audio
|
||||
self.reject_silent = reject_silent
|
||||
self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
|
||||
|
||||
audios = sorted(os.listdir(self.root))
|
||||
audios = set([
|
||||
Path(audio).stem for audio in audios
|
||||
if audio.endswith('.wav') or audio.endswith('.flac')
|
||||
])
|
||||
self.captions = {}
|
||||
|
||||
# read the caption tsv
|
||||
df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
|
||||
for record in df_list:
|
||||
id = record['id']
|
||||
caption = record['caption']
|
||||
self.captions[id] = caption
|
||||
|
||||
# read the clip tsv
|
||||
df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
|
||||
'id': str,
|
||||
'name': str
|
||||
}).to_dict('records')
|
||||
self.clips = []
|
||||
        for record in df_list:
            name = record['name']
            record['caption'] = self.captions[name]
            self.clips.append(record)
|
||||
|
||||
log.info(f'Found {len(self.clips)} audio files in {self.root}')
|
||||
|
||||
self.resampler = {}
|
||||
|
||||
    def __getitem__(self, idx: int) -> dict:
|
||||
try:
|
||||
clip = self.clips[idx]
|
||||
audio_name = clip['name']
|
||||
audio_id = clip['id']
|
||||
caption = clip['caption']
|
||||
start_sample = clip['start_sample']
|
||||
end_sample = clip['end_sample']
|
||||
|
||||
audio_path = self.root / f'{audio_name}.flac'
|
||||
if not audio_path.exists():
|
||||
audio_path = self.root / f'{audio_name}.wav'
|
||||
assert audio_path.exists()
|
||||
|
||||
audio_chunk, sample_rate = torchaudio.load(audio_path)
|
||||
audio_chunk = audio_chunk.mean(dim=0) # mono
|
||||
            abs_max = audio_chunk.abs().max()
            if self.reject_silent and abs_max < 1e-6:
                log.warning(f'Rejecting silent audio: {audio_path}')
                return None

            if self.normalize_audio:
                # clamp the peak so a silent clip does not divide by zero
                audio_chunk = audio_chunk / abs_max.clamp(min=1e-6) * 0.95
|
||||
|
||||
audio_chunk = audio_chunk[start_sample:end_sample]
|
||||
|
||||
# resample
|
||||
if sample_rate == self.sample_rate:
|
||||
audio_chunk = audio_chunk
|
||||
else:
|
||||
if sample_rate not in self.resampler:
|
||||
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
||||
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
||||
sample_rate,
|
||||
self.sample_rate,
|
||||
lowpass_filter_width=64,
|
||||
rolloff=0.9475937167399596,
|
||||
resampling_method='sinc_interp_kaiser',
|
||||
beta=14.769656459379492,
|
||||
)
|
||||
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
||||
|
||||
if audio_chunk.shape[0] < self.num_samples:
|
||||
raise ValueError('Audio is too short')
|
||||
audio_chunk = audio_chunk[:self.num_samples]
|
||||
|
||||
tokens = self.tokenizer([caption])[0]
|
||||
|
||||
output = {
|
||||
'waveform': audio_chunk,
|
||||
'id': audio_id,
|
||||
'caption': caption,
|
||||
'tokens': tokens,
|
||||
}
|
||||
|
||||
return output
|
||||
except Exception as e:
|
||||
log.error(f'Error reading {audio_path}: {e}')
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.clips)
|
||||
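Both audio datasets scale clips to a peak of 0.95 and treat a peak below `1e-6` as silence. The ordering matters: the silence test has to come before the division. A stdlib-only sketch on plain Python lists (`peak_normalize` is a hypothetical helper, not the repository's tensor code):

```python
from typing import Optional


def peak_normalize(samples: list[float], peak: float = 0.95,
                   silence_threshold: float = 1e-6) -> Optional[list[float]]:
    abs_max = max((abs(s) for s in samples), default=0.0)
    # Reject before dividing: a silent clip would otherwise divide by ~0.
    if abs_max < silence_threshold:
        return None
    return [s / abs_max * peak for s in samples]


loud = peak_normalize([0.1, -0.5, 0.25])
silent = peak_normalize([0.0, 0.0])
```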
@@ -0,0 +1,338 @@
|
||||
""" Embedding Mixup
|
||||
Reference: https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/mixup.py
|
||||
"""
|
||||
from typing import Literal, Tuple, Union, List, Optional
|
||||
from functools import partial
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
from torchvision.transforms import v2
|
||||
from einops import rearrange
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from selva_core.data.vgg_sound import _SYNC_SIZE
|
||||
|
||||
|
||||
class MixupBase:
|
||||
""" Base class for mixup on either data or feature domain.
|
||||
Applies different params to each element or whole batch.
|
||||
|
||||
Args:
|
||||
        generator (torch.Generator): Random number generator for reproducibility
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixup lambda in [0., 1.], used when < 1. Set to 1. to sample lambda instead.
        mixup_alpha (float): Beta(alpha, alpha) parameter for sampling lambda, used when > 0. and mixup_lambda == 1.
|
||||
prob (float): Probability of applying mixup per batch or element
|
||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||
eps (float): Small epsilon value to avoid zero lambda
|
||||
"""
|
||||
def __init__(self, generator:torch.Generator,
|
||||
*,
|
||||
modality:Literal['video', 'audio', 'both'],
|
||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||
eps:float=0.05
|
||||
):
|
||||
self.modality = modality
|
||||
self.mixup_lambda:float = mixup_lambda
|
||||
self.mixup_alpha:float = mixup_alpha
|
||||
self.mix_prob:float = prob
|
||||
self.mode:str = mode
|
||||
self.eps:float = eps
|
||||
self.mixup_enabled:bool = True # set to false to disable mixing (intended to be set by train loop)
|
||||
if generator.device.type == 'cuda':
|
||||
self.generator_cuda = generator
|
||||
generator_seed = generator.initial_seed()
|
||||
self.generator = torch.Generator(device='cpu')
|
||||
self.generator.manual_seed(generator_seed)
|
||||
else:
|
||||
self.generator = generator
|
||||
|
||||
if not (self.mixup_lambda >= 0. and self.mixup_lambda <= 1.):
|
||||
raise ValueError(f"mixup_lambda {self.mixup_lambda} should be in [0., 1.].")
|
||||
if not self.mixup_alpha >= 0.:
|
||||
raise ValueError(f"mixup_alpha {self.mixup_alpha} >= 0. should be true.")
|
||||
if (self.mixup_alpha > 0. and self.mixup_lambda < 1.) or (self.mixup_alpha == 0. and self.mixup_lambda == 1.):
|
||||
            raise ValueError(f"Exactly one of mixup_alpha {self.mixup_alpha} > 0. and mixup_lambda {self.mixup_lambda} < 1. should be true.")
|
||||
|
||||
def _params_per_elem(self, batch_size:int) -> np.ndarray:
|
||||
lam:np.ndarray = np.ones(batch_size, dtype=np.float32)
|
||||
if self.mixup_enabled:
|
||||
if self.mixup_lambda < 1.: # constant lambda
|
||||
lam_mix = np.full(batch_size, self.mixup_lambda, dtype=np.float32)
|
||||
elif self.mixup_alpha > 0.: # sampled lambda
|
||||
                # torch.distributions takes no generator argument, so this draw uses the global torch RNG
|
||||
lam_mix = torch.distributions.Beta(
|
||||
torch.tensor([self.mixup_alpha]),
|
||||
torch.tensor([self.mixup_alpha]),
|
||||
).sample([batch_size]).numpy().astype(np.float32).reshape(-1)
|
||||
else:
|
||||
assert False, f"One of mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
|
||||
lam_mix[lam_mix < self.eps] = self.eps
|
||||
|
||||
# Use torch's random with generator for the random comparison
|
||||
rand_vals = torch.rand(batch_size, generator=self.generator).numpy()
|
||||
lam = np.where(rand_vals < self.mix_prob, lam_mix, lam)
|
||||
return lam
|
||||
|
||||
def _params_per_batch(self) -> float:
|
||||
lam:float = 1.
|
||||
if self.mixup_enabled:
|
||||
if self.mixup_lambda < 1.: # constant lambda
|
||||
lam = self.mixup_lambda
|
||||
elif self.mixup_alpha > 0.: # sampled lambda
|
||||
lam = torch.distributions.Beta(
|
||||
torch.tensor([self.mixup_alpha]),
|
||||
torch.tensor([self.mixup_alpha]),
|
||||
).sample().item()
|
||||
else:
|
||||
assert False, f"mixup_alpha {self.mixup_alpha} > 0., mixup_lambda {self.mixup_lambda} < 1. should be true."
|
||||
if lam < self.eps: lam = self.eps
|
||||
lam = float(lam)
|
||||
return lam
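The parameter logic in `MixupBase` reduces to three steps: use the constant lambda if it is below 1, otherwise sample from Beta(alpha, alpha), then floor the result at `eps`. A stdlib-only sketch of that selection follows. It assumes Python's `random.betavariate` as a stand-in for `torch.distributions.Beta`, and `pick_lambda` is a hypothetical name:

```python
import random


def pick_lambda(mixup_lambda: float, mixup_alpha: float, eps: float,
                rng: random.Random) -> float:
    if mixup_lambda < 1.0:
        lam = mixup_lambda  # constant mixing weight
    elif mixup_alpha > 0.0:
        lam = rng.betavariate(mixup_alpha, mixup_alpha)  # sampled weight
    else:
        raise ValueError('need either mixup_lambda < 1 or mixup_alpha > 0')
    # Floor the weight so the primary sample never vanishes entirely.
    return max(lam, eps)


rng = random.Random(0)
constant = pick_lambda(0.3, 0.0, 0.05, rng)
floored = pick_lambda(0.01, 0.0, 0.05, rng)
sampled = [pick_lambda(1.0, 0.4, 0.05, rng) for _ in range(100)]
```

Passing an explicit `random.Random` keeps the sampled branch reproducible, which is the role the `generator` argument plays for the `torch.rand` call in the class.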
|
||||
|
||||
|
||||
class DataMixupCollate(MixupBase):
|
||||
""" Mixup video in data domain.
|
||||
Applies different params to each element or whole batch.
|
||||
|
||||
Args:
|
||||
        generator (torch.Generator): Random number generator for reproducibility
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixup lambda in [0., 1.], used when < 1. Set to 1. to sample lambda instead.
        mixup_alpha (float): Beta(alpha, alpha) parameter for sampling lambda, used when > 0. and mixup_lambda == 1.
|
||||
prob (float): Probability of applying mixup per batch or element
|
||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||
eps (float): Small epsilon value to avoid zero lambda
|
||||
"""
|
||||
def __init__(self, generator:torch.Generator,
|
||||
*,
|
||||
modality:Literal['video', 'audio', 'both']='video',
|
||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||
eps:float=0.05
|
||||
):
|
||||
super().__init__(generator, modality=modality,
|
||||
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
|
||||
mode=mode, eps=eps)
|
||||
|
||||
self.source_video_key= 'sync_video'
|
||||
self.source_audio_key = 'audio'
|
||||
self.target_video_key = 'sync_video_mixed'
|
||||
self.target_audio_key = 'audio_mixed'
|
||||
|
||||
if not mode == 'batch':
|
||||
raise ValueError(f"Mode {mode} is not supported for data domain.")
|
||||
self.sync_transform = v2.Compose([
|
||||
v2.CenterCrop(_SYNC_SIZE),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
def _concat_video_frames(self, batch:list, target_key:str='sync_video_mixed', source_key:str='sync_video') -> float:
|
||||
# only batch mode supported
|
||||
batch_size:int = len(batch)
|
||||
lam:float = self._params_per_batch()
|
||||
|
||||
if lam == 1.:
|
||||
# no mixup, just return
|
||||
for i in range(batch_size):
|
||||
batch[i][target_key] = batch[i][source_key]
|
||||
return lam
|
||||
|
||||
        # Split direction for the composite frame. The random choice is disabled,
        # so the two videos are always placed side by side (horizontal split).
        orig_size = int(lam * _SYNC_SIZE)
        is_horizontal = True  # torch.rand(1, generator=self.generator).item() < 0.5
|
||||
if is_horizontal:
|
||||
# Horizontal resize
|
||||
resize_shape_orig = (_SYNC_SIZE, orig_size)
|
||||
resize_shape_pair = (_SYNC_SIZE, _SYNC_SIZE-orig_size)
|
||||
else:
|
||||
# Vertical resize
|
||||
resize_shape_orig = (orig_size, _SYNC_SIZE)
|
||||
resize_shape_pair = (_SYNC_SIZE-orig_size, _SYNC_SIZE)
|
||||
sync_resize_orig = v2.Compose([
|
||||
v2.Resize(resize_shape_orig, interpolation=v2.InterpolationMode.BICUBIC),
|
||||
])
|
||||
sync_resize_pair = v2.Compose([
|
||||
v2.Resize(resize_shape_pair, interpolation=v2.InterpolationMode.BICUBIC),
|
||||
])
|
||||
|
||||
batch_videos_orig = torch.stack([batch[i][source_key] for i in range(batch_size)], dim=0)
|
||||
batch_videos_pair = torch.stack([batch[batch_size - i - 1][source_key] for i in range(batch_size)], dim=0)
|
||||
# (B, T, C, H, W)
|
||||
# pass through resize, transform and concat
|
||||
batch_videos_orig = sync_resize_orig(batch_videos_orig)
|
||||
batch_videos_pair = sync_resize_pair(batch_videos_pair)
|
||||
batch_videos_concat = torch.cat((batch_videos_orig, batch_videos_pair), dim=-1 if is_horizontal else -2)
|
||||
batch_videos_concat = self.sync_transform(batch_videos_concat)
|
||||
|
||||
num_mixup = int(self.mix_prob * batch_size)
|
||||
for i in range(num_mixup):
|
||||
batch[i][target_key] = batch_videos_concat[i]
|
||||
for i in range(num_mixup, batch_size):
|
||||
batch[i][target_key] = batch[i][source_key] # no mixup
|
||||
|
||||
del batch_videos_orig, batch_videos_pair, sync_resize_orig, sync_resize_pair
|
||||
gc.collect()
|
||||
|
||||
return lam
|
||||
|
||||
def _mix_audio_samples(self, batch:list, target_key:str='audio_mixed', source_key:str='audio',
|
||||
normalize:bool = True) -> float:
|
||||
# assume source_key audios are normalized
|
||||
batch_size:int = len(batch)
|
||||
lam:float = self._params_per_batch()
|
||||
|
||||
if lam == 1.:
|
||||
# no mixup, just return
|
||||
for i in range(batch_size):
|
||||
batch[i][target_key] = batch[i][source_key]
|
||||
return lam
|
||||
|
||||
num_mixup = int(self.mix_prob * batch_size)
|
||||
for i in range(num_mixup):
|
||||
batch[i][target_key] = batch[i][source_key] * lam + batch[batch_size - i - 1][source_key] * (1 - lam)
|
||||
if normalize:
|
||||
source_abs_max = batch[i][source_key].abs().max()
|
||||
target_abs_max = batch[i][target_key].abs().max()
|
||||
batch[i][target_key] = batch[i][target_key] * (source_abs_max / target_abs_max)
|
||||
for i in range(num_mixup, batch_size):
|
||||
batch[i][target_key] = batch[i][source_key] # no mixup
|
||||
|
||||
return lam
|
||||
|
||||
    def __call__(self, batch:list, _=None) -> dict:
|
||||
batch_size:int = len(batch)
|
||||
assert batch_size % 2 == 0, f'Batch size {batch_size} should be even when using mixup'
|
||||
half = 'half' in self.mode
|
||||
if half:
|
||||
batch_size //= 2
|
||||
|
||||
if self.modality == 'video' or self.modality == 'both':
|
||||
lam = self._concat_video_frames(batch, target_key=self.target_video_key, source_key=self.source_video_key)
|
||||
if self.modality == 'audio' or self.modality == 'both':
|
||||
# raise NotImplementedError('Audio mixup is not implemented yet.')
|
||||
lam = self._mix_audio_samples(batch, target_key=self.target_audio_key, source_key=self.source_audio_key)
|
||||
|
||||
return default_collate(batch)
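Both data-domain mixers pair element `i` with element `batch_size - i - 1` and only mix the first `int(prob * batch_size)` elements. A minimal sketch of that pairing on plain floats, where `mix_mirrored` is a hypothetical helper and scalars stand in for the audio tensors:

```python
def mix_mirrored(values: list[float], lam: float, mix_prob: float = 1.0) -> list[float]:
    batch_size = len(values)
    # Only the first int(mix_prob * batch_size) elements are mixed.
    num_mixup = int(mix_prob * batch_size)
    mixed = list(values)
    for i in range(num_mixup):
        partner = values[batch_size - i - 1]  # mirrored index
        mixed[i] = values[i] * lam + partner * (1 - lam)
    return mixed


full = mix_mirrored([0.0, 1.0, 2.0, 3.0], lam=0.5)
partial = mix_mirrored([0.0, 1.0, 2.0, 3.0], lam=0.5, mix_prob=0.5)
```

Mixing always reads from the untouched `values` list, so an element that has already been mixed is never reused as a partner.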
|
||||
|
||||
|
||||
class FeatureMixup(MixupBase):
|
||||
""" Mixup video in feature domain.
|
||||
Applies different params to each element or whole batch.
|
||||
|
||||
Args:
|
||||
        generator (torch.Generator): Random number generator for reproducibility
        modality (Literal['video', 'audio', 'both']): Modality to apply mixup on.
        mixup_lambda (float): Constant mixup lambda in [0., 1.], used when < 1. Set to 1. to sample lambda instead.
        mixup_alpha (float): Beta(alpha, alpha) parameter for sampling lambda, used when > 0. and mixup_lambda == 1.
|
||||
prob (float): Probability of applying mixup per batch or element
|
||||
mode (Literal['elem','pair','batch', 'half']): How to apply mixup params (per 'batch', 'pair' (pair of elements), 'elem' (element), 'half' (half batch))
|
||||
eps (float): Small epsilon value to avoid zero lambda
|
||||
"""
|
||||
def __init__(self, generator:torch.Generator,
|
||||
*,
|
||||
modality:Literal['video', 'audio', 'both']='video',
|
||||
mixup_lambda:float=0.5, mixup_alpha:float=1., prob:float=1.0,
|
||||
mode:Literal['elem','pair','batch', 'half']='batch',
|
||||
eps:float=0.05
|
||||
):
|
||||
super().__init__(generator, modality=modality,
|
||||
mixup_lambda=mixup_lambda, mixup_alpha=mixup_alpha, prob=prob,
|
||||
mode=mode, eps=eps)
|
||||
self.source_video_key= 'sync_f_vid_orig'
|
||||
self.source_audio_key = 'sync_f_aud_orig'
|
||||
self.target_video_key = 'sync_f_vid_mixed'
|
||||
self.target_audio_key = 'sync_f_aud_mixed'
|
||||
|
||||
def _mix_elem_collate(self, batch:dict,
|
||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig'],
|
||||
half:bool=False) -> torch.Tensor:
|
||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||
batch_size:int = len(batch['id'])
|
||||
num_elem:int = batch_size // 2 if half else batch_size
|
||||
lam_batch:torch.Tensor = torch.from_numpy(self._params_per_elem(num_elem))
|
||||
|
||||
indices = torch.arange(num_elem)
|
||||
mix_indices = batch_size - indices - 1
|
||||
mix_mask = lam_batch < 1
|
||||
active_indices = indices[mix_mask]
|
||||
active_mix_indices = mix_indices[mix_mask]
|
||||
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
|
||||
for target_key, source_key in zip(target_keys, source_keys):
|
||||
batch[target_key][active_indices] = (
|
||||
batch[source_key][active_indices] * active_lambdas +
|
||||
batch[source_key][active_mix_indices] * (1 - active_lambdas)
|
||||
)
|
||||
batch[target_key][indices[~mix_mask]] = batch[source_key][indices[~mix_mask]]
|
||||
if half:
|
||||
lam_batch = torch.cat((lam_batch, torch.ones(num_elem, dtype=lam_batch.dtype)))
|
||||
return lam_batch.unsqueeze(1)
|
||||
|
||||
def _mix_pair_collate(self, batch:dict,
|
||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> torch.Tensor:
|
||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||
batch_size:int = len(batch['id'])
|
||||
lam_batch:torch.Tensor = torch.from_numpy(self._params_per_elem(batch_size // 2))
|
||||
|
||||
indices = torch.arange(batch_size // 2)
|
||||
mix_indices = batch_size - indices - 1
|
||||
mix_mask = lam_batch < 1
|
||||
active_indices = indices[mix_mask]
|
||||
active_mix_indices = mix_indices[mix_mask]
|
||||
active_lambdas = lam_batch[mix_mask].unsqueeze(1)
|
||||
for target_key, source_key in zip(target_keys, source_keys):
|
||||
batch[target_key][active_indices] = (
|
||||
batch[source_key][active_indices] * active_lambdas +
|
||||
batch[source_key][active_mix_indices] * (1 - active_lambdas)
|
||||
)
|
||||
batch[target_key][active_mix_indices] = (
|
||||
batch[source_key][active_mix_indices] * active_lambdas +
|
||||
batch[source_key][active_indices] * (1 - active_lambdas)
|
||||
)
|
||||
batch[target_key][indices[~mix_mask]] = batch[source_key][indices[~mix_mask]]
|
||||
batch[target_key][mix_indices[~mix_mask]] = batch[source_key][mix_indices[~mix_mask]]
|
||||
lam_batch = torch.cat((lam_batch, lam_batch.flip(0)))
|
||||
return lam_batch.unsqueeze(1)
|
||||
|
||||
def _mix_batch_collate(self, batch:dict,
|
||||
target_keys:List[str]=['sync_features_mixed'], source_keys:List[str]=['sync_features_orig']) -> float:
|
||||
assert len(target_keys) == len(source_keys), f"Length of target_keys {len(target_keys)} and source_keys {len(source_keys)} should be equal."
|
||||
lam:float = self._params_per_batch()
|
||||
|
||||
for target_key, source_key in zip(target_keys, source_keys):
|
||||
num_mixup = int(self.mix_prob * batch[source_key].shape[0])
|
||||
flipped_source = torch.flip(batch[source_key], dims=[0])
|
||||
batch[target_key] = batch[source_key] * lam + flipped_source * (1 - lam)
|
||||
batch[target_key][num_mixup:] = batch[source_key][num_mixup:] # no mixup
|
||||
return lam
|
||||
|
||||
def __call__(self, batch:dict, _=None) -> None:
|
||||
batch_size:int = len(batch['id'])
|
||||
assert batch_size % 2 == 0, f'Batch size (={batch_size}) must be even for feature mixup'
|
||||
half = 'half' in self.mode
|
||||
if half:
|
||||
batch_size //= 2
|
||||
|
||||
# Mixup
|
||||
if self.mode == 'elem' or self.mode == 'half':
|
||||
collate_fn = partial(self._mix_elem_collate, half=half)
|
||||
elif self.mode == 'pair':
|
||||
collate_fn = self._mix_pair_collate
|
||||
else:
|
||||
collate_fn = self._mix_batch_collate
|
||||
|
||||
if self.modality == 'both':
|
||||
target_keys, source_keys = [self.target_video_key, self.target_audio_key], [self.source_video_key, self.source_audio_key]
|
||||
elif self.modality == 'video':
|
||||
target_keys, source_keys = [self.target_video_key], [self.source_video_key]
|
||||
elif self.modality == 'audio':
|
||||
target_keys, source_keys = [self.target_audio_key], [self.source_audio_key]
|
||||
lam = collate_fn(batch, target_keys=target_keys, source_keys=source_keys)
|
||||
|
||||
# return batch
|
||||
|
||||
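The element-wise mixing above pairs row `i` with row `batch_size - 1 - i` and treats a lambda of 1 as "no mixup". A minimal, framework-free sketch of that pairing follows; plain lists stand in for tensors, and the function name `mix_with_reversed_partner` and the sample batch are hypothetical, not part of the repository.

```python
def mix_with_reversed_partner(features, lambdas):
    """Blend row i with row (n - 1 - i); a lambda of 1.0 leaves that row unmixed."""
    n = len(features)
    mixed = [list(row) for row in features]  # start from an unmixed copy
    for i, lam in enumerate(lambdas):
        if lam >= 1.0:
            continue  # inactive element: keep the original features
        partner = features[n - 1 - i]
        mixed[i] = [lam * a + (1.0 - lam) * b for a, b in zip(features[i], partner)]
    return mixed


# Hypothetical batch of four 2-dimensional feature vectors.
batch = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [4.0, 4.0]]
# 'half'-style usage: lambdas are drawn only for the first half of the batch.
out = mix_with_reversed_partner(batch, lambdas=[0.5, 1.0])
```

Rows without a lambda below 1 are copied through unchanged. When selecting such rows from an integer index tensor, the boolean mask is what should be negated: bitwise NOT of an integer index `i` is `-i - 1`, which addresses the partner row instead.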
@@ -0,0 +1,45 @@
|
||||
import bisect
|
||||
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
|
||||
# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
|
||||
class MultiModalDataset(Dataset):
|
||||
datasets: list[Dataset]
|
||||
cumulative_sizes: list[int]
|
||||
|
||||
@staticmethod
|
||||
def cumsum(sequence):
|
||||
r, s = [], 0
|
||||
for e in sequence:
|
||||
l = len(e)
|
||||
r.append(l + s)
|
||||
s += l
|
||||
return r
|
||||
|
||||
def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
|
||||
super().__init__()
|
||||
self.video_datasets = list(video_datasets)
|
||||
self.audio_datasets = list(audio_datasets)
|
||||
self.datasets = self.video_datasets + self.audio_datasets
|
||||
|
||||
self.cumulative_sizes = self.cumsum(self.datasets)
|
||||
|
||||
def __len__(self):
|
||||
return self.cumulative_sizes[-1]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < 0:
|
||||
if -idx > len(self):
|
||||
raise ValueError("absolute value of index should not exceed dataset length")
|
||||
idx = len(self) + idx
|
||||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
||||
if dataset_idx == 0:
|
||||
sample_idx = idx
|
||||
else:
|
||||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
||||
return self.datasets[dataset_idx][sample_idx]
|
||||
|
||||
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.video_datasets[0].compute_latent_stats()
|
||||
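The index lookup in `__getitem__` can be tried in isolation. The sketch below assumes three hypothetical datasets of sizes 3, 2 and 4; the helper name `locate` is illustrative and does not exist in the repository.

```python
import bisect
from itertools import accumulate


def locate(cumulative_sizes: list[int], idx: int) -> tuple[int, int]:
    """Map a global index to (dataset_idx, sample_idx) using cumulative sizes."""
    total = cumulative_sizes[-1]
    if idx < 0:
        if -idx > total:
            raise ValueError("absolute value of index should not exceed dataset length")
        idx += total
    # First dataset whose cumulative size is strictly greater than idx.
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    offset = cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, idx - offset


# Hypothetical sizes: two video datasets (3 and 2 samples), one audio dataset (4 samples).
sizes = [3, 2, 4]
cumulative = list(accumulate(sizes))
first_of_second = locate(cumulative, 3)
last_overall = locate(cumulative, -1)
```

`bisect_right` is used rather than `bisect_left` so that an index equal to a cumulative boundary falls into the next dataset.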
@@ -0,0 +1,148 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from tensordict import MemoryMappedTensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from selva_core.utils.dist_utils import local_rank, world_size
|
||||
|
||||
scratch_path = Path(os.environ.get('SLURM_SCRATCH', '/dev/shm'))
|
||||
shm_path = Path('/dev/shm')
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
|
||||
def reseed(seed):
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def local_scatter_torch(obj: Optional[Any]):
|
||||
if world_size == 1:
|
||||
# Just one worker. Do nothing.
|
||||
return obj
|
||||
|
||||
array = [obj] * world_size
|
||||
target_array = [None]
|
||||
if local_rank == 0:
|
||||
dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
|
||||
else:
|
||||
dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
|
||||
return target_array[0]
|
||||
|
||||
|
||||
class ShardDataset(Dataset):
|
||||
|
||||
def __init__(self, root):
|
||||
self.root = root
|
||||
self.shards = sorted(os.listdir(root))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.shards)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
|
||||
|
||||
|
||||
def get_tmp_dir(in_memory: bool) -> Path:
|
||||
return shm_path if in_memory else scratch_path
|
||||
|
||||
|
||||
def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
|
||||
in_memory: bool) -> MemoryMappedTensor:
|
||||
if local_rank == 0:
|
||||
with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
|
||||
log.info(f'Loading shards from {data_path} into {f.name}...')
|
||||
data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
|
||||
data = share_tensor_to_all(data)
|
||||
torch.distributed.barrier()
|
||||
f.close() # why does the context manager not close the file for me?
|
||||
else:
|
||||
log.info('Waiting for the data to be shared with me...')
|
||||
data = share_tensor_to_all(None)
|
||||
torch.distributed.barrier()
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def load_shards(
|
||||
data_path: Union[str, Path],
|
||||
ids: list[int],
|
||||
*,
|
||||
tmp_file_path: str,
|
||||
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
|
||||
id_set = set(ids)
|
||||
shards = sorted(os.listdir(data_path))
|
||||
log.info(f'Found {len(shards)} shards in {data_path}.')
|
||||
first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
|
||||
|
||||
log.info(f'Rank {local_rank} created file {tmp_file_path}')
|
||||
first_item = next(iter(first_shard.values()))
|
||||
log.info(f'First item shape: {first_item.shape}')
|
||||
mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
|
||||
dtype=torch.float32,
|
||||
filename=tmp_file_path,
|
||||
existsok=True)
|
||||
total_count = 0
|
||||
used_index = set()
|
||||
id_indexing = {i: idx for idx, i in enumerate(ids)}
|
||||
# faster with no workers; otherwise we need to set_sharing_strategy('file_system')
|
||||
loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
|
||||
for data in tqdm(loader, desc='Loading shards'):
|
||||
for i, v in data.items():
|
||||
if i not in id_set:
|
||||
continue
|
||||
|
||||
# tensor_index = ids.index(i)
|
||||
tensor_index = id_indexing[i]
|
||||
if tensor_index in used_index:
|
||||
raise ValueError(f'Duplicate id {i} found in {data_path}.')
|
||||
used_index.add(tensor_index)
|
||||
mm_tensor[tensor_index] = v
|
||||
total_count += 1
|
||||
|
||||
assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
|
||||
log.info(f'Loaded {total_count} tensors from {data_path}.')
|
||||
|
||||
return mm_tensor
|
||||
|
||||
|
||||
def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
|
||||
"""
|
||||
x: the tensor to be shared; None if local_rank != 0
|
||||
return: the shared tensor
|
||||
"""
|
||||
|
||||
# there is no need to share your stuff with anyone if you are alone; must be in memory
|
||||
if world_size == 1:
|
||||
return x
|
||||
|
||||
if local_rank == 0:
|
||||
assert x is not None, 'x must not be None if local_rank == 0'
|
||||
else:
|
||||
assert x is None, 'x must be None if local_rank != 0'
|
||||
|
||||
if local_rank == 0:
|
||||
filename = x.filename
|
||||
meta_information = (filename, x.shape, x.dtype)
|
||||
else:
|
||||
meta_information = None
|
||||
|
||||
filename, data_shape, data_type = local_scatter_torch(meta_information)
|
||||
if local_rank == 0:
|
||||
data = x
|
||||
else:
|
||||
data = MemoryMappedTensor.from_filename(filename=filename,
|
||||
dtype=data_type,
|
||||
shape=data_shape)
|
||||
|
||||
return data
|
||||
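The bookkeeping in `load_shards` (a dict for constant-time id lookup, a set to detect duplicates, a final count check) can be sketched without tensors. Everything below is a simplified stand-in: `gather_in_id_order` is a hypothetical name, a plain list replaces the `MemoryMappedTensor`, and the count check raises instead of asserting.

```python
def gather_in_id_order(shards, ids):
    """Place each requested shard entry at the row given by its position in `ids`."""
    # Dict lookup is O(1); ids.index(item_id) would be O(len(ids)) per entry.
    row_of = {item_id: row for row, item_id in enumerate(ids)}
    rows = [None] * len(ids)
    filled = set()
    for shard in shards:
        for item_id, value in shard.items():
            if item_id not in row_of:
                continue  # entry was not requested
            row = row_of[item_id]
            if row in filled:
                raise ValueError(f'Duplicate id {item_id} found.')
            filled.add(row)
            rows[row] = value
    if len(filled) != len(ids):
        raise ValueError(f'Expected {len(ids)} entries, got {len(filled)}.')
    return rows


# Hypothetical shards: each maps an id to its payload.
shards = [{7: 'g', 3: 'c'}, {5: 'e', 9: 'unused'}]
rows = gather_in_id_order(shards, ids=[3, 5, 7])
```

The output order follows `ids`, not the order in which shards are read, which is what lets other ranks index the shared tensor by position.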
@@ -0,0 +1,299 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torchvision.transforms import v2
|
||||
from torio.io import StreamingMediaDecoder
|
||||
from tensordict import TensorDict
|
||||
|
||||
from selva_core.data.av_utils import normalize_video_chunk
|
||||
from selva_core.utils.dist_utils import local_rank
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
_CLIP_SIZE = 384
|
||||
_CLIP_FPS = 8.0
|
||||
|
||||
_SYNC_SIZE = 224
|
||||
_SYNC_FPS = 25.0
|
||||
|
||||
|
||||
class VGGSound(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: Union[str, Path],
|
||||
*,
|
||||
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
||||
for_generator: bool = True,
|
||||
audio_required: bool = False,
|
||||
sample_rate: int = 16_000,
|
||||
duration_sec: float = 8.0,
|
||||
audio_samples: Optional[int] = None,
|
||||
normalize_audio: bool = False,
|
||||
clip_video_required: bool = False,
|
||||
mmap_dir: Union[str, Path] = None,
|
||||
tsv_tsynch_path: Union[str, Path] = None,
|
||||
mmap_tsync_dir: Union[str, Path] = None,
|
||||
data_dim: dict[str, int] = None,
|
||||
):
|
||||
self.root = Path(root)
|
||||
self.audio_required = audio_required
|
||||
if audio_required:
|
||||
self.normalize_audio = normalize_audio
|
||||
if audio_samples is None:
|
||||
self.audio_samples = int(sample_rate * duration_sec)
|
||||
else:
|
||||
self.audio_samples = audio_samples
|
||||
effective_duration = audio_samples / sample_rate
|
||||
# make sure the duration is close enough, within 15ms
|
||||
assert abs(effective_duration - duration_sec) < 0.015, \
|
||||
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
||||
self.clip_video_required = clip_video_required
|
||||
self.for_generator = for_generator
|
||||
|
||||
videos = sorted(os.listdir(self.root))
|
||||
videos = {Path(v).stem for v in videos} # remove extensions
|
||||
self.labels = {}
|
||||
self.videos = []
|
||||
missing_videos = []
|
||||
|
||||
# read the tsv for subset information
|
||||
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
||||
for record in df_list:
|
||||
id = record['id']
|
||||
label = record['label']
|
||||
if id in videos:
|
||||
self.labels[id] = label
|
||||
self.videos.append(id)
|
||||
else:
|
||||
missing_videos.append(id)
|
||||
|
||||
if local_rank == 0:
|
||||
log.info(f'{len(videos)} videos found in {root}')
|
||||
log.info(f'{len(self.videos)} videos from {tsv_path} found in {root}')
|
||||
log.info(f'{len(missing_videos)} videos missing in {root}')
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.duration_sec = duration_sec
|
||||
|
||||
if audio_required:
|
||||
self.expected_audio_length = self.audio_samples
|
||||
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
||||
if clip_video_required:
|
||||
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
||||
|
||||
self.sync_transform = v2.Compose([
|
||||
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
# v2.CenterCrop(_SYNC_SIZE),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
])
|
||||
|
||||
if clip_video_required:
|
||||
self.clip_transform = v2.Compose([
|
||||
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
||||
v2.ToImage(),
|
||||
v2.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
if audio_required:
|
||||
self.resampler = {}
|
||||
|
||||
# mmap
|
||||
log.info(f'Loading precomputed mmap from {mmap_dir}')
|
||||
mmap_dir = Path(mmap_dir)
|
||||
td = TensorDict.load_memmap(mmap_dir)
|
||||
log.info(f'Loaded precomputed mmap from {mmap_dir}')
|
||||
self.sync_features = td['sync_features']
|
||||
if for_generator:
|
||||
self.mean = td['mean']
|
||||
self.std = td['std']
|
||||
self.text_clip_features = td['text_features']
|
||||
if clip_video_required:
|
||||
self.clip_features = td['clip_features']
|
||||
else:
|
||||
self.clip_features = None
|
||||
self.id2idx_mmap = {d['id']: i for i, d in enumerate(df_list)}
|
||||
|
||||
mmap_tsync_dir = Path(mmap_tsync_dir)
|
||||
td_tsync = TensorDict.load_memmap(mmap_tsync_dir)
|
||||
log.info(f'Loaded precomputed tsync mmap from {mmap_tsync_dir}')
|
||||
self.text_features = td_tsync['text_features']
|
||||
self.text_masks = td_tsync['text_masks']
|
||||
df_list_tsync = pd.read_csv(tsv_tsynch_path, sep='\t', dtype={'id': str}).to_dict('records')
|
||||
self.id2idx_mmap_tsync = {d['id']: i for i, d in enumerate(df_list_tsync)}
|
||||
|
||||
if local_rank == 0:
|
||||
log.info(f'Loaded {len(self)} samples.')
|
||||
log.info(f'Loaded sync_features: {self.sync_features.shape}.')
|
||||
log.info(f'Loaded text_features: {self.text_features.shape}.')
|
||||
log.info(f'Loaded text_masks: {self.text_masks.shape}.')
|
||||
if for_generator:
|
||||
log.info(f'Loaded mean: {self.mean.shape}.')
|
||||
log.info(f'Loaded std: {self.std.shape}.')
|
||||
log.info(f'Loaded text_clip_features: {self.text_clip_features.shape}.')
|
||||
if clip_video_required:
|
||||
log.info(f'Loaded clip_features: {self.clip_features.shape}.')
|
||||
|
||||
assert self.sync_features.shape[1] == data_dim['sync_seq_len'], \
|
||||
f'{self.sync_features.shape[1]} != {data_dim["sync_seq_len"]}'
|
||||
assert self.text_features.shape[1] <= data_dim['text_flant5_max_seq_len'], \
|
||||
f'{self.text_features.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
|
||||
assert self.text_masks.shape[1] <= data_dim['text_flant5_max_seq_len'], \
|
||||
f'{self.text_masks.shape[1]} > {data_dim["text_flant5_max_seq_len"]}'
|
||||
assert self.sync_features.shape[-1] == data_dim['sync_dim'], \
|
||||
f'{self.sync_features.shape[-1]} != {data_dim["sync_dim"]}'
|
||||
assert self.text_features.shape[-1] == data_dim['text_flant5_dim'], \
|
||||
f'{self.text_features.shape[-1]} != {data_dim["text_flant5_dim"]}'
|
||||
if for_generator:
|
||||
assert self.mean.shape[1] == data_dim['latent_seq_len'], \
|
||||
f'{self.mean.shape[1]} != {data_dim["latent_seq_len"]}'
|
||||
assert self.std.shape[1] == data_dim['latent_seq_len'], \
|
||||
f'{self.std.shape[1]} != {data_dim["latent_seq_len"]}'
|
||||
assert self.text_clip_features.shape[1] == data_dim['text_clip_seq_len'], \
|
||||
f'{self.text_clip_features.shape[1]} != {data_dim["text_clip_seq_len"]}'
|
||||
assert self.text_clip_features.shape[-1] == data_dim['text_clip_dim'], \
|
||||
f'{self.text_clip_features.shape[-1]} != {data_dim["text_clip_dim"]}'
|
||||
if clip_video_required:
|
||||
assert self.clip_features.shape[1] == data_dim['clip_seq_len'], \
|
||||
f'{self.clip_features.shape[1]} != {data_dim["clip_seq_len"]}'
|
||||
assert self.clip_features.shape[-1] == data_dim['clip_dim'], \
|
||||
f'{self.clip_features.shape[-1]} != {data_dim["clip_dim"]}'
|
||||
|
||||
self.video_exist = torch.tensor(1, dtype=torch.bool)
|
||||
self.text_exist = torch.tensor(1, dtype=torch.bool)
|
||||
|
||||
|
||||
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: # mmap
|
||||
latents = self.mean
|
||||
return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
|
||||
|
||||
def get_memory_mapped_tensor(self) -> TensorDict:
|
||||
td = TensorDict({
|
||||
'sync_features': self.sync_features,
|
||||
'text_features': self.text_features,
|
||||
'text_masks': self.text_masks,
|
||||
})
|
||||
if self.for_generator:
|
||||
td['mean'] = self.mean
|
||||
td['std'] = self.std
|
||||
td['text_clip_features'] = self.text_clip_features
|
||||
if self.clip_video_required:
|
||||
td['clip_features'] = self.clip_features
|
||||
return td
|
||||
|
||||
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
video_id = self.videos[idx]
|
||||
|
||||
if video_id in getattr(self, 'captions', {}) and torch.rand(1).item() < getattr(self, 'autoacd_sample_prob', 0.0):
|
||||
label = self.captions[video_id]
|
||||
else:
|
||||
label = self.labels[video_id]
|
||||
|
||||
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
||||
reader.add_basic_video_stream(
|
||||
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
||||
frame_rate=_SYNC_FPS,
|
||||
format='rgb24',
|
||||
)
|
||||
if self.audio_required:
|
||||
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
|
||||
if self.clip_video_required:
|
||||
reader.add_basic_video_stream(
|
||||
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
||||
frame_rate=_CLIP_FPS,
|
||||
format='rgb24',
|
||||
)
|
||||
|
||||
reader.fill_buffer()
|
||||
data_chunk = reader.pop_chunks()
|
||||
|
||||
sync_chunk = data_chunk[0]
|
||||
if sync_chunk is None:
|
||||
raise RuntimeError(f'Sync video returned None {video_id}')
|
||||
sync_chunk = normalize_video_chunk(sync_chunk, self.sync_expected_length,
|
||||
n_tolerance_frame=3, desc=video_id)
|
||||
sync_chunk = self.sync_transform(sync_chunk)
|
||||
|
||||
if self.audio_required:
|
||||
audio_chunk = data_chunk[1]
|
||||
|
||||
if self.clip_video_required:
|
||||
clip_chunk = data_chunk[2 if self.audio_required else 1]
|
||||
if clip_chunk is None:
|
||||
raise RuntimeError(f'CLIP video returned None {video_id}')
|
||||
clip_chunk = normalize_video_chunk(clip_chunk, self.clip_expected_length,
|
||||
n_tolerance_frame=1, desc=video_id)
|
||||
clip_chunk = self.clip_transform(clip_chunk)
|
||||
|
||||
# process audio
|
||||
if self.audio_required:
|
||||
sample_rate = int(reader.get_out_stream_info(1).sample_rate)
|
||||
audio_chunk = audio_chunk.transpose(0, 1)
|
||||
audio_chunk = audio_chunk.mean(dim=0) # mono
|
||||
if self.normalize_audio:
|
||||
abs_max = audio_chunk.abs().max()
|
||||
if abs_max <= 1e-6:
|
||||
raise RuntimeError(f'Audio is silent {video_id}')
|
||||
audio_chunk = audio_chunk * (0.95 / abs_max)
|
||||
|
||||
# resample
|
||||
if sample_rate == self.sample_rate:
|
||||
audio_chunk = audio_chunk
|
||||
else:
|
||||
if sample_rate not in self.resampler:
|
||||
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
||||
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
||||
sample_rate,
|
||||
self.sample_rate,
|
||||
lowpass_filter_width=64,
|
||||
rolloff=0.9475937167399596,
|
||||
resampling_method='sinc_interp_kaiser',
|
||||
beta=14.769656459379492,
|
||||
)
|
||||
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
||||
|
||||
if audio_chunk.shape[0] < self.expected_audio_length:
|
||||
raise RuntimeError(f'Audio too short {video_id}')
|
||||
audio_chunk = audio_chunk[:self.expected_audio_length]
|
||||
|
||||
data = {
|
||||
'id': video_id,
|
||||
'caption': label,
|
||||
'sync_video': sync_chunk,
|
||||
'sync_f_vid_orig': self.sync_features[self.id2idx_mmap[video_id]],
|
||||
'text_features': self.text_features[self.id2idx_mmap_tsync[video_id]],
|
||||
'text_masks': self.text_masks[self.id2idx_mmap_tsync[video_id]],
|
||||
'video_exist': self.video_exist,
|
||||
'text_exist': self.text_exist,
|
||||
}
|
||||
|
||||
if self.for_generator:
|
||||
data['a_mean'] = self.mean[self.id2idx_mmap[video_id]]
|
||||
data['a_std'] = self.std[self.id2idx_mmap[video_id]]
|
||||
data['text_clip_features'] = self.text_clip_features[self.id2idx_mmap[video_id]]
|
||||
|
||||
if self.audio_required:
|
||||
data['audio'] = audio_chunk
|
||||
|
||||
if self.clip_video_required:
|
||||
data['clip_video'] = clip_chunk
|
||||
data['clip_features'] = self.clip_features[self.id2idx_mmap[video_id]]
|
||||
|
||||
return data
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
try:
|
||||
return self.sample(idx)
|
||||
except Exception as e:
|
||||
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
||||
return None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
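The length bookkeeping in the `VGGSound` constructor (audio sample count, the 15 ms tolerance, and expected frame counts) reduces to a few lines of arithmetic. A small sketch under the defaults shown above follows; the helper names `resolve_audio_samples` and `expected_frames` are hypothetical, and the tolerance check raises rather than asserts.

```python
from typing import Optional


def resolve_audio_samples(sample_rate: int, duration_sec: float,
                          audio_samples: Optional[int] = None,
                          tolerance_sec: float = 0.015) -> int:
    """Return the clip length in samples, checking an explicit value against the duration."""
    if audio_samples is None:
        return int(sample_rate * duration_sec)
    effective_duration = audio_samples / sample_rate
    if abs(effective_duration - duration_sec) >= tolerance_sec:
        raise ValueError(
            f'audio_samples {audio_samples} does not match duration_sec {duration_sec}')
    return audio_samples


def expected_frames(fps: float, duration_sec: float) -> int:
    """Number of video frames expected for a clip of the given duration."""
    return int(fps * duration_sec)


default_samples = resolve_audio_samples(16_000, 8.0)  # derived from rate * duration
sync_frames = expected_frames(25.0, 8.0)              # sync stream at 25 fps
clip_frames = expected_frames(8.0, 8.0)               # CLIP stream at 8 fps
```

With a 16 kHz sample rate the 15 ms tolerance corresponds to 240 samples, so an explicit `audio_samples` may differ from `sample_rate * duration_sec` by less than that.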
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .autoencoder import AutoEncoderModule
|
||||
@@ -0,0 +1,52 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from selva_core.ext.autoencoder.vae import VAE, get_my_vae
|
||||
from selva_core.ext.bigvgan import BigVGAN
|
||||
from selva_core.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
|
||||
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class AutoEncoderModule(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vae_ckpt_path,
|
||||
vocoder_ckpt_path: Optional[str] = None,
|
||||
mode: Literal['16k', '44k'],
|
||||
need_vae_encoder: bool = True):
|
||||
super().__init__()
|
||||
self.vae: VAE = get_my_vae(mode).eval()
|
||||
vae_state_dict = torch.load(vae_ckpt_path, weights_only=False, map_location='cpu')
|
||||
self.vae.load_state_dict(vae_state_dict)
|
||||
self.vae.remove_weight_norm()
|
||||
|
||||
if mode == '16k':
|
||||
assert vocoder_ckpt_path is not None
|
||||
self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
|
||||
elif mode == '44k':
|
||||
self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
|
||||
use_cuda_kernel=False)
|
||||
self.vocoder.remove_weight_norm()
|
||||
else:
|
||||
raise ValueError(f'Unknown mode: {mode}')
|
||||
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if not need_vae_encoder:
|
||||
del self.vae.encoder
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
|
||||
return self.vae.encode(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self.vae.decode(z)
|
||||
|
||||
@torch.inference_mode()
|
||||
def vocode(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
return self.vocoder(spec)
|
||||
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# This work is licensed under a Creative Commons
|
||||
# Attribution-NonCommercial-ShareAlike 4.0 International License.
|
||||
# You should have received a copy of the license along with this
|
||||
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
"""Improved diffusion model architecture proposed in the paper
|
||||
"Analyzing and Improving the Training Dynamics of Diffusion Models"."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Variant of constant() that inherits dtype and device from the given
|
||||
# reference tensor by default.
|
||||
|
||||
_constant_cache = dict()
|
||||
|
||||
|
||||
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
||||
value = np.asarray(value)
|
||||
if shape is not None:
|
||||
shape = tuple(shape)
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if device is None:
|
||||
device = torch.device('cpu')
|
||||
if memory_format is None:
|
||||
memory_format = torch.contiguous_format
|
||||
|
||||
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
||||
tensor = _constant_cache.get(key, None)
|
||||
if tensor is None:
|
||||
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
||||
if shape is not None:
|
||||
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
||||
tensor = tensor.contiguous(memory_format=memory_format)
|
||||
_constant_cache[key] = tensor
|
||||
return tensor
|
||||
|
||||
|
||||
def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
|
||||
if dtype is None:
|
||||
dtype = ref.dtype
|
||||
if device is None:
|
||||
device = ref.device
|
||||
return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Normalize given tensor to unit magnitude with respect to the given
|
||||
# dimensions. Default = all dimensions except the first.
|
||||
|
||||
|
||||
def normalize(x, dim=None, eps=1e-4):
|
||||
if dim is None:
|
||||
dim = list(range(1, x.ndim))
|
||||
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||
norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
|
||||
return x / norm.to(x.dtype)
|
||||
|
||||
|
||||
class Normalize(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim=None, eps=1e-4):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
return normalize(x, dim=self.dim, eps=self.eps)
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Upsample or downsample the given tensor with the given filter,
|
||||
# or keep it as is.
|
||||
|
||||
|
||||
def resample(x, f=[1, 1], mode='keep'):
|
||||
if mode == 'keep':
|
||||
return x
|
||||
f = np.float32(f)
|
||||
assert f.ndim == 1 and len(f) % 2 == 0
|
||||
pad = (len(f) - 1) // 2
|
||||
f = f / f.sum()
|
||||
f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
|
||||
f = const_like(x, f)
|
||||
c = x.shape[1]
|
||||
if mode == 'down':
|
||||
return torch.nn.functional.conv2d(x,
|
||||
f.tile([c, 1, 1, 1]),
|
||||
groups=c,
|
||||
stride=2,
|
||||
padding=(pad, ))
|
||||
assert mode == 'up'
|
||||
return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]),
|
||||
groups=c,
|
||||
stride=2,
|
||||
padding=(pad, ))
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Magnitude-preserving SiLU (Equation 81).
|
||||
|
||||
|
||||
def mp_silu(x):
|
||||
return torch.nn.functional.silu(x) / 0.596
|
||||
|
||||
|
||||
class MPSiLU(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return mp_silu(x)
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Magnitude-preserving sum (Equation 88).
|
||||
|
||||
|
||||
def mp_sum(a, b, t=0.5):
|
||||
return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Magnitude-preserving concatenation (Equation 103).
|
||||
|
||||
|
||||
def mp_cat(a, b, dim=1, t=0.5):
|
||||
Na = a.shape[dim]
|
||||
Nb = b.shape[dim]
|
||||
C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
|
||||
wa = C / np.sqrt(Na) * (1 - t)
|
||||
wb = C / np.sqrt(Nb) * t
|
||||
return torch.cat([wa * a, wb * b], dim=dim)
|
||||
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
|
||||
# with forced weight normalization (Equation 66).
|
||||
|
||||
|
||||
class MPConv1D(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
|
||||
|
||||
self.weight_norm_removed = False
|
||||
|
||||
def forward(self, x, gain=1):
|
||||
assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
|
||||
|
||||
w = self.weight * gain
|
||||
if w.ndim == 2:
|
||||
return x @ w.t()
|
||||
assert w.ndim == 3
|
||||
return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2, ))
|
||||
|
||||
def remove_weight_norm(self):
|
||||
w = self.weight.to(torch.float32)
|
||||
w = normalize(w) # traditional weight normalization
|
||||
w = w / np.sqrt(w[0].numel())
|
||||
w = w.to(self.weight.dtype)
|
||||
self.weight.data.copy_(w)
|
||||
|
||||
self.weight_norm_removed = True
|
||||
return self
|
||||
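The point of `mp_sum` is that a plain interpolation of two independent unit-magnitude inputs shrinks the expected squared magnitude to `(1 - t)**2 + t**2`, and dividing by the square root of that factor restores it to about 1. A scalar sketch using only the standard library illustrates this; the sample count, seed, and helper `mean_square` are arbitrary choices for the demonstration.

```python
import math
import random


def mp_sum(a: float, b: float, t: float = 0.5) -> float:
    """Scalar magnitude-preserving sum: lerp(a, b, t) / sqrt((1 - t)^2 + t^2)."""
    blended = a + t * (b - a)  # same as (1 - t) * a + t * b
    return blended / math.sqrt((1 - t) ** 2 + t ** 2)


def mean_square(values: list[float]) -> float:
    """Average squared magnitude of a list of scalars."""
    return sum(v * v for v in values) / len(values)


rng = random.Random(0)  # fixed seed so the check is repeatable
t = 0.3
pairs = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(50_000)]

plain = mean_square([a + t * (b - a) for a, b in pairs])      # close to 0.49 + 0.09 = 0.58
preserved = mean_square([mp_sum(a, b, t) for a, b in pairs])  # close to 1
```

The same reasoning motivates the per-branch weights in `mp_cat`, where each input is rescaled by its own channel count before concatenation.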
@@ -0,0 +1,369 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from selva_core.ext.autoencoder.edm2_utils import MPConv1D
|
||||
from selva_core.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||
Upsample1D, nonlinearity)
|
||||
from selva_core.model.utils.distributions import DiagonalGaussianDistribution
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
DATA_MEAN_80D = [
|
||||
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
||||
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
|
||||
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
|
||||
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
|
||||
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
|
||||
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
|
||||
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
|
||||
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
|
||||
]
|
||||
|
||||
DATA_STD_80D = [
|
||||
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
|
||||
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
|
||||
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
|
||||
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
|
||||
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
|
||||
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
|
||||
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
|
||||
]
|
||||
|
||||
DATA_MEAN_128D = [
|
||||
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
|
||||
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
|
||||
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
|
||||
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
|
||||
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
|
||||
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
|
||||
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
|
||||
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
|
||||
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
|
||||
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
|
||||
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
|
||||
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
|
||||
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
|
||||
]
|
||||
|
||||
DATA_STD_128D = [
|
||||
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
|
||||
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
|
||||
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
|
||||
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
|
||||
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
|
||||
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
|
||||
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
|
||||
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
|
||||
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
|
||||
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
|
||||
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
|
||||
]
|
||||
|
||||
|
||||
class VAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
data_dim: int,
|
||||
embed_dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if data_dim == 80:
|
||||
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
|
||||
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
|
||||
elif data_dim == 128:
|
||||
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
|
||||
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
else:
raise ValueError(f'Unsupported data_dim: {data_dim} (expected 80 or 128)')
|
||||
|
||||
self.data_mean = self.data_mean.view(1, -1, 1)
|
||||
self.data_std = self.data_std.view(1, -1, 1)
|
||||
|
||||
self.encoder = Encoder1D(
|
||||
dim=hidden_dim,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_layers=[3],
|
||||
down_layers=[0],
|
||||
in_dim=data_dim,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
self.decoder = Decoder1D(
|
||||
dim=hidden_dim,
|
||||
ch_mult=(1, 2, 4),
|
||||
num_res_blocks=2,
|
||||
attn_layers=[3],
|
||||
down_layers=[0],
|
||||
in_dim=data_dim,
|
||||
out_dim=data_dim,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
|
||||
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
pass
|
||||
|
||||
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
|
||||
if normalize:
|
||||
x = self.normalize(x)
|
||||
moments = self.encoder(x)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
|
||||
dec = self.decoder(z)
|
||||
if unnormalize:
|
||||
dec = self.unnormalize(dec)
|
||||
return dec
|
||||
|
||||
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (x - self.data_mean) / self.data_std
|
||||
|
||||
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * self.data_std + self.data_mean
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sample_posterior: bool = True,
|
||||
rng: Optional[torch.Generator] = None,
|
||||
normalize: bool = True,
|
||||
unnormalize: bool = True,
|
||||
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
||||
|
||||
posterior = self.encode(x, normalize=normalize)
|
||||
if sample_posterior:
|
||||
z = posterior.sample(rng)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, unnormalize=unnormalize)
|
||||
return dec, posterior
|
||||
|
||||
def load_weights(self, src_dict) -> None:
|
||||
self.load_state_dict(src_dict, strict=True)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for name, m in self.named_modules():
|
||||
if isinstance(m, MPConv1D):
|
||||
m.remove_weight_norm()
|
||||
log.debug(f"Removed weight norm from {name}")
|
||||
return self
|
||||
|
||||
|
||||
class Encoder1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim: int,
|
||||
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
attn_layers: list[int] = [],
|
||||
down_layers: list[int] = [],
|
||||
resamp_with_conv: bool = True,
|
||||
in_dim: int,
|
||||
embed_dim: int,
|
||||
double_z: bool = True,
|
||||
kernel_size: int = 3,
|
||||
clip_act: float = 256.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_layers = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_dim
|
||||
self.clip_act = clip_act
|
||||
self.down_layers = down_layers
|
||||
self.attn_layers = attn_layers
|
||||
self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)
|
||||
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
# downsampling
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_layers):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = dim * in_ch_mult[i_level]
|
||||
block_out = dim * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_out,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True))
|
||||
block_in = block_out
|
||||
if i_level in attn_layers:
|
||||
attn.append(AttnBlock1D(block_in))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level in down_layers:
|
||||
down.downsample = Downsample1D(block_in, resamp_with_conv)
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_in,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True)
|
||||
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
|
||||
out_dim=block_in,
|
||||
kernel_size=kernel_size,
|
||||
use_norm=True)
|
||||
|
||||
# end
|
||||
self.conv_out = MPConv1D(block_in,
|
||||
2 * embed_dim if double_z else embed_dim,
|
||||
kernel_size=kernel_size)
|
||||
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_layers):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
hs.append(h)
|
||||
if i_level in self.down_layers:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
|
||||
# end
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, gain=(self.learnable_gain + 1))
|
||||
return h
|
||||
|
||||
|
||||
class Decoder1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
attn_layers: list[int] = [],
|
||||
down_layers: list[int] = [],
|
||||
kernel_size: int = 3,
|
||||
resamp_with_conv: bool = True,
|
||||
in_dim: int,
|
||||
embed_dim: int,
|
||||
clip_act: float = 256.0):
|
||||
super().__init__()
|
||||
self.ch = dim
|
||||
self.num_layers = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_dim
|
||||
self.clip_act = clip_act
|
||||
self.down_layers = [i + 1 for i in down_layers]  # upsample one level above each encoder downsample index
|
||||
|
||||
# compute block_in at the lowest resolution
|
||||
block_in = dim * ch_mult[self.num_layers - 1]
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_layers)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = dim * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
|
||||
block_in = block_out
|
||||
if i_level in attn_layers:
|
||||
attn.append(AttnBlock1D(block_in))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level in self.down_layers:
|
||||
up.upsample = Upsample1D(block_in, resamp_with_conv)
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
|
||||
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||
|
||||
def forward(self, z):
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_layers)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
h = h.clamp(-self.clip_act, self.clip_act)
|
||||
if i_level in self.down_layers:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, gain=(self.learnable_gain + 1))
|
||||
return h
|
||||
|
||||
|
||||
def VAE_16k(**kwargs) -> VAE:
|
||||
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
|
||||
|
||||
|
||||
def VAE_44k(**kwargs) -> VAE:
|
||||
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
|
||||
|
||||
|
||||
def get_my_vae(name: str, **kwargs) -> VAE:
|
||||
if name == '16k':
|
||||
return VAE_16k(**kwargs)
|
||||
if name == '44k':
|
||||
return VAE_44k(**kwargs)
|
||||
raise ValueError(f'Unknown model: {name}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
network = get_my_vae('16k')  # 'standard' is not a registered name; valid names are '16k' and '44k'
|
||||
|
||||
# print the number of parameters in terms of millions
|
||||
num_params = sum(p.numel() for p in network.parameters()) / 1e6
|
||||
print(f'Number of parameters: {num_params:.2f}M')
|
||||
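The channel widths that `Encoder1D` builds for the two presets follow directly from `dim`, `ch_mult` and `down_layers`. The sketch below repeats that bookkeeping in plain Python. `encoder_channel_plan` is a hypothetical helper written for illustration; it mirrors the loop in `Encoder1D.__init__` but does not construct any modules.

```python
def encoder_channel_plan(dim, ch_mult, down_layers):
    # Same indexing as Encoder1D: in_ch_mult is ch_mult shifted by one level.
    in_ch_mult = (1,) + tuple(ch_mult)
    plan = []
    time_factor = 1
    for level, mult in enumerate(ch_mult):
        block_in = dim * in_ch_mult[level]   # input width of the first block
        block_out = dim * mult               # width of every later block
        plan.append((level, block_in, block_out))
        if level in down_layers:
            time_factor *= 2                 # Downsample1D halves the time axis
    return plan, time_factor


# Values taken from VAE_16k and VAE_44k above.
plan_16k, factor_16k = encoder_channel_plan(384, (1, 2, 4), [0])
plan_44k, factor_44k = encoder_channel_plan(512, (1, 2, 4), [0])
moment_channels_16k = 2 * 20   # conv_out emits 2 * embed_dim when double_z=True

for level, c_in, c_out in plan_16k:
    print(f"16k level {level}: {c_in} -> {c_out}")
print("temporal downsampling factor:", factor_16k)
```

Note that both presets pass `attn_layers=[3]` with only three levels (indices 0 to 2), so the `if i_level in attn_layers` branch never fires and the only attention block that gets built is `mid.attn_1`.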
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from selva_core.ext.autoencoder.edm2_utils import (MPConv1D, mp_silu, mp_sum, normalize)
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return mp_silu(x)
|
||||
|
||||
|
||||
class ResnetBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
out_dim = in_dim if out_dim is None else out_dim
|
||||
self.out_dim = out_dim
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.use_norm = use_norm
|
||||
|
||||
self.conv1 = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
|
||||
self.conv2 = MPConv1D(out_dim, out_dim, kernel_size=kernel_size)
|
||||
if self.in_dim != self.out_dim:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
|
||||
else:
|
||||
self.nin_shortcut = MPConv1D(in_dim, out_dim, kernel_size=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# pixel norm
|
||||
if self.use_norm:
|
||||
x = normalize(x, dim=1)
|
||||
|
||||
h = x
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = nonlinearity(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_dim != self.out_dim:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return mp_sum(x, h, t=0.3)
|
||||
|
||||
|
||||
class AttnBlock1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, num_heads=1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.qkv = MPConv1D(in_channels, in_channels * 3, kernel_size=1)
|
||||
self.proj_out = MPConv1D(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
h = x
|
||||
y = self.qkv(h)
|
||||
y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[-1])
|
||||
q, k, v = normalize(y, dim=2).unbind(3)
|
||||
|
||||
q = rearrange(q, 'b h c l -> b h l c')
|
||||
k = rearrange(k, 'b h c l -> b h l c')
|
||||
v = rearrange(v, 'b h c l -> b h l c')
|
||||
|
||||
h = F.scaled_dot_product_attention(q, k, v)
|
||||
h = rearrange(h, 'b h l c -> b (h c) l')
|
||||
|
||||
h = self.proj_out(h)
|
||||
|
||||
return mp_sum(x, h, t=0.3)
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = MPConv1D(in_channels, in_channels, kernel_size=3)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# 1x1 convs applied before and after average pooling
|
||||
self.conv1 = MPConv1D(in_channels, in_channels, kernel_size=1)
|
||||
self.conv2 = MPConv1D(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv1(x)
|
||||
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv2(x)
|
||||
|
||||
return x
|
||||
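Both `ResnetBlock1D` and `AttnBlock1D` end with `mp_sum(x, h, t=0.3)`. The helper itself lives in `edm2_utils` and is not shown in this excerpt, so the sketch below assumes the EDM2 reference definition: a linear interpolation followed by a rescale that keeps the output at unit variance when both inputs are uncorrelated with unit variance. It uses scalars instead of tensors to keep the arithmetic visible.

```python
import math


def mp_sum(a, b, t=0.5):
    # Assumed EDM2 definition: lerp(a, b, t) / sqrt((1 - t)^2 + t^2).
    return ((1 - t) * a + t * b) / math.sqrt((1 - t) ** 2 + t ** 2)


t = 0.3
norm = math.sqrt((1 - t) ** 2 + t ** 2)
w_skip = (1 - t) / norm      # effective weight on the skip path x
w_branch = t / norm          # effective weight on the residual branch h

# For uncorrelated unit-variance inputs the output variance is the sum of
# squared weights, which the rescale pins to 1.
variance = w_skip ** 2 + w_branch ** 2
print(f"skip weight {w_skip:.3f}, branch weight {w_branch:.3f}, variance {variance:.3f}")
```

With `t=0.3` the skip path carries most of the signal, and the rescale prevents activations from shrinking as blocks are stacked.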
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 NVIDIA CORPORATION.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1 @@
|
||||
from .bigvgan import BigVGAN
|
||||
@@ -0,0 +1,120 @@
|
||||
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
'''
|
||||
Implementation of a sine-based periodic activation function
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter
|
||||
References:
|
||||
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = Snake(256)
|
||||
>>> x = torch.randn(1, 256, 16)  # (B, C, T)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: number of input channels (C)
|
||||
- alpha: trainable parameter
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(Snake, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
'''
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
'''
|
||||
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
References:
|
||||
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = SnakeBeta(256)
|
||||
>>> x = torch.randn(1, 256, 16)  # (B, C, T)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: number of input channels (C)
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||
alpha and beta will be trained along with the rest of your model.
|
||||
'''
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||
self.beta = Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
'''
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
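The formula in the `SnakeBeta.forward` docstring can be checked on scalars. The sketch below is a plain-Python restatement of that one line, including the `alpha_logscale` branch. `snake_beta` is an illustrative name and works on floats, not on `(B, C, T)` tensors.

```python
import math


def snake_beta(x, alpha=1.0, beta=1.0, logscale=False, eps=1e-9):
    # Scalar version of: x + 1 / (beta + eps) * sin(x * alpha) ** 2
    if logscale:
        # In log-scale mode the stored parameters are exponentiated first.
        alpha, beta = math.exp(alpha), math.exp(beta)
    return x + (1.0 / (beta + eps)) * math.sin(x * alpha) ** 2


at_zero = snake_beta(0.0)
at_half_pi = snake_beta(math.pi / 2)               # sin^2 term is 1 here
log_default = snake_beta(0.7, alpha=0.0, beta=0.0, logscale=True)
lin_default = snake_beta(0.7, alpha=1.0, beta=1.0)
print(at_zero, at_half_pi, log_default == lin_default)
```

A log-scale parameter of 0 corresponds to a linear-scale parameter of 1, which is why the log-scale branch initialises `alpha` and `beta` to zeros.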
@@ -0,0 +1,6 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from .filter import *
|
||||
from .resample import *
|
||||
from .act import *
|
||||
@@ -0,0 +1,28 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
from .resample import UpSample1d, DownSample1d
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(self,
|
||||
activation,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12):
|
||||
super().__init__()
|
||||
self.up_ratio = up_ratio
|
||||
self.down_ratio = down_ratio
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
# x: [B,C,T]
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,95 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
if 'sinc' in dir(torch):
|
||||
sinc = torch.sinc
|
||||
else:
|
||||
# This code is adapted from adefossez's julius.core.sinc under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/core.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def sinc(x: torch.Tensor):
|
||||
"""
|
||||
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||
"""
|
||||
return torch.where(x == 0,
|
||||
torch.tensor(1., device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x)
|
||||
|
||||
|
||||
# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = (kernel_size % 2 == 0)
|
||||
half_size = kernel_size // 2
|
||||
|
||||
# For the Kaiser window
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.:
|
||||
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||
else:
|
||||
beta = 0.
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
|
||||
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||
if even:
|
||||
time = (torch.arange(-half_size, half_size) + 0.5)
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
||||
# of the constant component in the input signal.
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride: int = 1,
|
||||
padding: bool = True,
|
||||
padding_mode: str = 'replicate',
|
||||
kernel_size: int = 12):
|
||||
# kernel_size should be even number for stylegan3 setup,
|
||||
# in this implementation, odd number is also possible.
|
||||
super().__init__()
|
||||
if cutoff < -0.:
|
||||
raise ValueError("Cutoff must be non-negative.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = (kernel_size % 2 == 0)
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# input: [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||
mode=self.padding_mode)
|
||||
out = F.conv1d(x, self.filter.expand(C, -1, -1),
|
||||
stride=self.stride, groups=C)
|
||||
|
||||
return out
|
||||
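The Kaiser window shape in `kaiser_sinc_filter1d` is controlled by `beta`, which is derived from `half_width` and `kernel_size` through the piecewise formula above. The sketch below isolates that computation in plain Python so the branch taken for a given setting is easy to inspect. `kaiser_beta` is an illustrative name, and the sample arguments correspond to a ratio-2 resampler with a 12-tap kernel (`half_width = 0.6 / 2`).

```python
import math


def kaiser_beta(half_width, kernel_size):
    # Same arithmetic as kaiser_sinc_filter1d, without building the window.
    half_size = kernel_size // 2
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        return 0.1102 * (A - 8.7)
    if A >= 21.0:
        return 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    return 0.0


beta_default = kaiser_beta(half_width=0.3, kernel_size=12)
beta_tiny = kaiser_beta(half_width=0.3, kernel_size=2)   # falls into the A < 21 branch
print(f"beta for 12 taps: {beta_default:.3f}, beta for 2 taps: {beta_tiny}")
```

Very short kernels fall through to `beta = 0`, which makes the Kaiser window rectangular.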
@@ -0,0 +1,49 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from .filter import LowPassFilter1d
|
||||
from .filter import kaiser_sinc_filter1d
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left:-self.pad_right]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
xx = self.lowpass(x)
|
||||
|
||||
return xx
|
||||
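The padding and trimming constants in `UpSample1d` are chosen so that the output is exactly `ratio` times longer than the input. The sketch below reproduces that length arithmetic in plain Python, using the standard output-length rule for a transposed 1D convolution with no padding or dilation, `(L - 1) * stride + kernel_size`. `upsample_length` is an illustrative helper and does no filtering.

```python
def upsample_length(t_in, ratio=2, kernel_size=None):
    # Constants computed exactly as in UpSample1d.__init__.
    if kernel_size is None:
        kernel_size = int(6 * ratio // 2) * 2
    stride = ratio
    pad = kernel_size // ratio - 1
    pad_left = pad * stride + (kernel_size - stride) // 2
    pad_right = pad * stride + (kernel_size - stride + 1) // 2

    padded = t_in + 2 * pad                      # replicate padding on both sides
    full = (padded - 1) * stride + kernel_size   # conv_transpose1d output length
    return full - pad_left - pad_right           # slice [pad_left:-pad_right]


print(upsample_length(100))            # default ratio 2, kernel size 12
print(upsample_length(100, ratio=4))   # kernel size falls back to 24
```

For `ratio=2` and `kernel_size=12` this gives `pad = 5` and `pad_left = pad_right = 15`, so the 30 extra samples produced by the transposed convolution are trimmed away.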
@@ -0,0 +1,32 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from selva_core.ext.bigvgan.models import BigVGANVocoder
|
||||
|
||||
_bigvgan_vocoder_path = Path(__file__).parent / 'bigvgan_vocoder.yml'
|
||||
|
||||
|
||||
class BigVGAN(nn.Module):
|
||||
|
||||
def __init__(self, ckpt_path, config_path=_bigvgan_vocoder_path):
|
||||
super().__init__()
|
||||
vocoder_cfg = OmegaConf.load(config_path)
|
||||
self.vocoder = BigVGANVocoder(vocoder_cfg).eval()
|
||||
vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)['generator']
|
||||
self.vocoder.load_state_dict(vocoder_ckpt)
|
||||
|
||||
self.weight_norm_removed = False
|
||||
self.remove_weight_norm()
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, x):
|
||||
assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
|
||||
return self.vocoder(x)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.vocoder.remove_weight_norm()
|
||||
self.weight_norm_removed = True
|
||||
return self
|
||||
@@ -0,0 +1,63 @@
|
||||
resblock: '1'
|
||||
num_gpus: 0
|
||||
batch_size: 64
|
||||
num_mels: 80
|
||||
learning_rate: 0.0001
|
||||
adam_b1: 0.8
|
||||
adam_b2: 0.99
|
||||
lr_decay: 0.999
|
||||
seed: 1234
|
||||
upsample_rates:
|
||||
- 4
|
||||
- 4
|
||||
- 2
|
||||
- 2
|
||||
- 2
|
||||
- 2
|
||||
upsample_kernel_sizes:
|
||||
- 8
|
||||
- 8
|
||||
- 4
|
||||
- 4
|
||||
- 4
|
||||
- 4
|
||||
upsample_initial_channel: 1536
|
||||
resblock_kernel_sizes:
|
||||
- 3
|
||||
- 7
|
||||
- 11
|
||||
resblock_dilation_sizes:
|
||||
- - 1
|
||||
- 3
|
||||
- 5
|
||||
- - 1
|
||||
- 3
|
||||
- 5
|
||||
- - 1
|
||||
- 3
|
||||
- 5
|
||||
activation: snakebeta
|
||||
snake_logscale: true
|
||||
resolutions:
|
||||
- - 1024
|
||||
- 120
|
||||
- 600
|
||||
- - 2048
|
||||
- 240
|
||||
- 1200
|
||||
- - 512
|
||||
- 50
|
||||
- 240
|
||||
mpd_reshapes:
|
||||
- 2
|
||||
- 3
|
||||
- 5
|
||||
- 7
|
||||
- 11
|
||||
use_spectral_norm: false
|
||||
discriminator_channel_mult: 1
|
||||
num_workers: 4
|
||||
dist_config:
|
||||
dist_backend: nccl
|
||||
dist_url: tcp://localhost:54341
|
||||
world_size: 1
|
||||
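A quick consistency check on this config is the total upsampling factor of the generator, which is the product of `upsample_rates`. The sketch below computes it in plain Python. Treating the product as the number of waveform samples per mel frame is the usual convention for HiFi-GAN-style vocoders and is an assumption here, since the mel hop size is not stated in this file.

```python
from math import prod

# Values copied from the config above.
upsample_rates = [4, 4, 2, 2, 2, 2]
upsample_kernel_sizes = [8, 8, 4, 4, 4, 4]

samples_per_frame = prod(upsample_rates)

# Each stage in this config uses a kernel twice as wide as its rate.
kernel_to_rate = [k // r for k, r in zip(upsample_kernel_sizes, upsample_rates)]

print("samples per mel frame:", samples_per_frame)
print("kernel/rate per stage:", kernel_to_rate)
```

Under that assumption, a mel spectrogram with `N` frames decodes to `N * samples_per_frame` audio samples.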
@@ -0,0 +1,18 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
shutil.copyfile(config, os.path.join(path, config_name))
|
||||
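`AttrDict` works by pointing the instance `__dict__` at the dictionary itself, so keys and attributes share one storage. The sketch below repeats the class so it runs on its own and shows both access styles staying in sync. The keys and values are borrowed from the vocoder config earlier in this diff purely for illustration.

```python
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attribute access and item access now read and write the same mapping.
        self.__dict__ = self


cfg = AttrDict(num_mels=80, activation="snakebeta")

cfg.seed = 1234              # set as an attribute, visible as a key
cfg["num_workers"] = 4       # set as a key, visible as an attribute

print(cfg.num_mels, cfg["seed"], cfg.num_workers)
print(sorted(cfg.keys()))
```

One consequence of this design is that a key named after a `dict` method, such as `items`, would shadow that method on attribute access, so config keys are best kept distinct from method names.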
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 Jungil Kong
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 Edward Dixon
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,29 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2019, Seungwon Park 박승원
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
@@ -0,0 +1,16 @@
|
||||
Copyright 2020 Alexandre Défossez
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or
|
||||
substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
|
||||
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
@@ -0,0 +1,255 @@
|
||||
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from selva_core.ext.bigvgan import activations
|
||||
from selva_core.ext.bigvgan.alias_free_torch import *
|
||||
from selva_core.ext.bigvgan.utils import get_padding, init_weights
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||
super(AMPBlock1, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2])))
|
||||
])
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)))
|
||||
])
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_parametrizations(l, 'weight')
|
||||
for l in self.convs2:
|
||||
remove_parametrizations(l, 'weight')
|
||||
|
||||
|
||||
class AMPBlock2(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
|
||||
super(AMPBlock2, self).__init__()
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])))
|
||||
])
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs) # total number of conv layers
|
||||
|
||||
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c, a in zip(self.convs, self.activations):
|
||||
xt = a(x)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_parametrizations(l, 'weight')
|
||||
|
||||
|
||||
class BigVGANVocoder(torch.nn.Module):
|
||||
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
||||
def __init__(self, h):
|
||||
super().__init__()
|
||||
self.h = h
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# pre conv
|
||||
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
|
||||
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
|
||||
|
||||
# transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
nn.ModuleList([
|
||||
weight_norm(
|
||||
ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2))
|
||||
]))
|
||||
|
||||
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# post conv
|
||||
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
||||
|
||||
# weight initialization
|
||||
for i in range(len(self.ups)):
|
||||
self.ups[i].apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
# pre conv
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
# upsampling
|
||||
for i_up in range(len(self.ups[i])):
|
||||
x = self.ups[i][i_up](x)
|
||||
# AMP blocks
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
# post conv
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
for l_i in l:
|
||||
remove_parametrizations(l_i, 'weight')
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_parametrizations(self.conv_pre, 'weight')
|
||||
remove_parametrizations(self.conv_post, 'weight')
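The upsamplers above use `padding=(k - u) // 2`. The sketch below checks, with the standard `ConvTranspose1d` output-length formula (dilation 1, no output padding), that this choice makes each stage multiply the sequence length by exactly its stride. The helper names are hypothetical, the rates and kernel sizes are copied by hand from the config in this diff, and it assumes every other layer in the vocoder preserves length.

```python
# Copied by hand from the config: upsample_rates and upsample_kernel_sizes.
RATES = [4, 4, 2, 2, 2, 2]
KERNELS = [8, 8, 4, 4, 4, 4]


def conv_transpose1d_out_length(length, kernel_size, stride, padding):
    """ConvTranspose1d output length for dilation=1 and output_padding=0."""
    return (length - 1) * stride - 2 * padding + kernel_size


def vocoder_output_length(num_frames):
    """Hypothetical helper: track only the upsampling stages.

    Assumption: conv_pre, the AMP blocks, and conv_post keep the length unchanged.
    """
    length = num_frames
    for u, k in zip(RATES, KERNELS):
        length = conv_transpose1d_out_length(length, k, u, padding=(k - u) // 2)
    return length


print(vocoder_output_length(10))  # 2560
```

The exact multiple only holds when `k - u` is even, which is the case for every pair in this config.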
|
||||
@@ -0,0 +1,31 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def apply_weight_norm(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
weight_norm(m)
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print("Loading '{}'".format(filepath))
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
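To illustrate what `get_padding` is for, the sketch below plugs its result into the standard `Conv1d` output-length formula and confirms that, for the odd kernel sizes and dilations used in this diff, a stride-1 convolution keeps the sequence length unchanged. `conv1d_out_length` is a hypothetical helper name and the input length of 100 is arbitrary.

```python
def get_padding(kernel_size, dilation=1):
    # Same formula as the helper above.
    return int((kernel_size * dilation - dilation) / 2)


def conv1d_out_length(length, kernel_size, dilation, padding, stride=1):
    """Standard Conv1d output-length formula (hypothetical helper name)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Kernel sizes and dilations taken from resblock_kernel_sizes / resblock_dilation_sizes.
results = {
    (k, d): conv1d_out_length(100, k, d, get_padding(k, d))
    for k in (3, 7, 11)
    for d in (1, 3, 5)
}

for (k, d), out_len in sorted(results.items()):
    print(f"kernel={k:2d} dilation={d} -> length {out_len}")
```

The length is preserved exactly only for odd kernel sizes; with an even kernel the integer truncation would leave the output one sample short.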
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,126 @@
|
||||
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
"""
|
||||
Implementation of a sine-based periodic activation function
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter
|
||||
References:
|
||||
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = Snake(256)
|
||||
>>> x = torch.randn(1, 256, 100)  # (B, C, T)
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: number of input channels (C)
|
||||
- alpha: trainable parameter
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
alpha will be trained along with the rest of your model.
|
||||
"""
|
||||
super(Snake, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# Initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # Linear scale alphas initialized to ones
|
||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
"""
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
"""
|
||||
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
References:
|
||||
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = SnakeBeta(256)
|
||||
>>> x = torch.randn(1, 256, 100)  # (B, C, T)
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: number of input channels (C)
|
||||
- alpha - trainable parameter that controls frequency
|
||||
- beta - trainable parameter that controls magnitude
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||
alpha and beta will be trained along with the rest of your model.
|
||||
"""
|
||||
super(SnakeBeta, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# Initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # Linear scale alphas initialized to ones
|
||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||
self.beta = Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||
"""
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
||||
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
beta = torch.exp(beta)
|
||||
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
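The SnakeBeta forward pass above reduces to a one-line formula per element. A minimal pure-Python sketch, assuming scalar inputs rather than the per-channel `[B, C, T]` tensors the module handles; the name `snake_beta` and the `logscale` flag are illustrative and not part of the repository:

```python
import math

# Same guard constant as self.no_div_by_zero in the module above.
NO_DIV_BY_ZERO = 1e-9


def snake_beta(x, alpha=1.0, beta=1.0, logscale=False):
    """Scalar version of x + 1/beta * sin^2(alpha * x).

    With logscale=True the stored parameters are exponentiated first,
    mirroring the alpha_logscale branch of the module.
    """
    if logscale:
        alpha, beta = math.exp(alpha), math.exp(beta)
    return x + (1.0 / (beta + NO_DIV_BY_ZERO)) * math.sin(x * alpha) ** 2


if __name__ == "__main__":
    # Log-scale parameters of 0 behave like linear-scale parameters of 1.
    print(snake_beta(0.7, 0.0, 0.0, logscale=True))
    print(snake_beta(0.7, 1.0, 1.0))
```

This also shows why log-scale parameters are initialized to zeros and linear-scale parameters to ones: both start from the same effective `alpha = beta = 1`.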
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
||||
|
||||
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
||||
from alias_free_activation.cuda import load
|
||||
|
||||
anti_alias_activation_cuda = load.load()
|
||||
|
||||
|
||||
class FusedAntiAliasActivation(torch.autograd.Function):
|
||||
"""
|
||||
Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
|
||||
The hyperparameters are hard-coded in the kernel to maximize speed.
|
||||
NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
||||
activation_results = anti_alias_activation_cuda.forward(
|
||||
inputs, up_ftr, down_ftr, alpha, beta
|
||||
)
|
||||
|
||||
return activation_results
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grads):
|
||||
raise NotImplementedError
|
||||
return output_grads, None, None
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12,
|
||||
fused: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.up_ratio = up_ratio
|
||||
self.down_ratio = down_ratio
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
self.fused = fused # Whether to use fused CUDA kernel or not
|
||||
|
||||
def forward(self, x):
|
||||
if not self.fused:
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
else:
|
||||
if self.act.__class__.__name__ == "Snake":
|
||||
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
||||
else:
|
||||
beta = (
|
||||
self.act.beta.data
|
||||
) # SnakeBeta uses different params for alpha and beta
|
||||
alpha = self.act.alpha.data
|
||||
if (
|
||||
not self.act.alpha_logscale
|
||||
): # Exp baked into cuda kernel, cancel it out with a log
|
||||
alpha = torch.log(alpha)
|
||||
beta = torch.log(beta)
|
||||
|
||||
x = FusedAntiAliasActivation.apply(
|
||||
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
|
||||
)
|
||||
return x
|
||||
@@ -0,0 +1,23 @@
|
||||
/* coding=utf-8
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
|
||||
}
|
||||
@@ -0,0 +1,246 @@
|
||||
/* coding=utf-8
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_profiler_api.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include "type_shim.h"
|
||||
#include <assert.h>
|
||||
#include <cfloat>
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace
|
||||
{
|
||||
// Hard-coded hyperparameters
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
constexpr int BUFFER_SIZE = 32;
|
||||
constexpr int FILTER_SIZE = 12;
|
||||
constexpr int HALF_FILTER_SIZE = 6;
|
||||
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
|
||||
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
|
||||
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
__global__ void anti_alias_activation_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const input_t *up_ftr,
|
||||
const input_t *down_ftr,
|
||||
const input_t *alpha,
|
||||
const input_t *beta,
|
||||
int batch_size,
|
||||
int channels,
|
||||
int seq_len)
|
||||
{
|
||||
// Up and downsample filters
|
||||
input_t up_filter[FILTER_SIZE];
|
||||
input_t down_filter[FILTER_SIZE];
|
||||
|
||||
// Load data from global memory including extra indices reserved for replication paddings
|
||||
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
|
||||
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
|
||||
|
||||
// Output stores downsampled output before writing to dst
|
||||
output_t output[BUFFER_SIZE];
|
||||
|
||||
// blockDim/threadIdx = (128, 1, 1)
|
||||
// gridDim/blockIdx = (seq_blocks, channels, batches)
|
||||
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
||||
int local_offset = threadIdx.x * BUFFER_SIZE;
|
||||
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
|
||||
|
||||
// intermediates have double the seq_len
|
||||
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
||||
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
|
||||
|
||||
// Get values needed for replication padding before moving pointer
|
||||
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
||||
input_t seq_left_most_value = right_most_pntr[0];
|
||||
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
|
||||
|
||||
// Move src and dst pointers
|
||||
src += block_offset + local_offset;
|
||||
dst += block_offset + local_offset;
|
||||
|
||||
// Alpha and beta values for snake activations. Applies exp by default
|
||||
alpha = alpha + blockIdx.y;
|
||||
input_t alpha_val = expf(alpha[0]);
|
||||
beta = beta + blockIdx.y;
|
||||
input_t beta_val = expf(beta[0]);
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < FILTER_SIZE; it += 1)
|
||||
{
|
||||
up_filter[it] = up_ftr[it];
|
||||
down_filter[it] = down_ftr[it];
|
||||
}
|
||||
|
||||
// Apply replication padding for upsampling, matching torch impl
|
||||
#pragma unroll
|
||||
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
|
||||
{
|
||||
int element_index = seq_offset + it; // index for element
|
||||
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
|
||||
{
|
||||
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
|
||||
}
|
||||
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
|
||||
{
|
||||
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
|
||||
}
|
||||
if ((element_index >= 0) && (element_index < seq_len))
|
||||
{
|
||||
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
|
||||
}
|
||||
}
|
||||
|
||||
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
|
||||
#pragma unroll
|
||||
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
|
||||
{
|
||||
input_t acc = 0.0;
|
||||
int element_index = intermediate_seq_offset + it; // index for intermediate
|
||||
#pragma unroll
|
||||
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
||||
{
|
||||
if ((element_index + f_idx) >= 0)
|
||||
{
|
||||
acc += up_filter[f_idx] * elements[it + f_idx];
|
||||
}
|
||||
}
|
||||
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
|
||||
}
|
||||
|
||||
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
|
||||
double no_div_by_zero = 0.000000001;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
|
||||
{
|
||||
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
||||
}
|
||||
|
||||
// Apply replication padding before downsampling conv from intermediates
|
||||
#pragma unroll
|
||||
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
|
||||
{
|
||||
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
|
||||
{
|
||||
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
|
||||
}
|
||||
|
||||
// Apply downsample strided convolution (assuming stride=2) from intermediates
|
||||
#pragma unroll
|
||||
for (int it = 0; it < BUFFER_SIZE; it += 1)
|
||||
{
|
||||
input_t acc = 0.0;
|
||||
#pragma unroll
|
||||
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
||||
{
|
||||
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
|
||||
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
|
||||
}
|
||||
output[it] = acc;
|
||||
}
|
||||
|
||||
// Write output to dst
|
||||
#pragma unroll
|
||||
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
|
||||
{
|
||||
int element_index = seq_offset + it;
|
||||
if (element_index < seq_len)
|
||||
{
|
||||
dst[it] = output[it];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_anti_alias_activation_forward(
|
||||
output_t *dst,
|
||||
const input_t *src,
|
||||
const input_t *up_ftr,
|
||||
const input_t *down_ftr,
|
||||
const input_t *alpha,
|
||||
const input_t *beta,
|
||||
int batch_size,
|
||||
int channels,
|
||||
int seq_len)
|
||||
{
|
||||
if (seq_len == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
constexpr int seq_len_per_block = 4096;
|
||||
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
||||
dim3 blocks(blocks_per_seq_len, channels, batch_size);
|
||||
dim3 threads(threads_per_block, 1, 1);
|
||||
|
||||
anti_alias_activation_forward<input_t, output_t, acc_t>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
|
||||
{
|
||||
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
|
||||
const int batches = input.size(0);
|
||||
const int channels = input.size(1);
|
||||
const int seq_len = input.size(2);
|
||||
|
||||
// Output
|
||||
auto act_options = input.options().requires_grad(false);
|
||||
|
||||
torch::Tensor anti_alias_activation_results =
|
||||
torch::empty({batches, channels, seq_len}, act_options);
|
||||
|
||||
void *input_ptr = static_cast<void *>(input.data_ptr());
|
||||
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
|
||||
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
|
||||
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
|
||||
void *beta_ptr = static_cast<void *>(beta.data_ptr());
|
||||
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
|
||||
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
input.scalar_type(),
|
||||
"dispatch anti alias activation_forward",
|
||||
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
|
||||
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
|
||||
reinterpret_cast<const scalar_t *>(input_ptr),
|
||||
reinterpret_cast<const scalar_t *>(up_filter_ptr),
|
||||
reinterpret_cast<const scalar_t *>(down_filter_ptr),
|
||||
reinterpret_cast<const scalar_t *>(alpha_ptr),
|
||||
reinterpret_cast<const scalar_t *>(beta_ptr),
|
||||
batches,
|
||||
channels,
|
||||
seq_len););
|
||||
return anti_alias_activation_results;
|
||||
}
|
||||
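The launch geometry used by the dispatcher above can be checked with plain integer arithmetic. A small Python sketch, assuming only the constants visible in the kernel (128 threads per block, `BUFFER_SIZE = 32`); the helper names `launch_grid` and `thread_span` are hypothetical and exist only for illustration:

```python
THREADS_PER_BLOCK = 128
BUFFER_SIZE = 32
# Each block covers 128 threads * 32 samples = 4096 samples of one channel.
SEQ_LEN_PER_BLOCK = THREADS_PER_BLOCK * BUFFER_SIZE


def launch_grid(batch_size, channels, seq_len):
    """Grid dimensions (x, y, z) as computed in the dispatcher: ceil-divide along time."""
    blocks_per_seq_len = (seq_len + SEQ_LEN_PER_BLOCK - 1) // SEQ_LEN_PER_BLOCK
    return (blocks_per_seq_len, channels, batch_size)


def thread_span(block_x, thread_x, seq_len):
    """Output indices one thread writes, clipped at seq_len like the final store loop."""
    start = block_x * SEQ_LEN_PER_BLOCK + thread_x * BUFFER_SIZE
    return range(start, min(start + BUFFER_SIZE, seq_len))


if __name__ == "__main__":
    print(launch_grid(batch_size=2, channels=3, seq_len=5000))
    print(list(thread_span(1, 0, 5000))[:4])
```

Every output sample is written by exactly one thread, and threads whose span starts past `seq_len` write nothing, which is why the last block may be only partially used.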
@@ -0,0 +1,29 @@
|
||||
/* coding=utf-8
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*This code is copied from NVIDIA apex:
|
||||
* https://github.com/NVIDIA/apex
|
||||
* with minor changes. */
|
||||
|
||||
#ifndef TORCH_CHECK
|
||||
#define TORCH_CHECK AT_CHECK
|
||||
#endif
|
||||
|
||||
#ifdef VERSION_GE_1_3
|
||||
#define DATA_PTR data_ptr
|
||||
#else
|
||||
#define DATA_PTR data
|
||||
#endif
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
"""
|
||||
Setting this param to a list generates different compilation commands (with a different order of architectures), which leads to recompilation of the fused kernels.
|
||||
Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below
|
||||
"""
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
||||
|
||||
|
||||
def load():
|
||||
# Check if cuda 11 is installed for compute capability 8.0
|
||||
cc_flag = []
|
||||
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11:
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_80,code=sm_80")
|
||||
|
||||
# Build path
|
||||
srcpath = pathlib.Path(__file__).parent.absolute()
|
||||
buildpath = srcpath / "build"
|
||||
_create_build_dir(buildpath)
|
||||
|
||||
# Helper function to build the kernels.
|
||||
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
|
||||
return cpp_extension.load(
|
||||
name=name,
|
||||
sources=sources,
|
||||
build_directory=buildpath,
|
||||
extra_cflags=[
|
||||
"-O3",
|
||||
],
|
||||
extra_cuda_cflags=[
|
||||
"-O3",
|
||||
"-gencode",
|
||||
"arch=compute_70,code=sm_70",
|
||||
"--use_fast_math",
|
||||
]
|
||||
+ extra_cuda_flags
|
||||
+ cc_flag,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
extra_cuda_flags = [
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
]
|
||||
|
||||
sources = [
|
||||
srcpath / "anti_alias_activation.cpp",
|
||||
srcpath / "anti_alias_activation_cuda.cu",
|
||||
]
|
||||
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
||||
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
||||
)
|
||||
|
||||
return anti_alias_activation_cuda
|
||||
|
||||
|
||||
def _get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
||||
)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
bare_metal_major = release[0]
|
||||
bare_metal_minor = release[1][0]
|
||||
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
|
||||
def _create_build_dir(buildpath):
|
||||
try:
|
||||
os.mkdir(buildpath)
|
||||
except OSError:
|
||||
if not os.path.isdir(buildpath):
|
||||
print(f"Creation of the build directory {buildpath} failed")
|
||||
@@ -0,0 +1,92 @@
|
||||
/* coding=utf-8
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include "compat.h"
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||
switch (TYPEIN) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_in = float; \
|
||||
switch (TYPEOUT) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_out = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
||||
} \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_in = at::Half; \
|
||||
using scalar_t_out = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: \
|
||||
{ \
|
||||
using scalar_t_in = at::BFloat16; \
|
||||
using scalar_t_out = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from .filter import *
|
||||
from .resample import *
|
||||
from .act import *
|
||||
@@ -0,0 +1,32 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from selva_core.ext.bigvgan_v2.alias_free_activation.torch.resample import (DownSample1d, UpSample1d)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12,
|
||||
):
|
||||
super().__init__()
|
||||
self.up_ratio = up_ratio
|
||||
self.down_ratio = down_ratio
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
# x: [B,C,T]
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
|
||||
return x
|
||||
@@ -0,0 +1,101 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
if "sinc" in dir(torch):
|
||||
sinc = torch.sinc
|
||||
else:
|
||||
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/core.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def sinc(x: torch.Tensor):
|
||||
"""
|
||||
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||
"""
|
||||
return torch.where(
|
||||
x == 0,
|
||||
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x,
|
||||
)
|
||||
|
||||
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(
|
||||
cutoff, half_width, kernel_size
|
||||
): # return filter [1,1,kernel_size]
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
|
||||
# For kaiser window
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.0:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.0:
|
||||
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
|
||||
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||
if even:
|
||||
time = torch.arange(-half_size, half_size) + 0.5
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||
"""
|
||||
Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
|
||||
"""
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride: int = 1,
|
||||
padding: bool = True,
|
||||
padding_mode: str = "replicate",
|
||||
kernel_size: int = 12,
|
||||
):
|
||||
"""
|
||||
kernel_size should be an even number for the StyleGAN3 setup; this implementation also accepts odd numbers.
|
||||
"""
|
||||
super().__init__()
|
||||
if cutoff < -0.0:
|
||||
raise ValueError("Cutoff must not be negative.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# Input [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
|
||||
return out
|
||||
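The filter design in `kaiser_sinc_filter1d` can be reproduced without torch to inspect the taps. A rough pure-Python sketch, assuming `cutoff > 0` and a non-periodic Kaiser window computed from a truncated series for the Bessel function I0; the helpers `bessel_i0`, `kaiser_window` and `kaiser_sinc_taps` are illustrative names, not part of the repository:

```python
import math


def bessel_i0(x, terms=30):
    """Modified Bessel function of the first kind, order 0, via its power series."""
    total, term = 1.0, 1.0
    for k in range(1, terms):
        term *= (x / (2.0 * k)) ** 2
        total += term
    return total


def kaiser_window(n, beta):
    """Symmetric (non-periodic) Kaiser window of length n."""
    half = (n - 1) / 2.0
    return [
        bessel_i0(beta * math.sqrt(max(0.0, 1.0 - ((i - half) / half) ** 2))) / bessel_i0(beta)
        for i in range(n)
    ]


def sinc(x):
    """Normalized sinc: sin(pi x) / (pi x), with sinc(0) = 1."""
    return 1.0 if x == 0 else math.sin(math.pi * x) / (math.pi * x)


def kaiser_sinc_taps(cutoff, half_width, kernel_size):
    """Windowed-sinc low-pass taps, normalized to sum to 1 (assumes cutoff > 0)."""
    half_size = kernel_size // 2
    delta_f = 4 * half_width
    amp = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if amp > 50.0:
        beta = 0.1102 * (amp - 8.7)
    elif amp >= 21.0:
        beta = 0.5842 * (amp - 21) ** 0.4 + 0.07886 * (amp - 21.0)
    else:
        beta = 0.0
    window = kaiser_window(kernel_size, beta)
    if kernel_size % 2 == 0:
        times = [i + 0.5 for i in range(-half_size, half_size)]
    else:
        times = [i - half_size for i in range(kernel_size)]
    taps = [2 * cutoff * w * sinc(2 * cutoff * t) for w, t in zip(window, times)]
    total = sum(taps)
    return [t / total for t in taps]


if __name__ == "__main__":
    # Parameters used by UpSample1d / DownSample1d with ratio=2, kernel_size=12.
    print(kaiser_sinc_taps(cutoff=0.25, half_width=0.3, kernel_size=12))
```

The normalization step is what keeps a constant input constant after filtering, which is the leakage the comment in the source refers to.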
@@ -0,0 +1,54 @@
|
||||
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from selva_core.ext.bigvgan_v2.alias_free_activation.torch.filter import (LowPassFilter1d,
|
||||
kaiser_sinc_filter1d)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size)
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = (self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2)
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left:-self.pad_right]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size)
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
xx = self.lowpass(x)
|
||||
|
||||
return xx
|
||||
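The padding arithmetic in `UpSample1d` and `DownSample1d` is chosen so that upsampling multiplies the length by `ratio` and downsampling restores it. A short Python sketch of that bookkeeping, assuming the standard output-length formulas for `conv_transpose1d` and `conv1d` with no extra padding or dilation; `upsample_len` and `downsample_len` are hypothetical helper names:

```python
def upsample_len(t, ratio=2, kernel_size=12):
    """Output length of UpSample1d for an input of length t."""
    pad = kernel_size // ratio - 1
    pad_left = pad * ratio + (kernel_size - ratio) // 2
    pad_right = pad * ratio + (kernel_size - ratio + 1) // 2
    padded = t + 2 * pad                            # replicate padding on both sides
    transposed = (padded - 1) * ratio + kernel_size  # conv_transpose1d output length
    return transposed - pad_left - pad_right         # trailing slice x[..., pad_left:-pad_right]


def downsample_len(t, ratio=2, kernel_size=12):
    """Output length of DownSample1d (LowPassFilter1d with stride=ratio)."""
    even = kernel_size % 2 == 0
    pad_left = kernel_size // 2 - int(even)
    pad_right = kernel_size // 2
    padded = t + pad_left + pad_right
    return (padded - kernel_size) // ratio + 1       # strided conv1d output length


if __name__ == "__main__":
    for t in (1, 7, 100):
        print(t, upsample_len(t), downsample_len(upsample_len(t)))
```

With the defaults (`ratio=2`, `kernel_size=12`) the pads evaluate to 5 for upsampling and 5/6 for downsampling, the same values hard-coded as `UPSAMPLE_REPLICATION_PAD`, `DOWNSAMPLE_REPLICATION_PAD_LEFT` and `DOWNSAMPLE_REPLICATION_PAD_RIGHT` in the fused CUDA kernel.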
@@ -0,0 +1,439 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
from selva_core.ext.bigvgan_v2 import activations
|
||||
from selva_core.ext.bigvgan_v2.alias_free_activation.torch.act import \
|
||||
Activation1d as TorchActivation1d
|
||||
from selva_core.ext.bigvgan_v2.env import AttrDict
|
||||
from selva_core.ext.bigvgan_v2.utils import get_padding, init_weights
|
||||
|
||||
|
||||
def load_hparams_from_json(path) -> AttrDict:
|
||||
with open(path) as f:
|
||||
data = f.read()
|
||||
return AttrDict(json.loads(data))
|
||||
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
"""
|
||||
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
||||
AMPBlock1 has an additional self.convs2 that contains Conv1d layers with a fixed dilation=1, one following each layer in self.convs1
|
||||
|
||||
Args:
|
||||
h (AttrDict): Hyperparameters.
|
||||
channels (int): Number of convolution channels.
|
||||
kernel_size (int): Size of the convolution kernel. Default is 3.
|
||||
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
||||
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
h: AttrDict,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: tuple = (1, 3, 5),
|
||||
activation: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=get_padding(kernel_size, d),
|
||||
)) for d in dilation
|
||||
])
|
||||
self.convs1.apply(init_weights)
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
)) for _ in range(len(dilation))
|
||||
])
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
Activation1d = TorchActivation1d
|
||||
|
||||
# Activation functions
|
||||
if activation == "snake":
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == "snakebeta":
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_parametrizations(l, 'weight')
|
||||
for l in self.convs2:
|
||||
remove_parametrizations(l, 'weight')
|
||||
|
||||
|
||||
class AMPBlock2(torch.nn.Module):
|
||||
"""
|
||||
AMPBlock2 applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
||||
Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1.
|
||||
|
||||
Args:
|
||||
h (AttrDict): Hyperparameters.
|
||||
channels (int): Number of convolution channels.
|
||||
kernel_size (int): Size of the convolution kernel. Default is 3.
|
||||
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
||||
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
h: AttrDict,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: tuple = (1, 3, 5),
|
||||
activation: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=get_padding(kernel_size, d),
|
||||
)) for d in dilation
|
||||
])
|
||||
self.convs.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs) # Total number of conv layers
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
Activation1d = TorchActivation1d
|
||||
|
||||
# Activation functions
|
||||
if activation == "snake":
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
elif activation == "snakebeta":
|
||||
self.activations = nn.ModuleList([
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c, a in zip(self.convs, self.activations):
|
||||
xt = a(x)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_parametrizations(l, 'weight')  # match AMPBlock1 and BigVGAN.remove_weight_norm
|
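Both AMP blocks add the convolution output back onto its input (`x = xt + x`), which only works if each dilated convolution preserves the sequence length. The sketch below checks that arithmetic without PyTorch. `get_padding` is not shown in this diff, so the formula used here is an assumption based on the HiFi-GAN helper of the same name, and `conv1d_out_len` is a hypothetical helper that applies the standard Conv1d output-length formula.

```python
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    # Assumed definition (HiFi-GAN style); the real helper is outside this diff.
    return (kernel_size * dilation - dilation) // 2


def conv1d_out_len(length: int, kernel_size: int, dilation: int, padding: int, stride: int = 1) -> int:
    # Standard Conv1d output-length formula.
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Default AMPBlock arguments from the diff: kernel_size=3, dilation=(1, 3, 5)
lengths = [
    conv1d_out_len(100, 3, d, get_padding(3, d))
    for d in (1, 3, 5)
]
```

With an odd kernel size, every entry in `lengths` equals the input length, so the residual sum is shape-compatible.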
||||
|
||||
|
||||
class BigVGAN(
|
||||
torch.nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="bigvgan",
|
||||
repo_url="https://github.com/NVIDIA/BigVGAN",
|
||||
docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
|
||||
pipeline_tag="audio-to-audio",
|
||||
license="mit",
|
||||
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
||||
):
|
||||
"""
|
||||
BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
|
||||
New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
|
||||
|
||||
Args:
|
||||
h (AttrDict): Hyperparameters.
|
||||
use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
|
||||
|
||||
Note:
|
||||
- The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
|
||||
- Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
|
||||
"""
|
||||
|
||||
def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
|
||||
super().__init__()
|
||||
self.h = h
|
||||
self.h["use_cuda_kernel"] = use_cuda_kernel
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import \
|
||||
Activation1d as CudaActivation1d
|
||||
|
||||
Activation1d = CudaActivation1d
|
||||
else:
|
||||
Activation1d = TorchActivation1d
|
||||
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# Pre-conv
|
||||
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
|
||||
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
if h.resblock == "1":
|
||||
resblock_class = AMPBlock1
|
||||
elif h.resblock == "2":
|
||||
resblock_class = AMPBlock2
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
|
||||
|
||||
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
nn.ModuleList([
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
))
|
||||
]))
|
||||
|
||||
# Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# Post-conv
|
||||
activation_post = (activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
if h.activation == "snake" else
|
||||
(activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
if h.activation == "snakebeta" else None))
|
||||
if activation_post is None:
|
||||
raise NotImplementedError(
|
||||
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||
)
|
||||
|
||||
self.activation_post = Activation1d(activation=activation_post)
|
||||
|
||||
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
||||
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
|
||||
|
||||
# Weight initialization
|
||||
for i in range(len(self.ups)):
|
||||
self.ups[i].apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
|
||||
# Final tanh activation. Defaults to True for backward compatibility
|
||||
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
|
||||
|
||||
def forward(self, x):
|
||||
# Pre-conv
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
# Upsampling
|
||||
for i_up in range(len(self.ups[i])):
|
||||
x = self.ups[i][i_up](x)
|
||||
# AMP blocks
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
# Post-conv
|
||||
x = self.activation_post(x)
|
||||
x = self.conv_post(x)
|
||||
# Final tanh activation
|
||||
if self.use_tanh_at_final:
|
||||
x = torch.tanh(x)
|
||||
else:
|
||||
x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
try:
|
||||
print("Removing weight norm...")
|
||||
for l in self.ups:
|
||||
for l_i in l:
|
||||
remove_parametrizations(l_i, 'weight')
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_parametrizations(self.conv_pre, 'weight')
|
||||
remove_parametrizations(self.conv_post, 'weight')
|
||||
except ValueError:
|
||||
print("[INFO] Weight norm has already been removed from the model. Skipping!")
|
||||
pass
|
||||
|
||||
# Additional methods for huggingface_hub support
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
"""Save weights and config.json from a PyTorch model to a local directory."""
|
||||
|
||||
model_path = save_directory / "bigvgan_generator.pt"
|
||||
torch.save({"generator": self.state_dict()}, model_path)
|
||||
|
||||
config_path = save_directory / "config.json"
|
||||
with open(config_path, "w") as config_file:
|
||||
json.dump(self.h, config_file, indent=4)
|
||||
|
||||
@classmethod
|
||||
def _from_pretrained(
|
||||
cls,
|
||||
*,
|
||||
model_id: str,
|
||||
revision: str,
|
||||
cache_dir: str,
|
||||
force_download: bool,
|
||||
proxies: Optional[Dict] = None,
|
||||
resume_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Union[str, bool, None] = None,
|
||||
map_location: str = "cpu", # Additional argument
|
||||
strict: bool = False, # Additional argument
|
||||
use_cuda_kernel: bool = False,
|
||||
**model_kwargs,
|
||||
):
|
||||
"""Load PyTorch pretrained weights and return the loaded model."""
|
||||
|
||||
# Download and load hyperparameters (h) used by BigVGAN
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading config.json from local directory")
|
||||
config_file = os.path.join(model_id, "config.json")
|
||||
else:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename="config.json",
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
h = load_hparams_from_json(config_file)
|
||||
|
||||
# instantiate BigVGAN using h
|
||||
if use_cuda_kernel:
|
||||
print(
|
||||
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||
)
|
||||
print(
|
||||
"[WARNING] You need nvcc and ninja installed on your system, matching the CUDA version of your PyTorch build, to build the kernel. If not, the model will fail to initialize or generate an incorrect waveform!"
|
||||
)
|
||||
print(
|
||||
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||
)
|
||||
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
||||
|
||||
# Download and load pretrained generator weight
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, "bigvgan_generator.pt")
|
||||
else:
|
||||
print(f"Loading weights from {model_id}")
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename="bigvgan_generator.pt",
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
checkpoint_dict = torch.load(model_file, map_location=map_location, weights_only=True)
|
||||
|
||||
try:
|
||||
model.load_state_dict(checkpoint_dict["generator"])
|
||||
except RuntimeError:
|
||||
print(
|
||||
"[INFO] The pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||
)
|
||||
model.remove_weight_norm()
|
||||
model.load_state_dict(checkpoint_dict["generator"])
|
||||
|
||||
return model
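The upsampler loop in `BigVGAN.__init__` halves the channel count at each stage and uses `padding=(k - u) // 2`, so each transposed convolution multiplies the frame count by its stride when `k - u` is even. A minimal sketch of that bookkeeping follows. `upsample_plan` is a hypothetical helper and the config values are illustrative assumptions, not taken from any shipped `config.json`.

```python
def upsample_plan(initial_channels, rates, kernel_sizes, frames):
    """Return (in_channels, out_channels, output_length) for each upsampling stage."""
    plan = []
    length = frames
    for i, (u, k) in enumerate(zip(rates, kernel_sizes)):
        in_ch = initial_channels // (2 ** i)
        out_ch = initial_channels // (2 ** (i + 1))
        padding = (k - u) // 2
        # ConvTranspose1d output length with dilation=1 and output_padding=0
        length = (length - 1) * u - 2 * padding + k
        plan.append((in_ch, out_ch, length))
    return plan


# Hypothetical hyperparameters: total upsampling factor 4*4*2*2*2*2 = 256
plan = upsample_plan(1536, [4, 4, 2, 2, 2, 2], [8, 8, 4, 4, 4, 4], frames=10)
```

Under these assumed values, 10 mel frames become 2560 waveform samples and the last stage leaves 24 channels, which is the `ch` that `conv_post` receives.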
|
||||
@@ -0,0 +1,18 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
shutil.copyfile(config, os.path.join(path, config_name))
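`AttrDict` is why the vocoder code can mix `h.activation`, `h["use_cuda_kernel"]` and `h.get(...)` on the same object: pointing `__dict__` at the dict itself makes every key an attribute as well. The sketch below repeats the class so it runs on its own; the keys and values are illustrative assumptions.

```python
class AttrDict(dict):
    """Dict whose keys are also readable and writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


# Hypothetical hyperparameters for illustration only
h = AttrDict({"activation": "snakebeta", "snake_logscale": True})
h["use_cuda_kernel"] = False                      # item assignment, as in BigVGAN.__init__
attr_value = h.use_cuda_kernel                    # same entry read as an attribute
default_tanh = h.get("use_tanh_at_final", True)   # missing key falls back to the default
```

Because it is still a plain `dict`, `json.dump(self.h, ...)` in `_save_pretrained` can serialize it directly.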
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 Jungil Kong
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 Edward Dixon
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
Some files were not shown because too many files have changed in this diff