chore: remove all PrismAudio code from main branch
- Delete prismaudio_core/, data_utils/, scripts/, docs/plans/ - Delete PrismAudio nodes (feature_extractor, feature_loader, model_loader, sampler, text_only) - Delete PrismAudio workflows (video_to_audio, text_to_audio) - Clean nodes/utils.py: rename PRISMAUDIO_CATEGORY → SELVA_CATEGORY, remove unused helpers - Strip PrismAudio-only deps from requirements.txt Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
+1
-1
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
|
||||
ComfyUI-SelVA: Text-guided video-to-audio generation using SelVA / MMAudio.
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
@@ -1,337 +0,0 @@
|
||||
"""
|
||||
PrismAudio feature extraction utilities.
|
||||
|
||||
Implements FeaturesUtils used by scripts/extract_features.py to extract:
|
||||
- Text features via T5-Gemma (transformers)
|
||||
- Video features via VideoPrism (JAX/Flax, google-deepmind/videoprism)
|
||||
- Sync features via Synchformer visual encoder (PyTorch)
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FeaturesUtils:
    """Feature extractors for text (T5-Gemma), video (VideoPrism) and sync (Synchformer).

    The T5-Gemma and VideoPrism models are loaded lazily on first use; the
    Synchformer visual encoder is loaded eagerly in ``__init__`` when a
    checkpoint path that exists on disk is supplied.
    """

    def __init__(self, vae_config_path=None, synchformer_ckpt=None, device=None):
        """
        Args:
            vae_config_path: accepted for interface compatibility; not used by this class.
            synchformer_ckpt: path to the Synchformer checkpoint, or None to skip loading.
            device: torch device for the PyTorch models (defaults to CPU).
        """
        self.device = device or torch.device("cpu")
        self._t5_tokenizer = None
        self._t5_encoder = None
        self._vp_model = None
        self._vp_state = None
        self._vp_text_tokenizer = None
        self._sync_model = None

        self._synchformer_ckpt = synchformer_ckpt
        self._load_synchformer()

    # ------------------------------------------------------------------
    # T5-Gemma text encoding
    # ------------------------------------------------------------------

    def _ensure_t5(self):
        """Load the T5-Gemma tokenizer and encoder once; later calls are no-ops."""
        if self._t5_encoder is not None:
            return
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
        model_id = "google/t5gemma-l-l-ul2-it"
        print(f"[FeaturesUtils] Loading T5-Gemma: {model_id}")
        self._t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._t5_encoder = (
            AutoModelForSeq2SeqLM.from_pretrained(model_id)
            .get_encoder()
            .to(self.device)
            .eval()
        )

    def encode_t5_text(self, texts):
        """
        Args:
            texts: list of str
        Returns:
            Tensor [seq_len, 1024]
        """
        self._ensure_t5()
        # The encoder is parked on the CPU at the end of every call, so it has
        # to be moved back to the working device before each forward pass;
        # otherwise the second call would mix CPU weights with device tokens.
        self._t5_encoder.to(self.device)
        tokens = self._t5_tokenizer(
            texts, return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            out = self._t5_encoder(**tokens)
        # Move encoder off GPU to save VRAM
        self._t5_encoder.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return out.last_hidden_state.squeeze(0)  # [seq_len, 1024]

    # ------------------------------------------------------------------
    # VideoPrism video + text encoding (JAX)
    # ------------------------------------------------------------------

    def _ensure_videoprism(self):
        """Load the VideoPrism model, weights and text tokenizer once and build the jitted forward."""
        if self._vp_model is not None:
            return
        from videoprism import models as vp
        import jax
        model_name = "videoprism_lvt_public_v1_large"
        print("[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...")
        self._vp_model = vp.get_model(model_name)
        self._vp_state = vp.load_pretrained_weights(model_name)
        self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en")
        jax_dev = jax.devices()[0]
        # NOTE(review): the `device=` argument of jax.jit may not be accepted
        # by every JAX release — confirm against the pinned JAX version.
        self._jax_forward = jax.jit(
            lambda x, y, z: self._vp_model.apply(
                self._vp_state, x, y, z, train=False, return_intermediate=True
            ),
            device=jax_dev,
        )

    def encode_video_and_text_with_videoprism(self, clip_input, texts):
        """
        Args:
            clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
            texts: list of str — CoT captions, passed to VideoPrism LvT text tower
        Returns:
            global_video_features: Tensor [1, D]
            video_features: Tensor [T, D] — per-frame L2-normalized embeddings
            global_text_features: Tensor [1, D]
        """
        self._ensure_videoprism()
        import jax.numpy as jnp
        from videoprism import models as vp

        # Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
        frames = clip_input.squeeze(0)          # [T, C, H, W]
        frames = (frames + 1.0) / 2.0           # [-1,1] → [0,1]
        frames = frames.permute(0, 2, 3, 1)     # [T, H, W, C]
        frames_np = frames.cpu().numpy().astype(np.float32)
        frames_jax = jnp.array(frames_np)[None]  # [1, T, H, W, C]

        # Tokenize text (padding value 1.0 = pad, 0.0 = real token)
        text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts)

        # Joint video+text forward with intermediate outputs
        video_embeddings, text_embeddings, outputs = self._jax_forward(
            frames_jax, text_ids, text_paddings
        )

        # Per-frame features: [B, T, 1024] L2-normalized
        frame_embed_np = np.array(outputs["frame_embeddings"])  # [1, T, 1024]
        per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device)  # [T, 1024]

        # Global video embedding: [1024] → [1, 1024]
        global_video = torch.from_numpy(
            np.array(video_embeddings[0])
        ).unsqueeze(0).to(self.device)  # [1, 1024]

        # Global text embedding: [1024] → [1, 1024]
        global_text = torch.from_numpy(
            np.array(text_embeddings[0])
        ).unsqueeze(0).to(self.device)  # [1, 1024]

        return global_video, per_frame, global_text

    # ------------------------------------------------------------------
    # Synchformer sync feature encoding
    # ------------------------------------------------------------------

    def _load_synchformer(self):
        """Build the Synchformer visual encoder; silently skip when no checkpoint file is available."""
        if not self._synchformer_ckpt or not os.path.exists(self._synchformer_ckpt):
            return

        print(f"[FeaturesUtils] Loading Synchformer from: {self._synchformer_ckpt}")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects, so the
        # checkpoint must come from a trusted source.
        state = torch.load(self._synchformer_ckpt, map_location="cpu", weights_only=False)

        # Checkpoint may be raw state_dict or wrapped in {"model": ...}
        if isinstance(state, dict) and "model" in state:
            state_dict = state["model"]
        else:
            state_dict = state

        self._sync_model = _SynchformerVisualEncoder(state_dict, self.device)
        self._sync_model.eval()

    def encode_video_with_sync(self, sync_input):
        """
        Args:
            sync_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
        Returns:
            sync_features: Tensor [num_segments, 768]
        Raises:
            RuntimeError: if no Synchformer checkpoint was loaded.
        """
        if self._sync_model is None:
            raise RuntimeError(
                "[FeaturesUtils] Synchformer checkpoint not loaded. "
                "Pass synchformer_ckpt to FeaturesUtils or set --synchformer_ckpt."
            )
        frames = sync_input.squeeze(0).to(self.device)  # [T, C, H, W]
        with torch.no_grad():
            return self._sync_model(frames)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Synchformer visual encoder — TimeSformer-style ViT-B/16
|
||||
# Architecture reverse-engineered from synchformer_state_dict.pth
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class _PatchEmbed(nn.Module):
    """Project an image batch into a sequence of 16x16 patch tokens.

    For 224x224 RGB input: [B, 3, 224, 224] → [B, 196, 768].
    """

    def __init__(self):
        super().__init__()
        # Kernel size == stride, so each output position sees one disjoint patch.
        self.proj = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)

    def forward(self, x):
        grid = self.proj(x)            # [B, 768, H/16, W/16]
        tokens = grid.flatten(2)       # [B, 768, N]
        return tokens.transpose(1, 2)  # [B, N, 768]
|
||||
|
||||
|
||||
class _ViTAttn(nn.Module):
    """Multi-head self-attention with a fused QKV projection (timm-style parameter names)."""

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, tokens, width = x.shape
        # One matmul produces q, k and v; split them into per-head tensors
        # laid out as [3, batch, heads, tokens, head_dim].
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = F.softmax(scores, dim=-1)
        # Merge the heads back into the channel dimension before projecting.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, width)
        return self.proj(merged)
|
||||
|
||||
|
||||
class _BlockMLP(nn.Module):
    """Feed-forward sub-layer: Linear → GELU → Linear.

    The attribute names ``fc1``/``fc2`` are kept so the parameters line up
    with the checkpoint keys.
    """

    def __init__(self, dim=768, mlp_dim=3072):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)
|
||||
|
||||
|
||||
class _TimeSformerBlock(nn.Module):
    """
    Transformer block with factorized space-time attention.

    Order of sub-layers (each with a residual connection):
    norm1 → spatial attn, norm3 → temporal attn, norm2 → MLP.
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = _ViTAttn(dim, num_heads)
        self.norm3 = nn.LayerNorm(dim)
        self.timeattn = _ViTAttn(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = _BlockMLP(dim)

    def forward(self, x, T):
        # x is [T, N, D]: the frame axis plays the role of the batch axis for
        # spatial attention. The `T` argument is accepted but not read here.
        x = x + self.attn(self.norm1(x))

        # Swap the first two axes so that every spatial position becomes a
        # batch item and attention runs across the frames: [T, N, D] → [N, T, D].
        across_time = x.transpose(0, 1)
        across_time = across_time + self.timeattn(self.norm3(across_time))
        x = across_time.transpose(0, 1)  # back to [T, N, D]

        return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class _SpatialAttnAgg(nn.Module):
    """
    Pool the spatial patch tokens of each frame into a single vector.

    A learnable CLS token is prepended, one pre-norm attention + MLP layer is
    applied, and the CLS position is returned. Attribute names follow
    nn.TransformerEncoderLayer (self_attn, linear1, linear2, norm1, norm2).
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.linear1 = nn.Linear(dim, dim * 4)
        self.linear2 = nn.Linear(dim * 4, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # x: [T, P, D] patch tokens per frame (the caller has already removed its own CLS).
        num_frames = x.shape[0]
        cls_per_frame = self.cls_token.expand(num_frames, -1, -1)
        tokens = torch.cat([cls_per_frame, x], dim=1)  # [T, P + 1, D]

        normed = self.norm1(tokens)
        attended, _ = self.self_attn(normed, normed, normed)
        tokens = tokens + attended

        feed_forward = self.linear2(F.gelu(self.linear1(self.norm2(tokens))))
        tokens = tokens + feed_forward

        return tokens[:, 0, :]  # [T, D] — the CLS position of every frame
|
||||
|
||||
|
||||
class _SynchformerVisualEncoder(nn.Module):
    """
    TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.
    Processes video in segments of 8 frames and returns one 768-dim feature per frame.
    """

    def __init__(self, state_dict, device):
        """
        Args:
            state_dict: checkpoint mapping; only keys under ``vfeat_extractor.``
                are used, and ``patch_embed_3d*`` entries are dropped.
            device: torch device the module is moved to after loading.
        """
        super().__init__()
        self.device = device
        self.segment_frames = 8

        self.patch_embed = _PatchEmbed()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
        self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
        self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
        self.norm = nn.LayerNorm(768)
        self.spatial_attn_agg = _SpatialAttnAgg()

        # Load weights from vfeat_extractor.* prefix
        prefix = "vfeat_extractor."
        sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        # Exclude 3D patch embed (we use 2D only)
        sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
        # strict=False: mismatches are reported below instead of raising.
        missing, unexpected = self.load_state_dict(sub, strict=False)
        print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
        if missing:
            print(f"[FeaturesUtils] missing keys (first 5): {missing[:5]}")

        self.to(device)

    def forward(self, frames):
        """
        Args:
            frames: [T, C, H, W] float32 in [-1, 1], at 25fps
        Returns:
            [T_aligned, 768] — per-frame features (T_aligned = floor(T/8)*8).
            Trailing frames that do not fill a whole segment are dropped. A
            clip shorter than one segment is processed as a single short
            segment and yields [T, 768].
        """
        T = frames.shape[0]
        seg = self.segment_frames
        num_seg = max(1, T // seg)

        results = [
            self._forward_segment(frames[i * seg:(i + 1) * seg])  # [<=8, C, H, W]
            for i in range(num_seg)
        ]
        return torch.cat(results, dim=0)

    def _forward_segment(self, x):
        """Encode one segment of at most ``segment_frames`` frames into [T, 768]."""
        # x: [T, 3, 224, 224] with T <= 8
        T = x.shape[0]

        # Patch embedding + CLS token
        x = self.patch_embed(x)                          # [T, 196, 768]
        cls = self.cls_token.expand(T, -1, -1)
        x = torch.cat([cls, x], dim=1)                   # [T, 197, 768]

        # Positional + temporal embeddings
        x = x + self.pos_embed                           # broadcast (1,197,768)
        # Use only the first T temporal embeddings. Adding all 8 to a shorter
        # segment either fails to broadcast (2-7 frames) or silently expands a
        # single frame into 8 outputs.
        x = x + self.temp_embed[0, :T].unsqueeze(1)      # (T,1,768) broadcast

        # Transformer blocks (factorized space-time)
        for block in self.blocks:
            x = block(x, T)

        x = self.norm(x)

        # Aggregate spatial patches → 1 feature per frame
        return self.spatial_attn_agg(x[:, 1:, :])         # [T, 768]
|
||||
@@ -1,194 +0,0 @@
|
||||
# ComfyUI-PrismAudio Design Document
|
||||
|
||||
**Date:** 2026-03-27
|
||||
**Status:** Approved
|
||||
|
||||
## Overview
|
||||
|
||||
ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
|
||||
|
||||
## Architecture
|
||||
|
||||
**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. The ComfyUI environment contains no JAX/TensorFlow; feature extraction runs in a separate, isolated environment.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
ComfyUI-PrismAudio/
|
||||
├── __init__.py # Node registration
|
||||
├── nodes/
|
||||
│ ├── __init__.py
|
||||
│ ├── model_loader.py # PrismAudioModelLoader
|
||||
│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz)
|
||||
│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge)
|
||||
│ ├── sampler.py # PrismAudioSampler
|
||||
│ ├── text_only.py # PrismAudioTextOnly
|
||||
│ └── utils.py # Shared helpers
|
||||
├── prismaudio_core/ # Extracted inference code from PrismAudio
|
||||
│ ├── __init__.py
|
||||
│ ├── configs/
|
||||
│ │ └── prismaudio.json
|
||||
│ ├── models/ # DiT, conditioners, autoencoders, etc.
|
||||
│ ├── inference/ # sampling.py, generation.py
|
||||
│ └── factory.py # create_model_from_config
|
||||
├── scripts/
|
||||
│ ├── extract_features.py # Standalone VideoPrism feature extraction
|
||||
│ └── environment.yml # Conda env for extraction (JAX + TF)
|
||||
├── requirements.txt # PyTorch-only deps (no JAX/TF)
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Nodes
|
||||
|
||||
### PrismAudioModelLoader
|
||||
|
||||
Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability |
|
||||
| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] |
|
||||
| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | |
|
||||
| **Output** | | |
|
||||
| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config |
|
||||
|
||||
**Token resolution order** (no widget — env/CLI only for security):
|
||||
1. `HF_TOKEN` environment variable
|
||||
2. `huggingface-cli login` cached token
|
||||
3. None — gated models fail with a clear error message that links to the license page
|
||||
|
||||
**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually.
|
||||
|
||||
### PrismAudioFeatureLoader
|
||||
|
||||
Loads pre-computed `.npz` feature files for maximum quality video-to-audio.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| npz_path | STRING | Path to .npz file |
|
||||
| **Output** | | |
|
||||
| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features |
|
||||
|
||||
### PrismAudioFeatureExtractor
|
||||
|
||||
Subprocess bridge — extracts features from video using VideoPrism in an isolated environment.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| video | IMAGE | ComfyUI video frames tensor |
|
||||
| caption_cot | STRING | CoT description text |
|
||||
| python_env | STRING | Path to python binary with JAX/TF (default: "python") |
|
||||
| output_dir | STRING | Cache directory for .npz files (default: temp dir) |
|
||||
| **Output** | | |
|
||||
| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output |
|
||||
|
||||
**Caching:** Hashes video + text to avoid re-extraction on repeated runs.
|
||||
|
||||
### PrismAudioSampler
|
||||
|
||||
Main generation node — takes model + features, produces audio.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| model | PRISMAUDIO_MODEL | From ModelLoader |
|
||||
| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor |
|
||||
| cot_description | STRING | Multiline CoT text |
|
||||
| duration | FLOAT | 1.0-30.0, defaults to video length |
|
||||
| steps | INT | 1-100, default 24 |
|
||||
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
|
||||
| seed | INT | Controls noise generation |
|
||||
| **Output** | | |
|
||||
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
|
||||
|
||||
**Pipeline:**
|
||||
1. Encode CoT text via T5-Gemma -> text_features
|
||||
2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond)
|
||||
3. Compute latent_seq_len = round(44100 / 2048 * duration)
|
||||
4. Generate noise [1, 64, latent_seq_len] from seed
|
||||
5. Discrete Euler sampling (rectified flow) with CFG
|
||||
6. VAE decode -> stereo waveform at 44100 Hz
|
||||
7. Normalize to [-1, 1], return as AUDIO
|
||||
|
||||
### PrismAudioTextOnly
|
||||
|
||||
Text-to-audio without video input.
|
||||
|
||||
| Field | Type | Details |
|
||||
|-------|------|---------|
|
||||
| **Inputs** | | |
|
||||
| model | PRISMAUDIO_MODEL | From ModelLoader |
|
||||
| text_prompt | STRING | Text description |
|
||||
| duration | FLOAT | 1.0-30.0 |
|
||||
| steps | INT | 1-100, default 24 |
|
||||
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
|
||||
| seed | INT | Controls noise generation |
|
||||
| **Output** | | |
|
||||
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
|
||||
|
||||
Uses empty tensors for video/sync features, T5-Gemma encodes the text prompt.
|
||||
|
||||
## VRAM Management
|
||||
|
||||
Adaptive strategy using `comfy.model_management`:
|
||||
|
||||
| Available VRAM | Behavior |
|
||||
|---|---|
|
||||
| 24GB+ | Keep diffusion + VAE in VRAM |
|
||||
| 12-24GB | Sequential offload between stages |
|
||||
| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced |
|
||||
| <8GB | Warn user, attempt with aggressive offload + fp16 |
|
||||
|
||||
Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()`
|
||||
|
||||
## Feature Extraction Paths
|
||||
|
||||
### Path 1: Pre-computed .npz (FeatureLoader)
|
||||
User runs `scripts/extract_features.py` externally in the extraction conda env. Loads result into ComfyUI. Original VideoPrism quality, zero ComfyUI env risk.
|
||||
|
||||
### Path 2: Subprocess bridge (FeatureExtractor)
|
||||
Node calls extraction script as subprocess using a user-specified Python binary. Seamless in-ComfyUI experience, JAX runs isolated. Caches results by content hash.
|
||||
|
||||
### Path 3: Text-only (TextOnly node)
|
||||
No video features needed. T5-Gemma text encoding only (PyTorch-native).
|
||||
|
||||
## Dependencies
|
||||
|
||||
### ComfyUI environment (`requirements.txt`)
|
||||
```
|
||||
einops>=0.7.0
|
||||
safetensors
|
||||
huggingface_hub
|
||||
transformers>=4.52.3
|
||||
k-diffusion>=0.1.1
|
||||
```
|
||||
|
||||
flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`.
|
||||
|
||||
### Extraction environment (`scripts/environment.yml`)
|
||||
Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as ready-made conda env file for one-command setup.
|
||||
|
||||
## Model Files
|
||||
|
||||
Stored in `ComfyUI/models/prismaudio/`:
|
||||
|
||||
| File | Size | Source |
|
||||
|------|------|--------|
|
||||
| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio |
|
||||
| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio |
|
||||
| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio |
|
||||
|
||||
T5-Gemma (`google/t5gemma-l-l-ul2-it`) cached in standard HuggingFace cache.
|
||||
|
||||
Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)`
|
||||
|
||||
## Design Decisions
|
||||
|
||||
- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes.
|
||||
- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env.
|
||||
- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes — better models are available than the bundled Qwen2.5-VL.
|
||||
- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`.
|
||||
- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,207 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import hashlib
|
||||
import subprocess
|
||||
import tempfile
|
||||
import torch
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY
|
||||
from .feature_loader import PrismAudioFeatureLoader
|
||||
|
||||
# Managed venv created automatically when python_env is left as default
|
||||
# Plugin root: the parent of the directory that contains this module.
_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
# Location of the auto-created virtual environment and its interpreter.
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python")

# Requirement specifiers installed, in this order, into the managed venv.
_EXTRACT_PACKAGES = [
    "torch",
    "torchaudio",
    "torchvision",
    # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
    "tensorflow-cpu>=2.16.0",
    # jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed)
    "jax[cuda13]",
    "flax",
    "transformers",
    "decord",
    "einops",
    "numpy",
    "mediapy",
    "git+https://github.com/google-deepmind/videoprism.git",
]
|
||||
|
||||
|
||||
def _pip_install(pip, *packages, label=None):
    """Run ``<pip> install`` for the given packages, streaming pip's output.

    Args:
        pip: path of the pip executable to invoke.
        *packages: requirement specifiers passed to pip in one invocation.
        label: name shown in log lines; defaults to the first package.

    Raises:
        RuntimeError: if pip exits with a non-zero status.
    """
    shown_name = label or packages[0]
    print(f"[PrismAudio] installing {shown_name} ...", flush=True)

    command = [pip, "install", "--progress-bar", "on", *packages]
    # Output is not captured so pip's progress is visible to the user.
    completed = subprocess.run(command, capture_output=False)

    if completed.returncode != 0:
        raise RuntimeError(
            f"[PrismAudio] Failed to install {shown_name} (exit {completed.returncode}). "
            "See pip output above for details."
        )

    print(f"[PrismAudio] {shown_name} OK", flush=True)
|
||||
|
||||
|
||||
def _ensure_extract_env():
    """Create and populate the managed venv on first use.

    Returns:
        Path of the Python interpreter inside the managed venv.

    Raises:
        subprocess.CalledProcessError: if creating the venv or upgrading pip fails.
        RuntimeError: if installing one of the packages fails.

    If any step fails (or is interrupted), the partially built venv is removed
    before the error propagates. Without that cleanup the interpreter file
    would already exist and every later call would treat the broken
    environment as ready.
    """
    if os.path.exists(_MANAGED_PYTHON):
        return _MANAGED_PYTHON

    import shutil
    if os.path.exists(_MANAGED_VENV):
        print("[PrismAudio] Removing incomplete venv and retrying...", flush=True)
        shutil.rmtree(_MANAGED_VENV)

    print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True)
    try:
        subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True)

        pip = os.path.join(_MANAGED_VENV, "bin", "pip")

        print("[PrismAudio] Upgrading pip...", flush=True)
        subprocess.run([pip, "install", "--upgrade", "pip"], check=True)

        total = len(_EXTRACT_PACKAGES)
        print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True)

        for i, pkg in enumerate(_EXTRACT_PACKAGES, 1):
            # Short display name: repo name for git URLs, otherwise the bare
            # distribution name without version pins or extras.
            label = pkg.split("/")[-1] if pkg.startswith("git+") else pkg.split(">=")[0].split("==")[0].split("[")[0]
            print(f"[PrismAudio] [{i}/{total}] {label}", flush=True)
            _pip_install(pip, pkg, label=label)
    except BaseException:
        # BaseException so that Ctrl-C during a long install is cleaned up too.
        shutil.rmtree(_MANAGED_VENV, ignore_errors=True)
        raise

    print("[PrismAudio] Feature-extraction env ready.", flush=True)
    return _MANAGED_PYTHON
|
||||
|
||||
|
||||
def _hash_inputs(video_tensor, cot_text):
    """Create a 16-hex-character cache key from the video tensor and the CoT text.

    The key covers the tensor's shape and dtype plus up to 1 MB of element
    data sampled at a uniform stride across the whole tensor. Hashing only the
    leading bytes would make clips that start alike collide, because 1 MB is
    less than a single 512x512 RGB float32 frame.

    Args:
        video_tensor: torch tensor holding the frames.
        cot_text: caption text.

    Returns:
        First 16 hex characters of the SHA-256 digest.
    """
    h = hashlib.sha256()
    h.update(str(tuple(video_tensor.shape)).encode())
    h.update(str(video_tensor.dtype).encode())

    flat = video_tensor.detach().reshape(-1)
    budget = max(1, (1024 * 1024) // flat.element_size())  # elements that fit in 1 MB
    if flat.numel() > budget:
        step = -(-flat.numel() // budget)  # ceil division keeps the sample within budget
        flat = flat[::step]
    # Only the sampled elements are copied to host memory and serialised.
    h.update(flat.contiguous().cpu().numpy().tobytes())

    h.update(cot_text.encode())
    return h.hexdigest()[:16]
|
||||
|
||||
|
||||
def _save_frames_to_npy(video_tensor, output_path):
    """Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8.

    Lossless — avoids H.264 encode/decode roundtrip.

    Values are rounded to the nearest integer and clipped to [0, 255] before
    the cast. A bare ``astype("uint8")`` truncates (so a pixel that was
    ``v / 255`` can come back as ``v - 1``) and wraps around for inputs
    slightly outside [0, 1].

    Args:
        video_tensor: torch tensor of frames.
        output_path: destination file path handed to ``numpy.save``.
    """
    import numpy as np
    frames = video_tensor.detach().cpu().numpy()
    frames_np = np.clip(np.rint(frames * 255.0), 0, 255).astype("uint8")
    np.save(output_path, frames_np)
|
||||
|
||||
|
||||
class PrismAudioFeatureExtractor:
    """ComfyUI node that extracts PrismAudio features by running scripts/extract_features.py
    in a subprocess and caching the resulting .npz file."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE",),
                "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
            },
            "optional": {
                "video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info output to auto-set fps."}),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Ignored if video_info is connected."}),
                "python_env": (["managed_env", "comfyui_env"], {"tooltip": "managed_env: auto-created isolated venv with JAX/TF (recommended). comfyui_env: current ComfyUI Python — WARNING: may conflict with existing packages and destabilize ComfyUI."}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
                "hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES", "FLOAT")
    RETURN_NAMES = ("features", "fps")
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, video, caption_cot, video_info=None, fps=30.0, python_env="managed_env", cache_dir="", hf_token=""):
        """Extract (or load cached) features for a video and its CoT caption.

        Args:
            video: IMAGE tensor of frames.
            caption_cot: chain-of-thought description passed to the extraction script.
            video_info: optional dict; when given, its "loaded_fps" entry overrides ``fps``.
            fps: frame rate of the input video.
            python_env: "managed_env" (auto-created venv) or "comfyui_env" (current interpreter).
            cache_dir: directory for cached .npz files; empty selects a folder in the temp dir.
            hf_token: HuggingFace token; falls back to the HF_TOKEN environment variable.

        Returns:
            Tuple ``(features, fps)``.

        Raises:
            RuntimeError: if the Synchformer checkpoint is missing or the subprocess fails.
            subprocess.TimeoutExpired: if extraction exceeds the 10 minute timeout.
        """
        import time
        import folder_paths

        # Resolve fps from VHS video_info if connected
        if video_info is not None:
            fps = video_info["loaded_fps"]

        # Resolve python binary
        if python_env == "comfyui_env":
            print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
                  "Installing them here may conflict with existing packages and destabilize ComfyUI.", flush=True)
            python_bin = sys.executable
        else:
            python_bin = _ensure_extract_env()

        # Determine cache directory
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
        os.makedirs(cache_dir, exist_ok=True)

        # Check cache. The frame rate is part of the key because it is passed
        # to the extraction script (--source_fps) and so affects its output.
        cache_hash = _hash_inputs(video, caption_cot)
        cached_path = os.path.join(cache_dir, f"{cache_hash}_fps{float(fps):g}.npz")
        if os.path.exists(cached_path):
            print(f"[PrismAudio] Using cached features: {cached_path}")
            loader = PrismAudioFeatureLoader()
            features, = loader.load_features(cached_path)
            return (features, float(fps))

        # Fail before any temp file is written if the checkpoint is missing.
        synchformer_ckpt = os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth")
        if not os.path.exists(synchformer_ckpt):
            raise RuntimeError(
                f"[PrismAudio] Synchformer checkpoint not found: {synchformer_ckpt}\n"
                "Download synchformer_state_dict.pth from FunAudioLLM/PrismAudio and place it in models/prismaudio/."
            )

        script_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "scripts", "extract_features.py"
        )

        # Build env: inherit current env, inject HF token if provided.
        # os.environ.copy() returns a plain dict. copy.copy(os.environ) would
        # still write through to the real process environment and leak the
        # token into ComfyUI itself.
        env = os.environ.copy()
        token = (hf_token or "").strip() or os.environ.get("HF_TOKEN", "")
        if token:
            env["HF_TOKEN"] = token
            env["HUGGING_FACE_HUB_TOKEN"] = token
        else:
            print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. "
                  "Add your token in the hf_token input or set HF_TOKEN env var.", flush=True)

        # The subprocess writes to a scratch file that is renamed into place
        # only on success, so a crash or timeout can never leave a truncated
        # file that a later run would pick up as a cache hit.
        tmp_video = None
        tmp_output = os.path.join(cache_dir, f"{cache_hash}.{os.getpid()}.partial.npz")
        try:
            # Save frames to temp file (lossless .npy, no codec roundtrip)
            t0 = time.perf_counter()
            frames = video.shape[0]
            print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True)
            with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
                tmp_video = tmp.name
            _save_frames_to_npy(video, tmp_video)
            print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True)

            cmd = [
                python_bin,
                script_path,
                "--video", tmp_video,
                "--cot_text", caption_cot,
                "--output", tmp_output,
                "--source_fps", str(fps),
                "--synchformer_ckpt", synchformer_ckpt,
            ]

            print(f"[PrismAudio] Extracting features via subprocess (output streams live)...")
            # capture_output=False: let stdout/stderr stream directly to ComfyUI logs
            result = subprocess.run(
                cmd,
                capture_output=False,
                timeout=600,  # 10 minute timeout
                env=env,
            )
            if result.returncode != 0:
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. "
                    "See output above for details."
                )
            os.replace(tmp_output, cached_path)
            print("[PrismAudio] Feature extraction subprocess finished successfully.")
        finally:
            for leftover in (tmp_video, tmp_output):
                if leftover and os.path.exists(leftover):
                    os.unlink(leftover)

        # Load the extracted features
        loader = PrismAudioFeatureLoader()
        features, = loader.load_features(cached_path)
        return (features, float(fps))
|
||||
@@ -1,53 +0,0 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from .utils import PRISMAUDIO_CATEGORY
|
||||
|
||||
# Keys consumed by the conditioners (video_features, text_features, sync_features)
|
||||
# global_video_features and global_text_features are NOT consumed by any conditioner
|
||||
# in the prismaudio.json config — they are unused.
|
||||
REQUIRED_KEYS = [
    "video_features",
    "text_features",
    "sync_features",
]


class PrismAudioFeatureLoader:
    """ComfyUI node that loads pre-computed PrismAudio features from an .npz file."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "load_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_features(self, npz_path):
        """Read the feature arrays listed in REQUIRED_KEYS into float tensors.

        Args:
            npz_path: path to the .npz archive.

        Returns:
            One-element tuple holding a dict with a tensor per required key
            (zeros for keys missing from the file) and, when the archive has
            it, a float "duration".

        Raises:
            FileNotFoundError: if ``npz_path`` does not exist.
        """
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")

        features = {}
        # SECURITY NOTE: allow_pickle=True lets object arrays in the archive
        # run arbitrary code when unpickled; only load files from trusted sources.
        # The context manager closes the underlying zip handle, which was
        # previously left open.
        with np.load(npz_path, allow_pickle=True) as data:
            for key in REQUIRED_KEYS:
                if key in data:
                    features[key] = torch.from_numpy(data[key]).float()
                else:
                    print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
                    # Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None
                    # Sync_MLP requires length divisible by 8 (segments of 8 frames)
                    if key == "sync_features":
                        features[key] = torch.zeros(8, 768)
                    else:
                        features[key] = torch.zeros(1, 1024)

            # Load duration if present
            if "duration" in data:
                features["duration"] = float(data["duration"])

        return (features,)
|
||||
@@ -1,154 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
|
||||
get_device, get_offload_device, determine_precision, determine_offload_strategy,
|
||||
soft_empty_cache, resolve_hf_token,
|
||||
)
|
||||
|
||||
# Hugging Face Hub repository that missing checkpoints are downloaded from.
HF_REPO_ID = "FunAudioLLM/PrismAudio"
# Role -> checkpoint filename expected inside the local model directory.
REQUIRED_FILES = dict(
    diffusion="prismaudio.ckpt",
    vae="vae.ckpt",
    synchformer="synchformer_state_dict.pth",
)
|
||||
|
||||
|
||||
def _download_if_missing(filename, model_dir, hf_token=None):
|
||||
"""Download a model file from HuggingFace if not present locally."""
|
||||
filepath = os.path.join(model_dir, filename)
|
||||
if os.path.exists(filepath):
|
||||
return filepath
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
|
||||
try:
|
||||
downloaded = hf_hub_download(
|
||||
repo_id=HF_REPO_ID,
|
||||
filename=filename,
|
||||
local_dir=model_dir,
|
||||
token=hf_token or None,
|
||||
)
|
||||
return downloaded
|
||||
except Exception as e:
|
||||
if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower():
|
||||
raise RuntimeError(
|
||||
f"[PrismAudio] Model '{filename}' requires license acceptance. "
|
||||
f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
|
||||
f"then set HF_TOKEN env var or run: huggingface-cli login"
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
class PrismAudioModelLoader:
    """ComfyUI node that builds the PrismAudio model and loads its diffusion and VAE weights."""

    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    @classmethod
    def INPUT_TYPES(cls):
        """Register the model folder, then declare the precision/offload choices."""
        register_model_folder()
        required = {
            "precision": (["auto", "fp32", "fp16", "bf16"],),
            "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],),
        }
        return {"required": required}

    def load_model(self, precision, offload_strategy):
        """Download missing checkpoints, build the model and load its weights.

        Args:
            precision: One of the ``precision`` choices; resolved to a dtype by
                ``determine_precision``.
            offload_strategy: One of the ``offload_strategy`` choices; resolved
                by ``determine_offload_strategy``.

        Returns:
            A 1-tuple holding a dict with the keys ``model``, ``dtype``,
            ``strategy``, ``config`` and ``model_dir``.
        """
        device = get_device()
        dtype = determine_precision(precision, device)
        strategy = determine_offload_strategy(offload_strategy)
        hf_token = resolve_hf_token()
        model_dir = get_prismaudio_model_dir()

        # Fetch whichever checkpoint files are not on disk yet.
        for ckpt_name in REQUIRED_FILES.values():
            _download_if_missing(ckpt_name, model_dir, hf_token=hf_token)

        # The JSON config lives two directory levels above this file's location.
        package_root = os.path.dirname(os.path.dirname(__file__))
        config_path = os.path.join(package_root, "prismaudio_core", "configs", "prismaudio.json")
        with open(config_path) as config_file:
            model_config = json.load(config_file)

        # Instantiate the (still weight-less) model from the config.
        from prismaudio_core.factory import create_model_from_config
        model = create_model_from_config(model_config)

        # --- Diffusion weights -------------------------------------------------
        diffusion_state = comfy.utils.load_torch_file(
            os.path.join(model_dir, REQUIRED_FILES["diffusion"])
        )
        # Some checkpoints wrap the weights in {"state_dict": ...}; unwrap them.
        if "state_dict" in diffusion_state:
            diffusion_state = diffusion_state["state_dict"]
        diff_report = model.load_state_dict(diffusion_state, strict=False)
        print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
        print(f"[PrismAudio] Diffusion load: missing={len(diff_report.missing_keys)}, unexpected={len(diff_report.unexpected_keys)}", flush=True)
        if diff_report.missing_keys:
            print(f"[PrismAudio] missing (first 10): {diff_report.missing_keys[:10]}", flush=True)
        if diff_report.unexpected_keys:
            print(f"[PrismAudio] unexpected (first 5): {diff_report.unexpected_keys[:5]}", flush=True)
        # A few checkpoint keys, logged so prefix alignment can be eyeballed.
        ckpt_key_preview = list(diffusion_state.keys())[:5]
        print(f"[PrismAudio] ckpt key samples: {ckpt_key_preview}", flush=True)

        # --- VAE weights (separate checkpoint) ---------------------------------
        # Loaded through comfy.utils.load_torch_file, like the diffusion weights
        # (original note: consistency and PyTorch 2.6+ compatibility).
        vae_raw_state = comfy.utils.load_torch_file(
            os.path.join(model_dir, REQUIRED_FILES["vae"])
        )
        print(f"[PrismAudio] VAE ckpt: {len(vae_raw_state)} keys in file", flush=True)
        raw_key_preview = list(vae_raw_state.keys())[:8]
        print(f"[PrismAudio] VAE raw key samples: {raw_key_preview}", flush=True)

        # Drop a leading "autoencoder." from every key that has it.
        prefix = "autoencoder."
        vae_state = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in vae_raw_state.items()
        }
        print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)

        # Model-side keys, logged for comparison with the stripped checkpoint keys.
        model_key_preview = list(model.pretransform.state_dict().keys())[:5]
        print(f"[PrismAudio] pretransform model key samples: {model_key_preview}", flush=True)

        # strict=False: per the original note, vae.ckpt is a training checkpoint
        # that also carries discriminator, loss modules and EMA wrappers absent
        # from the inference autoencoder, so extra keys are ignored. The state is
        # loaded into the inner model directly to get the IncompatibleKeys result
        # back (the wrapper's load_state_dict reportedly does not return it).
        vae_report = model.pretransform.model.load_state_dict(vae_state, strict=False)
        print(f"[PrismAudio] VAE load: missing={len(vae_report.missing_keys)}, unexpected={len(vae_report.unexpected_keys)}", flush=True)
        if vae_report.missing_keys:
            print(f"[PrismAudio] VAE missing (first 10): {vae_report.missing_keys[:10]}", flush=True)

        # Cast only the DiT wrapper and the conditioner to the selected dtype.
        # model.pretransform (the VAE) is left untouched — original note: fp32
        # avoids NaN from snake activations in fp16.
        model.model.to(dtype)
        model.conditioner.to(dtype)

        # Place the model on the compute device or on the offload device.
        target_device = device if strategy == "keep_in_vram" else get_offload_device()
        model = model.to(target_device)

        model.eval()

        bundle = {
            "model": model,
            "dtype": dtype,
            "strategy": strategy,
            "config": model_config,
            "model_dir": model_dir,
        }
        return (bundle,)
|
||||
@@ -1,165 +0,0 @@
|
||||
import torch
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
||||
get_device, get_offload_device, soft_empty_cache,
|
||||
)
|
||||
|
||||
|
||||
class PrismAudioSampler:
    """ComfyUI node that samples audio from a loaded PrismAudio model and pre-computed features."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the model/features inputs and the sampling parameters."""
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "features": ("PRISMAUDIO_FEATURES",),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
                "steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
                "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, features, duration, steps, cfg_scale, seed):
        """Sample a latent with the diffusion model and decode it to a waveform.

        Args:
            model: Dict with the keys ``model``, ``dtype`` and ``strategy``.
            features: Dict with ``video_features``, ``text_features`` and
                ``sync_features`` tensors and, optionally, ``duration``.
            duration: Audio length in seconds; a value <= 0 means "use
                ``features["duration"]``".
            steps: Number of sampler steps (also sizes the progress bar).
            cfg_scale: Classifier-free guidance scale passed to the sampler.
            seed: Seed for the noise generator.

        Returns:
            A 1-tuple holding ``{"waveform": tensor_on_cpu, "sample_rate": SAMPLE_RATE}``.

        Raises:
            ValueError: If ``duration`` <= 0 and ``features`` has no ``duration``.
        """
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        # Resolve duration: 0 means use video duration from features
        if duration <= 0:
            if "duration" not in features:
                raise ValueError("[PrismAudio] duration=0 but features contain no duration. Set duration manually or use PrismAudioFeatureExtractor.")
            duration = features["duration"]
            print(f"[PrismAudio] Using video duration from features: {duration:.2f}s", flush=True)

        # Compute latent dimensions
        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Note (from the original author): no seq length config is needed — the
        # model adapts to the input tensor shapes.

        # Video counts as present only when its features are not all zeros.
        # Collapse the comparison to a plain bool: passing the 0-d comparison
        # tensor on to torch.tensor() below would be a copy-construction from a
        # tensor, which PyTorch warns about.
        raw_video = features.get("video_features")
        has_video = raw_video is not None and bool(raw_video.abs().sum() > 0)

        video_feat = features["video_features"].to(device, dtype=dtype)
        sync_feat = features["sync_features"].to(device, dtype=dtype)

        # Metadata is a TUPLE of dicts (one per batch sample); the conditioner
        # is called with this tuple below.
        sample_meta = {
            "video_features": video_feat,
            "text_features": features["text_features"].to(device, dtype=dtype),
            "sync_features": sync_feat,
            "video_exist": torch.tensor(has_video),
        }
        metadata = (sample_meta,)

        # Move model to device if offloaded
        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            # Run conditioning
            conditioning = diffusion.conditioner(metadata, device)

            # Handle missing video: substitute learned empty embeddings
            if not has_video:
                _substitute_empty_features(diffusion, conditioning, device, dtype)

            # Assemble conditioning inputs for the DiT
            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Seeded noise; on MPS the generator and the draw live on CPU and the
            # result is moved afterwards (original note: MPS doesn't support
            # torch.Generator).
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            # Sample with progress bar
            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

        fakes_f = fakes.float()
        print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)

        # Offload diffusion model and conditioner before VAE decode
        if strategy == "offload_to_cpu":
            diffusion.model.to(get_offload_device())
            diffusion.conditioner.to(get_offload_device())
            soft_empty_cache()
        diffusion.pretransform.to(device)

        # VAE decode in fp32 (snake activations overflow in fp16)
        with torch.amp.autocast(device_type=device.type, enabled=False):
            audio = diffusion.pretransform.decode(fakes_f)

        # Offload VAE
        if strategy == "offload_to_cpu":
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp (matching reference: div by max abs before clamp)
        audio = audio.float()
        pre_norm_std = audio.std().item()
        raw_peak = audio.abs().max()  # computed once, reused for logging and scaling
        pre_norm_peak = raw_peak.item()
        peak = raw_peak.clamp(min=1e-8)  # floor avoids dividing by zero on silence
        audio = (audio / peak).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)

        # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
|
||||
|
||||
|
||||
def _substitute_empty_features(diffusion, conditioning, device, dtype):
|
||||
"""Replace video/sync conditioning with learned empty embeddings when video is absent.
|
||||
|
||||
empty_clip_feat and empty_sync_feat are learned null embeddings in the conditioner
|
||||
output space (1024-dim). Passing zero features through bias-free Cond_MLP produces
|
||||
near-zero activations, NOT the learned null signal the model was trained with.
|
||||
|
||||
The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim].
|
||||
"""
|
||||
dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model
|
||||
|
||||
# Substitute video_features with learned empty_clip_feat
|
||||
if hasattr(dit, 'empty_clip_feat') and 'video_features' in conditioning:
|
||||
empty = dit.empty_clip_feat.to(device, dtype=dtype) # [1, 1024]
|
||||
batch_size = conditioning['video_features'][0].shape[0]
|
||||
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
||||
conditioning['video_features'][0] = empty_expanded
|
||||
conditioning['video_features'][1] = torch.ones(batch_size, 1, device=device)
|
||||
|
||||
# Substitute sync_features with learned empty_sync_feat
|
||||
if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning:
|
||||
empty = dit.empty_sync_feat.to(device, dtype=dtype) # [1, 1024]
|
||||
batch_size = conditioning['sync_features'][0].shape[0]
|
||||
empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024]
|
||||
conditioning['sync_features'][0] = empty_expanded
|
||||
conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device)
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
|
||||
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
|
||||
_CLIP_SIZE = 384
|
||||
@@ -68,7 +68,7 @@ class SelvaFeatureExtractor:
|
||||
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
|
||||
RETURN_NAMES = ("features", "fps", "prompt")
|
||||
FUNCTION = "extract_features"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
|
||||
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
|
||||
duration=0.0, cache_dir=""):
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
import torch
|
||||
import folder_paths
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY, get_offload_device, determine_offload_strategy
|
||||
from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy
|
||||
|
||||
# Variant → (generator filename, mode, has_bigvgan)
|
||||
_VARIANTS = {
|
||||
@@ -96,7 +96,7 @@ class SelvaModelLoader:
|
||||
RETURN_TYPES = ("SELVA_MODEL",)
|
||||
RETURN_NAMES = ("model",)
|
||||
FUNCTION = "load_model"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
|
||||
def load_model(self, variant, precision, offload_strategy):
|
||||
from selva_core.model.networks_generator import get_my_mmaudio
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
|
||||
|
||||
|
||||
class SelvaSampler:
|
||||
@@ -35,7 +35,7 @@ class SelvaSampler:
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
RETURN_NAMES = ("audio",)
|
||||
FUNCTION = "generate"
|
||||
CATEGORY = PRISMAUDIO_CATEGORY
|
||||
CATEGORY = SELVA_CATEGORY
|
||||
|
||||
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed):
|
||||
from selva_core.model.flow_matching import FlowMatching
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
import torch
|
||||
import comfy.model_management as mm
|
||||
import comfy.utils
|
||||
|
||||
from .utils import (
|
||||
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
|
||||
get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
|
||||
)
|
||||
from .sampler import _substitute_empty_features
|
||||
|
||||
|
||||
class PrismAudioTextOnly:
    """ComfyUI node that generates audio from a text prompt alone (no video features)."""

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the model input, the text prompt and the sampling parameters."""
        required = {
            "model": ("PRISMAUDIO_MODEL",),
            "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}),
            "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
            "steps": ("INT", {"default": 100, "min": 1, "max": 100}),
            "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1}),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
        }
        return {"required": required}

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        """Encode the prompt, sample a latent and decode it to a waveform.

        Args:
            model: Dict with the keys ``model``, ``dtype`` and ``strategy``.
            text_prompt: Text passed to ``_encode_text_t5``.
            duration: Audio length in seconds.
            steps: Number of sampler steps (also sizes the progress bar).
            cfg_scale: Classifier-free guidance scale passed to the sampler.
            seed: Seed for the noise generator.

        Returns:
            A 1-tuple holding ``{"waveform": tensor_on_cpu, "sample_rate": SAMPLE_RATE}``.
        """
        device = get_device()
        dtype, strategy, diffusion = model["dtype"], model["strategy"], model["model"]
        offloading = strategy == "offload_to_cpu"

        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Text features come from the T5-Gemma encoder helper.
        text_features = _encode_text_t5(text_prompt, device, dtype)

        # One metadata dict per sample, wrapped in a tuple. Video/sync slots get
        # zero tensors rather than None (original note: Cond_MLP crashes on None
        # via pad_sequence, and Sync_MLP needs a length divisible by 8, hence
        # [8, 768]). They are replaced by the empty embeddings after conditioning.
        metadata = ({
            "video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
            "text_features": text_features.to(device, dtype=dtype),
            "sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
            "video_exist": torch.tensor(False),
        },)

        # Bring the diffusion model and conditioner onto the compute device.
        if offloading:
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            conditioning = diffusion.conditioner(metadata, device)

            # No video here, so the video/sync conditioning is always swapped out.
            _substitute_empty_features(diffusion, conditioning, device, dtype)

            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Seeded noise; on MPS it is drawn on CPU and moved afterwards
            # (original note: MPS doesn't support torch.Generator).
            noise_device = "cpu" if device.type == "mps" else device
            rng = torch.Generator(device=noise_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=rng,
                device=noise_device,
            )
            noise = noise.to(device=device, dtype=dtype)

            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def _advance(info):
                # One progress tick per sampler callback.
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=_advance,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

        fakes_f = fakes.float()
        print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)

        # Swap residency: diffusion parts out, VAE in.
        if offloading:
            diffusion.model.to(get_offload_device())
            diffusion.conditioner.to(get_offload_device())
            soft_empty_cache()
        diffusion.pretransform.to(device)

        # VAE decode in fp32 (snake activations overflow in fp16)
        with torch.amp.autocast(device_type=device.type, enabled=False):
            audio = diffusion.pretransform.decode(fakes_f)

        if offloading:
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak-normalise, then clamp to [-1, 1].
        audio = audio.float()
        pre_norm_std = audio.std().item()
        raw_peak = audio.abs().max()
        pre_norm_peak = raw_peak.item()
        audio = (audio / raw_peak.clamp(min=1e-8)).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
        print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)

        result = {"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE}
        return (result,)
|
||||
|
||||
|
||||
# T5-Gemma encoder singleton
|
||||
_t5_model = None
|
||||
_t5_tokenizer = None
|
||||
|
||||
|
||||
def _encode_text_t5(text, device, dtype):
    """Return T5-Gemma encoder hidden states for ``text``.

    The encoder and tokenizer are loaded once and cached in the module globals
    ``_t5_model`` / ``_t5_tokenizer``. The encoder is obtained through
    ``AutoModelForSeq2SeqLM(...).get_encoder()``; per the original author's
    note this mirrors the reference ``FeaturesUtils.encode_t5_text()``, and no
    truncation is applied.

    Args:
        text: Prompt string handed to the tokenizer.
        device: Device the encoder runs on for this call.
        dtype: Dtype the encoder is cast to for this call.

    Returns:
        ``last_hidden_state`` with its first dimension squeezed — described as
        ``[seq_len, dim]`` in the original code.
    """
    global _t5_model, _t5_tokenizer

    # First call: load and cache tokenizer + encoder.
    if _t5_model is None:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        model_id = "google/t5gemma-l-l-ul2-it"
        hf_token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=hf_token)
        _t5_model = seq2seq.get_encoder()
        _t5_model.eval()

    # Bring the cached encoder onto the requested device/dtype for this call.
    _t5_model.to(device, dtype=dtype)

    encoded = _t5_tokenizer(text, return_tensors="pt", padding=True)
    encoded = encoded.to(device)

    with torch.no_grad():
        encoder_out = _t5_model(**encoded)

    # Park the encoder on CPU again and release cached memory to save VRAM.
    _t5_model.to("cpu")
    soft_empty_cache()

    hidden = encoder_out.last_hidden_state
    return hidden.squeeze(0)  # [seq_len, dim]
|
||||
+4
-47
@@ -1,21 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
|
||||
PRISMAUDIO_CATEGORY = "PrismAudio"
|
||||
SAMPLE_RATE = 44100
|
||||
DOWNSAMPLING_RATIO = 2048
|
||||
IO_CHANNELS = 64
|
||||
|
||||
def get_prismaudio_model_dir():
|
||||
model_dir = os.path.join(folder_paths.models_dir, "prismaudio")
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
return model_dir
|
||||
|
||||
def register_model_folder():
|
||||
model_dir = get_prismaudio_model_dir()
|
||||
folder_paths.add_model_folder_path("prismaudio", model_dir)
|
||||
SELVA_CATEGORY = "SelVA"
|
||||
|
||||
def get_device():
|
||||
return mm.get_torch_device()
|
||||
@@ -23,42 +9,13 @@ def get_device():
|
||||
def get_offload_device():
|
||||
return mm.unet_offload_device()
|
||||
|
||||
def get_free_memory(device=None):
|
||||
if device is None:
|
||||
device = get_device()
|
||||
return mm.get_free_memory(device)
|
||||
|
||||
def soft_empty_cache():
|
||||
mm.soft_empty_cache()
|
||||
|
||||
def determine_precision(preference, device):
|
||||
if preference != "auto":
|
||||
return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference]
|
||||
if device.type == "cpu":
|
||||
return torch.float32
|
||||
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
||||
return torch.bfloat16
|
||||
return torch.float16
|
||||
|
||||
def determine_offload_strategy(preference):
|
||||
if preference != "auto":
|
||||
return preference
|
||||
free_mem = get_free_memory()
|
||||
gb = free_mem / (1024 ** 3)
|
||||
if gb >= 24:
|
||||
free_mem = mm.get_free_memory(get_device())
|
||||
if free_mem / (1024 ** 3) >= 16:
|
||||
return "keep_in_vram"
|
||||
else:
|
||||
return "offload_to_cpu"
|
||||
|
||||
def try_import_flash_attn():
|
||||
try:
|
||||
import flash_attn
|
||||
return flash_attn
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
def resolve_hf_token():
|
||||
env_token = os.environ.get("HF_TOKEN")
|
||||
if env_token:
|
||||
return env_token
|
||||
return None
|
||||
return "offload_to_cpu"
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
PrismAudio core inference modules.
|
||||
Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
|
||||
Only inference-critical code — no training, no JAX/TF dependencies.
|
||||
"""
|
||||
@@ -1,141 +0,0 @@
|
||||
{
|
||||
"model_type": "diffusion_cond",
|
||||
"sample_size": 397312,
|
||||
"sample_rate": 44100,
|
||||
"audio_channels": 2,
|
||||
"model": {
|
||||
"pretransform": {
|
||||
"type": "autoencoder",
|
||||
"iterate_batch": true,
|
||||
"config": {
|
||||
"encoder": {
|
||||
"type": "oobleck",
|
||||
"config": {
|
||||
"in_channels": 2,
|
||||
"channels": 128,
|
||||
"c_mults": [1, 2, 4, 8, 16],
|
||||
"strides": [2, 4, 4, 8, 8],
|
||||
"latent_dim": 128,
|
||||
"use_snake": true
|
||||
}
|
||||
},
|
||||
"decoder": {
|
||||
"type": "oobleck",
|
||||
"config": {
|
||||
"out_channels": 2,
|
||||
"channels": 128,
|
||||
"c_mults": [1, 2, 4, 8, 16],
|
||||
"strides": [2, 4, 4, 8, 8],
|
||||
"latent_dim": 64,
|
||||
"use_snake": true,
|
||||
"final_tanh": false
|
||||
}
|
||||
},
|
||||
"bottleneck": {
|
||||
"type": "vae"
|
||||
},
|
||||
"latent_dim": 64,
|
||||
"downsampling_ratio": 2048,
|
||||
"io_channels": 2
|
||||
}
|
||||
},
|
||||
"conditioning": {
|
||||
"configs": [
|
||||
{
|
||||
"id": "video_features",
|
||||
"type": "cond_mlp",
|
||||
"config": {
|
||||
"dim": 1024,
|
||||
"output_dim": 1024
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "text_features",
|
||||
"type": "cond_mlp",
|
||||
"config": {
|
||||
"dim": 1024,
|
||||
"output_dim": 1024
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "sync_features",
|
||||
"type": "sync_mlp",
|
||||
"config": {
|
||||
"dim": 768,
|
||||
"output_dim": 1024
|
||||
}
|
||||
}
|
||||
],
|
||||
"cond_dim": 768
|
||||
},
|
||||
"diffusion": {
|
||||
"cross_attention_cond_ids": ["video_features","text_features"],
|
||||
"add_cond_ids": ["video_features"],
|
||||
"sync_cond_ids": ["sync_features"],
|
||||
"type": "dit",
|
||||
"diffusion_objective": "rectified_flow",
|
||||
"config": {
|
||||
"io_channels": 64,
|
||||
"embed_dim": 1024,
|
||||
"depth": 24,
|
||||
"num_heads": 16,
|
||||
"cond_token_dim": 1024,
|
||||
"add_token_dim": 1024,
|
||||
"sync_token_dim": 1024,
|
||||
"project_cond_tokens": false,
|
||||
"transformer_type": "continuous_transformer",
|
||||
"attn_kwargs":{
|
||||
"qk_norm": "rns"
|
||||
},
|
||||
"use_gated": true,
|
||||
"use_sync_gated": true
|
||||
}
|
||||
},
|
||||
"io_channels": 64
|
||||
},
|
||||
"training": {
|
||||
"use_ema": true,
|
||||
"log_loss_info": false,
|
||||
"cfg_dropout_prob": 0.1,
|
||||
"pre_encoded": true,
|
||||
"timestep_sampler": "trunc_logit_normal",
|
||||
"optimizer_configs": {
|
||||
"diffusion": {
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"config": {
|
||||
"lr": 1e-4,
|
||||
"betas": [0.9, 0.999],
|
||||
"weight_decay": 1e-3
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "InverseLR",
|
||||
"config": {
|
||||
"inv_gamma": 100000,
|
||||
"power": 0.5,
|
||||
"warmup": 0.99
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"demo": {
|
||||
"demo_every": 5000,
|
||||
"demo_steps": 24,
|
||||
"num_demos": 10,
|
||||
"demo_cond": [
|
||||
"dataset/videoprism/test/0Cu33yBwAPg_000060.npz",
|
||||
"dataset/videoprism/test/bmKtI808DsU_000009.npz",
|
||||
"dataset/videoprism/test/VC0c22cJTbM_000424.npz",
|
||||
"dataset/videoprism/test/F3gsbUTdc2U_000090.npz",
|
||||
"dataset/videoprism/test/WatvT8A8iug_000100.npz",
|
||||
"dataset/videoprism/test/0nvBTp-q7tU_000112.npz",
|
||||
"dataset/videoprism/test/3-PFuDkTM48_000080.npz",
|
||||
"dataset/videoprism/test/luSAuu-BoPs_000232.npz",
|
||||
"dataset/videoprism/test/__8UJxW0aOQ_000002.npz",
|
||||
"dataset/videoprism/test/_0m_YMpQayA_000168.npz"
|
||||
],
|
||||
"demo_cfg_scales": [5]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,413 +0,0 @@
|
||||
"""
|
||||
Model factory functions for PrismAudio inference.
|
||||
|
||||
Extracted from:
|
||||
- PrismAudio/models/factory.py
|
||||
- PrismAudio/models/autoencoders.py (create_autoencoder_from_config)
|
||||
- PrismAudio/models/diffusion.py (create_diffusion_cond_from_config)
|
||||
- PrismAudio/models/conditioners.py (create_multi_conditioner_from_conditioning_config)
|
||||
|
||||
Source: https://github.com/FunAudioLLM/ThinkSound (prismaudio branch)
|
||||
Only inference-critical factory functions are retained.
|
||||
"""
|
||||
|
||||
import json
|
||||
import typing as tp
|
||||
from typing import Dict, Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_model_from_config(model_config):
    """Build the model described by ``model_config``, dispatching on ``model_type``.

    Args:
        model_config: Mapping that must contain a ``model_type`` entry.

    Returns:
        The result of ``create_autoencoder_from_config`` for ``'autoencoder'``,
        or of ``create_diffusion_cond_from_config`` for any of the diffusion
        model types listed below.

    Raises:
        AssertionError: If ``model_type`` is missing or ``None``.
        NotImplementedError: If ``model_type`` is not one of the known types.
    """
    model_type = model_config.get('model_type')
    assert model_type is not None, 'model_type must be specified in model config'

    if model_type == 'autoencoder':
        return create_autoencoder_from_config(model_config)

    # Every conditional-diffusion flavour is built by the same factory.
    diffusion_types = (
        'diffusion_cond',
        'diffusion_cond_inpaint',
        "diffusion_prior",
        "diffusion_infill",
        "mm_diffusion_cond",
    )
    if model_type in diffusion_types:
        return create_diffusion_cond_from_config(model_config)

    raise NotImplementedError(f'Unknown model type: {model_type}')
|
||||
|
||||
|
||||
def create_pretransform_from_config(pretransform_config, sample_rate):
    """Build the pretransform selected by ``pretransform_config['type']``.

    Args:
        pretransform_config: Mapping with a ``type`` entry and, for every
            supported type, a ``config`` entry. For ``'autoencoder'`` the
            optional ``scale``, ``model_half``, ``iterate_batch`` and
            ``chunked`` entries are read as well; ``enable_grad`` is read for
            all types.
        sample_rate: Forwarded to the autoencoder factory (only used for the
            ``'autoencoder'`` type).

    Returns:
        The pretransform, switched to eval mode, with ``enable_grad`` stored on
        it and ``requires_grad_`` set to that same value (default ``False``).

    Raises:
        AssertionError: If ``type`` is missing or ``None``.
        NotImplementedError: For ``'wavelet'`` and for any unknown type.
    """
    kind = pretransform_config.get('type')
    assert kind is not None, 'type must be specified in pretransform config'

    # Explicitly unsupported type: rejected before anything is imported.
    if kind == 'wavelet':
        raise NotImplementedError("wavelet pretransform type is not supported")

    if kind == 'autoencoder':
        from prismaudio_core.models.pretransforms import AutoencoderPretransform

        # Wrap the nested config in a fake top-level config so the sample rate
        # reaches the autoencoder constructor without being repeated in the file.
        autoencoder = create_autoencoder_from_config(
            {"sample_rate": sample_rate, "model": pretransform_config["config"]}
        )
        pretransform = AutoencoderPretransform(
            autoencoder,
            scale=pretransform_config.get("scale", 1.0),
            model_half=pretransform_config.get("model_half", False),
            iterate_batch=pretransform_config.get("iterate_batch", False),
            chunked=pretransform_config.get("chunked", False),
        )
    elif kind == 'pqmf':
        from prismaudio_core.models.pretransforms import PQMFPretransform
        pretransform = PQMFPretransform(**pretransform_config["config"])
    elif kind == 'dac_pretrained':
        from prismaudio_core.models.pretransforms import PretrainedDACPretransform
        pretransform = PretrainedDACPretransform(**pretransform_config["config"])
    elif kind == "audiocraft_pretrained":
        from prismaudio_core.models.pretransforms import AudiocraftCompressionPretransform
        pretransform = AudiocraftCompressionPretransform(**pretransform_config["config"])
    else:
        raise NotImplementedError(f'Unknown pretransform type: {kind}')

    # Record the flag on the module, then apply it: eval mode, with gradients
    # enabled only when the config asks for them.
    pretransform.enable_grad = pretransform_config.get('enable_grad', False)
    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform
|
||||
|
||||
|
||||
def create_bottleneck_from_config(bottleneck_config):
    """Instantiate the bottleneck selected by ``bottleneck_config['type']``.

    Each bottleneck class is imported only when its branch is taken. The
    ``'rvq'`` and ``'rvq_vae'`` types start from a shared set of quantizer
    defaults that the user's ``'config'`` dict overrides. When the optional
    ``'requires_grad'`` entry is falsy, every parameter of the bottleneck is
    frozen.

    Raises:
        AssertionError: if ``'type'`` is missing or ``None``.
        NotImplementedError: if the type is not recognised.
    """
    kind = bottleneck_config.get('type', None)
    assert kind is not None, 'type must be specified in bottleneck config'

    def _quantizer_kwargs():
        # Defaults for the residual-VQ bottlenecks; user config entries win.
        merged = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }
        merged.update(bottleneck_config["config"])
        return merged

    if kind == 'tanh':
        from prismaudio_core.models.bottleneck import TanhBottleneck

        bottleneck = TanhBottleneck()
    elif kind == 'vae':
        from prismaudio_core.models.bottleneck import VAEBottleneck

        bottleneck = VAEBottleneck()
    elif kind == 'rvq':
        from prismaudio_core.models.bottleneck import RVQBottleneck

        bottleneck = RVQBottleneck(**_quantizer_kwargs())
    elif kind == "dac_rvq":
        from prismaudio_core.models.bottleneck import DACRVQBottleneck

        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif kind == 'rvq_vae':
        from prismaudio_core.models.bottleneck import RVQVAEBottleneck

        bottleneck = RVQVAEBottleneck(**_quantizer_kwargs())
    elif kind == 'dac_rvq_vae':
        from prismaudio_core.models.bottleneck import DACRVQVAEBottleneck

        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif kind == 'l2_norm':
        from prismaudio_core.models.bottleneck import L2Bottleneck

        bottleneck = L2Bottleneck()
    elif kind == "wasserstein":
        from prismaudio_core.models.bottleneck import WassersteinBottleneck

        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif kind == "fsq":
        from prismaudio_core.models.bottleneck import FSQBottleneck

        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f'Unknown bottleneck type: {kind}')

    if not bottleneck_config.get('requires_grad', True):
        for weight in bottleneck.parameters():
            weight.requires_grad = False

    return bottleneck
|
||||
|
||||
|
||||
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an ``AudioAutoencoder`` from a top-level config dictionary.

    Originally in PrismAudio/models/autoencoders.py.

    Reads ``config["model"]`` for the encoder, decoder, optional bottleneck
    and optional pretransform sub-configs plus ``latent_dim``,
    ``downsampling_ratio`` and ``io_channels``; ``sample_rate`` is read from
    the top level of ``config``.

    Raises:
        KeyError: if ``"model"``, ``"encoder"`` or ``"decoder"`` is missing.
        AssertionError: if ``latent_dim``, ``downsampling_ratio``,
            ``io_channels`` or ``sample_rate`` is missing or ``None``.
    """
    from prismaudio_core.models.autoencoders import (
        AudioAutoencoder,
        create_encoder_from_config,
        create_decoder_from_config,
    )

    model_cfg = config["model"]

    encoder = create_encoder_from_config(model_cfg["encoder"])
    decoder = create_decoder_from_config(model_cfg["decoder"])

    latent_dim = model_cfg.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = model_cfg.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = model_cfg.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    # Optional sub-modules: left as None when their config entry is absent.
    pretransform_cfg = model_cfg.get("pretransform", None)
    pretransform = None
    if pretransform_cfg is not None:
        pretransform = create_pretransform_from_config(pretransform_cfg, sample_rate)

    bottleneck_cfg = model_cfg.get("bottleneck", None)
    bottleneck = None
    if bottleneck_cfg is not None:
        bottleneck = create_bottleneck_from_config(bottleneck_cfg)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=model_cfg.get("in_channels", None),
        out_channels=model_cfg.get("out_channels", None),
        soft_clip=model_cfg["decoder"].get("soft_clip", False),
    )
|
||||
|
||||
|
||||
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]):
    """Build a ``MultiConditioner`` from a conditioning config dictionary.

    Originally in PrismAudio/models/conditioners.py.

    ``config["configs"]`` is a list of entries, each with ``"id"``, ``"type"``
    and ``"config"``. Every conditioner receives ``output_dim=config["cond_dim"]``
    unless its own ``"config"`` overrides it. The ``"pretransform"`` type
    additionally builds (and optionally loads a checkpoint into) a pretransform
    before constructing the conditioner.

    Raises:
        ValueError: if an entry's ``"type"`` is not recognised.
        AssertionError: if a ``"pretransform"`` entry has no ``sample_rate``.
    """
    from prismaudio_core.models.conditioners import (
        MultiConditioner,
        T5Conditioner,
        CLAPTextConditioner,
        CLIPTextConditioner,
        MetaCLIPTextConditioner,
        CLAPAudioConditioner,
        Cond_MLP,
        Global_MLP,
        Sync_MLP,
        Cond_MLP_1,
        Cond_ConvMLP,
        Cond_MLP_Global,
        Cond_MLP_Global_1,
        Cond_MLP_Global_2,
        Video_Global,
        Video_Sync,
        Text_Linear,
        CLIPConditioner,
        IntConditioner,
        NumberConditioner,
        PhonemeConditioner,
        TokenizerLUTConditioner,
        PretransformConditioner,
        mm_unchang,
    )
    from prismaudio_core.models.utils import load_ckpt_state_dict

    # Types whose conditioner is built directly from the merged kwargs.
    direct_classes = {
        "t5": T5Conditioner,
        "clap_text": CLAPTextConditioner,
        "clip_text": CLIPTextConditioner,
        "metaclip_text": MetaCLIPTextConditioner,
        "clap_audio": CLAPAudioConditioner,
        "cond_mlp": Cond_MLP,
        "global_mlp": Global_MLP,
        "sync_mlp": Sync_MLP,
        "cond_mlp_1": Cond_MLP_1,
        "cond_convmlp": Cond_ConvMLP,
        "cond_mlp_global": Cond_MLP_Global,
        "cond_mlp_global_1": Cond_MLP_Global_1,
        "cond_mlp_global_2": Cond_MLP_Global_2,
        "video_global": Video_Global,
        "video_sync": Video_Sync,
        "text_linear": Text_Linear,
        "video_clip": CLIPConditioner,
        "int": IntConditioner,
        "number": NumberConditioner,
        "phoneme": PhonemeConditioner,
        "lut": TokenizerLUTConditioner,
        "mm_unchang": mm_unchang,
    }

    built = {}
    cond_dim = config["cond_dim"]
    default_keys = config.get("default_keys", {})

    for entry in config["configs"]:
        cond_id = entry["id"]
        cond_type = entry["type"]

        cond_kwargs = {"output_dim": cond_dim}
        cond_kwargs.update(entry["config"])

        if cond_type == "pretransform":
            sample_rate = cond_kwargs.pop("sample_rate", None)
            assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"

            pretransform = create_pretransform_from_config(
                cond_kwargs.pop("pretransform_config"), sample_rate=sample_rate
            )

            # The checkpoint path is consumed here only when it is set.
            if cond_kwargs.get("pretransform_ckpt_path", None) is not None:
                pretransform.load_state_dict(
                    load_ckpt_state_dict(cond_kwargs.pop("pretransform_ckpt_path"))
                )

            built[cond_id] = PretransformConditioner(pretransform, **cond_kwargs)
        elif cond_type in direct_classes:
            built[cond_id] = direct_classes[cond_type](**cond_kwargs)
        else:
            raise ValueError(f"Unknown conditioner type: {cond_type}")

    return MultiConditioner(built, default_keys=default_keys)
|
||||
|
||||
|
||||
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
    """Create a conditioned diffusion model wrapper from a config dictionary.

    Originally in PrismAudio/models/diffusion.py.

    The whole config is validated before the model classes are imported or
    any module is built, so an unsupported or misspelled type fails
    immediately with a descriptive error instead of surfacing later as an
    unbound local variable.

    Args:
        config: Top-level config. Reads ``config["model"]`` (with its
            ``"diffusion"``, ``"io_channels"`` and optional ``"conditioning"``
            and ``"pretransform"`` entries), ``config["model_type"]`` and
            ``config["sample_rate"]``.

    Returns:
        A ``MMConditionedDiffusionModelWrapper`` when ``model_type`` is
        ``"mm_diffusion_cond"``, otherwise a ``ConditionedDiffusionModelWrapper``.

    Raises:
        KeyError: if ``"model"`` or ``"model_type"`` is missing.
        AssertionError: if the diffusion config, its ``type`` or ``config``,
            ``io_channels`` or ``sample_rate`` is missing.
        NotImplementedError: for the unsupported ``mmdit`` diffusion type and
            ``diffusion_prior`` model type, and for any unknown diffusion
            model type or model type.
    """
    model_config = config["model"]
    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    # Fail fast on anything that cannot be built.
    if diffusion_model_type == 'mmdit':
        raise NotImplementedError("mmdit diffusion model type is not supported")
    if diffusion_model_type not in ('adp_cfg_1d', 'adp_1d', 'dit'):
        raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
    if model_type == "diffusion_prior":
        raise NotImplementedError("diffusion_prior model type is not supported")
    if model_type not in ("mm_diffusion_cond", "diffusion_cond", "diffusion_cond_inpaint", "diffusion_infill"):
        raise NotImplementedError(f'Unknown model type: {model_type}')

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    from prismaudio_core.models.diffusion import (
        ConditionedDiffusionModelWrapper,
        MMConditionedDiffusionModelWrapper,
        UNetCFG1DWrapper,
        UNet1DCondWrapper,
        DiTWrapper,
    )

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    else:  # 'dit' (validated above)
        diffusion_model = DiTWrapper(**diffusion_model_config)

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    add_cond_ids = diffusion_config.get('add_cond_ids', [])
    sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
    mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
    zero_init = diffusion_config.get('zero_init', False)

    pretransform = model_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    # Scale the minimum input length by the model's own reduction factor.
    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    # Pick the wrapper class; model_type was validated above, so everything
    # other than the multimodal variant uses the standard wrapper.
    extra_kwargs = {"diffusion_objective": diffusion_objective}
    if model_type == "mm_diffusion_cond":
        wrapper_fn = MMConditionedDiffusionModelWrapper
        extra_kwargs["mm_cond_ids"] = mm_cond_ids
    else:
        wrapper_fn = ConditionedDiffusionModelWrapper

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        add_cond_ids=add_cond_ids,
        sync_cond_ids=sync_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        zero_init=zero_init,
        **extra_kwargs
    )
|
||||
@@ -1,4 +0,0 @@
|
||||
from .sampling import sample_discrete_euler
|
||||
from .utils import set_audio_channels, prepare_audio
|
||||
|
||||
__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]
|
||||
@@ -1,29 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Integrate ``x`` from ``sigma_max`` down to 0 with fixed-step Euler updates.

    Modified from PrismAudio to take a ``callback`` for ComfyUI progress
    reporting (the original used tqdm internally).

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``; its output
            is multiplied by the step size and added to ``x``.
        x: Starting tensor; ``x.shape[0]`` is used as the batch size.
        steps: Number of Euler steps; the schedule has ``steps + 1`` evenly
            spaced points from ``sigma_max`` to 0.
        sigma_max: First value of the schedule (default 1).
        callback: Optional callable receiving ``{"i": step, "x": current_x}``
            after each step.
        **extra_args: Forwarded unchanged to ``model`` on every step.

    Returns:
        The tensor after the final step (``x`` itself when ``steps`` is 0).
    """
    schedule = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)

    for step in range(schedule.shape[0] - 1):
        sigma_now = schedule[step]
        step_size = schedule[step + 1] - sigma_now

        # One timestep value per batch element.
        batch_sigma = sigma_now * torch.ones(x.shape[0], dtype=x.dtype, device=x.device)
        x = x + step_size * model(x, batch_sigma, **extra_args)

        if callback is not None:
            callback({"i": step, "x": x})

    return x
|
||||
@@ -1,62 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchaudio import transforms as T
|
||||
|
||||
|
||||
def set_audio_channels(audio, target_channels):
    """Return ``audio`` with its channel axis (dim 1) adjusted to ``target_channels``.

    Args:
        audio: Audio tensor indexed as [B, C, T].
        target_channels: 1 averages all channels into one; 2 duplicates a
            single channel or keeps the first two of more than two channels.
            Any other value returns ``audio`` unchanged.

    Returns:
        The converted tensor, or the input itself when no conversion applies.
    """
    if target_channels == 1:
        # Downmix to mono by averaging over the channel axis.
        return audio.mean(1, keepdim=True)

    if target_channels != 2:
        return audio

    channel_count = audio.shape[1]
    if channel_count == 1:
        # Mono to stereo: duplicate the single channel.
        return audio.repeat(1, 2, 1)
    if channel_count > 2:
        # Multichannel to stereo: keep the first two channels.
        return audio[:, :2, :]
    return audio
|
||||
|
||||
|
||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move, resample, batch, pad or trim, and channel-convert an audio tensor.

    Args:
        audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T]).
        in_sr: Sample rate of ``audio``.
        target_sr: Sample rate to resample to (skipped when equal to ``in_sr``).
        target_length: Final length of the last axis, in samples.
        target_channels: Channel count passed to ``set_audio_channels``.
        device: Torch device the audio and the resampler are moved to.

    Returns:
        A 3D tensor on ``device`` whose last axis has ``target_length`` samples.
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        resampler = T.Resample(in_sr, target_sr).to(device)
        audio = resampler(audio)

    # Promote 1D and 2D inputs to [B, C, T].
    rank = audio.dim()
    if rank == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif rank == 2:
        audio = audio.unsqueeze(0)

    # Zero-pad on the right when short, truncate when long.
    current_length = audio.shape[-1]
    if current_length < target_length:
        audio = F.pad(audio, (0, target_length - current_length))
    elif current_length > target_length:
        audio = audio[:, :, :target_length]

    return set_audio_channels(audio, target_channels)
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
PrismAudio model modules for inference.
|
||||
|
||||
Re-exports create_model_from_config from the factory module.
|
||||
"""
|
||||
|
||||
from prismaudio_core.factory import create_model_from_config
|
||||
|
||||
__all__ = ["create_model_from_config"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,821 +0,0 @@
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchaudio import transforms as T
|
||||
from alias_free_torch import Activation1d
|
||||
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
||||
from typing import Literal, Dict, Any
|
||||
|
||||
from .blocks import SnakeBeta
|
||||
from .bottleneck import Bottleneck, DiscreteBottleneck
|
||||
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
||||
from .pretransforms import Pretransform
|
||||
from .utils import checkpoint
|
||||
|
||||
|
||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Minimal stand-in for ``inference.utils.prepare_audio`` used by the autoencoders.

    Moves the audio to ``device``, resamples it when the sample rates differ,
    fits the channel axis (dim 0) to ``target_channels``, pads or trims the
    last axis to ``target_length`` and adds a leading batch axis.

    Args:
        audio: Audio tensor; channel handling indexes dim 0, so a 2D
            [C, T] layout is expected.
        in_sr: Sample rate of ``audio``.
        target_sr: Sample rate to resample to (skipped when equal to ``in_sr``).
        target_length: Final length of the last axis, in samples.
        target_channels: Desired size of dim 0.
        device: Torch device for the audio and the resampler.

    Returns:
        ``audio`` with a new leading axis of size 1.
    """
    # Keep audio and resampler on the same device: the resampler is moved to
    # ``device`` below, so the input has to be there too.
    audio = audio.to(device)

    if in_sr != target_sr:
        # torchaudio is only needed on the resampling path.
        import torchaudio.transforms as T

        audio = T.Resample(in_sr, target_sr).to(device)(audio)

    channel_count = audio.shape[0]
    if channel_count > target_channels:
        audio = audio[:target_channels]
    elif channel_count < target_channels:
        # Tile the existing channels past the target count, then cut to size.
        audio = audio.repeat(target_channels // channel_count + 1, 1)[:target_channels]

    current_length = audio.shape[-1]
    if current_length < target_length:
        audio = torch.nn.functional.pad(audio, (0, target_length - current_length))
    elif current_length > target_length:
        audio = audio[..., :target_length]

    return audio.unsqueeze(0)
|
||||
|
||||
|
||||
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
    """Forward to ``prismaudio_core.factory.create_pretransform_from_config``.

    The factory is imported at call time rather than at module import time
    (presumably to avoid a circular import with the factory module; confirm).
    """
    from prismaudio_core.factory import create_pretransform_from_config as _build_pretransform

    return _build_pretransform(pretransform, sample_rate)
|
||||
|
||||
|
||||
def _lazy_create_bottleneck_from_config(bottleneck):
    """Forward to ``prismaudio_core.factory.create_bottleneck_from_config``.

    The factory is imported at call time rather than at module import time
    (presumably to avoid a circular import with the factory module; confirm).
    """
    from prismaudio_core.factory import create_bottleneck_from_config as _build_bottleneck

    return _build_bottleneck(bottleneck)
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Return the activation module named by ``activation``.

    Args:
        activation: ``"elu"`` gives ``nn.ELU``, ``"snake"`` gives
            ``SnakeBeta(channels)``, ``"none"`` gives ``nn.Identity``.
        antialias: When truthy, the module is wrapped in ``Activation1d``.
        channels: Passed to ``SnakeBeta``; unused by the other activations.

    Raises:
        ValueError: if ``activation`` is not one of the names above.
    """
    if activation == "snake":
        module = SnakeBeta(channels)
    elif activation == "none":
        module = nn.Identity()
    elif activation == "elu":
        module = nn.ELU()
    else:
        raise ValueError(f"Unknown activation {activation}")

    return Activation1d(module) if antialias else module
|
||||
|
||||
class ResidualUnit(nn.Module):
    """Residual block: activation, dilated kernel-7 conv, activation, kernel-1 conv, plus skip."""

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        activation_name = "snake" if use_snake else "elu"
        # Half of the dilated kernel-7 receptive span on each side.
        conv_padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation(activation_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
                padding=conv_padding,
            ),
            get_activation(activation_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Apply the layer stack and add the input back (skip connection)."""
        return self.layers(x) + x
|
||||
|
||||
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9), an activation, then a strided conv.

    The strided convolution uses kernel size ``2 * stride`` and changes the
    channel count from ``in_channels`` to ``out_channels``.
    NOTE(review): ``antialias_activation`` is applied only to the block's own
    activation; it is not forwarded to the residual units.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        residual_stack = [
            ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=rate, use_snake=use_snake)
            for rate in (1, 3, 9)
        ]

        self.layers = nn.Sequential(
            *residual_stack,
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        """Run ``x`` through the block."""
        return self.layers(x)
|
||||
|
||||
class DecoderBlock(nn.Module):
    """Activation, upsampling layer, then three residual units (dilations 1, 3, 9).

    The upsampling layer is either a transposed convolution with the given
    stride or, when ``use_nearest_upsample`` is set, nearest-neighbour
    upsampling followed by a stride-1 convolution.
    NOTE(review): ``antialias_activation`` is applied only to the block's own
    activation; it is not forwarded to the residual units.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if not use_nearest_upsample:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        else:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding='same',
                ),
            )

        self.layers = nn.Sequential(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample_layer,
            *(
                ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=rate, use_snake=use_snake)
                for rate in (1, 3, 9)
            ),
        )

    def forward(self, x):
        """Run ``x`` through the block."""
        return self.layers(x)
|
||||
|
||||
class OobleckEncoder(nn.Module):
    """Convolutional encoder: input conv, a chain of ``EncoderBlock``s, output conv.

    Stage widths are ``channels`` times each entry of ``[1] + c_mults``; stage
    ``i`` downsamples by ``strides[i]``. The final convolution projects to
    ``latent_dim`` channels.
    NOTE(review): ``antialias_activation`` is applied only to the final
    activation; it is not forwarded to the encoder blocks.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
        ):
        super().__init__()

        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        widths = [mult * channels for mult in c_mults]

        stages = [WNConv1d(in_channels=in_channels, out_channels=widths[0], kernel_size=7, padding=3)]

        for level, (width_in, width_out) in enumerate(zip(widths, widths[1:])):
            stages.append(
                EncoderBlock(in_channels=width_in, out_channels=width_out, stride=strides[level], use_snake=use_snake)
            )

        stages.append(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=widths[-1])
        )
        stages.append(WNConv1d(in_channels=widths[-1], out_channels=latent_dim, kernel_size=3, padding=1))

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` by running it through every stage."""
        return self.layers(x)
|
||||
|
||||
|
||||
class OobleckDecoder(nn.Module):
    """Convolutional decoder: input conv, a chain of ``DecoderBlock``s, output conv.

    Stage widths are ``channels`` times each entry of ``[1] + c_mults`` and
    are walked from widest to narrowest. The output convolution produces
    ``out_channels`` channels and is followed by ``tanh`` when ``final_tanh``
    is set.
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        widths = [mult * channels for mult in c_mults]

        stages = [WNConv1d(in_channels=latent_dim, out_channels=widths[-1], kernel_size=7, padding=3)]

        for level in reversed(range(1, self.depth)):
            stages.append(
                DecoderBlock(
                    in_channels=widths[level],
                    out_channels=widths[level - 1],
                    stride=strides[level - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                )
            )

        stages.append(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=widths[0])
        )
        stages.append(
            WNConv1d(in_channels=widths[0], out_channels=out_channels, kernel_size=7, padding=3, bias=False)
        )
        stages.append(nn.Tanh() if final_tanh else nn.Identity())

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Decode ``x`` by running it through every stage."""
        return self.layers(x)
|
||||
|
||||
|
||||
class DACEncoderWrapper(nn.Module):
    """Wrap the DAC ``Encoder`` with an optional 1x1 projection to ``latent_dim``.

    ``kwargs`` must contain ``d_model`` and ``strides``; an optional
    ``latent_dim`` entry is removed before the rest is passed to the DAC encoder.
    """

    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        # Width of the DAC encoder output: d_model doubled once per stride.
        dac_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=dac_out_dim, **kwargs)
        self.latent_dim = latent_dim

        # DAC gained its own latent-dim support later and implements it
        # differently; this projection is kept for backwards compatibility.
        if latent_dim is not None:
            self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1)
        else:
            self.proj_out = nn.Identity()

        # Replace the first conv so inputs with more than one channel are accepted.
        if in_channels != 1:
            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)

    def forward(self, x):
        """Encode ``x`` with DAC, then apply the output projection."""
        return self.proj_out(self.encoder(x))
|
||||
|
||||
class DACDecoderWrapper(nn.Module):
    """Wrap the DAC ``Decoder``, taking ``latent_dim`` input channels and emitting ``out_channels``."""

    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Decoder as DACDecoder

        self.decoder = DACDecoder(input_channel=latent_dim, d_out=out_channels, **kwargs)
        self.latent_dim = latent_dim

    def forward(self, x):
        """Decode ``x`` with the wrapped DAC decoder."""
        decoded = self.decoder(x)
        return decoded
|
||||
|
||||
class AudioAutoencoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder,
|
||||
decoder,
|
||||
latent_dim,
|
||||
downsampling_ratio,
|
||||
sample_rate,
|
||||
io_channels=2,
|
||||
bottleneck: Bottleneck = None,
|
||||
pretransform: Pretransform = None,
|
||||
in_channels = None,
|
||||
out_channels = None,
|
||||
soft_clip = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.downsampling_ratio = downsampling_ratio
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.io_channels = io_channels
|
||||
self.in_channels = io_channels
|
||||
self.out_channels = io_channels
|
||||
|
||||
self.min_length = self.downsampling_ratio
|
||||
|
||||
if in_channels is not None:
|
||||
self.in_channels = in_channels
|
||||
|
||||
if out_channels is not None:
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.bottleneck = bottleneck
|
||||
|
||||
self.encoder = encoder
|
||||
|
||||
self.decoder = decoder
|
||||
|
||||
self.pretransform = pretransform
|
||||
|
||||
self.soft_clip = soft_clip
|
||||
|
||||
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
|
||||
|
||||
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
|
||||
|
||||
info = {}
|
||||
|
||||
if self.pretransform is not None and not skip_pretransform:
|
||||
if self.pretransform.enable_grad:
|
||||
if iterate_batch:
|
||||
audios = []
|
||||
for i in range(audio.shape[0]):
|
||||
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||
audio = torch.cat(audios, dim=0)
|
||||
else:
|
||||
audio = self.pretransform.encode(audio)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if iterate_batch:
|
||||
audios = []
|
||||
for i in range(audio.shape[0]):
|
||||
audios.append(self.pretransform.encode(audio[i:i+1]))
|
||||
audio = torch.cat(audios, dim=0)
|
||||
else:
|
||||
audio = self.pretransform.encode(audio)
|
||||
|
||||
if self.encoder is not None:
|
||||
if iterate_batch:
|
||||
latents = []
|
||||
for i in range(audio.shape[0]):
|
||||
latents.append(self.encoder(audio[i:i+1]))
|
||||
latents = torch.cat(latents, dim=0)
|
||||
else:
|
||||
latents = self.encoder(audio)
|
||||
else:
|
||||
latents = audio
|
||||
|
||||
if self.bottleneck is not None:
|
||||
# TODO: Add iterate batch logic, needs to merge the info dicts
|
||||
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
|
||||
|
||||
info.update(bottleneck_info)
|
||||
|
||||
if return_info:
|
||||
return latents, info
|
||||
|
||||
return latents
|
||||
|
||||
def decode(self, latents, iterate_batch=False, **kwargs):
|
||||
|
||||
if self.bottleneck is not None:
|
||||
if iterate_batch:
|
||||
decoded = []
|
||||
for i in range(latents.shape[0]):
|
||||
decoded.append(self.bottleneck.decode(latents[i:i+1]))
|
||||
latents = torch.cat(decoded, dim=0)
|
||||
else:
|
||||
latents = self.bottleneck.decode(latents)
|
||||
|
||||
if iterate_batch:
|
||||
decoded = []
|
||||
for i in range(latents.shape[0]):
|
||||
decoded.append(self.decoder(latents[i:i+1]))
|
||||
decoded = torch.cat(decoded, dim=0)
|
||||
else:
|
||||
decoded = self.decoder(latents, **kwargs)
|
||||
|
||||
if self.pretransform is not None:
|
||||
if self.pretransform.enable_grad:
|
||||
if iterate_batch:
|
||||
decodeds = []
|
||||
for i in range(decoded.shape[0]):
|
||||
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||
decoded = torch.cat(decodeds, dim=0)
|
||||
else:
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
if iterate_batch:
|
||||
decodeds = []
|
||||
for i in range(latents.shape[0]):
|
||||
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
|
||||
decoded = torch.cat(decodeds, dim=0)
|
||||
else:
|
||||
decoded = self.pretransform.decode(decoded)
|
||||
|
||||
if self.soft_clip:
|
||||
decoded = torch.tanh(decoded)
|
||||
|
||||
return decoded
|
||||
|
||||
def decode_tokens(self, tokens, **kwargs):
    """Decode discrete tokens to audio.

    Only valid when the model's bottleneck is a ``DiscreteBottleneck``. The
    tokens are turned into latents by the bottleneck and the latents are then
    passed to ``self.decode``; the same ``**kwargs`` go to both calls.
    """
    bottleneck = self.bottleneck
    assert isinstance(bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"

    return self.decode(bottleneck.decode_tokens(tokens, **kwargs), **kwargs)
|
||||
|
||||
|
||||
def preprocess_audio_for_encoder(self, audio, in_sr):
    """Preprocess one audio tensor (Channels x Length) for the encoder.

    Thin wrapper that wraps the audio and its sample rate in one-element
    lists and delegates to ``preprocess_audio_list_for_encoder``, so the same
    mono conversion, resampling and silence padding rules apply. The result
    has batch size 1, i.e. shape (1 x Channels x Length).
    """
    single_audio = [audio]
    single_rate = [in_sr]
    return self.preprocess_audio_list_for_encoder(single_audio, single_rate)
|
||||
|
||||
def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
    """Preprocess a list of audio tensors into one batch for the encoder.

    Each item is expected as (Channels x Length); items may differ in length
    and channel count. A leading batch dimension of size 1 is squeezed away
    and a missing channel dimension is added for 1-D input.

    Args:
        audio_list: List of audio tensors.
        in_sr_list: Either one integer sample rate applied to every item, or
            a list with one sample rate per item.

    Returns:
        A tensor of shape (Batch x Channels x Length). Every item is
        resampled to ``self.sample_rate``, padded to the longest item rounded
        up to a multiple of ``self.min_length``, and converted to
        ``self.in_channels`` channels by ``prepare_audio``.
    """
    batch_size = len(audio_list)
    if isinstance(in_sr_list, int):
        in_sr_list = [in_sr_list] * batch_size
    assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"

    new_audio = []
    max_length = 0
    # Resample every item and track the longest resampled length.
    for audio, in_sr in zip(audio_list, in_sr_list):
        if len(audio.shape) == 3 and audio.shape[0] == 1:
            # Batch size 1 was given by accident. Just squeeze it.
            audio = audio.squeeze(0)
        elif len(audio.shape) == 1:
            # Mono signal, channel dimension is missing, unsqueeze it in.
            audio = audio.unsqueeze(0)
        assert len(audio.shape) == 2, "Audio should be shape (Channels x Length) with no batch dimension"

        if in_sr != self.sample_rate:
            resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
            audio = resample_tf(audio)

        new_audio.append(audio)
        max_length = max(max_length, audio.shape[-1])

    # Round the common length up to a multiple of the model's minimum length.
    padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length

    for i, audio in enumerate(new_audio):
        # Pad and, if necessary, mix down / duplicate channels. The audio is
        # already at the model sample rate, so source and target rates are
        # both self.sample_rate and no further resampling is requested.
        new_audio[i] = prepare_audio(
            audio,
            in_sr=self.sample_rate,
            target_sr=self.sample_rate,
            target_length=padded_audio_length,
            target_channels=self.in_channels,
            device=audio.device,
        ).squeeze(0)

    return torch.stack(new_audio)
|
||||
|
||||
def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
    """Encode audio into latents, optionally chunk by chunk.

    Audio should already be preprocessed by ``preprocess_audio_for_encoder``.

    When ``chunked`` is False the whole batch is encoded with a single call
    to ``self.encode(audio, **kwargs)``. When it is True the audio is split
    along its last dimension into windows of ``chunk_size`` latents that
    overlap by ``overlap`` latents; each window is encoded on its own
    (without ``**kwargs``) and pasted into the output, dropping half of the
    overlap on every interior edge. The final window is aligned with the end
    of the audio, so it may overlap its predecessor by more than ``overlap``
    in order to keep the window size constant.

    ``overlap`` and ``chunk_size`` are measured in latents, not samples, so
    the same values can likely be reused with ``decode_audio``. An overlap of
    zero causes discontinuity artefacts; it should be at least the receptive
    field of the autoencoder, which can be found empirically by diffing
    chunked against unchunked output. Smaller chunks use less memory but more
    compute.

    Args:
        audio: Tensor indexed as (batch, channels, samples).
        chunked: Whether to encode window by window.
        overlap: Window overlap, in latents.
        chunk_size: Window size, in latents.
        **kwargs: Forwarded to ``self.encode`` in the unchunked case only.

    Returns:
        The latents. In chunked mode this is a tensor of shape
        (batch, ``self.latent_dim``, samples // ``self.downsampling_ratio``).

    Raises:
        ValueError: If ``chunked`` is True and ``overlap >= chunk_size``.
    """
    if not chunked:
        # Default behavior: encode the entire audio in parallel.
        return self.encode(audio, **kwargs)

    # samples_per_latent is the downsampling ratio of the model.
    samples_per_latent = self.downsampling_ratio
    total_size = audio.shape[2]  # in samples
    batch_size = audio.shape[0]
    chunk_size *= samples_per_latent  # latents -> samples
    overlap *= samples_per_latent  # latents -> samples
    hop_size = chunk_size - overlap
    if hop_size <= 0:
        raise ValueError("overlap must be smaller than chunk_size")

    if total_size < chunk_size:
        # Not even one full window fits: encode everything in a single pass
        # instead of failing while trying to build the window list.
        return self.encode(audio)

    starts = list(range(0, total_size - chunk_size + 1, hop_size))
    chunks = [audio[:, :, start:start + chunk_size] for start in starts]
    if starts[-1] + chunk_size != total_size:
        # Final window, aligned with the end of the audio.
        chunks.append(audio[:, :, -chunk_size:])
    chunks = torch.stack(chunks)
    num_chunks = chunks.shape[0]

    # y_size can differ from the latent length used in training because audio
    # of any length may be encoded; the audio is expected to be padded to a
    # multiple of samples_per_latent already.
    y_size = total_size // samples_per_latent
    # Empty latent buffer that is filled window by window.
    y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(audio.device)

    # Half of the overlap (in latents) is trimmed from each interior edge.
    ol = overlap // samples_per_latent // 2
    for i in range(num_chunks):
        y_chunk = self.encode(chunks[i, :])

        # Where this window lands on the latent time axis.
        if i == num_chunks - 1:
            # The final window always goes at the end.
            t_end = y_size
            t_start = t_end - y_chunk.shape[2]
        else:
            t_start = i * hop_size // samples_per_latent
            t_end = t_start + chunk_size // samples_per_latent

        chunk_start = 0
        chunk_end = y_chunk.shape[2]
        if i > 0:
            # The first window keeps its start untouched.
            t_start += ol
            chunk_start += ol
        if i < num_chunks - 1:
            # The last window keeps its end untouched.
            t_end -= ol
            chunk_end -= ol

        y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]

    return y_final
|
||||
|
||||
def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
    """Decode latents to audio, optionally chunk by chunk.

    When ``chunked`` is False the whole batch is decoded with a single call
    to ``self.decode(latents, **kwargs)``. When it is True the latents are
    split along their last dimension into windows of ``chunk_size`` latents
    that overlap by ``overlap`` latents; each window is decoded on its own
    (without ``**kwargs``) and pasted into the output waveform, dropping half
    of the overlap on every interior edge. The final window is aligned with
    the end of the latents, so it may overlap its predecessor by more than
    ``overlap`` in order to keep the window size constant.

    An overlap of zero causes discontinuity artefacts; it should be at least
    the receptive field of the autoencoder, which can be found empirically by
    diffing chunked against unchunked audio. Smaller chunks use less memory
    but more compute.

    Args:
        latents: Tensor indexed as (batch, channels, latent steps).
        chunked: Whether to decode window by window.
        overlap: Window overlap, in latents.
        chunk_size: Window size, in latents.
        **kwargs: Forwarded to ``self.decode`` in the unchunked case only.

    Returns:
        The decoded audio. In chunked mode this is a tensor of shape
        (batch, ``self.out_channels``, steps * ``self.downsampling_ratio``).

    Raises:
        ValueError: If ``chunked`` is True and ``overlap >= chunk_size``.
    """
    if not chunked:
        # Default behavior: decode the entire latent in parallel.
        return self.decode(latents, **kwargs)

    hop_size = chunk_size - overlap
    if hop_size <= 0:
        raise ValueError("overlap must be smaller than chunk_size")
    total_size = latents.shape[2]
    batch_size = latents.shape[0]

    if total_size < chunk_size:
        # Not even one full window fits: decode everything in a single pass
        # instead of failing while trying to build the window list.
        return self.decode(latents)

    starts = list(range(0, total_size - chunk_size + 1, hop_size))
    chunks = [latents[:, :, start:start + chunk_size] for start in starts]
    if starts[-1] + chunk_size != total_size:
        # Final window, aligned with the end of the latents.
        chunks.append(latents[:, :, -chunk_size:])
    chunks = torch.stack(chunks)
    num_chunks = chunks.shape[0]

    # samples_per_latent is the downsampling ratio of the model.
    samples_per_latent = self.downsampling_ratio
    # Empty waveform buffer that is filled window by window.
    y_size = total_size * samples_per_latent
    y_final = torch.zeros((batch_size, self.out_channels, y_size)).to(latents.device)

    # Half of the overlap (in samples) is trimmed from each interior edge.
    ol = (overlap // 2) * samples_per_latent
    for i in range(num_chunks):
        y_chunk = self.decode(chunks[i, :])

        # Where this window lands on the sample time axis.
        if i == num_chunks - 1:
            # The final window always goes at the end.
            t_end = y_size
            t_start = t_end - y_chunk.shape[2]
        else:
            t_start = i * hop_size * samples_per_latent
            t_end = t_start + chunk_size * samples_per_latent

        chunk_start = 0
        chunk_end = y_chunk.shape[2]
        if i > 0:
            # The first window keeps its start untouched.
            t_start += ol
            chunk_start += ol
        if i < num_chunks - 1:
            # The last window keeps its end untouched.
            t_end -= ol
            chunk_end -= ol

        y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]

    return y_final
|
||||
|
||||
|
||||
class DiffusionAutoencoder(AudioAutoencoder):
    """Autoencoder whose decoding step is performed by a diffusion model.

    The (optional) bottleneck and decoder outputs are used as concatenated
    conditioning for ``self.diffusion``, which is sampled from random noise.
    """

    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs
    ):
        """Store the diffusion model and rescale the initial encoder weights.

        Args:
            diffusion: Diffusion model used by ``decode``.
            diffusion_downsampling_ratio: Multiplied with
                ``self.downsampling_ratio`` to obtain ``self.min_length``.
            *args, **kwargs: Passed on to ``AudioAutoencoder.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.diffusion = diffusion
        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio

        if self.encoder is None:
            return

        # Shrink the initial encoder parameters to avoid saturated latents.
        with torch.no_grad():
            for weight in self.encoder.parameters():
                weight *= 0.5

    def decode(self, latents, steps=100):
        """Decode latents by sampling the diffusion model for ``steps`` steps."""
        # Target length is fixed from the incoming latents, before any
        # bottleneck / decoder processing changes their shape.
        target_length = latents.shape[2] * self.downsampling_ratio

        cond = latents
        if self.bottleneck is not None:
            cond = self.bottleneck.decode(cond)
        if self.decoder is not None:
            cond = self.decoder(cond)

        # Upsample the conditioning to match the diffusion length.
        if cond.shape[2] != target_length:
            cond = F.interpolate(cond, size=target_length, mode='nearest')

        noise = torch.randn(cond.shape[0], self.io_channels, target_length, device=cond.device)
        from prismaudio_core.inference.sampling import sample
        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=cond)

        pretransform = self.pretransform
        if pretransform is None:
            return decoded

        if pretransform.enable_grad:
            return pretransform.decode(decoded)

        with torch.no_grad():
            return pretransform.decode(decoded)
|
||||
|
||||
# AE factories
|
||||
|
||||
def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Build an encoder module from its config dictionary.

    Args:
        encoder_config: Mapping with a ``"type"`` key (``"oobleck"``,
            ``"seanet"``, ``"dac"`` or ``"local_attn"``), a ``"config"``
            mapping of constructor keyword arguments, and an optional
            ``"requires_grad"`` flag (default True).

    Returns:
        The constructed encoder. When ``requires_grad`` is False all of its
        parameters are frozen.

    Raises:
        AssertionError: If no type is given.
        ValueError: If the type is not recognised.
    """
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
            **encoder_config["config"]
        )

    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder

        # Work on a shallow copy: reversing "ratios" in the caller's dict
        # would flip them back on a second call with the same config.
        seanet_encoder_config = dict(encoder_config["config"])

        # SEANet encoder expects strides in reverse order.
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(
            **seanet_encoder_config
        )

    elif encoder_type == "dac":
        dac_config = encoder_config["config"]
        encoder = DACEncoderWrapper(**dac_config)

    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        local_attn_config = encoder_config["config"]
        encoder = TransformerEncoder1D(
            **local_attn_config
        )

    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder
|
||||
|
||||
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Build a decoder module from its config dictionary.

    ``decoder_config["type"]`` selects the implementation (``"oobleck"``,
    ``"seanet"``, ``"dac"`` or ``"local_attn"``) and
    ``decoder_config["config"]`` holds its constructor keyword arguments.
    If ``decoder_config["requires_grad"]`` is present and false, every
    parameter of the decoder is frozen.

    Raises:
        AssertionError: If no type is given.
        ValueError: If the type is not recognised.
    """
    kind = decoder_config.get("type", None)
    assert kind is not None, "Decoder type must be specified"

    if kind == "oobleck":
        decoder = OobleckDecoder(**decoder_config["config"])
    elif kind == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(**decoder_config["config"])
    elif kind == "dac":
        decoder = DACDecoderWrapper(**decoder_config["config"])
    elif kind == "local_attn":
        from .local_attention import TransformerDecoder1D

        decoder = TransformerDecoder1D(**decoder_config["config"])
    else:
        raise ValueError(f"Unknown decoder type {kind}")

    trainable = decoder_config.get("requires_grad", True)
    if trainable:
        return decoder

    # Freeze the decoder when gradients were explicitly disabled.
    for weight in decoder.parameters():
        weight.requires_grad = False
    return decoder
|
||||
|
||||
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an ``AudioAutoencoder`` from a full model config.

    ``config["model"]`` must contain ``"encoder"`` and ``"decoder"``
    sub-configs plus ``latent_dim``, ``downsampling_ratio`` and
    ``io_channels``; ``config["sample_rate"]`` is required as well.
    ``bottleneck``, ``pretransform``, ``in_channels`` and ``out_channels``
    are optional. ``soft_clip`` is read from the decoder sub-config.

    Raises:
        AssertionError: If one of the required values is missing.
    """
    model_cfg = config["model"]

    # Encoder and decoder are built first, before any validation.
    encoder = create_encoder_from_config(model_cfg["encoder"])
    decoder = create_decoder_from_config(model_cfg["decoder"])

    bottleneck_cfg = model_cfg.get("bottleneck", None)

    latent_dim = model_cfg.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = model_cfg.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = model_cfg.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    pretransform_cfg = model_cfg.get("pretransform", None)

    # Optional components are only instantiated when configured.
    pretransform = None
    if pretransform_cfg is not None:
        pretransform = _lazy_create_pretransform_from_config(pretransform_cfg, sample_rate)

    bottleneck = None
    if bottleneck_cfg is not None:
        bottleneck = _lazy_create_bottleneck_from_config(bottleneck_cfg)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=model_cfg.get("in_channels", None),
        out_channels=model_cfg.get("out_channels", None),
        soft_clip=model_cfg["decoder"].get("soft_clip", False),
    )
|
||||
|
||||
def create_diffAE_from_config(config: Dict[str, Any]):
    """Build a ``DiffusionAutoencoder`` from a full model config.

    ``config["model"]`` may contain ``"encoder"`` and ``"decoder"``
    sub-configs (both optional) and must contain a ``"diffusion"`` sub-config
    whose ``"type"`` is ``"DAU1d"``, ``"adp_1d"`` or ``"dit"``, plus
    ``latent_dim``, ``downsampling_ratio`` and ``io_channels``.
    ``config["sample_rate"]`` is required; ``bottleneck`` and
    ``pretransform`` are optional.

    The diffusion downsampling ratio is the product of the configured
    ``"strides"`` (DAU1d) or ``"factors"`` (adp_1d), and 1 for ``"dit"``.

    Raises:
        ValueError: If the diffusion type is not recognised.
        AssertionError: If one of the required values is missing.
    """
    diffae_config = config["model"]

    if "encoder" in diffae_config:
        encoder = create_encoder_from_config(diffae_config["encoder"])
    else:
        encoder = None

    if "decoder" in diffae_config:
        decoder = create_decoder_from_config(diffae_config["decoder"])
    else:
        decoder = None

    diffusion_model_type = diffae_config["diffusion"]["type"]
    diffusion_config = diffae_config["diffusion"]["config"]

    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffusion_config)
        diffusion_downsampling_ratio = np.prod(diffusion_config["strides"])
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffusion_config)
        diffusion_downsampling_ratio = np.prod(diffusion_config["factors"])
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffusion_config)
        diffusion_downsampling_ratio = 1
    else:
        # Fail early with a clear message instead of leaving `diffusion`
        # undefined and crashing later with an unrelated error.
        raise ValueError(f"Unknown diffusion model type {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)
    pretransform = diffae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = _lazy_create_bottleneck_from_config(bottleneck)

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform
    )
|
||||
@@ -1,331 +0,0 @@
|
||||
from functools import reduce
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch.backends.cuda import sdp_kernel
|
||||
from packaging import version
|
||||
|
||||
from dac.nn.layers import Snake1d
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Residual wrapper: ``main(input) + skip(input)``.

    ``main`` is an iterable of modules that is wrapped in ``nn.Sequential``;
    ``skip`` defaults to ``nn.Identity`` when it is missing or falsy.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip or nn.Identity()

    def forward(self, input):
        residual = self.main(input)
        shortcut = self.skip(input)
        return residual + shortcut
|
||||
|
||||
class ResConvBlock(ResidualBlock):
    """Residual block of two 1-D convolutions.

    The main path is conv -> GroupNorm -> activation -> conv, followed by a
    second GroupNorm and activation unless ``is_last`` is set (then both are
    replaced by ``nn.Identity``). The activation is ``Snake1d`` when
    ``use_snake`` is True and ``nn.GELU`` otherwise. When ``c_in`` differs
    from ``c_out`` the skip path is a bias-free 1x1 convolution.
    """

    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
        # The skip projection is created first, then the main-path layers.
        shortcut = nn.Conv1d(c_in, c_out, 1, bias=False) if c_in != c_out else None
        half_kernel = kernel_size // 2

        def make_activation(channels):
            if use_snake:
                return Snake1d(channels)
            return nn.GELU()

        layers = [
            nn.Conv1d(c_in, c_mid, kernel_size, padding=half_kernel, bias=conv_bias),
            nn.GroupNorm(1, c_mid),
            make_activation(c_mid),
            nn.Conv1d(c_mid, c_out, kernel_size, padding=half_kernel, bias=conv_bias),
        ]
        if is_last:
            layers += [nn.Identity(), nn.Identity()]
        else:
            layers += [nn.GroupNorm(1, c_out), make_activation(c_out)]

        super().__init__(layers, shortcut)
|
||||
|
||||
class SelfAttention1d(nn.Module):
    """Multi-head self-attention over the last axis of a (n, c, s) tensor.

    The input is group-normalised, projected to queries/keys/values with a
    1x1 convolution, attended, projected back and added to the input.
    """

    def __init__(self, c_in, n_head=1, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv1d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)

        # The fused kernel path needs CUDA and torch >= 2.0.0.
        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        if self.use_flash:
            props = torch.cuda.get_device_properties(torch.device('cuda'))
            is_sm80 = props.major == 8 and props.minor == 0
            # Compute capability 8.0 (A100): flash attention only.
            # Anything else: flash attention disabled, other kernels enabled.
            self.sdp_kernel_config = (True, False, False) if is_sm80 else (False, True, True)

    def forward(self, input):
        n, c, s = input.shape
        heads = self.n_head

        qkv = self.qkv_proj(self.norm(input))
        qkv = qkv.view([n, heads * 3, c // heads, s]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        # Scale is applied to q and k separately, hence the fourth root.
        scale = k.shape[3]**-0.25

        if self.use_flash:
            # NOTE(review): unlike the fallback branch below, this result is
            # viewed as (n, c, s) without a transpose(2, 3) first - confirm
            # that the difference in layout is intended.
            with sdp_kernel(*self.sdp_kernel_config):
                y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
        else:
            att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
            y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])

        projected = self.out_proj(y)
        return input + self.dropout(projected)
|
||||
|
||||
class SkipBlock(nn.Module):
    """Run ``main`` and concatenate its output with the input along dim 1."""

    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, input):
        processed = self.main(input)
        return torch.cat((processed, input), dim=1)
|
||||
|
||||
class FourierFeatures(nn.Module):
    """Random Fourier feature embedding.

    Holds a learnable weight of shape (out_features // 2, in_features) drawn
    from a normal distribution scaled by ``std``. The forward pass projects
    the input, multiplies by 2*pi and returns the cosines followed by the
    sines along the last dimension.
    """

    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        half = out_features // 2
        self.weight = nn.Parameter(torch.randn([half, in_features]) * std)

    def forward(self, input):
        angles = (2 * math.pi * input) @ self.weight.T
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
|
||||
|
||||
def expand_to_planes(input, shape):
    """Append a trailing axis to ``input`` and tile it ``shape[2]`` times."""
    with_axis = input.unsqueeze(-1)
    return with_axis.repeat([1, 1, shape[2]])
|
||||
|
||||
# Symmetric 1-D filter taps, keyed by interpolation name. Downsample1d uses
# them as-is (each set sums to roughly 1); Upsample1d multiplies them by 2.
_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
        0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
        -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
        0.44638532400131226, 0.13550527393817902, -0.066637322306633,
        -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
|
||||
|
||||
class Downsample1d(nn.Module):
    """Halve the length of a signal with a fixed low-pass kernel.

    The kernel named by ``kernel`` is looked up in ``_kernels`` and applied
    to each channel independently with stride 2 after padding with
    ``pad_mode``. With ``channels_last`` the input is permuted from
    (n, s, c) to (n, c, s) for the convolution and back afterwards.
    """

    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel])
        self.pad = taps.shape[0] // 2 - 1
        self.register_buffer('kernel', taps)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)

        x = F.pad(x, (self.pad, self.pad), self.pad_mode)

        # Channel-diagonal weight: every channel is filtered by the same
        # kernel and there is no mixing between channels.
        channels = x.shape[1]
        weight = x.new_zeros([channels, channels, self.kernel.shape[0]])
        diag = torch.arange(channels, device=x.device)
        weight[diag, diag] = self.kernel.to(weight)

        x = F.conv1d(x, weight, stride=2)

        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
    """Double the length of a signal with a fixed interpolation kernel.

    The kernel named by ``kernel`` is looked up in ``_kernels``, multiplied
    by 2 and applied to each channel independently through a stride-2
    transposed convolution after padding with ``pad_mode``. With
    ``channels_last`` the input is permuted from (n, s, c) to (n, c, s) for
    the convolution and back afterwards.
    """

    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel]) * 2
        self.pad = taps.shape[0] // 2 - 1
        self.register_buffer('kernel', taps)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)

        edge = (self.pad + 1) // 2
        x = F.pad(x, (edge, edge), self.pad_mode)

        # Channel-diagonal weight: every channel is filtered by the same
        # kernel and there is no mixing between channels.
        channels = x.shape[1]
        weight = x.new_zeros([channels, channels, self.kernel.shape[0]])
        diag = torch.arange(channels, device=x.device)
        weight[diag, diag] = self.kernel.to(weight)

        x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)

        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x
|
||||
|
||||
def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    """Return a strided ``nn.Conv1d`` that downsamples by ``factor``.

    The kernel size is ``factor * kernel_multiplier + 1`` and the padding is
    ``factor * (kernel_multiplier // 2)``; ``kernel_multiplier`` must be even.
    """
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    kernel_size = factor * kernel_multiplier + 1
    padding = factor * (kernel_multiplier // 2)
    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=factor,
        padding=padding,
    )
|
||||
|
||||
|
||||
def Upsample1d_2(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    """Return a module that upsamples by ``factor``.

    * ``factor == 1``: a plain kernel-3 ``nn.Conv1d`` (no upsampling).
    * ``use_nearest``: nearest-neighbour ``nn.Upsample`` followed by a
      kernel-3 ``nn.Conv1d``.
    * otherwise: a ``nn.ConvTranspose1d`` with kernel ``2 * factor`` and
      stride ``factor``.
    """

    def smoothing_conv():
        return nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )

    if factor == 1:
        return smoothing_conv()

    if not use_nearest:
        odd = factor % 2
        return nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
            padding=factor // 2 + odd,
            output_padding=odd,
        )

    return nn.Sequential(
        nn.Upsample(scale_factor=factor, mode="nearest"),
        smoothing_conv(),
    )
|
||||
|
||||
def zero_init(layer):
    """Zero ``layer.weight`` (and ``layer.bias`` if present) in place and return the layer."""
    tensors = [layer.weight]
    if layer.bias is not None:
        tensors.append(layer.bias)
    for tensor in tensors:
        nn.init.zeros_(tensor)
    return layer
|
||||
|
||||
class AdaRMSNorm(nn.Module):
    """RMS normalisation whose scale is predicted from a conditioning vector.

    The scale is ``linear(cond) + 1``, where the bias-free linear layer is
    zero-initialised, so the initial scale is 1.
    """

    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        projection = nn.Linear(cond_features, features, bias=False)
        self.linear = zero_init(projection)

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        # Insert an axis at position 1 so the scale broadcasts against x.
        scale = self.linear(cond)[:, None, :] + 1
        return rms_norm(x, scale, self.eps)
|
||||
|
||||
def normalize(x, eps=1e-4):
    """Divide ``x`` by its per-item vector norm, rescaled and offset by ``eps``.

    The norm is taken over every dimension except the first. It is multiplied
    by ``sqrt(norm.numel() / x.numel())`` and ``eps`` is added before the
    division.
    """
    reduce_dims = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True)
    alpha = np.sqrt(norm.numel() / x.numel())
    # torch.add(eps, norm, alpha=alpha) evaluates eps + alpha * norm.
    denominator = torch.add(eps, norm, alpha=alpha)
    return x / denominator
|
||||
|
||||
class ForcedWNConv1d(nn.Module):
    """Bias-free 1-D convolution with forced weight normalisation.

    In training mode the stored weight is overwritten with its normalised
    version on every forward pass. The weight that is actually applied is
    the normalised weight divided by ``sqrt(fan_in)``, with 'same' padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        initial = torch.randn([out_channels, in_channels, kernel_size])
        self.weight = nn.Parameter(initial)

    def forward(self, x):
        if self.training:
            # Re-normalise the stored parameter without tracking gradients.
            with torch.no_grad():
                self.weight.copy_(normalize(self.weight))

        # Number of weights feeding one output channel.
        fan_in = self.weight[0].numel()
        effective_weight = normalize(self.weight) / math.sqrt(fan_in)
        return F.conv1d(x, effective_weight, padding='same')
|
||||
|
||||
# Kernels
|
||||
|
||||
# Module-wide switch: set to False to skip torch.compile altogether.
use_compile = True


def compile(function, *args, **kwargs):
    """Wrap ``function`` with ``torch.compile`` when possible.

    Returns ``function`` unchanged when ``use_compile`` is False or when
    ``torch.compile`` raises ``RuntimeError``.
    """
    if use_compile:
        try:
            return torch.compile(function, *args, **kwargs)
        except RuntimeError:
            pass
    return function
|
||||
|
||||
|
||||
@compile
def linear_geglu(x, weight, bias=None):
    """Linear projection followed by a GELU-gated linear unit.

    The projection output is split in half along the last dimension; the
    first half is multiplied by the GELU of the second half.
    """
    projected = x @ weight.mT
    if bias is not None:
        projected = projected + bias
    value, gate = projected.chunk(2, dim=-1)
    return value * F.gelu(gate)
|
||||
|
||||
|
||||
@compile
def rms_norm(x, scale, eps):
    """Scale ``x`` by ``scale / sqrt(mean(x**2) + eps)`` over the last dim.

    The statistics are computed in the promotion of the dtypes of ``x`` and
    ``scale`` with float32; the final factor is cast back to the dtype of
    ``x`` before the multiplication.
    """
    compute_dtype = torch.promote_types(
        torch.promote_types(x.dtype, scale.dtype), torch.float32
    )
    mean_sq = torch.mean(x.to(compute_dtype)**2, dim=-1, keepdim=True)
    factor = scale.to(compute_dtype) * torch.rsqrt(mean_sq + eps)
    return x * factor.to(x.dtype)
|
||||
|
||||
# Layers
|
||||
|
||||
class LinearGEGLU(nn.Linear):
    """Linear layer with a GEGLU output.

    The underlying linear layer is created with twice ``out_features`` so its
    output can be split into value and gate; ``self.out_features`` is then
    reset to the requested (halved) size.
    """

    def __init__(self, in_features, out_features, bias=True):
        doubled = out_features * 2
        super().__init__(in_features, doubled, bias=bias)
        self.out_features = out_features

    def forward(self, x):
        return linear_geglu(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
    """RMS normalisation with a scale initialised to ones.

    With ``fix_scale`` the scale is registered as a buffer (not trained);
    otherwise it is a learnable parameter.
    """

    def __init__(self, shape, fix_scale = False, eps=1e-6):
        super().__init__()
        self.eps = eps

        ones = torch.ones(shape)
        if fix_scale:
            self.register_buffer("scale", ones)
        else:
            self.scale = nn.Parameter(ones)

    def extra_repr(self):
        shape = tuple(self.scale.shape)
        return f"shape={shape}, eps={self.eps}"

    def forward(self, x):
        return rms_norm(x, self.scale, self.eps)
|
||||
|
||||
def snake_beta(x, alpha, beta):
    """Snake activation: ``x + sin(alpha * x)**2 / beta``.

    A small constant is added to ``beta`` to avoid division by zero.
    """
    inverse_beta = 1.0 / (beta + 0.000000001)
    periodic = pow(torch.sin(x * alpha), 2)
    return x + inverse_beta * periodic
|
||||
|
||||
# try:
|
||||
# snake_beta = torch.compile(snake_beta)
|
||||
# except RuntimeError:
|
||||
# pass
|
||||
|
||||
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
||||
# License available in LICENSES/LICENSE_NVIDIA.txt
|
||||
class SnakeBeta(nn.Module):
    """Snake activation with separate per-channel ``alpha`` and ``beta``.

    With ``alpha_logscale`` both parameters are stored in log scale
    (initialised to zeros and exponentiated in ``forward``); otherwise they
    are stored in linear scale (initialised to ones). In both cases the
    initial tensor is multiplied by ``alpha``. ``alpha_trainable`` controls
    ``requires_grad`` for both parameters.
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log scale starts from zeros, linear scale starts from ones.
        fill = torch.zeros if alpha_logscale else torch.ones
        self.alpha = nn.Parameter(fill(in_features) * alpha)
        self.beta = nn.Parameter(fill(in_features) * alpha)

        for parameter in (self.alpha, self.beta):
            parameter.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        # Reshape (C,) -> (1, C, 1) so the parameters line up with [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha, beta = torch.exp(alpha), torch.exp(beta)
        return snake_beta(x, alpha, beta)
|
||||
@@ -1,355 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from vector_quantize_pytorch import ResidualVQ, FSQ
|
||||
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
|
||||
|
||||
class Bottleneck(nn.Module):
    """Base class for autoencoder bottlenecks.

    Subclasses implement ``encode`` and ``decode``; ``is_discrete`` records
    whether the bottleneck is a discrete one.
    """

    def __init__(self, is_discrete: bool = False):
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        """Abstract: must be provided by subclasses."""
        raise NotImplementedError()

    def decode(self, x):
        """Abstract: must be provided by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
class DiscreteBottleneck(Bottleneck):
    """Base class for discrete (token producing) bottlenecks.

    Stores the number of quantizers, the codebook size and the identifier
    under which tokens are reported. Subclasses implement ``decode_tokens``.
    """

    def __init__(self, num_quantizers, codebook_size, tokens_id):
        super().__init__(is_discrete=True)
        self.num_quantizers, self.codebook_size, self.tokens_id = (
            num_quantizers,
            codebook_size,
            tokens_id,
        )

    def decode_tokens(self, codes, **kwargs):
        """Abstract: must be provided by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
class TanhBottleneck(Bottleneck):
    """Continuous bottleneck that squashes latents with ``tanh``."""

    def __init__(self):
        super().__init__(is_discrete=False)
        self.tanh = nn.Tanh()

    def encode(self, x, return_info=False, **kwargs):
        """Apply ``tanh`` to ``x``.

        Args:
            x: Latents to squash.
            return_info: When True, also return an (empty) info dict.
            **kwargs: Accepted for compatibility with the ``Bottleneck.encode``
                signature and ignored.

        Returns:
            ``tanh(x)``, or ``(tanh(x), {})`` when ``return_info`` is True.
        """
        info = {}

        x = torch.tanh(x)

        if return_info:
            return x, info
        return x

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
|
||||
|
||||
def vae_sample(mean, scale):
    """Draw a reparameterised sample and compute its KL term.

    The standard deviation is ``softplus(scale) + 1e-4``. The KL term is
    ``mean**2 + var - log(var) - 1`` summed over dim 1 and averaged over the
    remaining dimensions.

    Returns:
        Tuple ``(latents, kl)``.
    """
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev

    # Reparameterisation: noise scaled by the standard deviation plus mean.
    noise = torch.randn_like(mean)
    latents = noise * stdev + mean

    logvar = torch.log(var)
    per_element = mean * mean + var - logvar - 1
    kl = per_element.sum(1).mean()

    return latents, kl
|
||||
|
||||
class VAEBottleneck(Bottleneck):
    """Continuous VAE bottleneck.

    ``encode`` splits its input in two halves along dim 1 (mean and scale),
    samples latents with ``vae_sample`` and reports the KL term under the
    ``"kl"`` key of the info dict. ``decode`` is the identity.
    """

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False, **kwargs):
        mean, scale = x.chunk(2, dim=1)
        sampled, kl = vae_sample(mean, scale)

        if not return_info:
            return sampled
        return sampled, {"kl": kl}

    def decode(self, x):
        return x
|
||||
|
||||
def compute_mean_kernel(x, y):
    """Return the mean of ``exp(-d)`` over all row pairs of ``x`` and ``y``.

    ``d`` is the squared difference averaged over dim 2 of the broadcast pair
    tensor, then divided by ``x.shape[-1]``.
    """
    pairwise_sq = (x.unsqueeze(1) - y.unsqueeze(0)).pow(2)
    scaled = pairwise_sq.mean(2) / x.shape[-1]
    return torch.exp(-scaled).mean()
|
||||
|
||||
def compute_mmd(latents):
    """Compute an MMD-style statistic between ``latents`` and Gaussian noise.

    ``latents`` is permuted with ``(0, 2, 1)`` and flattened to rows of width
    ``latents.shape[1]``. A noise tensor of the same shape is drawn, and the
    result is ``k(l, l) + k(n, n) - 2 * k(l, n)`` with ``k`` being
    ``compute_mean_kernel``.
    """
    flat = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
    gaussian = torch.randn_like(flat)

    k_latents = compute_mean_kernel(flat, flat)
    k_noise = compute_mean_kernel(gaussian, gaussian)
    k_cross = compute_mean_kernel(flat, gaussian)

    return (k_latents + k_noise - 2 * k_cross).mean()
|
||||
|
||||
class WassersteinBottleneck(Bottleneck):
    """Continuous bottleneck that reports an MMD term during training.

    ``encode`` passes ``x`` through untouched. ``decode`` can append
    ``noise_augment_dim`` channels of Gaussian noise along dim 1.
    """

    def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
        super().__init__(is_discrete=False)

        self.bypass_mmd = bypass_mmd
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False):
        """Return ``x``, plus an info dict when ``return_info`` is set.

        ``info["mmd"]`` is filled only in training mode with ``return_info``
        set: a constant ``0.0`` tensor if ``bypass_mmd``, otherwise
        ``compute_mmd(x)``.
        """
        info = {}

        if self.training and return_info:
            info["mmd"] = torch.tensor(0.0) if self.bypass_mmd else compute_mmd(x)

        return (x, info) if return_info else x

    def decode(self, x):
        """Return ``x``, with noise channels appended when configured."""
        if self.noise_augment_dim > 0:
            extra = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, extra], dim=1)

        return x
|
||||
|
||||
class L2Bottleneck(Bottleneck):
    """Continuous bottleneck that normalises along dim 1 on encode and decode."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        """Return ``F.normalize(x, dim=1)``, plus an empty info dict if requested."""
        normed = F.normalize(x, dim=1)
        if not return_info:
            return normed
        return normed, {}

    def decode(self, x):
        """Return ``F.normalize(x, dim=1)``."""
        return F.normalize(x, dim=1)
|
||||
|
||||
class RVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by ``ResidualVQ``.

    All keyword arguments are forwarded to ``ResidualVQ``. ``num_quantizers``
    and ``codebook_size`` must be among them.
    """

    def __init__(self, **quantizer_kwargs):
        n_quantizers = quantizer_kwargs["num_quantizers"]
        super().__init__(
            num_quantizers=n_quantizers,
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = n_quantizers

    def encode(self, x, return_info=False, **kwargs):
        """Quantize ``x``.

        ``x`` is rearranged ``"b c n -> b n c"`` for the quantizer and back
        afterwards. The info dict holds ``quantizer_indices`` and the mean
        ``quantizer_loss``. Extra keyword arguments are ignored.
        """
        channels_last = rearrange(x, "b c n -> b n c")
        quantized, indices, loss = self.quantizer(channels_last)
        quantized = rearrange(quantized, "b n c -> b c n")

        info = {
            "quantizer_indices": indices,
            "quantizer_loss": loss.mean(),
        }

        return (quantized, info) if return_info else quantized

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` in the quantizer and decode them."""
        looked_up = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(looked_up, **kwargs)
|
||||
|
||||
class RVQVAEBottleneck(DiscreteBottleneck):
    """Discrete bottleneck: VAE sampling followed by ``ResidualVQ`` quantization.

    All keyword arguments are forwarded to ``ResidualVQ``. ``num_quantizers``
    and ``codebook_size`` must be among them.
    """

    def __init__(self, **quantizer_kwargs):
        n_quantizers = quantizer_kwargs["num_quantizers"]
        super().__init__(
            num_quantizers=n_quantizers,
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = n_quantizers

    def encode(self, x, return_info=False):
        """Sample from the (mean, scale) halves of ``x``, then quantize.

        The info dict holds ``kl``, ``quantizer_indices`` and the mean
        ``quantizer_loss``.
        """
        mean, scale = x.chunk(2, dim=1)
        sampled, kl = vae_sample(mean, scale)

        info = {"kl": kl}

        channels_last = rearrange(sampled, "b c n -> b n c")
        quantized, indices, loss = self.quantizer(channels_last)
        quantized = rearrange(quantized, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        return (quantized, info) if return_info else quantized

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` in the quantizer and decode them."""
        looked_up = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(looked_up, **kwargs)
|
||||
|
||||
class DACRVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by the DAC residual vector quantizer.

    Args:
        quantize_on_decode: When true, ``encode`` leaves ``x`` unquantized and
            quantization happens in ``decode`` instead.
        noise_augment_dim: Number of Gaussian-noise channels ``decode``
            appends along dim 1 (0 disables this).
        **quantizer_kwargs: Forwarded to ``DACResidualVQ``; must contain
            ``n_codebooks`` and ``codebook_size``.
    """

    def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False, **kwargs):
        """Quantize ``x`` unless quantization is deferred to ``decode``.

        Returns ``z``, or ``(z, info)`` when ``return_info`` is set. ``info``
        always holds ``pre_quantizer``. On the quantizing path it also holds
        ``z``, ``codes``, ``latents`` and the two VQ losses, each divided by
        the number of quantizers. Extra keyword arguments are forwarded to
        the quantizer.
        """
        info = {}

        info["pre_quantizer"] = x

        if self.quantize_on_decode:
            # Parenthesised on purpose: `return x, info if return_info else x`
            # parses as `(x, (info if return_info else x))`, which returned a
            # tuple even when return_info was False.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        """Optionally quantize ``x``, then optionally append noise channels."""
        if self.quantize_on_decode:
            # The first element of the quantizer output is the quantized tensor.
            x = self.quantizer(x)[0]

        if self.noise_augment_dim > 0:
            noise = torch.randn(x.shape[0], self.noise_augment_dim,
                                x.shape[-1]).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, codes, **kwargs):
        """Convert ``codes`` to latents via the quantizer and decode them."""
        latents, _, _ = self.quantizer.from_codes(codes)

        return self.decode(latents, **kwargs)
|
||||
|
||||
class DACRVQVAEBottleneck(DiscreteBottleneck):
    """Discrete bottleneck: VAE sampling followed by DAC residual VQ.

    Args:
        quantize_on_decode: When true, ``encode`` returns the VAE sample
            unquantized and quantization happens in ``decode`` instead.
        **quantizer_kwargs: Forwarded to ``DACResidualVQ``; must contain
            ``n_codebooks`` and ``codebook_size``.
    """

    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, n_quantizers: int = None):
        """Sample from the (mean, scale) halves of ``x``, then quantize.

        Returns ``z``, or ``(z, info)`` when ``return_info`` is set. ``info``
        always holds ``pre_quantizer`` (the VAE sample) and ``kl``. On the
        quantizing path it also holds ``z``, ``codes``, ``latents`` and the
        two VQ losses, each divided by the number of quantizers.
        ``n_quantizers`` is forwarded to the quantizer.
        """
        info = {}

        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["pre_quantizer"] = x
        info["kl"] = kl

        if self.quantize_on_decode:
            # Parenthesised on purpose: `return x, info if return_info else x`
            # parses as `(x, (info if return_info else x))`, which returned a
            # tuple even when return_info was False.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        """Quantize ``x`` when quantization was deferred; otherwise pass through."""
        if self.quantize_on_decode:
            # The first element of the quantizer output is the quantized tensor.
            x = self.quantizer(x)[0]

        return x

    def decode_tokens(self, codes, **kwargs):
        """Convert ``codes`` to latents via the quantizer and decode them."""
        latents, _, _ = self.quantizer.from_codes(codes)

        return self.decode(latents, **kwargs)
|
||||
|
||||
class FSQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by ``FSQ``.

    Args:
        noise_augment_dim: Number of Gaussian-noise channels ``decode``
            appends along dim 1 (0 disables this).
        **kwargs: Forwarded to ``FSQ``; must contain ``levels`` and may
            contain ``num_codebooks`` (treated as 1 when absent).
    """

    def __init__(self, noise_augment_dim=0, **kwargs):
        super().__init__(
            num_quantizers=kwargs.get("num_codebooks", 1),
            codebook_size=np.prod(kwargs["levels"]),
            tokens_id="quantizer_indices",
        )

        self.noise_augment_dim = noise_augment_dim

        self.quantizer = FSQ(
            **kwargs,
            allowed_dtypes=[torch.float16, torch.float32, torch.float64],
        )

    def encode(self, x, return_info=False):
        """Quantize ``x`` in float32 and cast the result back to its dtype.

        The info dict holds ``quantizer_indices`` rearranged
        ``"b n q -> b q n"``.
        """
        input_dtype = x.dtype

        channels_last = rearrange(x.float(), "b c n -> b n c")
        quantized, indices = self.quantizer(channels_last)
        quantized = rearrange(quantized, "b n c -> b c n").to(input_dtype)

        # Reorder indices to match the expected format
        info = {"quantizer_indices": rearrange(indices, "b n q -> b q n")}

        return (quantized, info) if return_info else quantized

    def decode(self, x):
        """Return ``x``, with noise channels appended when configured."""
        if self.noise_augment_dim > 0:
            extra = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, extra], dim=1)

        return x

    def decode_tokens(self, tokens, **kwargs):
        """Convert ``tokens`` to codes via the quantizer and decode them."""
        looked_up = self.quantizer.indices_to_codes(tokens)
        return self.decode(looked_up, **kwargs)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,884 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import typing as tp
|
||||
|
||||
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
|
||||
from .conditioners import MultiConditioner
|
||||
from .dit import DiffusionTransformer
|
||||
from .pretransforms import Pretransform
|
||||
|
||||
from .adp import UNetCFG1d, UNet1d
|
||||
|
||||
# Lazy imports for factory functions to avoid circular imports
|
||||
def _get_create_pretransform_from_config():
    """Import and return ``create_pretransform_from_config`` on first use.

    The import happens at call time rather than module import time.
    """
    from prismaudio_core.factory import create_pretransform_from_config as factory_fn
    return factory_fn
|
||||
|
||||
def _get_create_multi_conditioner_from_conditioning_config():
    """Import and return ``create_multi_conditioner_from_conditioning_config``.

    The import happens at call time rather than module import time.
    """
    from prismaudio_core.factory import (
        create_multi_conditioner_from_conditioning_config as factory_fn,
    )
    return factory_fn
|
||||
|
||||
class DiffusionModel(nn.Module):
    """Abstract base for diffusion models called as ``model(x, t, **kwargs)``."""

    def __init__(self, *args, **kwargs):
        # Everything is handed to nn.Module unchanged.
        super().__init__(*args, **kwargs)

    def forward(self, x, t, **kwargs):
        """Abstract hook; must be provided by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
class DiffusionModelWrapper(nn.Module):
    """Bundle a diffusion model with its audio metadata and optional pretransform.

    The metadata values are stored as given; ``forward`` delegates straight to
    the wrapped model.
    """

    def __init__(
        self,
        model: DiffusionModel,
        io_channels,
        sample_size,
        sample_rate,
        min_input_length,
        pretransform: tp.Optional[Pretransform] = None,
    ):
        super().__init__()
        self.io_channels = io_channels
        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.model = model

        # None simply means "no pretransform".
        self.pretransform = pretransform

    def forward(self, x, t, **kwargs):
        """Call the wrapped model with the same arguments."""
        return self.model(x, t, **kwargs)
|
||||
|
||||
class ConditionedDiffusionModel(nn.Module):
    """Abstract base for diffusion backbones that accept conditioning.

    The four ``supports_*`` flags are stored as attributes. Remaining
    positional and keyword arguments go to ``nn.Module.__init__``.
    """

    def __init__(self,
                 *args,
                 supports_cross_attention: bool = False,
                 supports_input_concat: bool = False,
                 supports_global_cond: bool = False,
                 supports_prepend_cond: bool = False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.supports_prepend_cond = supports_prepend_cond
        self.supports_global_cond = supports_global_cond
        self.supports_input_concat = supports_input_concat
        self.supports_cross_attention = supports_cross_attention

    def forward(self,
                x: torch.Tensor,
                t: torch.Tensor,
                cross_attn_cond: torch.Tensor = None,
                cross_attn_mask: torch.Tensor = None,
                input_concat_cond: torch.Tensor = None,
                global_embed: torch.Tensor = None,
                prepend_cond: torch.Tensor = None,
                prepend_cond_mask: torch.Tensor = None,
                cfg_scale: float = 1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                **kwargs):
        """Abstract hook; must be provided by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
class ConditionedDiffusionModelWrapper(nn.Module):
    """
    A diffusion model that takes in conditioning

    Wraps a backbone ``model`` together with a ``conditioner`` and the lists
    of conditioning keys that feed each conditioning pathway (cross-attention,
    global, input-concat, prepend, add, sync).
    """
    def __init__(
            self,
            model: ConditionedDiffusionModel,
            conditioner: MultiConditioner,
            io_channels,
            sample_rate,
            min_input_length: int,
            diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
            zero_init: bool = False,
            pretransform: tp.Optional[Pretransform] = None,
            cross_attn_cond_ids: tp.List[str] = [],
            global_cond_ids: tp.List[str] = [],
            input_concat_ids: tp.List[str] = [],
            prepend_cond_ids: tp.List[str] = [],
            add_cond_ids: tp.List[str] = [],
            sync_cond_ids: tp.List[str] = [],
            ):
        """Store the components and, if ``zero_init`` is True, re-initialise weights.

        With ``zero_init`` set, every ``nn.Linear`` in the conditioner gets
        Xavier-uniform weights and zero bias, and
        ``model.model.initialize_weights()`` is called.
        """
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        # Copy the id lists: the signature defaults are shared list objects, so
        # storing them directly would alias them across instances.
        self.cross_attn_cond_ids = list(cross_attn_cond_ids)
        self.global_cond_ids = list(global_cond_ids)
        self.input_concat_ids = list(input_concat_ids)
        self.prepend_cond_ids = list(prepend_cond_ids)
        self.add_cond_ids = list(add_cond_ids)
        self.sync_cond_ids = list(sync_cond_ids)
        self.min_input_length = min_input_length

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        if zero_init is True:
            self.conditioner.apply(_basic_init)
            self.model.model.initialize_weights()

    @staticmethod
    def _with_seq_dim(tensor):
        """Insert a size-1 dim at position 1 when ``tensor`` is 2-D."""
        return tensor.unsqueeze(1) if len(tensor.shape) == 2 else tensor

    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        """Assemble the backbone's conditioning kwargs from conditioner outputs.

        Args:
            conditioning_tensors: Mapping from conditioning key to an
                indexable value. Cross-attention and prepend entries are
                unpacked as ``(tensor, mask)``; all other pathways read
                element ``[0]``.
            negative: When true, return only the ``negative_*`` keys
                (cross-attention input/mask, global, input-concat).

        Returns:
            A dict of keyword arguments for the backbone. A pathway with no
            configured ids maps to ``None``.
        """
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None
        add_input = None
        sync_input = None

        if self.cross_attn_cond_ids:
            # Concatenate all cross-attention inputs over the sequence dimension
            # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
            cross_inputs = []
            cross_masks = []
            for key in self.cross_attn_cond_ids:
                cross_attn_in, cross_attn_mask = conditioning_tensors[key]
                # Only the input gets a sequence dim added; the mask is used as given.
                cross_inputs.append(self._with_seq_dim(cross_attn_in))
                cross_masks.append(cross_attn_mask)

            cross_attention_input = torch.cat(cross_inputs, dim=1)
            cross_attention_masks = torch.cat(cross_masks, dim=1)

        if self.add_cond_ids:
            # Concatenate all additive-conditioning inputs along dim 2.
            add_input = torch.cat(
                [self._with_seq_dim(conditioning_tensors[key][0]) for key in self.add_cond_ids],
                dim=2,
            )

        if self.sync_cond_ids:
            # Concatenate all sync-conditioning inputs along dim 2.
            sync_input = torch.cat(
                [self._with_seq_dim(conditioning_tensors[key][0]) for key in self.sync_cond_ids],
                dim=2,
            )

        if self.global_cond_ids:
            # Global conditioning inputs are summed element-wise (not concatenated).
            # Assumes that the global conditioning inputs are of shape (batch, channels)
            global_conds = [
                self._with_seq_dim(conditioning_tensors[key][0]) for key in self.global_cond_ids
            ]
            global_cond = sum(global_conds)

            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if self.input_concat_ids:
            # Concatenate all input concat conditioning inputs over the channel dimension
            # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
            input_concat_cond = torch.cat(
                [conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1
            )

        if self.prepend_cond_ids:
            # Concatenate all prepend conditioning inputs over the sequence dimension
            # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
            prepend_conds = []
            prepend_cond_masks = []
            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_mask = conditioning_tensors[key]
                prepend_conds.append(self._with_seq_dim(prepend_cond_input))
                prepend_cond_masks.append(prepend_mask)

            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond
            }

        return {
            "cross_attn_cond": cross_attention_input,
            "cross_attn_mask": cross_attention_masks,
            "global_cond": global_cond,
            "input_concat_cond": input_concat_cond,
            "prepend_cond": prepend_cond,
            "prepend_cond_mask": prepend_cond_mask,
            "add_cond": add_input,
            "sync_cond": sync_input
        }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        """Run the backbone on ``x``/``t`` with kwargs derived from ``cond``."""
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to ``generate_diffusion_cond`` with this wrapper as the model."""
        # Imported at call time rather than at module import time.
        from prismaudio_core.inference.generation import generate_diffusion_cond
        return generate_diffusion_cond(self, *args, **kwargs)
|
||||
|
||||
class UNetCFG1DWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``UNetCFG1d`` through the conditioned-model interface.

    All constructor arguments go to ``UNetCFG1d``. Its parameters are halved
    in place right after construction.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)

        self.model = UNetCFG1d(*args, **kwargs)

        # Scale every freshly constructed parameter by 0.5 (no autograd tracking).
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                input_concat_cond=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                **kwargs):
        """Translate the generic conditioning names into ``UNetCFG1d`` kwargs.

        ``negative_global_cond``, ``negative_input_concat_cond`` and the
        prepend arguments are accepted but not passed to the model.
        """
        # The model takes concat conditioning as a list (or None when absent).
        channels_list = [input_concat_cond] if input_concat_cond is not None else None

        return self.model(
            x,
            t,
            embedding=cross_attn_cond,
            embedding_mask=cross_attn_mask,
            features=global_cond,
            channels_list=channels_list,
            embedding_scale=cfg_scale,
            embedding_mask_proba=cfg_dropout_prob,
            batch_cfg=batch_cfg,
            rescale_cfg=rescale_cfg,
            negative_embedding=negative_cross_attn_cond,
            negative_embedding_mask=negative_cross_attn_mask,
            **kwargs)
|
||||
|
||||
class UNet1DCondWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``UNet1d`` with global and input-concat conditioning.

    All constructor arguments go to ``UNet1d``. Its parameters are halved in
    place right after construction.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)

        self.model = UNet1d(*args, **kwargs)

        # Scale every freshly constructed parameter by 0.5 (no autograd tracking).
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                global_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                **kwargs):
        """Run ``UNet1d`` with ``global_cond`` and optional concat conditioning.

        Only ``input_concat_cond``, ``global_cond`` and ``**kwargs`` reach the
        model; the other named arguments are accepted and ignored.
        """
        channels_list = None
        if input_concat_cond is not None:
            target_len = x.shape[2]

            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != target_len:
                input_concat_cond = F.interpolate(input_concat_cond, (target_len, ), mode='nearest')

            channels_list = [input_concat_cond]

        return self.model(
            x,
            t,
            features=global_cond,
            channels_list=channels_list,
            **kwargs)
|
||||
|
||||
class UNet1DUncondWrapper(DiffusionModel):
    """Unconditional ``UNet1d`` wrapper.

    ``in_channels`` is passed to ``UNet1d`` and also stored as
    ``io_channels``. The model's parameters are halved in place right after
    construction.
    """

    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = UNet1d(*args, in_channels=in_channels, **kwargs)

        self.io_channels = in_channels

        # Scale every freshly constructed parameter by 0.5 (no autograd tracking).
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(self, x, t, **kwargs):
        """Call the wrapped model with the same arguments."""
        return self.model(x, t, **kwargs)
|
||||
|
||||
class DAU1DCondWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``DiffusionAttnUnet1D`` with input-concat conditioning only.

    All constructor arguments go to ``DiffusionAttnUnet1D``. Its parameters
    are halved in place right after construction.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)

        self.model = DiffusionAttnUnet1D(*args, **kwargs)

        # Scale every freshly constructed parameter by 0.5 (no autograd tracking).
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                **kwargs):
        """Run the model with ``input_concat_cond`` as its ``cond`` argument.

        Every other conditioning argument, including ``**kwargs``, is accepted
        and ignored.
        """
        concat_cond = input_concat_cond
        return self.model(x, t, cond=concat_cond)
|
||||
|
||||
class DiffusionAttnUnet1D(nn.Module):
    """1-D U-Net built from nested skip blocks, with optional self-attention.

    The network is assembled from the innermost level outwards. Each level
    wraps the previously built block between a downsample/conv stack and a
    conv/upsample stack. The outermost level is a plain ``nn.Sequential``
    whose first conv takes the input, the 16-channel timestep embedding and
    the optional conditioning.
    """
    def __init__(
        self,
        io_channels = 2,
        depth=14,
        n_attn_layers = 6,
        channels = [128, 128, 256, 256] + [512] * 10,
        cond_dim = 0,
        cond_noise_aug = False,
        kernel_size = 5,
        learned_resample = False,
        strides = [2] * 13,
        conv_bias = True,
        use_snake = False
    ):
        """Build the nested network.

        Args:
            io_channels: Channel count of the network input and output.
            depth: Number of nesting levels; ``channels`` is indexed up to
                ``depth - 1`` and ``strides`` up to ``depth - 2``.
            n_attn_layers: Levels with index ``i >= depth - n_attn_layers``
                get self-attention (none when this is 0).
            channels: Per-level channel widths.
            cond_dim: Channel count of the conditioning signal concatenated
                to the input.
            cond_noise_aug: When true, a Sobol engine is created for noise
                levels and the first conv expects an extra 16 embedding
                channels (32 instead of 16).
            kernel_size, conv_bias, use_snake: Forwarded to ``ResConvBlock``.
            learned_resample: Selects ``Downsample1d_2``/``Upsample1d_2``
                instead of the fixed "cubic" resamplers.
            strides: Per-level strides; a leading 1 is prepended internally.
        """
        super().__init__()

        self.cond_noise_aug = cond_noise_aug

        self.io_channels = io_channels

        if self.cond_noise_aug:
            self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.timestep_embed = FourierFeatures(1, 16)

        # First (1-based) level index that receives attention.
        attn_layer = depth - n_attn_layers

        # Builds a new list, so the default argument is not mutated.
        strides = [1] + strides

        # Innermost block; each loop iteration wraps it in one more level.
        block = nn.Identity()

        conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)

        # Walk from the deepest level (i == depth) out to the top level (i == 1).
        for i in range(depth, 0, -1):
            c = channels[i - 1]
            stride = strides[i-1]
            # NOTE(review): the message says "stride 2" but the check only rejects
            # strides greater than 2 — confirm the intended constraint.
            if stride > 2 and not learned_resample:
                raise ValueError("Must have stride 2 without learned resampling")

            if i > 1:
                c_prev = channels[i - 2]
                add_attn = i >= attn_layer and n_attn_layers > 0
                # presumably SkipBlock concatenates its input with its output, which
                # would explain the c * 2 input width after the inner block —
                # TODO confirm against blocks.py
                block = SkipBlock(
                    Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
                    conv_block(c_prev, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    block,
                    # The deepest level wraps nn.Identity, so its input width stays c.
                    conv_block(c * 2 if i != depth else c, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(
                        c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c_prev),
                    SelfAttention1d(c_prev, c_prev //
                                    32) if add_attn else nn.Identity(),
                    Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
                )
            else:
                # 16 channels for the timestep embedding, plus 16 more for the
                # noise-augmentation level embedding when enabled.
                cond_embed_dim = 16 if not self.cond_noise_aug else 32
                block = nn.Sequential(
                    conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, c),
                    block,
                    conv_block(c * 2, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, io_channels, is_last=True),
                )
        self.net = block

        # Halve every parameter of the assembled network in place.
        with torch.no_grad():
            for param in self.net.parameters():
                param *= 0.5

    def forward(self, x, t, cond=None, cond_aug_scale=None):
        """Run the network on ``x`` at timestep ``t``.

        Args:
            x: Input tensor; dim 1 is the channel dim used for concatenation
                and dim 2 the length that ``cond`` is resampled to.
            t: Timestep tensor, indexed as ``t[:, None]`` for the embedder.
            cond: Optional conditioning, linearly interpolated to
                ``x.shape[2]`` when its length differs.
            cond_aug_scale: Fixed noise level used in place of a Sobol draw
                when ``cond_noise_aug`` is enabled.

        Returns:
            The output of the nested network.
        """

        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)

        inputs = [x, timestep_embed]

        if cond is not None:
            if cond.shape[2] != x.shape[2]:
                cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)

            if self.cond_noise_aug:
                # Get a random number between 0 and 1, uniformly sampled
                if cond_aug_scale is None:
                    aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
                else:
                    aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)

                # Add noise to the conditioning signal
                cond = cond + torch.randn_like(cond) * aug_level[:, None, None]

                # Get embedding for noise cond level, reusing timestamp_embed
                aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)

                inputs.append(aug_level_embed)

            inputs.append(cond)

        # Everything is concatenated along the channel dim before the first conv.
        outputs = self.net(torch.cat(inputs, dim=1))

        return outputs
|
||||
|
||||
class DiTWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``DiffusionTransformer`` through the conditioned-model interface.

    All constructor arguments go to ``DiffusionTransformer``. No
    post-construction rescaling of parameters is applied.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)

        self.model = DiffusionTransformer(*args, **kwargs)

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                input_concat_cond=None,
                negative_input_concat_cond=None,
                global_cond=None,
                negative_global_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):
        """Translate the generic conditioning names into transformer kwargs.

        ``cross_attn_mask`` is passed as ``cross_attn_cond_mask`` and
        ``global_cond`` as ``global_embed``. ``negative_input_concat_cond``,
        ``negative_global_cond`` and ``rescale_cfg`` are accepted but not
        passed on. ``batch_cfg`` must be truthy.
        """
        assert batch_cfg, "batch_cfg must be True for DiTWrapper"

        dit_inputs = {
            "cross_attn_cond": cross_attn_cond,
            "cross_attn_cond_mask": cross_attn_mask,
            "negative_cross_attn_cond": negative_cross_attn_cond,
            "negative_cross_attn_mask": negative_cross_attn_mask,
            "input_concat_cond": input_concat_cond,
            "prepend_cond": prepend_cond,
            "prepend_cond_mask": prepend_cond_mask,
            "cfg_scale": cfg_scale,
            "cfg_dropout_prob": cfg_dropout_prob,
            "scale_phi": scale_phi,
            "global_embed": global_cond,
        }

        return self.model(x, t, **dit_inputs, **kwargs)
|
||||
|
||||
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
|
||||
"""
|
||||
A diffusion model that takes in conditioning
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
conditioner: MultiConditioner,
|
||||
io_channels,
|
||||
sample_rate,
|
||||
min_input_length: int,
|
||||
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
|
||||
pretransform: tp.Optional[Pretransform] = None,
|
||||
cross_attn_cond_ids: tp.List[str] = [],
|
||||
global_cond_ids: tp.List[str] = [],
|
||||
input_concat_ids: tp.List[str] = [],
|
||||
prepend_cond_ids: tp.List[str] = [],
|
||||
add_cond_ids: tp.List[str] = [],
|
||||
mm_cond_ids: tp.List[str] = [],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
self.conditioner = conditioner
|
||||
self.io_channels = io_channels
|
||||
self.sample_rate = sample_rate
|
||||
self.diffusion_objective = diffusion_objective
|
||||
self.pretransform = pretransform
|
||||
self.cross_attn_cond_ids = cross_attn_cond_ids
|
||||
self.global_cond_ids = global_cond_ids
|
||||
self.input_concat_ids = input_concat_ids
|
||||
self.prepend_cond_ids = prepend_cond_ids
|
||||
self.add_cond_ids = add_cond_ids
|
||||
self.min_input_length = min_input_length
|
||||
self.mm_cond_ids = mm_cond_ids
|
||||
|
||||
assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
|
||||
assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
|
||||
assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper"
|
||||
assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
|
||||
assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"
|
||||
# assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper"
|
||||
|
||||
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
    """Map conditioner outputs onto the keyword arguments passed to the model.

    Args:
        conditioning_tensors: Mapping of conditioning ids to values. The keys
            "metaclip_features", "sync_features" and "metaclip_text_features"
            are required; "inpaint_masked_input", "t5_features" and
            "metaclip_global_text_features" are optional.
        negative: Must be falsy; negative conditioning is rejected.

    Returns:
        A dict with keys "clip_f", "sync_f", "text_f", "inpaint_masked_input",
        "t5_features" and "metaclip_global_text_features". Optional entries
        that are absent from ``conditioning_tensors`` are set to None.

    Raises:
        AssertionError: If ``negative`` is set.
        KeyError: If a required key is missing.
    """
    assert negative == False, "negative conditioning is not supported for MMDiTWrapper"

    # Required entries: renamed to the short argument names the model expects.
    inputs = {
        "clip_f": conditioning_tensors["metaclip_features"],
        "sync_f": conditioning_tensors["sync_features"],
        "text_f": conditioning_tensors["metaclip_text_features"],
    }

    # Optional entries keep their own name and default to None when absent.
    optional_ids = (
        "inpaint_masked_input",
        "t5_features",
        "metaclip_global_text_features",
    )
    for cond_id in optional_ids:
        if cond_id in conditioning_tensors:
            inputs[cond_id] = conditioning_tensors[cond_id]
        else:
            inputs[cond_id] = None

    return inputs
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
    """Call the wrapped model on ``x`` and ``t`` with conditioning from ``cond``.

    ``cond`` is translated by ``get_conditioning_inputs`` into keyword
    arguments; any extra ``kwargs`` are forwarded to the model unchanged.
    """
    model_inputs = self.get_conditioning_inputs(cond)
    return self.model(x=x, t=t, **model_inputs, **kwargs)
|
||||
|
||||
def generate(self, *args, **kwargs):
    """Run ``generate_diffusion_cond`` with this wrapper as the model.

    All positional and keyword arguments are passed through unchanged, and
    the sampler's return value is returned as-is.
    """
    # Imported at call time, not module import time
    # (presumably to avoid a circular import -- TODO confirm).
    from prismaudio_core.inference.generation import generate_diffusion_cond as _sample

    return _sample(self, *args, **kwargs)
|
||||
|
||||
class DiTUncondWrapper(DiffusionModel):
    """Unconditional diffusion model backed by a ``DiffusionTransformer``.

    The transformer is built from the constructor arguments and all of its
    parameters are scaled by 0.5 immediately after construction.
    """

    def __init__(self, io_channels, *args, **kwargs):
        """Build the transformer and halve its initial parameters.

        Args:
            io_channels: Channel count passed to ``DiffusionTransformer`` and
                stored on the wrapper as ``io_channels``.
            *args: Extra positional arguments for ``DiffusionTransformer``.
            **kwargs: Extra keyword arguments for ``DiffusionTransformer``.
        """
        super().__init__()

        # NOTE(review): positional extras are forwarded alongside the
        # io_channels keyword, exactly as before.
        self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs)
        self.io_channels = io_channels

        # In-place scaling under no_grad so autograd does not record it.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight.mul_(0.5)

    def forward(self, x, t, **kwargs):
        """Delegate to the wrapped transformer."""
        return self.model(x, t, **kwargs)
|
||||
|
||||
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
    """Build an unconditional ``DiffusionModelWrapper`` from a config dict.

    Args:
        config: Must contain "model" (with "type", optional "config" and
            optional "pretransform"), "sample_size" and "sample_rate".

    Returns:
        A ``DiffusionModelWrapper`` around the model selected by
        ``config["model"]["type"]`` ('DAU1d', "adp_uncond_1d" or "dit").

    Raises:
        AssertionError: If the model type, sample size or sample rate is missing.
        NotImplementedError: If the model type is not recognised.
    """
    model_section = config["model"]

    model_type = model_section.get('type', None)
    network_kwargs = model_section.get('config', {})
    assert model_type is not None, "Must specify model type in config"

    pretransform = model_section.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    # Without a pretransform the model works on raw samples, so any length is valid.
    min_input_length = 1
    if pretransform is not None:
        build_pretransform = _get_create_pretransform_from_config()
        pretransform = build_pretransform(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(**network_kwargs)
    elif model_type == "adp_uncond_1d":
        model = UNet1DUncondWrapper(**network_kwargs)
    elif model_type == "dit":
        model = DiTUncondWrapper(**network_kwargs)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(
        model,
        io_channels=model.io_channels,
        sample_size=sample_size,
        sample_rate=sample_rate,
        pretransform=pretransform,
        min_input_length=min_input_length,
    )
|
||||
|
||||
def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
    """Build an infill ``DiffusionModelWrapper`` from a config dict.

    Args:
        config: Must contain "model", "sample_size" and "sample_rate".
            ``config["model"]["diffusion"]`` supplies "type" and the optional
            "config" kwargs for the network; ``config["model"]`` may also
            carry a "pretransform" section.

    Returns:
        A ``DiffusionModelWrapper`` around the model selected by the diffusion
        "type" ('DAU1d', "adp_uncond_1d" or "dit").

    Raises:
        AssertionError: If the model type, sample size or sample rate is missing.
        NotImplementedError: If the model type is not recognised.
    """
    model_section = config["model"]

    diffusion_config = model_section.get('diffusion', {})
    model_type = diffusion_config.get('type', None)
    model_config = diffusion_config.get("config", {})
    assert model_type is not None, "Must specify model type in config"

    pretransform = model_section.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    if pretransform is not None:
        pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(**model_config)
    elif model_type == "adp_uncond_1d":
        # io_channels defaults to 64. Merge the default into one kwargs dict
        # instead of passing io_channels both explicitly and via
        # **model_config, which raised "got multiple values for keyword
        # argument" whenever the config set it.
        unet_kwargs = {"io_channels": 64, **model_config}
        model = UNet1DUncondWrapper(**unet_kwargs)
    elif model_type == "dit":
        model = DiTUncondWrapper(**model_config)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(
        model,
        io_channels=model.io_channels,
        sample_size=sample_size,
        sample_rate=sample_rate,
        pretransform=pretransform,
        min_input_length=min_input_length,
    )
|
||||
|
||||
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
    """Build a conditioned diffusion model wrapper from a config dict.

    Args:
        config: Must contain "model", "model_type" and "sample_rate".
            ``config["model"]`` must provide "diffusion" (with "type" and
            "config") and "io_channels"; "conditioning" and "pretransform"
            are optional.

    Returns:
        ``MMConditionedDiffusionModelWrapper`` when ``model_type`` is
        "mm_diffusion_cond", otherwise ``ConditionedDiffusionModelWrapper``
        (for "diffusion_cond", "diffusion_cond_inpaint", 'diffusion_infill').

    Raises:
        KeyError: If "model" or "model_type" is missing.
        AssertionError: If a required config entry is missing.
        NotImplementedError: If the diffusion model type or the model type is
            not recognised.
    """
    model_config = config["model"]
    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    backbone_type = diffusion_config.get('type', None)
    assert backbone_type is not None, "Must specify diffusion model type"

    backbone_config = diffusion_config.get('config', None)
    assert backbone_config is not None, "Must specify diffusion model config"

    # The backbone is constructed before the remaining config is validated.
    if backbone_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**backbone_config)
    elif backbone_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**backbone_config)
    elif backbone_type == 'dit':
        diffusion_model = DiTWrapper(**backbone_config)
    else:
        raise NotImplementedError(f'Unknown diffusion model type: {backbone_type}')

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    conditioner = None
    conditioning_config = model_config.get('conditioning', None)
    if conditioning_config is not None:
        conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config)

    # Conditioning id lists; each defaults to an empty list when absent.
    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    add_cond_ids = diffusion_config.get('add_cond_ids', [])
    sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
    mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
    zero_init = diffusion_config.get('zero_init', False)

    min_input_length = 1
    pretransform = model_config.get("pretransform", None)
    if pretransform is not None:
        pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio

    # Scale the minimum length by the backbone's own downsampling / patching.
    if backbone_type in ("adp_cfg_1d", "adp_1d"):
        min_input_length *= np.prod(backbone_config["factors"])
    elif backbone_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    # Pick the wrapper class and the keyword arguments specific to it.
    if model_type == "mm_diffusion_cond":
        wrapper_fn = MMConditionedDiffusionModelWrapper
        extra_kwargs = {
            "diffusion_objective": diffusion_objective,
            "mm_cond_ids": mm_cond_ids,
        }
    elif model_type in ("diffusion_cond", "diffusion_cond_inpaint", 'diffusion_infill'):
        wrapper_fn = ConditionedDiffusionModelWrapper
        extra_kwargs = {"diffusion_objective": diffusion_objective}
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        add_cond_ids=add_cond_ids,
        sync_cond_ids=sync_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        zero_init=zero_init,
        **extra_kwargs
    )
|
||||
@@ -1,539 +0,0 @@
|
||||
import typing as tp
|
||||
import math
|
||||
import torch
|
||||
# from beartype.typing import Tuple
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||
from .blocks import FourierFeatures
|
||||
from .transformer import ContinuousTransformer
|
||||
from .utils import mask_from_frac_lengths, resample
|
||||
|
||||
class DiffusionTransformer(nn.Module):
    """Diffusion backbone built around a ``ContinuousTransformer``.

    The module projects the timestep and the optional conditioning signals
    (cross-attention tokens, global embedding, prepend tokens, additive
    tokens, sync tokens) and feeds them together with the (optionally
    patched) input sequence to the transformer. ``forward`` adds dtype
    casting, conditioning dropout and classifier-free guidance on top of
    ``_forward``.
    """

    def __init__(self,
        io_channels=32,
        patch_size=1,
        embed_dim=768,
        cond_token_dim=0,
        project_cond_tokens=True,
        global_cond_dim=0,
        project_global_cond=True,
        input_concat_dim=0,
        prepend_cond_dim=0,
        cond_ctx_dim=0,
        depth=12,
        num_heads=8,
        transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
        add_token_dim=0,
        sync_token_dim=0,
        use_mlp=False,
        use_zero_init=False,
        timestep_embed_dim: tp.Optional[int] = None,
        **kwargs):
        """Create the conditioning projections and the transformer.

        Args:
            io_channels: Number of input and output channels.
            patch_size: Number of consecutive time steps folded into one token.
            embed_dim: Transformer width.
            cond_token_dim: Width of cross-attention tokens (0 disables them).
            project_cond_tokens: If true, cross-attention / add / sync tokens
                are projected to ``embed_dim``; otherwise they keep their width.
            global_cond_dim: Width of the global conditioning (0 disables it).
            project_global_cond: If true, the global conditioning is projected
                to ``embed_dim``.
            input_concat_dim: Channels concatenated to the input.
            prepend_cond_dim: Width of prepend conditioning (0 disables it).
            cond_ctx_dim: Accepted but not used by this class.
            depth: Number of transformer layers.
            num_heads: Number of attention heads.
            transformer_type: Only "continuous_transformer" is supported.
            global_cond_type: "prepend" or "adaLN".
            timestep_cond_type: "global" or "input_concat".
            add_token_dim: Width of additive conditioning (0 disables it).
            sync_token_dim: Width of sync conditioning (0 disables it).
            use_mlp: Stored on the module; not otherwise used here.
            use_zero_init: Accepted but not used by this class.
            timestep_embed_dim: Width of the timestep embedding. Required when
                ``timestep_cond_type`` is "input_concat"; ignored otherwise
                (the embedding is then ``embed_dim`` wide).
            **kwargs: Forwarded to ``ContinuousTransformer``.

        Raises:
            AssertionError: If "input_concat" is selected without
                ``timestep_embed_dim``.
            ValueError: If ``transformer_type`` is not supported.
        """
        super().__init__()

        self.cond_token_dim = cond_token_dim

        # Timestep embeddings
        timestep_features_dim = 256
        self.timestep_cond_type = timestep_cond_type
        self.timestep_features = FourierFeatures(1, timestep_features_dim)

        if timestep_cond_type == "input_concat":
            # timestep_embed_dim is now a real keyword argument; previously it
            # was never bound in this branch, so this mode always crashed.
            assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat"
            input_concat_dim += timestep_embed_dim
        else:
            # "global" (and any other value) keeps the embed_dim-wide embedding.
            timestep_embed_dim = embed_dim

        self.to_timestep_embed = nn.Sequential(
            nn.Linear(timestep_features_dim, timestep_embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(timestep_embed_dim, timestep_embed_dim, bias=True),
        )
        self.use_mlp = use_mlp

        if cond_token_dim > 0:
            # Conditioning tokens
            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
            )
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                nn.Linear(global_cond_dim, global_embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(global_embed_dim, global_embed_dim, bias=False)
            )

        if add_token_dim > 0:
            # Additive conditioning tokens
            add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
            self.to_add_embed = nn.Sequential(
                nn.Linear(add_token_dim, add_embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(add_embed_dim, add_embed_dim, bias=False)
            )
        else:
            add_embed_dim = 0

        if sync_token_dim > 0:
            # Sync conditioning tokens
            sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
            self.to_sync_embed = nn.Sequential(
                nn.Linear(sync_token_dim, sync_embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(sync_embed_dim, sync_embed_dim, bias=False)
            )
        else:
            sync_embed_dim = 0

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        # Transformer
        self.transformer_type = transformer_type

        self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
        self.global_cond_type = global_cond_type
        if self.transformer_type == "continuous_transformer":

            global_dim = None

            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to the embed_dim already at this point
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend = cond_token_dim > 0,
                cond_token_dim = cond_embed_dim,
                global_cond_dim=global_dim,
                **kwargs
            )
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        # Residual 1x1 convolutions around the transformer, zero-initialised so
        # that `conv(x) + x` starts out as the identity.
        self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
        self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
        nn.init.zeros_(self.preprocess_conv.weight)
        nn.init.zeros_(self.postprocess_conv.weight)

    def initialize_weights(self):
        """Re-initialise the module's weights in place.

        Linear layers get Xavier-uniform weights and zero biases, the timestep
        MLP weights are redrawn from N(0, 0.02), adaLN output layers are zeroed
        when ``global_cond_type`` is "adaLN", and the empty clip / sync
        features are reset to zero.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02)
        nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02)

        # Zero-out output layers:
        if self.global_cond_type == "adaLN":
            for block in self.transformer.layers:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        nn.init.constant_(self.empty_clip_feat, 0)
        nn.init.constant_(self.empty_sync_feat, 0)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        add_cond=None,
        add_masks=None,
        sync_cond=None,
        return_info=False,
        **kwargs):
        """Single pass through the projections and the transformer.

        ``x`` is laid out channels-first (it is rearranged "b c t -> b t c"
        before the transformer and back afterwards). ``cross_attn_cond_mask``
        and ``add_masks`` are accepted but not used here.

        Returns:
            The output in the same channels-first layout, or
            ``(output, info)`` when ``return_info`` is true.
        """
        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:
            # reshape from (b, n, c) to (b, c, n)
            if input_concat_cond.shape[1] != x.shape[1]:
                input_concat_cond = input_concat_cond.transpose(1,2)
            # Interpolate input_concat_cond to the same length as x
            input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, timestep_embed_dim)
        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if self.timestep_cond_type == "global":
            if global_embed is not None:
                if len(global_embed.shape) == 3:
                    timestep_embed = timestep_embed.unsqueeze(1)
                global_embed = global_embed + timestep_embed
            else:
                global_embed = timestep_embed
        elif self.timestep_cond_type == "input_concat":
            # (b, d) -> (b, d, 1) -> (b, d, t): broadcast the embedding over
            # time and append it as extra input channels. The previous
            # unsqueeze(1) produced (b, 1, d), which cannot be expanded to the
            # sequence length.
            x = torch.cat([x, timestep_embed.unsqueeze(-1).expand(-1, -1, x.shape[2])], dim=1)

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend" and global_embed is not None:
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                if len(global_embed.shape) == 2:
                    prepend_inputs = global_embed.unsqueeze(1)
                else:
                    prepend_inputs = global_embed
                prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            else:
                # Prepend inputs are the prepend conditioning + the global embed
                if len(global_embed.shape) == 2:
                    prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
                else:
                    prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1)
                prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)

            prepend_length = prepend_inputs.shape[1]

        x = self.preprocess_conv(x) + x
        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            b, seq_len, c = x.shape

            # Compute how much padding is needed to reach a multiple of patch_size
            pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size

            if pad_amount > 0:
                # Pad along the time dimension
                x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        if add_cond is not None:
            add_cond = self.to_add_embed(add_cond)
            if add_cond.shape[1] != x.shape[1]:
                # Interpolate add_cond to the same length as x
                add_cond = add_cond.transpose(1,2)
                add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
                add_cond = add_cond.transpose(1,2)

        if sync_cond is not None:
            sync_cond = self.to_sync_embed(sync_cond)

        if self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

            if return_info:
                output, info = output

        # Back to channels-first and drop the prepended tokens.
        output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
            # Remove the padding added earlier
            if pad_amount > 0:
                output = output[:, :, :seq_len]

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
        self,
        x,
        t,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        input_concat_cond=None,
        global_embed=None,
        negative_global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        add_cond=None,
        sync_cond=None,
        cfg_scale=1.0,
        cfg_dropout_prob=0.0,
        causal=False,
        scale_phi=0.0,
        mask=None,
        return_info=False,
        **kwargs):
        """Run the model, optionally with conditioning dropout or CFG.

        Inputs are cast to the module's parameter dtype. When
        ``cfg_dropout_prob > 0`` and ``cfg_scale == 1.0``, each conditioning
        tensor is replaced by zeros per batch item with that probability.
        When ``cfg_scale != 1.0`` and cross-attention, prepend or add
        conditioning is present, conditioned and unconditioned inputs are
        batched together and combined as
        ``uncond + (cond - uncond) * cfg_scale``, with optional rescaling
        controlled by ``scale_phi``. Otherwise ``_forward`` is called directly.

        ``negative_global_embed`` is only cast to the model dtype and is not
        used further. Cross-attention masks are currently discarded.

        Raises:
            AssertionError: If ``causal`` is set.
        """
        assert causal == False, "Causal mode is not supported for DiffusionTransformer"
        # Unpacking also requires x to be three-dimensional.
        bsz, _, _ = x.shape
        model_dtype = next(self.parameters()).dtype
        x = x.to(model_dtype)
        t = t.to(model_dtype)

        if cross_attn_cond is not None:
            cross_attn_cond = cross_attn_cond.to(model_dtype)

        if negative_cross_attn_cond is not None:
            negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype)

        if input_concat_cond is not None:
            input_concat_cond = input_concat_cond.to(model_dtype)

        if global_embed is not None:
            global_embed = global_embed.to(model_dtype)

        if negative_global_embed is not None:
            negative_global_embed = negative_global_embed.to(model_dtype)

        if prepend_cond is not None:
            prepend_cond = prepend_cond.to(model_dtype)

        if add_cond is not None:
            add_cond = add_cond.to(model_dtype)

        if sync_cond is not None:
            sync_cond = sync_cond.to(model_dtype)

        if cross_attn_cond_mask is not None:
            cross_attn_cond_mask = cross_attn_cond_mask.bool()

        cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention

        if prepend_cond_mask is not None:
            prepend_cond_mask = prepend_cond_mask.bool()

        # CFG dropout
        if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
            if cross_attn_cond is not None:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
                dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
                cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)

            if prepend_cond is not None:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
                dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
                prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)

            if add_cond is not None:
                null_embed = torch.zeros_like(add_cond, device=add_cond.device)
                dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool)
                add_cond = torch.where(dropout_mask, null_embed, add_cond)

            if sync_cond is not None:
                null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
                dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool)
                sync_cond = torch.where(dropout_mask, null_embed, sync_cond)

        if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
            # Classifier-free guidance
            # Concatenate conditioned and unconditioned inputs on the batch dimension
            batch_inputs = torch.cat([x, x], dim=0)
            batch_timestep = torch.cat([t, t], dim=0)
            if global_embed is not None and global_embed.shape[0] == bsz:
                batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
            elif global_embed is not None:
                # Batch size differs from x: passed through as given.
                batch_global_cond = global_embed
            else:
                batch_global_cond = None

            if input_concat_cond is not None and input_concat_cond.shape[0] == bsz:
                batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
            elif input_concat_cond is not None:
                batch_input_concat_cond = input_concat_cond
            else:
                batch_input_concat_cond = None

            batch_cond = None
            batch_cond_masks = None

            # Handle CFG for cross-attention conditioning
            if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz:

                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)

                # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
                if negative_cross_attn_cond is not None:

                    # If there's a negative cross-attention mask, set the masked tokens to the null embed
                    if negative_cross_attn_mask is not None:
                        negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)

                        negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)

                    batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)

                else:
                    batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)

                if cross_attn_cond_mask is not None:
                    batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
            elif cross_attn_cond is not None:
                batch_cond = cross_attn_cond
            else:
                batch_cond = None

            batch_prepend_cond = None
            batch_prepend_cond_mask = None

            # Handle CFG for prepend conditioning
            if prepend_cond is not None and prepend_cond.shape[0] == bsz:

                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)

                batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)

                if prepend_cond_mask is not None:
                    batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
            elif prepend_cond is not None:
                batch_prepend_cond = prepend_cond
            else:
                batch_prepend_cond = None

            batch_add_cond = None

            # Handle CFG for additive conditioning
            if add_cond is not None and add_cond.shape[0] == bsz:

                null_embed = torch.zeros_like(add_cond, device=add_cond.device)

                batch_add_cond = torch.cat([add_cond, null_embed], dim=0)
            elif add_cond is not None:
                batch_add_cond = add_cond
            else:
                batch_add_cond = None

            batch_sync_cond = None

            # Handle CFG for sync conditioning
            if sync_cond is not None and sync_cond.shape[0] == bsz:

                null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)

                batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0)
            elif sync_cond is not None:
                batch_sync_cond = sync_cond
            else:
                batch_sync_cond = None

            if mask is not None:
                batch_masks = torch.cat([mask, mask], dim=0)
            else:
                batch_masks = None

            batch_output = self._forward(
                batch_inputs,
                batch_timestep,
                cross_attn_cond=batch_cond,
                cross_attn_cond_mask=batch_cond_masks,
                mask = batch_masks,
                input_concat_cond=batch_input_concat_cond,
                global_embed = batch_global_cond,
                prepend_cond = batch_prepend_cond,
                prepend_cond_mask = batch_prepend_cond_mask,
                add_cond = batch_add_cond,
                sync_cond = batch_sync_cond,
                return_info = return_info,
                **kwargs)

            if return_info:
                batch_output, info = batch_output

            # First half of the batch is conditioned, second half unconditioned.
            cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
            cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

            # CFG Rescale
            if scale_phi != 0.0:
                cond_out_std = cond_output.std(dim=1, keepdim=True)
                out_cfg_std = cfg_output.std(dim=1, keepdim=True)
                output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
            else:
                output = cfg_output

            if return_info:
                return output, info

            return output

        else:
            return self._forward(
                x,
                t,
                cross_attn_cond=cross_attn_cond,
                cross_attn_cond_mask=cross_attn_cond_mask,
                input_concat_cond=input_concat_cond,
                global_embed=global_embed,
                prepend_cond=prepend_cond,
                prepend_cond_mask=prepend_cond_mask,
                add_cond=add_cond,
                sync_cond=sync_cond,
                mask=mask,
                return_info=return_info,
                **kwargs
            )
|
||||
@@ -1,275 +0,0 @@
|
||||
import torch
|
||||
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from .blocks import AdaRMSNorm
|
||||
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
|
||||
from .utils import checkpoint
|
||||
|
||||
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
|
||||
class ContinuousLocalTransformer(nn.Module):
    """Stack of local-attention transformer layers over continuous inputs.

    Each layer consists of a norm, a self-attention block built with
    ``natten_kernel_size=local_attn_window_size``, an optional
    cross-attention block, a second norm and a feed-forward block. Inputs and
    outputs are optionally projected from ``dim_in`` / to ``dim_out``.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_in = None,
        dim_out = None,
        causal = False,
        local_attn_window_size = 64,
        heads = 8,
        ff_mult = 2,
        cond_dim = 0,
        cross_attn_cond_dim = 0,
        **kwargs
    ):
        """Build the projections, rotary embedding and ``depth`` layers.

        Args:
            dim: Model width.
            depth: Number of layers.
            dim_in: Input width; no input projection when None.
            dim_out: Output width; no output projection when None.
            causal: Passed to the self-attention blocks.
            local_attn_window_size: Passed to self-attention as
                ``natten_kernel_size``.
            heads: Number of heads; the per-head width is ``dim // heads``.
            ff_mult: Feed-forward multiplier.
            cond_dim: When > 0, norms are ``AdaRMSNorm`` conditioned on a
                vector of this width; otherwise ``LayerNorm``.
            cross_attn_cond_dim: When > 0, each layer gets a cross-attention
                block with this context width; otherwise ``nn.Identity``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        dim_head = dim//heads

        self.layers = nn.ModuleList([])

        self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
        self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()

        self.local_attn_window_size = local_attn_window_size
        self.cond_dim = cond_dim
        self.cross_attn_cond_dim = cross_attn_cond_dim

        self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))

        for _ in range(depth):
            # Sub-modules are created in the same order in which they are
            # stored in the layer, so initialisation order is unchanged.
            attn_norm = AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim)
            self_attn = Attention(
                dim=dim,
                dim_heads=dim_head,
                causal=causal,
                zero_init_output=True,
                natten_kernel_size=local_attn_window_size,
            )
            if self.cross_attn_cond_dim > 0:
                cross_attn = Attention(
                    dim=dim,
                    dim_heads=dim_head,
                    dim_context = cross_attn_cond_dim,
                    zero_init_output=True
                )
            else:
                cross_attn = nn.Identity()
            ff_norm = AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim)
            feed_forward = FeedForward(dim = dim, mult = ff_mult, no_bias=True)

            self.layers.append(nn.ModuleList([attn_norm, self_attn, cross_attn, ff_norm, feed_forward]))

    def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
        """Apply the input projection, all layers and the output projection.

        ``prepend_cond``, when given, is concatenated in front of ``x`` along
        dim 1 after the input projection and stays in the returned sequence.
        ``cond`` is handed to both norms of every layer when it is not None.
        """
        x = checkpoint(self.project_in, x)

        if prepend_cond is not None:
            x = torch.cat([prepend_cond, x], dim=1)

        pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])

        # Extra positional argument for the norms, present only with conditioning.
        norm_extra = (cond,) if cond is not None else ()

        for attn_norm, attn, xattn, ff_norm, ff in self.layers:
            # Self-attention sub-block with residual connection.
            skip = x
            x = checkpoint(attn_norm, x, *norm_extra)
            x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + skip

            # Optional cross-attention, also residual.
            if cross_attn_cond is not None:
                x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x

            # Feed-forward sub-block with residual connection.
            skip = x
            x = checkpoint(ff_norm, x, *norm_extra)
            x = checkpoint(ff, x) + skip

        return checkpoint(self.project_out, x)
|
||||
|
||||
class TransformerDownsampleBlock1D(nn.Module):
    """Transformer stage that trades sequence length for channels.

    The input is projected to ``embed_dim`` (skipped when ``in_channels``
    already equals ``embed_dim``), run through a ``ContinuousLocalTransformer``,
    then every ``downsample_ratio`` consecutive positions are folded into the
    channel axis and linearly projected back to ``embed_dim``.  Extra
    ``**kwargs`` are forwarded to the transformer.
    """

    def __init__(
        self,
        in_channels,
        embed_dim=768,
        depth=3,
        heads=12,
        downsample_ratio=2,
        local_attn_window_size=64,
        **kwargs
    ):
        super().__init__()

        self.downsample_ratio = downsample_ratio

        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size=local_attn_window_size,
            **kwargs
        )

        if in_channels != embed_dim:
            self.project_in = nn.Linear(in_channels, embed_dim, bias=False)
        else:
            self.project_in = nn.Identity()

        self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)

    def forward(self, x):
        """Return ``x`` after projection, attention and sequence downsampling."""
        projected = checkpoint(self.project_in, x)

        attended = self.transformer(projected)

        # Fold groups of `downsample_ratio` positions into the channel axis.
        folded = rearrange(attended, "b (n r) c -> b n (c r)", r=self.downsample_ratio)

        # Bring the widened channel axis back to embed_dim.
        return checkpoint(self.project_down, folded)
|
||||
|
||||
class TransformerUpsampleBlock1D(nn.Module):
    """Transformer stage that trades channels for sequence length.

    The input is projected to ``embed_dim`` (skipped when ``in_channels``
    already equals ``embed_dim``), widened to ``embed_dim * upsample_ratio``
    channels, unfolded so the sequence becomes ``upsample_ratio`` times longer,
    and finally run through a ``ContinuousLocalTransformer``.  Extra
    ``**kwargs`` are forwarded to the transformer.
    """

    def __init__(
        self,
        in_channels,
        embed_dim,
        depth=3,
        heads=12,
        upsample_ratio=2,
        local_attn_window_size=64,
        **kwargs
    ):
        super().__init__()

        self.upsample_ratio = upsample_ratio

        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size=local_attn_window_size,
            **kwargs
        )

        if in_channels != embed_dim:
            self.project_in = nn.Linear(in_channels, embed_dim, bias=False)
        else:
            self.project_in = nn.Identity()

        self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)

    def forward(self, x):
        """Return ``x`` after projection, sequence upsampling and attention."""
        projected = checkpoint(self.project_in, x)

        # Widen the channel axis so it can be unfolded into extra positions.
        widened = checkpoint(self.project_up, projected)

        # Unfold the extra channels into `upsample_ratio` positions each.
        unfolded = rearrange(widened, "b n (c r) -> b (n r) c", r=self.upsample_ratio)

        return self.transformer(unfolded)
|
||||
|
||||
|
||||
class TransformerEncoder1D(nn.Module):
    """Chain of ``TransformerDownsampleBlock1D`` stages with in/out projections.

    One stage is created per entry of ``depths``; stage ``i`` uses
    ``embed_dims[i]``, ``heads[i]``, ``depths[i]`` and ``ratios[i]``, and takes
    the previous stage's embedding width as its input width.  ``forward``
    accepts and returns channel-first tensors (``b c n``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dims = [96, 192, 384, 768],
        heads = [12, 12, 12, 12],
        depths = [3, 3, 3, 3],
        ratios = [2, 2, 2, 2],
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        stages = []
        for idx, stage_depth in enumerate(depths):
            # The first stage is fed by project_in, which already emits embed_dims[0].
            incoming_dim = embed_dims[max(idx - 1, 0)]
            stage = TransformerDownsampleBlock1D(
                in_channels=incoming_dim,
                embed_dim=embed_dims[idx],
                heads=heads[idx],
                depth=stage_depth,
                downsample_ratio=ratios[idx],
                local_attn_window_size=local_attn_window_size,
                **kwargs
            )
            stages.append(stage)

        self.layers = nn.Sequential(*stages)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Encode a ``b c n`` tensor and return a ``b c n`` tensor."""
        tokens = rearrange(x, "b c n -> b n c")
        tokens = checkpoint(self.project_in, tokens)
        tokens = self.layers(tokens)
        tokens = checkpoint(self.project_out, tokens)
        return rearrange(tokens, "b n c -> b c n")
|
||||
|
||||
|
||||
class TransformerDecoder1D(nn.Module):
    """Chain of ``TransformerUpsampleBlock1D`` stages with in/out projections.

    One stage is created per entry of ``depths``; stage ``i`` uses
    ``embed_dims[i]``, ``heads[i]``, ``depths[i]`` and ``ratios[i]``, and takes
    the previous stage's embedding width as its input width.  ``forward``
    accepts and returns channel-first tensors (``b c n``).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dims = [768, 384, 192, 96],
        heads = [12, 12, 12, 12],
        depths = [3, 3, 3, 3],
        ratios = [2, 2, 2, 2],
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        stages = []
        for idx, stage_depth in enumerate(depths):
            # The first stage is fed by project_in, which already emits embed_dims[0].
            incoming_dim = embed_dims[max(idx - 1, 0)]
            stage = TransformerUpsampleBlock1D(
                in_channels=incoming_dim,
                embed_dim=embed_dims[idx],
                heads=heads[idx],
                depth=stage_depth,
                upsample_ratio=ratios[idx],
                local_attn_window_size=local_attn_window_size,
                **kwargs
            )
            stages.append(stage)

        self.layers = nn.Sequential(*stages)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Decode a ``b c n`` tensor and return a ``b c n`` tensor."""
        tokens = rearrange(x, "b c n -> b n c")
        tokens = checkpoint(self.project_in, tokens)
        tokens = self.layers(tokens)
        tokens = checkpoint(self.project_out, tokens)
        return rearrange(tokens, "b n c -> b c n")
|
||||
@@ -1 +0,0 @@
|
||||
# mmmodules package
|
||||
@@ -1 +0,0 @@
|
||||
# mmmodules.model package
|
||||
@@ -1,95 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class ChannelLastConv1d(nn.Conv1d):
    """``nn.Conv1d`` that takes and returns channel-last tensors.

    The last two axes of the 3-D input are swapped before the convolution and
    swapped back afterwards, so callers work with ``(batch, length, channels)``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_first = x.permute(0, 2, 1)
        convolved = super().forward(channel_first)
        return convolved.permute(0, 2, 1)
|
||||
|
||||
|
||||
# https://github.com/Stability-AI/sd3-ref
|
||||
class MLP(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Build the three bias-free linear layers.

        Args:
            dim (int): Input and output width.
            hidden_dim (int): Nominal hidden width; the width actually used is
                two thirds of this value, rounded up to a multiple of
                ``multiple_of``.
            multiple_of (int): Granularity the hidden width is rounded up to.

        Attributes:
            w1 (nn.Linear): ``dim -> hidden`` projection fed through SiLU.
            w2 (nn.Linear): ``hidden -> dim`` output projection.
            w3 (nn.Linear): ``dim -> hidden`` gating projection.
        """
        super().__init__()

        two_thirds = int(2 * hidden_dim / 3)
        num_chunks = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        # Creation order (w1, w2, w3) is kept so parameter init is unchanged.
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        activated = F.silu(self.w1(x))
        gated = activated * self.w3(x)
        return self.w2(gated)
|
||||
|
||||
|
||||
class ConvMLP(nn.Module):
    """Gated feed-forward block built from channel-last 1-D convolutions.

    Computes ``w2(silu(w1(x)) * w3(x))`` where each ``w*`` is a bias-free
    ``ChannelLastConv1d``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Build the three bias-free convolutions.

        Args:
            dim (int): Input and output channel count.
            hidden_dim (int): Nominal hidden width; the width actually used is
                two thirds of this value, rounded up to a multiple of
                ``multiple_of``.
            multiple_of (int): Granularity the hidden width is rounded up to.
            kernel_size (int): Kernel size shared by all three convolutions.
            padding (int): Padding shared by all three convolutions.

        Attributes:
            w1 (ChannelLastConv1d): ``dim -> hidden`` convolution fed through SiLU.
            w2 (ChannelLastConv1d): ``hidden -> dim`` output convolution.
            w3 (ChannelLastConv1d): ``dim -> hidden`` gating convolution.
        """
        super().__init__()

        two_thirds = int(2 * hidden_dim / 3)
        num_chunks = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        conv_options = dict(bias=False, kernel_size=kernel_size, padding=padding)

        # Creation order (w1, w2, w3) is kept so parameter init is unchanged.
        self.w1 = ChannelLastConv1d(dim, inner_dim, **conv_options)
        self.w2 = ChannelLastConv1d(inner_dim, dim, **conv_options)
        self.w3 = ChannelLastConv1d(dim, inner_dim, **conv_options)

    def forward(self, x):
        activated = F.silu(self.w1(x))
        gated = activated * self.w3(x)
        return self.w2(gated)
|
||||
@@ -1,393 +0,0 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from scipy.optimize import fmin
|
||||
from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord
|
||||
|
||||
class PQMF(nn.Module):
    """
    Pseudo Quadrature Mirror Filter (PQMF) bank for splitting a signal into
    frequency bands and reconstructing it, implemented with polyphase helpers.

    Parameters:
    - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
    - num_bands (int): Number of desired frequency bands. It must be a power of 2.
    """

    def __init__(self, attenuation, num_bands):
        super().__init__()

        # num_bands must be an exact power of two.
        log_bands = math.log2(num_bands)
        assert log_bands == int(log_bands), "'num_bands' must be a power of 2."

        # Prototype low-pass -> modulated bank -> padded to a power-of-two length.
        prototype = design_prototype_filter(attenuation, num_bands)
        bank = generate_modulated_filter_bank(prototype, num_bands)
        padded_bank = pad_to_nearest_power_of_two(bank)

        # Buffers follow the module across devices and into state_dict.
        self.register_buffer("filter_bank", padded_bank)
        self.register_buffer("prototype", prototype)
        self.num_bands = num_bands

    def forward(self, signal):
        """Decompose the signal into multiple frequency bands."""
        # Accept numpy arrays / 1-D / 2-D input by normalising to B x C x L.
        prepared = prepare_signal_dimensions(signal)
        # Zero-pad so the length is a multiple of num_bands.
        prepared = pad_signal(prepared, self.num_bands)
        bands = polyphase_analysis(prepared, self.filter_bank)
        return apply_alias_cancellation(bands)

    def inverse(self, bands):
        """Reconstruct the original signal from the frequency bands."""
        unaliased = apply_alias_cancellation(bands)
        return polyphase_synthesis(unaliased, self.filter_bank)
|
||||
|
||||
|
||||
def prepare_signal_dimensions(signal):
    """
    Normalise a signal to a ``Batch x Channels x Length`` tensor.

    Parameters
    ----------
    signal : torch.Tensor or numpy.ndarray
        The input signal.

    Returns
    -------
    torch.Tensor
        1-D input becomes ``1 x 1 x Length``; 2-D input becomes
        ``1 x Channels x Length`` with the larger axis treated as length;
        tensors of any other rank are returned unchanged.

    Raises
    ------
    ValueError
        If ``signal`` is neither a numpy array nor a PyTorch tensor.
    """
    if isinstance(signal, np.ndarray):
        signal = torch.from_numpy(signal)
    elif not isinstance(signal, torch.Tensor):
        raise ValueError("Input should be either a numpy array or a PyTorch tensor.")

    rank = signal.dim()

    if rank == 1:
        # Mono signal: add batch and channel axes.
        return signal[None, None]

    if rank == 2:
        # Multi-channel signal: make sure the longer axis (length) comes last.
        rows, cols = signal.shape
        if rows > cols:
            signal = signal.transpose(0, 1)
        return signal[None]

    return signal
|
||||
|
||||
def pad_signal(signal, num_bands):
    """
    Zero-pad the end of the last axis so its length is divisible by ``num_bands``.

    Parameters
    ----------
    signal : torch.Tensor
        The input signal tensor, where the last dimension represents the signal length.

    num_bands : int
        The number of bands by which the signal length should be divisible.

    Returns
    -------
    torch.Tensor
        The padded tensor, or the very same ``signal`` object when its length
        is already divisible by ``num_bands``.
    """
    # Samples still missing to reach the next multiple of num_bands.
    shortfall = -signal.shape[-1] % num_bands
    if shortfall > 0:
        return nn.functional.pad(signal, (0, shortfall))
    return signal
|
||||
|
||||
def generate_modulated_filter_bank(prototype_filter, num_bands):
    """
    Build a bank of cosine-modulated copies of a prototype filter.

    Parameters
    ----------
    prototype_filter : torch.Tensor
        The prototype filter used as the basis for modulation.
    num_bands : int
        The number of desired subbands or filters.

    Returns
    -------
    torch.Tensor
        One modulated filter per band (``num_bands`` rows).
    """
    half_length = prototype_filter.shape[-1] // 2

    # Column vector of band numbers so it broadcasts against the tap axis.
    band = torch.arange(num_bands).unsqueeze(1)

    # Tap positions centred on zero: -half_length .. +half_length.
    taps = torch.arange(-half_length, half_length + 1)

    # Alternating +pi/4 / -pi/4 phase per band.
    phase = (-1)**band * np.pi / 4

    angular_step = (2 * band + 1) * np.pi / (2 * num_bands)
    carrier = torch.cos(angular_step * taps + phase)

    return 2 * prototype_filter * carrier
|
||||
|
||||
|
||||
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
    """
    Design a lowpass FIR filter using the Kaiser window.

    Parameters
    ----------
    angular_cutoff : float
        The angular frequency cutoff of the filter (radians, Nyquist = pi).
    attenuation : float
        The desired stopband attenuation in decibels (dB).
    filter_length : int, optional
        Desired length of the filter. If not provided, the length estimated by
        ``kaiserord`` is used, bumped to the next odd number when even.

    Returns
    -------
    ndarray
        The designed lowpass filter coefficients.
    """
    # kaiserord expects the transition width as a fraction of Nyquist.
    estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)

    if filter_length is None:
        # Force an odd tap count.
        filter_length = 2 * (estimated_length // 2) + 1

    # `fs=2*pi` puts Nyquist at pi, exactly what the removed `nyq=np.pi`
    # keyword expressed; `nyq` is no longer accepted by current SciPy.
    return firwin(
        filter_length,
        angular_cutoff,
        window=("kaiser", beta),
        scale=False,
        fs=2 * np.pi,
    )
|
||||
|
||||
|
||||
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
    """
    Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427

    Parameters
    ----------
    angular_cutoff : float
        Angular frequency cutoff of the filter.
    attenuation : float
        Desired stopband attenuation in dB.
    num_bands : int
        Number of bands for the multiband filter system.
    filter_length : int or None
        Desired length of the filter, forwarded to ``design_kaiser_lowpass``.

    Returns
    -------
    float
        Largest absolute value among the samples of the filter's
        autocorrelation taken every ``2 * num_bands`` taps from the centre,
        excluding the centre sample itself.
    """
    taps = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)

    # Full convolution of the filter with its own reverse.
    autocorrelation = np.convolve(taps, taps[::-1], "full")

    centre = autocorrelation.shape[-1] // 2
    stride = 2 * num_bands

    # Drop the first (centre) sample and keep the strided remainder.
    off_centre = autocorrelation[centre::stride][1:]

    return np.max(np.abs(off_centre))
|
||||
|
||||
|
||||
def design_prototype_filter(attenuation, num_bands, filter_length=None):
    """
    Design the prototype filter for a multiband system given the desired specs.

    The cutoff is found by minimising ``evaluate_filter_objective`` with
    ``scipy.optimize.fmin``, starting from ``1 / num_bands``.

    Parameters
    ----------
    attenuation : float
        The desired stopband attenuation in dB.
    num_bands : int
        Number of bands for the multiband filter system.
    filter_length : int, optional
        Desired length of the filter. If not provided, it's computed based on the given specs.

    Returns
    -------
    torch.Tensor
        The prototype filter coefficients as a float32 tensor.
    """

    def objective(angular_cutoff):
        return evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length)

    initial_cutoff = 1 / num_bands
    solution = fmin(objective, initial_cutoff, disp=0)
    best_cutoff = solution[0]

    coefficients = design_kaiser_lowpass(best_cutoff, attenuation, filter_length)
    return torch.tensor(coefficients, dtype=torch.float32)
|
||||
|
||||
def pad_to_nearest_power_of_two(x):
    """
    Zero-pad the last dimension of ``x`` on both sides so that its length
    becomes the smallest power of two that is not below the current length.

    Parameters:
    -----------
    x : torch.Tensor
        The input tensor to be padded.

    Returns:
    --------
    torch.Tensor
        The padded tensor. When the padding is odd, the extra sample goes on
        the right.
    """
    length = x.shape[-1]
    padded_length = 2**math.ceil(math.log2(length))

    # Split the padding as evenly as possible; any odd remainder goes right.
    left, leftover = divmod(padded_length - length, 2)
    right = left + leftover

    return nn.functional.pad(x, (left, right))
|
||||
|
||||
def apply_alias_cancellation(x):
    """
    Flip the sign of every second element (starting at index 0) of every
    second row (starting at row 1) over the last two dimensions of ``x``.

    The same operation is applied both after analysis and before synthesis
    in this module.

    Parameters:
    -----------
    x : torch.Tensor
        The input tensor.

    Returns:
    --------
    torch.Tensor
        A new tensor with the selected elements negated; ``x`` is not modified.
    """
    signs = torch.ones_like(x)

    # Odd rows, even columns get -1; the slice is a view so this edits `signs`.
    signs[..., 1::2, ::2].fill_(-1)

    return x * signs
|
||||
|
||||
def ensure_odd_length(tensor):
    """
    Make the last dimension of a tensor odd-sized by appending one zero if needed.

    Parameters:
    -----------
    tensor : torch.Tensor
        Input tensor whose last dimension might need padding.

    Returns:
    --------
    torch.Tensor
        The same tensor object if its last dimension was already odd,
        otherwise a copy padded with one trailing zero.
    """
    is_even = tensor.shape[-1] % 2 == 0
    return nn.functional.pad(tensor, (0, 1)) if is_even else tensor
|
||||
|
||||
def polyphase_analysis(signal, filter_bank):
    """
    Split a signal into sub-bands with a polyphase filter-bank convolution.

    Parameters:
    -----------
    signal : torch.Tensor
        Input signal tensor with shape (Batch x Channels x Length).

    filter_bank : torch.Tensor
        Filter bank tensor with shape (Bands x Length).

    Returns:
    --------
    torch.Tensor
        Signal split into sub-bands. (Batch x Channels x Bands x Length)
    """
    band_count = filter_bank.shape[0]
    channel_count = signal.shape[1]

    # Polyphase layout: one conv input channel per phase; batch and channel
    # axes are merged for the convolution and separated again afterwards.
    phases = rearrange(signal, "b c (t n) -> (b c) n t", n=band_count)

    # Same polyphase split for the filters.
    kernels = rearrange(filter_bank, "c (t n) -> c n t", n=band_count)

    half_kernel = kernels.shape[-1] // 2
    bands = nn.functional.conv1d(phases, kernels, padding=half_kernel)

    # Drop the last output sample produced by the padding above.
    bands = bands[..., :-1]

    return rearrange(bands, "(b c) n t -> b c n t", c=channel_count)
|
||||
|
||||
def polyphase_synthesis(signal, filter_bank):
    """
    Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.

    Parameters
    ----------
    signal : torch.Tensor
        Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).

    filter_bank : torch.Tensor
        Analysis filter bank (shape: Bands x Length).

    Returns
    -------
    torch.Tensor
        Reconstructed signal (shape: Batch x Channels X Length)
    """
    band_count = filter_bank.shape[0]
    channel_count = signal.shape[1]

    # Time-reverse the analysis filters, then split them into polyphase kernels.
    kernels = rearrange(filter_bank.flip(-1), "c (t n) -> n c t", n=band_count)

    # Merge batch and channel axes for the convolution.
    merged = rearrange(signal, "b c n t -> (b c) n t")

    pad = int(kernels.shape[-1] // 2 + 1)
    rebuilt = nn.functional.conv1d(merged, kernels, padding=pad)

    # Trim the trailing sample and scale by the number of bands.
    rebuilt = rebuilt[..., :-1] * band_count

    # Reverse the phase order, interleave phases back into time, and drop
    # the leading samples.
    rebuilt = rebuilt.flip(1)
    rebuilt = rearrange(rebuilt, "(b c) n t -> b c (t n)", c=channel_count, n=band_count)
    rebuilt = rebuilt[..., 2 * kernels.shape[1]:]

    return rebuilt
|
||||
@@ -1,239 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
class Pretransform(nn.Module):
    """Base class for pretransforms.

    Stores the common configuration; ``encode``, ``decode``, ``tokenize`` and
    ``decode_tokens`` all raise ``NotImplementedError`` and are meant to be
    overridden by subclasses.
    """

    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        self.enable_grad = enable_grad
        self.io_channels = io_channels
        self.is_discrete = is_discrete

        # Left as None here; subclasses fill these in.
        self.encoded_channels = None
        self.downsampling_ratio = None

    def encode(self, x):
        raise NotImplementedError

    def decode(self, z):
        raise NotImplementedError

    def tokenize(self, x):
        raise NotImplementedError

    def decode_tokens(self, tokens):
        raise NotImplementedError
|
||||
|
||||
class AutoencoderPretransform(Pretransform):
    """Pretransform that delegates to a frozen autoencoder model.

    The wrapped ``model`` is put in eval mode with gradients disabled.
    Latents returned by ``encode`` are divided by ``scale`` and latents given
    to ``decode`` are multiplied by ``scale`` first.  With ``model_half`` the
    model and its inputs are cast to float16 and results cast back to float32.
    """

    def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
        # Discrete only when the model has a bottleneck that reports is_discrete.
        has_discrete_bottleneck = model.bottleneck is not None and model.bottleneck.is_discrete
        super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=has_discrete_bottleneck)

        self.model = model
        self.model.requires_grad_(False).eval()
        self.scale = scale
        self.downsampling_ratio = model.downsampling_ratio
        self.io_channels = model.io_channels
        self.sample_rate = model.sample_rate

        self.model_half = model_half
        self.iterate_batch = iterate_batch

        self.encoded_channels = model.latent_dim

        self.chunked = chunked
        # Quantizer metadata exists only for discrete bottlenecks.
        self.num_quantizers = model.bottleneck.num_quantizers if has_discrete_bottleneck else None
        self.codebook_size = model.bottleneck.codebook_size if has_discrete_bottleneck else None

        if self.model_half:
            self.model.half()

    def encode(self, x, **kwargs):
        """Encode ``x`` with ``model.encode_audio`` and return latents divided by ``scale``."""
        if self.model_half:
            x = x.half()
            self.model.to(torch.float16)

        encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)

        if self.model_half:
            encoded = encoded.float()

        return encoded / self.scale

    def decode(self, z, **kwargs):
        """Multiply ``z`` by ``scale`` and decode it with ``model.decode_audio``."""
        z = z * self.scale

        if self.model_half:
            z = z.half()
            self.model.to(torch.float16)

        decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)

        if self.model_half:
            decoded = decoded.float()

        return decoded

    def tokenize(self, x, **kwargs):
        """Return the token entry of the info dict produced by ``model.encode``."""
        assert self.model.is_discrete, "Cannot tokenize with a continuous model"

        _, info = self.model.encode(x, return_info = True, **kwargs)

        return info[self.model.bottleneck.tokens_id]

    def decode_tokens(self, tokens, **kwargs):
        """Decode ``tokens`` with ``model.decode_tokens``."""
        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"

        return self.model.decode_tokens(tokens, **kwargs)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the wrapped model (keys are the model's own).

        Returns whatever the wrapped model's ``load_state_dict`` returns, so
        callers using ``strict=False`` can still inspect missing/unexpected
        keys; previously the result was discarded and ``None`` was returned.
        """
        return self.model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
class PQMFPretransform(Pretransform):
    """Pretransform that splits audio into PQMF sub-bands stacked on the channel axis."""

    def __init__(self, attenuation=100, num_bands=16):
        # TODO: Fix PQMF to take in in-channels
        super().__init__(enable_grad=False, io_channels=1, is_discrete=False)

        # Imported here rather than at module level, as in the original.
        from .pqmf import PQMF

        self.pqmf = PQMF(attenuation, num_bands)

    def encode(self, x):
        """Map (Batch x Channels x Time) audio to (Batch x (Channels Bands) x Time)."""
        # pqmf.forward yields Batch x Channels x Bands x Time; the channel and
        # band axes are merged because a Pretransform works on 3-D tensors.
        sub_bands = self.pqmf.forward(x)
        return rearrange(sub_bands, "b c n t -> b (c n) t")

    def decode(self, x):
        """Map (Batch x (Channels Bands) x Time) back to (Batch x Channels x Time)."""
        sub_bands = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
        return self.pqmf.inverse(sub_bands)
|
||||
|
||||
class PretrainedDACPretransform(Pretransform):
    """Pretransform backed by a pretrained model loaded through the ``dac`` package.

    With ``quantize_on_decode`` (the default) ``encode`` returns the raw
    encoder latents and quantization happens inside ``decode``; otherwise
    ``encode`` returns already-quantized latents.  Latents are divided by
    ``scale`` on encode and multiplied by it on decode.
    """

    def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        import dac

        weights_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
        self.model = dac.DAC.load(weights_path)

        self.quantize_on_decode = quantize_on_decode

        # 512 for the "44khz" model, 320 for every other model_type.
        self.downsampling_ratio = 512 if model_type == "44khz" else 320

        self.io_channels = 1
        self.scale = scale
        self.chunked = chunked

        self.encoded_channels = self.model.latent_dim
        self.num_quantizers = self.model.n_codebooks
        self.codebook_size = self.model.codebook_size

    def encode(self, x):
        """Return encoder latents (quantized unless ``quantize_on_decode``), divided by ``scale``."""
        latents = self.model.encoder(x)

        if not self.quantize_on_decode:
            latents, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)

        if self.scale != 1.0:
            latents = latents / self.scale

        return latents

    def decode(self, z):
        """Rescale ``z``, quantize it when ``quantize_on_decode`` is set, and decode."""
        if self.scale != 1.0:
            z = z * self.scale

        if self.quantize_on_decode:
            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)

        return self.model.decode(z)

    def tokenize(self, x):
        """Return element 1 of ``model.encode(x)``."""
        encoded = self.model.encode(x)
        return encoded[1]

    def decode_tokens(self, tokens):
        """Turn codes back into latents via the quantizer and decode them."""
        return self.model.decode(self.model.quantizer.from_codes(tokens))
|
||||
|
||||
class AudiocraftCompressionPretransform(Pretransform):
    """Token-only pretransform backed by an Audiocraft ``CompressionModel``.

    Only ``tokenize`` and ``decode_tokens`` are supported; ``encode`` and
    ``decode`` always raise ``AssertionError``.  The wrapped model is cast to
    float16, put in eval mode and has gradients disabled.
    """

    def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        try:
            from audiocraft.models import CompressionModel
        except ImportError as err:
            # Chain the original failure so the real import problem stays visible.
            raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") from err

        self.model = CompressionModel.get_pretrained(model_type)

        self.quantize_on_decode = quantize_on_decode

        self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)

        self.sample_rate = self.model.sample_rate

        self.io_channels = self.model.channels

        self.scale = scale

        # NOTE: encoded_channels is intentionally not set here; it keeps the
        # base-class value of None.

        self.num_quantizers = self.model.num_codebooks

        self.codebook_size = self.model.cardinality

        self.model.to(torch.float16).eval().requires_grad_(False)

    def encode(self, x):
        """Unsupported: always raises ``AssertionError``."""
        # Raised explicitly (not via `assert False`) so the guard survives
        # `python -O`, which strips assert statements.
        raise AssertionError("Audiocraft compression models do not support continuous encoding")

    def decode(self, z):
        """Unsupported: always raises ``AssertionError``."""
        # Raised explicitly (not via `assert False`) so the guard survives
        # `python -O`, which strips assert statements.
        raise AssertionError("Audiocraft compression models do not support continuous decoding")

    def tokenize(self, x):
        """Return element 0 of ``model.encode`` applied to ``x`` cast to float16."""
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.encode(x.to(torch.float16))[0]

    def decode_tokens(self, tokens):
        """Decode ``tokens`` with ``model.decode`` (autocast disabled)."""
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.decode(tokens)
|
||||
@@ -1,989 +0,0 @@
|
||||
from functools import reduce, partial
|
||||
from packaging import version
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from torch.cuda.amp import autocast
|
||||
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||
from typing import Callable, Literal
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
||||
HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
HAS_FLASH_ATTN = False
|
||||
flash_attn_kvpacked_func = None
|
||||
flash_attn_func = None
|
||||
|
||||
from .utils import compile, checkpoint
|
||||
try:
|
||||
import natten
|
||||
except ImportError:
|
||||
natten = None
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Apply an affine modulation: ``x * (1 + scale) + shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
|
||||
|
||||
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
|
||||
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
|
||||
|
||||
def create_causal_mask(i, j, device):
    """Build an ``(i, j)`` boolean mask that is True strictly above diagonal ``j - i``."""
    all_true = torch.ones((i, j), device=device, dtype=torch.bool)
    first_masked_diagonal = j - i + 1
    return all_true.triu(first_masked_diagonal)
|
||||
|
||||
def or_reduce(masks):
    """Combine every item of ``masks`` with ``|`` and return the result.

    ``masks`` must contain at least one item; unpacking an empty iterable
    raises ``ValueError``.
    """
    first, *others = masks
    return reduce(lambda combined, mask: combined | mask, others, first)
|
||||
|
||||
# positional embeddings
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute position embedding, scaled by ``dim ** -0.5``."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return scaled position embeddings for the positions of ``x``.

        Only ``x.shape[1]`` and ``x.device`` are read from ``x``. ``pos``
        overrides the default ``0..seq_len-1`` indices; ``seq_start_pos``
        shifts the indices left (clamped at zero).
        """
        seq_len = x.shape[1]
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        positions = torch.arange(seq_len, device=x.device) if pos is None else pos

        if seq_start_pos is not None:
            positions = (positions - seq_start_pos[..., None]).clamp(min=0)

        return self.emb(positions) * self.scale
|
||||
|
||||
class ScaledSinusoidalEmbedding(nn.Module):
    """Sinusoidal position embedding multiplied by a learned scalar scale."""

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        # Learned scale, initialised to dim ** -0.5.
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        exponents = torch.arange(half_dim).float() / half_dim
        # Non-persistent: recomputed at construction, never saved in state_dict.
        self.register_buffer('inv_freq', theta ** -exponents, persistent=False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return ``[sin | cos]`` embeddings for the positions of ``x``, times ``scale``.

        Only ``x.shape[1]`` and ``x.device`` are read from ``x``.
        """
        positions = pos if pos is not None else torch.arange(x.shape[1], device=x.device)

        if seq_start_pos is not None:
            positions = positions - seq_start_pos[..., None]

        angles = einsum('i, j -> i j', positions, self.inv_freq)
        sin_cos = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return sin_cos * self.scale
|
||||
|
||||
class RotaryEmbedding(nn.Module):
    """Rotary position embedding frequencies, with optional xpos scaling.

    Args:
        dim: Rotary dimension; ``dim // 2`` inverse frequencies are created.
        use_xpos: When True, ``forward`` also returns a per-position scale
            tensor; otherwise the returned scale is the constant ``1.``.
        scale_base: Divisor applied to the centred position index when
            computing the xpos exponent.
        interpolation_factor: Positions are divided by this value (must be
            >= 1).
        base: Base of the geometric frequency progression.
        base_rescale_factor: NTK-aware rescaling factor for ``base``.
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        # A factor of 1 is a no-op, so skip it: this avoids dividing by zero
        # for dim == 2 in the (default) case where no rescaling is requested.
        if base_rescale_factor != 1.:
            base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        """Run ``forward`` on positions ``0..seq_len-1``."""
        device = self.inv_freq.device

        t = torch.arange(seq_len, device = device)
        return self.forward(t)

    @autocast(enabled = False)
    def forward(self, t):
        """Return ``(freqs, scale)`` for the 1-D position tensor ``t``.

        ``freqs`` has shape ``(len(t), 2 * len(inv_freq))``. ``scale`` is the
        float ``1.`` without xpos, otherwise a tensor of the same shape as
        ``freqs``.
        """
        device = self.inv_freq.device

        t = t.to(torch.float32)

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # The xpos exponent is centred on the middle of the sequence. The
        # sequence length comes from the number of positions supplied
        # (it was previously an undefined name, so this branch always raised).
        seq_len = t.shape[0]
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = self.scale ** power.unsqueeze(-1)
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
|
||||
|
||||
def rotate_half(x):
    """Split the last axis into halves ``(a, b)`` and return ``(-b, a)`` concatenated."""
    first_half, second_half = x.unflatten(-1, (2, -1)).unbind(dim=-2)
    return torch.cat((-second_half, first_half), dim=-1)
|
||||
|
||||
@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Rotate the leading ``freqs.shape[-1]`` features of ``t`` by ``freqs``.

    Features beyond the rotary dimension pass through unchanged (partial
    rotary embeddings, Wang et al. GPT-J). The result has the dtype of the
    incoming ``t``.
    """
    out_dtype = t.dtype

    # Work in at least float32 for numerical stability.
    compute_dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim = freqs.shape[-1]
    seq_len = t.shape[-2]

    freqs = freqs.to(compute_dtype)
    t = t.to(compute_dtype)
    freqs = freqs[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    rotary_part = t[..., :rot_dim]
    passthrough_part = t[..., rot_dim:]
    rotary_part = (rotary_part * freqs.cos() * scale) + (rotate_half(rotary_part) * freqs.sin() * scale)

    return torch.cat((rotary_part.to(out_dtype), passthrough_part.to(out_dtype)), dim = -1)
|
||||
|
||||
# norms
|
||||
class DynamicTanh(nn.Module):
    """Element-wise ``gamma * tanh(alpha * x) + beta`` with learned parameters.

    ``alpha`` is a single learned scalar; ``gamma`` and ``beta`` have ``dim``
    entries each.
    """

    def __init__(self, dim, init_alpha=10.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.gamma * squashed + self.beta
|
||||
|
||||
class RunningInstanceNorm(nn.Module):
    """Normalise by running per-feature mean/std, updated only in training mode.

    Optionally applies ``asinh`` saturation and a learned scalar gain.
    """

    def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(1, 1, dim))
        self.register_buffer("running_std", torch.ones(1, 1, dim))
        self.saturate = saturate
        self.eps = eps
        self.momentum = momentum
        self.dim = dim
        self.trainable_gain = trainable_gain
        if self.trainable_gain:
            self.gain = nn.Parameter(torch.ones(1))

    def _update_stats(self, x):
        """Blend the batch mean/std (over the first two axes) into the running stats."""
        detached = x.detach()
        batch_mean = detached.mean(dim = [0, 1]).view(1, 1, self.dim)
        self.running_mean = self.running_mean * self.momentum + batch_mean * (1 - self.momentum)

        batch_std = detached.std(dim = [0, 1]).view(1, 1, self.dim)
        blended_std = self.running_std * self.momentum + batch_std * (1 - self.momentum)
        # Floor the std at eps so the division in forward stays finite.
        self.running_std = blended_std.clip(min = self.eps)

    def forward(self, x):
        if self.training:
            self._update_stats(x)

        normalised = (x - self.running_mean) / self.running_std
        if self.saturate:
            normalised = torch.asinh(normalised)
        if self.trainable_gain:
            normalised = normalised * self.gain
        return normalised
|
||||
|
||||
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis, bias-less by default.

    bias-less layernorm has been shown to be more stable. most newer models
    have moved towards rmsnorm, also bias-less

    Args:
        dim: Size of the normalised (last) axis.
        bias: When True ``beta`` is learned; otherwise it is a zero buffer.
        fix_scale: When True ``gamma`` is a buffer of ones instead of a
            learned parameter.
        force_fp32: When True the normalisation is computed in float32 and
            cast back to the input dtype.
        eps: Epsilon passed to ``F.layer_norm``.
    """

    def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
        super().__init__()

        ones = torch.ones(dim)
        if fix_scale:
            self.register_buffer("gamma", ones)
        else:
            self.gamma = nn.Parameter(ones)

        zeros = torch.zeros(dim)
        if bias:
            self.beta = nn.Parameter(zeros)
        else:
            self.register_buffer("beta", zeros)

        self.eps = eps
        self.force_fp32 = force_fp32

    def forward(self, x):
        normalized_shape = x.shape[-1:]

        if self.force_fp32:
            result = F.layer_norm(x.float(), normalized_shape, weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
            return result.to(x.dtype)

        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta, eps=self.eps)
|
||||
|
||||
class LayerScale(nn.Module):
    """Multiply the input by a learned per-feature scale, initialised to ``init_val``."""

    def __init__(self, dim, init_val = 1e-5):
        super().__init__()
        initial_scale = torch.full([dim], init_val)
        self.scale = nn.Parameter(initial_scale)

    def forward(self, x):
        return x * self.scale
|
||||
|
||||
class GLU(nn.Module):
    """Gated linear unit: project to ``2 * dim_out``, then ``value * act(gate)``.

    The projection is a ``Linear`` by default, or a ``Conv1d`` applied along
    the sequence axis when ``use_conv`` is set.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        use_conv = False,
        conv_kernel_size = 3,
    ):
        super().__init__()
        self.act = activation
        if use_conv:
            self.proj = nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
        else:
            self.proj = nn.Linear(dim_in, dim_out * 2)
        self.use_conv = use_conv

    def forward(self, x):
        if self.use_conv:
            # Conv1d wants channels before the sequence axis; swap in and back out.
            channels_first = rearrange(x, 'b n d -> b d n')
            projected = rearrange(self.proj(channels_first), 'b d n -> b n d')
        else:
            projected = self.proj(x)

        value, gate = projected.chunk(2, dim = -1)
        return value * self.act(gate)
|
||||
|
||||
class FeedForward(nn.Module):
    """Transformer feed-forward block (SwiGLU by default).

    Args:
        dim: Input feature size.
        dim_out: Output feature size; defaults to ``dim``.
        mult: Hidden width multiplier (``inner_dim = int(dim * mult)``).
        no_bias: Drop the bias of the non-GLU input projection and of the
            output projection.
        glu: Use a ``GLU`` input projection instead of projection + SiLU.
        use_conv: Use ``Conv1d`` projections along the sequence axis.
        conv_kernel_size: Kernel size of those convolutions.
        zero_init_output: Zero-initialise the output projection.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        use_bias = not no_bias
        conv_padding = conv_kernel_size // 2

        # Default to SwiGLU
        activation = nn.SiLU()

        if dim_out is None:
            dim_out = dim

        def axis_swap(pattern):
            # Conv projections need channels-first input; linear ones need nothing.
            return Rearrange(pattern) if use_conv else nn.Identity()

        def projection(features_in, features_out):
            if use_conv:
                return nn.Conv1d(features_in, features_out, conv_kernel_size, padding = conv_padding, bias = use_bias)
            return nn.Linear(features_in, features_out, bias = use_bias)

        if glu:
            linear_in = GLU(dim, inner_dim, activation)
        else:
            linear_in = nn.Sequential(
                axis_swap('b n d -> b d n'),
                projection(dim, inner_dim),
                axis_swap('b n d -> b d n'),
                activation
            )

        linear_out = projection(inner_dim, dim_out)

        # init last linear layer to 0
        if zero_init_output:
            nn.init.zeros_(linear_out.weight)
            if use_bias:
                nn.init.zeros_(linear_out.bias)

        self.ff = nn.Sequential(
            linear_in,
            axis_swap('b d n -> b n d'),
            linear_out,
            axis_swap('b n d -> b d n'),
        )

    def forward(self, x):
        return self.ff(x)
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head attention used for both self- and cross-attention.

    Args:
        dim: Model width of the query input and of the output.
        dim_heads: Width of each head; ``dim // dim_heads`` query heads.
        dim_context: Width of the cross-attention context. When given,
            separate ``to_q`` / ``to_kv`` projections are created; otherwise
            a fused ``to_qkv`` projection is used.
        causal: Default causal flag (can be overridden per call).
        zero_init_output: Zero-initialise the output projection.
        qk_norm: Query/key normalisation: ``'l2'``, ``'ln'``, ``'rns'``
            (RMSNorm), ``'dyt'`` (DynamicTanh) or ``'none'``.
        differential: Compute a second attention pass from extra q/k
            projections and subtract it from the first.
        feat_scale: Learn per-feature gains for the sequence-mean ("DC")
            and residual ("high frequency") parts of the output.

    Raises:
        ValueError: If ``qk_norm`` is not one of the accepted values.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
        differential = False,
        feat_scale = False
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.differential = differential

        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            if differential:
                self.to_q = nn.Linear(dim, dim * 2, bias=False)
                self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
            else:
                self.to_q = nn.Linear(dim, dim, bias=False)
                self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
        else:
            if differential:
                self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
            else:
                self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        self.to_out = nn.Linear(dim, dim, bias=False)

        if zero_init_output:
            nn.init.zeros_(self.to_out.weight)

        if qk_norm not in ['l2', 'ln', 'rns', 'dyt', 'none']:
            # The message lists every accepted value (it used to omit 'rns' and 'dyt').
            raise ValueError(f'qk_norm must be one of ["l2", "ln", "rns", "dyt", "none"], got {qk_norm}')

        self.qk_norm = qk_norm
        if self.qk_norm == "ln":
            self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
            self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
        elif self.qk_norm == 'rns':
            self.q_norm = nn.RMSNorm(dim_heads)
            self.k_norm = nn.RMSNorm(dim_heads)
        elif self.qk_norm == 'dyt':
            self.q_norm = DynamicTanh(dim_heads)
            self.k_norm = DynamicTanh(dim_heads)

        self.sdp_kwargs = dict(
            enable_flash = True,
            enable_math = True,
            enable_mem_efficient = True
        )

        self.feat_scale = feat_scale

        if self.feat_scale:
            self.lambda_dc = nn.Parameter(torch.zeros(dim))
            self.lambda_hf = nn.Parameter(torch.zeros(dim))

        self.causal = causal

    @compile
    def apply_qk_layernorm(self, q, k):
        """Apply the learned q/k norms, casting each result back to its input dtype."""
        q_type = q.dtype
        k_type = k.dtype
        q = self.q_norm(q).to(q_type)
        k = self.k_norm(k).to(k_type)
        return q, k

    def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
        """Run the attention kernel on ``(b, h, n, d)`` tensors.

        Uses flash-attn when it was importable, otherwise
        ``F.scaled_dot_product_attention``.

        Raises:
            NotImplementedError: If a flex-attention mask or score mod is
                supplied for a non-causal call.
        """
        if self.num_heads != self.kv_heads:
            # Repeat interleave kv_heads to match q_heads for grouped query attention
            heads_per_kv_head = self.num_heads // self.kv_heads
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        flash_attn_available = HAS_FLASH_ATTN

        # Flex-attention arguments are dropped for causal calls.
        if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
            flex_attention_block_mask = None
            flex_attention_score_mod = None

        if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
            raise NotImplementedError(
                "FlexAttention is not available in this build. "
                "flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
            )
        elif flash_attn_available:
            fa_dtype_in = q.dtype
            # flash-attn takes (b, n, h, d).
            q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))

            # Inputs that are not already half precision are cast to bfloat16 for the kernel.
            if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
                q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))

            out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])

            out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
        else:
            out = F.scaled_dot_product_attention(q, k, v, is_causal = causal)
        return out

    def forward(
        self,
        x,
        context = None,
        rotary_pos_emb = None,
        causal = None,
        flex_attention_block_mask = None,
        flex_attention_score_mod = None,
        flash_attn_sliding_window = None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself) and project the result.

        Args:
            x: Query input, ``(batch, seq, dim)``.
            context: Optional key/value input for cross-attention.
            rotary_pos_emb: Optional ``(freqs, scale)`` pair; only ``freqs``
                is used.
            causal: Overrides the module-level causal flag when not None.
            flex_attention_block_mask, flex_attention_score_mod,
            flash_attn_sliding_window: Forwarded to ``apply_attn``.

        Returns:
            Tensor of shape ``(batch, seq, dim)``.
        """
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Use separate linear projections for q and k/v
            if self.differential:
                q, q_diff = self.to_q(x).chunk(2, dim=-1)
                q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
                q = torch.stack([q, q_diff], dim = 1)
                k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
                k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
                k = torch.stack([k, k_diff], dim = 1)
            else:
                q = self.to_q(x)
                q = rearrange(q, 'b n (h d) -> b h n d', h = h)
                k, v = self.to_kv(kv_input).chunk(2, dim=-1)
                k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Use fused linear projection
            if self.differential:
                q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
                q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
                q = torch.stack([q, q_diff], dim = 1)
                k = torch.stack([k, k_diff], dim = 1)
            else:
                q, k, v = self.to_qkv(x).chunk(3, dim=-1)
                q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine sim attention
        if self.qk_norm == "l2":
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)
        elif self.qk_norm != "none":
            q, k = self.apply_qk_layernorm(q, k)

        if rotary_pos_emb is not None:
            freqs, _ = rotary_pos_emb
            # Rotate in float32, then cast q/k to the dtype of v.
            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)
            # When q and k differ in length, the shorter side's frequencies
            # are scaled by the length ratio.
            if q.shape[-2] >= k.shape[-2]:
                ratio = q.shape[-2] / k.shape[-2]
                q_freqs, k_freqs = freqs, ratio * freqs
            else:
                ratio = k.shape[-2] / q.shape[-2]
                q_freqs, k_freqs = ratio * freqs, freqs
            q = apply_rotary_pos_emb(q, q_freqs)
            k = apply_rotary_pos_emb(k, k_freqs)
            q = q.to(v.dtype)
            k = k.to(v.dtype)

        n = q.shape[-2]

        causal = self.causal if causal is None else causal

        # A single query position needs no causal mask.
        if n == 1 and causal:
            causal = False

        if self.differential:
            q, q_diff = q.unbind(dim = 1)
            k, k_diff = k.unbind(dim = 1)
            out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
            out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
            out = out - out_diff
        else:
            out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)

        # merge heads
        out = rearrange(out, ' b h n d -> b n (h d)')

        # Communicate between heads
        out = self.to_out(out)

        if self.feat_scale:
            out_dc = out.mean(dim=-2, keepdim=True)
            out_hf = out - out_dc

            # Selectively modulate DC and high frequency components
            out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf

        return out
|
||||
|
||||
class ConformerModule(nn.Module):
    """Conformer-style convolution block.

    Pipeline: norm -> pointwise conv -> GLU -> depthwise conv (kernel 17)
    -> norm -> SiLU -> pointwise conv. Input and output are ``(b, n, d)``.
    """

    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):
        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    @staticmethod
    def _conv_over_sequence(conv, x):
        """Apply a ``Conv1d`` to a ``(b, n, d)`` tensor and return ``(b, n, d)``."""
        channels_first = rearrange(x, 'b n d -> b d n')
        return rearrange(conv(channels_first), 'b d n -> b n d')

    def forward(self, x):
        h = self.in_norm(x)
        h = self._conv_over_sequence(self.pointwise_conv, h)
        h = self.glu(h)
        h = self._conv_over_sequence(self.depthwise_conv, h)
        h = self.swish(self.mid_norm(h))
        return self._conv_over_sequence(self.pointwise_conv_2, h)
|
||||
|
||||
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention, optional cross-attention,
    optional conformer block, optional sync FiLM, then feed-forward.

    When built with ``global_cond_dim`` and called with ``global_cond``, the
    self-attention and feed-forward branches are modulated adaLN-style
    (scale, shift and gate) from ``to_scale_shift_gate + global_cond``.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        cross_attend = False,
        dim_context = None,
        global_cond_dim = None,
        causal = False,
        zero_init_branch_outputs = True,
        conformer = False,
        layer_ix = -1,
        remove_norms = False,
        add_rope = False,
        layer_scale = False,
        use_sync_block_film = False,
        attn_kwargs = {},
        ff_kwargs = {},
        norm_kwargs = {}
    ):
        super().__init__()
        self.dim = dim
        self.dim_heads = min(dim_heads, dim)
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        # Layer scale already starts the branches near zero, so the output
        # projections are not zero-initialised on top of it.
        if layer_scale and zero_init_branch_outputs:
            zero_init_branch_outputs = False

        def make_norm():
            return LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)

        def make_branch_scale():
            return LayerScale(dim) if layer_scale else nn.Identity()

        self.pre_norm = make_norm()

        self.add_rope = add_rope

        self.self_attn = Attention(
            dim,
            dim_heads = self.dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            **attn_kwargs
        )

        self.self_attn_scale = make_branch_scale()

        if cross_attend:
            self.cross_attend_norm = make_norm()
            self.cross_attn = Attention(
                dim,
                dim_heads = self.dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                **attn_kwargs
            )
            self.cross_attn_scale = make_branch_scale()

        self.ff_norm = make_norm()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
        self.ff_scale = make_branch_scale()

        self.layer_ix = layer_ix

        self.conformer = None
        if conformer:
            self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
            self.conformer_scale = make_branch_scale()

        self.global_cond_dim = global_cond_dim
        if global_cond_dim is not None:
            # Learned offset added to the global conditioning before it is
            # split into six scale/shift/gate chunks.
            self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)

        self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None

        if use_sync_block_film:
            self.sync_film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False) # generate scale and shift from sync_cond
            )

    @compile
    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        rotary_pos_emb = None,
        self_attention_block_mask = None,
        self_attention_score_mod = None,
        cross_attention_block_mask = None,
        cross_attention_score_mod = None,
        self_attention_flash_sliding_window = None,
        cross_attention_flash_sliding_window = None,
        sync_cond = None,
        prepend_length=0
    ):
        """Run the block on ``x`` and return a tensor of the same shape.

        ``prepend_length`` is only used on the non-adaLN path, where the
        sync FiLM modulation skips the first ``prepend_length`` positions.
        """
        if rotary_pos_emb is None and self.add_rope:
            rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])

        self_attn_args = dict(
            rotary_pos_emb = rotary_pos_emb,
            flex_attention_block_mask = self_attention_block_mask,
            flex_attention_score_mod = self_attention_score_mod,
            flash_attn_sliding_window = self_attention_flash_sliding_window,
        )
        cross_attn_args = dict(
            context = context,
            flex_attention_block_mask = cross_attention_block_mask,
            flex_attention_score_mod = cross_attention_score_mod,
            flash_attn_sliding_window = cross_attention_flash_sliding_window,
        )

        use_adaln = self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None

        # --- self-attention -------------------------------------------------
        if use_adaln:
            modulation = self.to_scale_shift_gate + global_cond
            if len(global_cond.shape) == 2:
                # Per-sample conditioning: add a sequence axis to broadcast over.
                modulation = modulation.unsqueeze(1)
            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = modulation.chunk(6, dim=-1)

            # self-attention with adaLN
            attn_in = self.pre_norm(x) * (1 + scale_self) + shift_self
            attn_out = self.self_attn(attn_in, **self_attn_args) * torch.sigmoid(1 - gate_self)
            x = self.self_attn_scale(attn_out) + x
        else:
            x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), **self_attn_args))

        # --- cross-attention ------------------------------------------------
        if context is not None and self.cross_attend:
            x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), **cross_attn_args))

        # --- conformer ------------------------------------------------------
        if self.conformer is not None:
            x = x + self.conformer_scale(self.conformer(x))

        # --- sync FiLM ------------------------------------------------------
        if sync_cond is not None and hasattr(self, 'sync_film_generator'):
            if use_adaln:
                scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
            else:
                # Only the positions after the prepended tokens are modulated.
                prepend_part = x[:, :prepend_length, :]
                audio_part = x[:, prepend_length:, :]
                scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
                modulated_audio_part = audio_part * (1 + scale) + shift
                x = torch.cat([prepend_part, modulated_audio_part], dim=1)

        # --- feed-forward ---------------------------------------------------
        if use_adaln:
            # feedforward with adaLN
            ff_in = self.ff_norm(x) * (1 + scale_ff) + shift_ff
            ff_out = self.ff(ff_in) * torch.sigmoid(1 - gate_ff)
            x = self.ff_scale(ff_out) + x
        else:
            x = x + self.ff_scale(self.ff(self.ff_norm(x)))

        return x
|
||||
|
||||
class ContinuousTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
*,
|
||||
dim_in = None,
|
||||
dim_out = None,
|
||||
dim_heads = 64,
|
||||
cross_attend=False,
|
||||
cond_token_dim=None,
|
||||
pre_cross_attn_ix=-1,
|
||||
final_cross_attn_ix=-1,
|
||||
global_cond_dim=None,
|
||||
causal=False,
|
||||
rotary_pos_emb=True,
|
||||
zero_init_branch_outputs=True,
|
||||
conformer=False,
|
||||
use_sinusoidal_emb=False,
|
||||
use_abs_pos_emb=False,
|
||||
abs_pos_emb_max_length=10000,
|
||||
num_memory_tokens=0,
|
||||
sliding_window=None,
|
||||
use_mlp=False,
|
||||
use_add_norm=False,
|
||||
use_gated=False,
|
||||
use_final_layer=False,
|
||||
use_zeros=False,
|
||||
use_conv=False,
|
||||
use_fusion_mlp=False,
|
||||
use_film=False,
|
||||
use_sync_film=False,
|
||||
use_sync_gated=False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.causal = causal
|
||||
self.layers = nn.ModuleList([])
|
||||
if use_mlp:
|
||||
self.project_in = nn.Sequential(
|
||||
nn.Linear(dim_in, dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim, bias=False)
|
||||
)
|
||||
else:
|
||||
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
|
||||
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
|
||||
self.video_temporal_conv = None
|
||||
self.audio_temporal_conv = None
|
||||
self.fusion_mlp = None
|
||||
if use_conv:
|
||||
self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
||||
self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
|
||||
if use_fusion_mlp:
|
||||
self.fusion_mlp = nn.Sequential(
|
||||
nn.Linear(dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim)
|
||||
)
|
||||
|
||||
if rotary_pos_emb:
|
||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
||||
else:
|
||||
self.rotary_pos_emb = None
|
||||
self.num_memory_tokens = num_memory_tokens
|
||||
if num_memory_tokens > 0:
|
||||
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
||||
|
||||
self.use_sinusoidal_emb = use_sinusoidal_emb
|
||||
if use_sinusoidal_emb:
|
||||
self.pos_emb = ScaledSinusoidalEmbedding(dim)
|
||||
|
||||
self.use_abs_pos_emb = use_abs_pos_emb
|
||||
if use_abs_pos_emb:
|
||||
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)
|
||||
|
||||
self.adaLN_modulation = None
|
||||
if global_cond_dim is not None:
|
||||
if use_final_layer:
|
||||
self.norm_final = LayerNorm(dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
dim, 2 * dim, bias=True
|
||||
),
|
||||
)
|
||||
|
||||
if use_zeros:
|
||||
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.project_out.weight, 0)
|
||||
self.global_cond_embedder = nn.Sequential(
|
||||
nn.Linear(global_cond_dim, dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 6)
|
||||
)
|
||||
if use_zeros:
|
||||
nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
|
||||
nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
|
||||
nn.init.constant_(self.global_cond_embedder[0].weight, 0)
|
||||
nn.init.constant_(self.global_cond_embedder[0].bias, 0)
|
||||
|
||||
self.final_cross_attn_ix = final_cross_attn_ix
|
||||
self.use_gated = use_gated
|
||||
self.use_film = use_film
|
||||
self.use_add_norm = use_add_norm
|
||||
if self.use_add_norm:
|
||||
self.add_norm = nn.LayerNorm(dim)
|
||||
if use_gated:
|
||||
self.gate = nn.Parameter(torch.ones(1, 1, dim))
|
||||
|
||||
if use_film:
|
||||
self.film_generator = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
||||
)
|
||||
else:
|
||||
self.film_generator = None
|
||||
|
||||
if use_sync_film:
|
||||
self.sync_film_generator = nn.Sequential(
|
||||
nn.Linear(dim, dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
|
||||
)
|
||||
else:
|
||||
self.sync_film_generator = None
|
||||
if use_sync_gated:
|
||||
self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
else:
|
||||
self.sync_gate = None
|
||||
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
for i in range(depth):
|
||||
should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
|
||||
# print(f"Layer {i} cross attends: {should_cross_attend}")
|
||||
self.layers.append(
|
||||
TransformerBlock(
|
||||
dim,
|
||||
dim_heads = dim_heads,
|
||||
cross_attend = should_cross_attend,
|
||||
dim_context = cond_token_dim,
|
||||
global_cond_dim = global_cond_dim,
|
||||
causal = causal,
|
||||
zero_init_branch_outputs = zero_init_branch_outputs,
|
||||
conformer=conformer,
|
||||
layer_ix=i,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
    self,
    x,
    mask = None,
    prepend_embeds = None,
    prepend_mask = None,
    add_cond = None,
    sync_cond = None,
    global_cond = None,
    return_info = False,
    use_checkpointing = True,
    exit_layer_ix = None,
    video_dropout_prob = 0.0,
    **kwargs
):
    """Project ``x`` in, merge the conditioning signals, run the layers, project out.

    Args:
        x: Input sequence. It is cast to the dtype of the module's first
            parameter before ``project_in`` is applied.
        mask: Accepted for interface compatibility; not read in this method.
        prepend_embeds: Optional embeddings concatenated in front of ``x``
            along the sequence axis. Their last dimension must equal that of
            the projected ``x``.
        prepend_mask: Accepted for interface compatibility; not read here.
        add_cond: Optional conditioning merged into ``x`` right after
            ``project_in``: sigmoid-gated add when ``use_gated`` is set, FiLM
            (scale/shift) when ``use_film`` is set, plain addition otherwise.
        sync_cond: Optional conditioning. When its length on axis 1 differs
            from that of ``x`` (measured before prepend/memory tokens are
            attached) it is linearly interpolated to match. It is then merged
            through ``sync_film_generator`` or ``sync_gate`` if one is
            configured, and is also forwarded to every layer.
        global_cond: Optional conditioning. It is passed through
            ``global_cond_embedder`` when that module exists and handed to
            every layer; the un-embedded value drives the final
            ``adaLN_modulation`` when that module exists.
        return_info: When true, also return a dict whose ``"hidden_states"``
            list holds the output of every executed layer.
        use_checkpointing: Run each layer through ``checkpoint`` instead of
            calling it directly.
        exit_layer_ix: If given, return right after the layer with this index,
            with memory tokens removed and without the final modulation or
            ``project_out``.
        video_dropout_prob: Accepted for interface compatibility; not read here.
        **kwargs: Forwarded unchanged to every layer call.

    Returns:
        The output tensor, or ``(output, info)`` when ``return_info`` is true.
    """
    batch, _ = x.shape[:2]
    x = x.to(next(self.parameters()).dtype)

    prepend_length = 0
    info = {"hidden_states": []}

    x = self.project_in(x)

    if add_cond is not None:
        if self.use_gated:
            x = x + torch.sigmoid(self.gate) * add_cond
        elif self.use_film:
            film_scale, film_shift = self.film_generator(add_cond).chunk(2, dim=-1)
            x = x * (1 + film_scale) + film_shift
        else:
            x = x + add_cond

        # NOTE(review): the listing this was reconstructed from had lost its
        # indentation, so it is not certain whether the norm / fusion-MLP steps
        # belong inside this ``add_cond`` branch or should run unconditionally
        # - confirm against the original file.
        if self.use_add_norm:
            x = self.add_norm(x)
        if self.fusion_mlp is not None:
            x = self.fusion_mlp(x)

    if sync_cond is not None:
        # Bring sync_cond to the same sequence length as x when they differ.
        target_len = x.shape[1]
        if sync_cond.shape[1] != target_len:
            sync_cond = torch.nn.functional.interpolate(
                sync_cond.transpose(1, 2),
                size=target_len,
                mode='linear',
                align_corners=False,
            ).transpose(1, 2)

        if self.sync_film_generator is not None:
            sync_scale, sync_shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
            x = x * (1 + sync_scale) + sync_shift
        elif self.sync_gate is not None:
            x = x + torch.sigmoid(self.sync_gate) * sync_cond
        # With neither module configured, sync_cond is not merged here; it only
        # reaches the layers through the keyword argument below.

    if prepend_embeds is not None:
        prepend_length, prepend_dim = prepend_embeds.shape[1:]
        assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
        x = torch.cat((prepend_embeds, x), dim = -2)

    if self.num_memory_tokens > 0:
        x = torch.cat((self.memory_tokens.expand(batch, -1, -1), x), dim=1)

    # Rotary embeddings are sized for the full sequence, including any
    # prepended and memory tokens attached above.
    rotary_pos_emb = None
    if self.rotary_pos_emb is not None:
        rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])

    if self.use_sinusoidal_emb or self.use_abs_pos_emb:
        x = x + self.pos_emb(x)

    global_cond_embed = global_cond
    if global_cond is not None and self.global_cond_embedder is not None:
        global_cond_embed = self.global_cond_embedder(global_cond)

    # Iterate over the transformer layers
    for layer_ix, layer in enumerate(self.layers):
        layer_kwargs = dict(
            rotary_pos_emb=rotary_pos_emb,
            global_cond=global_cond_embed,
            self_attention_flash_sliding_window=self.sliding_window,
            sync_cond=sync_cond,
            prepend_length=prepend_length,
            **kwargs,
        )
        if use_checkpointing:
            x = checkpoint(layer, x, **layer_kwargs)
        else:
            x = layer(x, **layer_kwargs)

        if return_info:
            info["hidden_states"].append(x)

        # Early exit: drop the memory tokens and skip the output head.
        if exit_layer_ix is not None and layer_ix == exit_layer_ix:
            x = x[:, self.num_memory_tokens:, :]
            return (x, info) if return_info else x

    x = x[:, self.num_memory_tokens:, :]

    if global_cond is not None and self.adaLN_modulation is not None:
        # Add a sequence axis to a 2-D global_cond before modulation.
        if global_cond.ndim == 2:
            global_cond = global_cond.unsqueeze(1)
        shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)

    x = self.project_out(x)

    return (x, info) if return_info else x
|
||||
@@ -1,180 +0,0 @@
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
|
||||
#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
def load_ckpt_state_dict(ckpt_path, prefix=None):
    """Load a checkpoint's state dict, optionally restricted to one key prefix.

    Args:
        ckpt_path: Path to the checkpoint. Paths ending in ``.safetensors`` are
            read with ``load_file``; anything else is read with ``torch.load``
            onto the CPU and must contain a ``"state_dict"`` entry.
        prefix: When given, only keys starting with this string are kept and
            the leading prefix is removed from each kept key.

    Returns:
        The (possibly filtered and re-keyed) state dict.
    """
    if ckpt_path.endswith(".safetensors"):
        state_dict = load_file(ckpt_path)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    if prefix is None:
        return state_dict

    # Keep only the entries under the requested prefix. Slice the prefix off
    # the front instead of using str.replace, which would also delete any later
    # occurrence of the same text inside the key (e.g. "model.a.model.b").
    prefix_len = len(prefix)
    return {
        key[prefix_len:]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
|
||||
|
||||
def remove_weight_norm_from_model(model):
    """Strip weight normalisation from every sub-module of ``model`` that has it.

    Args:
        model: An ``nn.Module``; it is modified in place.

    Returns:
        The same ``model`` object, for call chaining.
    """
    for module in model.modules():
        if not hasattr(module, "weight"):
            continue
        try:
            remove_weight_norm(module)
        except ValueError:
            # remove_weight_norm raises ValueError for a module that has a
            # ``weight`` but was never weight-normalised; such modules need no
            # change, so skip them instead of aborting the whole walk.
            continue

    return model
|
||||
|
||||
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
|
||||
# License can be found in LICENSES/LICENSE_META.txt
|
||||
|
||||
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """Sample indices along the last dimension of a tensor of any rank.

    Args:
        input (torch.Tensor): Probabilities, with the candidates on the last dimension.
        num_samples (int): How many indices to draw per distribution.
        replacement (bool): Whether to draw with replacement. Only consulted
            when ``num_samples`` is not 1.
    Keywords args:
        generator (torch.Generator): Pseudorandom number generator used for sampling.
    Returns:
        torch.Tensor: Same leading dimensions as ``input``; the last dimension
            holds the ``num_samples`` sampled indices.
    """
    if num_samples != 1:
        # torch.multinomial is called on a 2-D view: flatten the leading
        # dimensions, sample, then restore them.
        flat_probs = input.reshape(-1, input.shape[-1])
        flat_draws = torch.multinomial(
            flat_probs,
            num_samples=num_samples,
            replacement=replacement,
            generator=generator,
        )
        return flat_draws.reshape(*list(input.shape[:-1]), -1)

    # Single draw: divide by exponential noise and take the arg-max, which
    # needs no reshaping.
    noise = torch.empty_like(input).exponential_(1, generator=generator)
    return torch.argmax(input / noise, dim=-1, keepdim=True).to(torch.int64)
|
||||
|
||||
|
||||
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample the next token from the top-K values along the last dimension.

    The input tensor is left untouched: filtering and renormalisation are done
    on a new tensor rather than in place on the caller's ``probs``.

    Args:
        probs (torch.Tensor): Probabilities with token candidates on the last dimension.
        k (int): The k in "top-k".
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    # Smallest probability that still belongs to the top k.
    min_value_top_k = top_k_value[..., [-1]]
    # Zero everything below the threshold. The mask is cast to the dtype of
    # ``probs`` so the result keeps that dtype, as the in-place form did.
    keep = (probs >= min_value_top_k).to(probs.dtype)
    filtered = probs * keep
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return multinomial(filtered, num_samples=1)
|
||||
|
||||
|
||||
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample the next token from the top-P (nucleus) mass along the last dimension.

    Args:
        probs (torch.Tensor): Probabilities with token candidates on the last dimension.
        p (float): The p in "top-p".
    Returns:
        torch.Tensor: Sampled tokens, as indices into the original ``probs``.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # An entry is dropped when the mass accumulated before it already exceeds p.
    dropped = (cumulative - sorted_probs) > p
    sorted_probs *= (~dropped).float()
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    pick = multinomial(sorted_probs, num_samples=1)
    # Translate the position in sorted order back to the original token index.
    return torch.gather(sorted_idx, -1, pick)
|
||||
|
||||
def next_power_of_two(n):
    """Return the smallest power of two that is >= ``n`` for an integer ``n`` >= 1.

    ``n = 0`` yields 2, because ``(-1).bit_length()`` is 1.
    """
    exponent = (n - 1).bit_length()
    return 1 << exponent
|
||||
|
||||
def next_multiple_of_64(n):
    """Round ``n`` up to the nearest multiple of 64 (multiples of 64 are unchanged)."""
    blocks = (n + 63) // 64
    return blocks * 64
|
||||
|
||||
|
||||
# mask construction helpers
|
||||
|
||||
def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    """Build a boolean mask that is True on ``[start, end)`` along a new last axis.

    Args:
        seq_len: Length of the new trailing axis.
        start: Tensor of inclusive start positions, of any rank.
        end: Tensor of exclusive end positions; must have the shape of ``start``.

    Returns:
        Boolean tensor of shape ``(*start.shape, seq_len)``.
    """
    assert start.shape == end.shape

    positions = torch.arange(seq_len, device = start.device, dtype = torch.long)

    # Broadcasting aligns ``positions`` with the trailing singleton axis added
    # to start/end, so no reshape is needed. The earlier
    # ``reshape(-1, ..., -1, seq_len)`` used one -1 per dimension of ``start``,
    # which torch rejects as soon as ``start`` has two or more dimensions.
    lower_ok = positions >= start[..., None].long()
    upper_ok = positions < end[..., None].long()
    return lower_ok & upper_ok
|
||||
|
||||
def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    """Build a mask with one randomly placed True span per entry of ``frac_lengths``.

    Each span is ``floor(frac * seq_len)`` positions long. Its start is drawn
    uniformly so that the span fits inside ``[0, seq_len)``.

    Args:
        seq_len: Length of the mask's trailing axis.
        frac_lengths: Tensor of span lengths expressed as fractions of ``seq_len``.

    Returns:
        The boolean mask produced by ``mask_from_start_end_indices``.
    """
    span = (frac_lengths * seq_len).long()
    latest_start = seq_len - span

    # Uniform [0, 1) draw per entry, in float32 on the input's device.
    jitter = torch.zeros_like(
        frac_lengths, dtype = torch.float32, device = frac_lengths.device
    ).uniform_(0, 1)

    span_start = (latest_start * jitter).clamp(min = 0)
    return mask_from_start_end_indices(seq_len, span_start, span_start + span)
|
||||
|
||||
def _build_spline(video_feat, video_t, target_t):
    """Fit a natural cubic spline to ``video_feat`` over ``video_t`` and evaluate it at ``target_t``.

    The last two axes of ``video_feat`` are swapped before fitting, and the
    last two axes of the evaluated result are swapped back.
    """
    # NOTE(review): natural_cubic_spline_coeffs and NaturalCubicSpline come
    # from the torchcubicspline import that is commented out at the top of this
    # module; unless they are provided some other way, this call raises
    # NameError - confirm.
    knots_last = video_feat.permute(0, 2, 1)
    spline = NaturalCubicSpline(natural_cubic_spline_coeffs(video_t, knots_last))
    evaluated = spline.evaluate(target_t)
    return evaluated.permute(0, 2, 1)
|
||||
|
||||
def resample(video_feat, audio_latent):
    """Resample video features along time to the audio latent length via cubic spline.

    Both timelines are laid out with ``linspace`` over the same 0-9 span; the
    original note on this function reads "9s", so that span is presumably the
    clip duration in seconds.

    Args:
        video_feat: 3-D tensor. The original note gives ``[B, 72, D]`` as an example.
        audio_latent: Either the target length as an ``int``, or a tensor
            (example from the original note: ``[B, D', 194]``). For a tensor,
            the length is read from axis 1, unless axis 1 has size 64, in which
            case it is read from axis 2.

    Returns:
        The resampled features, presumably ``[B, Ta, D]`` - this depends on the
        axis order ``_build_spline`` returns.

    Raises:
        TypeError: If ``audio_latent`` is neither a tensor nor an int.
    """
    _, Tv, _ = video_feat.shape

    if isinstance(audio_latent, torch.Tensor):
        # A size of 64 on axis 1 is taken to mean the length lives on axis 2.
        length_axis = 2 if audio_latent.shape[1] == 64 else 1
        Ta = audio_latent.shape[length_axis]
    elif isinstance(audio_latent, int):
        Ta = audio_latent
    else:
        raise TypeError("audio_latent must be either a tensor or an int")

    # Timestamps for source and target samples over the same span.
    video_time = torch.linspace(0, 9, Tv, device=video_feat.device)
    audio_time = torch.linspace(0, 9, Ta, device=video_feat.device)

    # Swap the last two axes, interpolate, then swap back.
    aligned_video = _build_spline(video_feat.permute(0, 2, 1), video_time, audio_time)
    return aligned_video.permute(0, 2, 1)
|
||||
|
||||
def checkpoint(function, *args, **kwargs):
    """Call ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False`` by default.

    An explicit ``use_reentrant`` keyword from the caller takes precedence; all
    other arguments are forwarded unchanged.
    """
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||
|
||||
import os
|
||||
# torch.compile is opt-in: it is only attempted when the environment variable
# ENABLE_TORCH_COMPILE is exactly "1" at import time.
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"


def compile(function, *args, **kwargs):
    """Return ``torch.compile(function, ...)`` when enabled, else ``function`` itself.

    If compilation is enabled but ``torch.compile`` raises ``RuntimeError``,
    the uncompiled function is returned instead.
    """
    if not enable_torch_compile:
        return function

    try:
        return torch.compile(function, *args, **kwargs)
    except RuntimeError:
        return function
|
||||
@@ -1,11 +1,5 @@
|
||||
einops>=0.7.0
|
||||
einops-exts
|
||||
safetensors
|
||||
huggingface_hub
|
||||
transformers>=4.52.3
|
||||
k-diffusion>=0.1.1
|
||||
alias-free-torch
|
||||
descript-audio-codec
|
||||
vector-quantize-pytorch
|
||||
scipy
|
||||
tqdm
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
name: prismaudio-extract
|
||||
channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.10
|
||||
- pip
|
||||
- ffmpeg<7
|
||||
- pip:
|
||||
- torch>=2.6.0
|
||||
- torchaudio>=2.6.0
|
||||
- torchvision>=0.21.0
|
||||
- tensorflow-cpu==2.15.0
|
||||
- jax
|
||||
- jaxlib
|
||||
- transformers>=4.52.3
|
||||
- decord
|
||||
- einops>=0.7.0
|
||||
- numpy
|
||||
- mediapy
|
||||
- git+https://github.com/google-deepmind/videoprism.git
|
||||
@@ -1,168 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone PrismAudio feature extraction script.
|
||||
Runs in a separate Python env with JAX/TF installed (auto-created by PrismAudioFeatureExtractor).
|
||||
|
||||
Usage:
|
||||
python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Put the plugin root (the parent of this script's directory) at the front of
# sys.path so data_utils (and prismaudio_core) are importable.
_SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
_PLUGIN_DIR = os.path.split(_SCRIPT_DIR)[0]
if _PLUGIN_DIR not in sys.path:
    sys.path.insert(0, _PLUGIN_DIR)
|
||||
|
||||
|
||||
def _step(n, total, label):
    """Announce step ``n`` of ``total`` and return a ``time.perf_counter()`` start mark."""
    banner = f"[extract] Step {n}/{total} — {label}..."
    print(banner, flush=True)
    return time.perf_counter()
|
||||
|
||||
|
||||
def _done(t0, extra=""):
    """Print the time elapsed since ``t0``, followed by ``extra`` when it is non-empty."""
    elapsed = time.perf_counter() - t0
    message = f"[extract] done in {elapsed:.1f}s"
    if extra:
        message += f" {extra}"
    print(message, flush=True)
|
||||
|
||||
|
||||
def _sample_frame_indices(source_fps, target_fps, total_frames):
    """Return source-frame indices that resample a clip from ``source_fps`` to ``target_fps``."""
    duration = total_frames / source_fps
    last_frame = total_frames - 1
    return [
        min(int(i * source_fps / target_fps), last_frame)
        for i in range(int(duration * target_fps))
    ]


def _frames_to_batch(frames, size, device):
    """Resize, centre-crop and normalise ``frames``, then stack them with a leading batch axis."""
    import torchvision.transforms as T

    transform = T.Compose([
        T.ToPILImage(),
        T.Resize(size),
        T.CenterCrop(size),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return torch.stack([transform(frame) for frame in frames]).unsqueeze(0).to(device)


def main():
    """Extract text, video and sync features for one video and save them to an ``.npz`` file.

    Command-line arguments are read from ``sys.argv``. The process exits with
    status 1 when the input video does not exist, or when the video is too
    short to yield a single frame at one of the requested sampling rates.
    """
    t_total = time.perf_counter()

    parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
    parser.add_argument("--video", required=True, help="Path to input video")
    parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
    parser.add_argument("--output", required=True, help="Output .npz path")
    parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
    parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
    parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)")
    parser.add_argument("--clip_fps", type=float, default=4.0)
    parser.add_argument("--clip_size", type=int, default=288)
    parser.add_argument("--sync_fps", type=float, default=25.0)
    parser.add_argument("--sync_size", type=int, default=224)
    args = parser.parse_args()

    print(f"[extract] Python : {sys.executable}", flush=True)
    print(f"[extract] Video : {args.video}", flush=True)
    print(f"[extract] Output : {args.output}", flush=True)
    print(f"[extract] CoT text : {args.cot_text[:80]}{'...' if len(args.cot_text) > 80 else ''}", flush=True)

    if not os.path.exists(args.video):
        print(f"[extract] ERROR: video not found: {args.video}", flush=True)
        sys.exit(1)

    print(f"[extract] Device : {'cuda' if torch.cuda.is_available() else 'cpu'}", flush=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------------------------
    t0 = _step(1, 6, "importing dependencies")
    from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
    # Imported here so its load time is reported under step 1; the transform
    # helper re-imports it from the module cache.
    import torchvision.transforms  # noqa: F401
    _done(t0)

    # ------------------------------------------------------------------
    t0 = _step(2, 6, "loading models (T5, VideoPrism, Synchformer)")
    feat_utils = FeaturesUtils(
        vae_config_path=args.vae_config,
        synchformer_ckpt=args.synchformer_ckpt,
        device=device,
    )
    _done(t0)

    # ------------------------------------------------------------------
    t0 = _step(3, 6, "reading and preprocessing video")
    if args.video.endswith(".npy"):
        all_frames = np.load(args.video)  # [T, H, W, C] uint8
        fps = args.source_fps
        total_frames = all_frames.shape[0]

        def fetch_frames(indices):
            return all_frames[indices]
    else:
        from decord import VideoReader, cpu
        vr = VideoReader(args.video, ctx=cpu(0))
        fps = vr.get_avg_fps()
        total_frames = len(vr)

        def fetch_frames(indices):
            return vr.get_batch(indices).asnumpy()

    duration = total_frames / fps
    print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)

    def sample(target_fps, label):
        """Pick indices for ``target_fps`` and fetch the frames, failing clearly when there are none."""
        indices = _sample_frame_indices(fps, target_fps, total_frames)
        if not indices:
            # Without this check the failure surfaces later, inside
            # torch.stack on an empty list.
            print(f"[extract] ERROR: video too short to sample any {label} frames at {target_fps}fps", flush=True)
            sys.exit(1)
        return indices, fetch_frames(indices)

    clip_indices, clip_frames = sample(args.clip_fps, "CLIP")
    print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)

    sync_indices, sync_frames = sample(args.sync_fps, "sync")
    print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)

    clip_input = _frames_to_batch(clip_frames, args.clip_size, device)
    sync_input = _frames_to_batch(sync_frames, args.sync_size, device)
    _done(t0)

    # ------------------------------------------------------------------
    t0 = _step(4, 6, "encoding text with T5-Gemma")
    text_features = feat_utils.encode_t5_text([args.cot_text])
    _done(t0, f"shape={tuple(text_features.shape)}")

    # ------------------------------------------------------------------
    t0 = _step(5, 6, "encoding video with VideoPrism")
    global_video_features, video_features, global_text_features = \
        feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
    _done(t0, f"video={tuple(video_features.shape)} global={tuple(global_video_features.shape)}")

    # ------------------------------------------------------------------
    t0 = _step(6, 6, "encoding video with Synchformer")
    sync_features = feat_utils.encode_video_with_sync(sync_input)
    _done(t0, f"shape={tuple(sync_features.shape)}")

    # ------------------------------------------------------------------
    t0 = time.perf_counter()
    print(f"[extract] Saving features to {args.output} ...", flush=True)
    np.savez(
        args.output,
        video_features=video_features.cpu().float().numpy(),
        global_video_features=global_video_features.cpu().float().numpy(),
        text_features=text_features.cpu().float().numpy(),
        global_text_features=global_text_features.cpu().float().numpy(),
        sync_features=sync_features.cpu().float().numpy(),
        caption_cot=args.cot_text,
        duration=duration,
    )
    print(f"[extract] Saved in {time.perf_counter() - t0:.1f}s", flush=True)
    print(f"[extract] Total time: {time.perf_counter() - t_total:.1f}s", flush=True)


if __name__ == "__main__":
    main()
|
||||
@@ -1,44 +0,0 @@
|
||||
#!/usr/bin/env bash
# Build the PrismAudio feature-extraction environment as a plain pip venv.
# Alternative to environment.yml for hosts without conda (e.g. NVIDIA Docker).
#
# Usage:
#   bash scripts/install_extract_env.sh [/path/to/venv]
#
# Default venv path: /opt/prismaudio-extract
# When finished, point the PrismAudioFeatureExtractor node's python_env to:
#   <venv>/bin/python          (Linux/Mac)
#   <venv>\Scripts\python.exe  (Windows)

set -euo pipefail

VENV_DIR="${1:-/opt/prismaudio-extract}"
PIP="${VENV_DIR}/bin/pip"

# Print a progress message carrying the project tag.
log() {
  echo "[PrismAudio] $*"
}

log "Creating venv at: ${VENV_DIR}"
python3 -m venv "${VENV_DIR}"

log "Upgrading pip..."
"${PIP}" install --upgrade pip

log "Installing PyTorch stack..."
"${PIP}" install torch torchaudio torchvision

log "Installing feature-extraction dependencies..."
EXTRACT_DEPS=(
  "tensorflow-cpu>=2.16.0"
  "jax[cpu]"
  "jaxlib"
  "transformers"
  "decord"
  "einops"
  "numpy"
  "mediapy"
)
"${PIP}" install "${EXTRACT_DEPS[@]}"

log "Installing VideoPrism..."
"${PIP}" install "git+https://github.com/google-deepmind/videoprism.git"

echo ""
log "Done. Set python_env in PrismAudioFeatureExtractor to:"
echo " ${VENV_DIR}/bin/python"
|
||||
@@ -1,158 +0,0 @@
|
||||
{
|
||||
"id": "a1c3e5f7-b2d4-4e6a-8c0f-1a3b5c7d9e2f",
|
||||
"revision": 0,
|
||||
"last_node_id": 3,
|
||||
"last_link_id": 2,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "PrismAudioModelLoader",
|
||||
"pos": [
|
||||
-160,
|
||||
-224
|
||||
],
|
||||
"size": [
|
||||
288,
|
||||
96
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "PRISMAUDIO_MODEL",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
1
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"aux_id": "ethanfel/ComfyUI-Prismaudio",
|
||||
"ver": "62a3c5d",
|
||||
"Node name for S&R": "PrismAudioModelLoader",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.8",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"auto",
|
||||
"auto"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "PrismAudioTextOnly",
|
||||
"pos": [
|
||||
192,
|
||||
-224
|
||||
],
|
||||
"size": [
|
||||
480,
|
||||
222
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "PRISMAUDIO_MODEL",
|
||||
"link": 1
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "audio",
|
||||
"type": "AUDIO",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
2
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"aux_id": "ethanfel/ComfyUI-Prismaudio",
|
||||
"ver": "62a3c5d",
|
||||
"Node name for S&R": "PrismAudioTextOnly",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.8",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"A large dog barks sharply twice in an outdoor setting, with ambient background noise of rustling leaves and a gentle breeze. The sound is clear and close, recorded at ground level.",
|
||||
10.0,
|
||||
100,
|
||||
7.0,
|
||||
0,
|
||||
"randomize"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "PreviewAudio",
|
||||
"pos": [
|
||||
736,
|
||||
-224
|
||||
],
|
||||
"size": [
|
||||
300,
|
||||
76
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "audio",
|
||||
"type": "AUDIO",
|
||||
"link": 2
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"properties": {
|
||||
"Node name for S&R": "PreviewAudio"
|
||||
},
|
||||
"widgets_values": []
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
[
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
2,
|
||||
0,
|
||||
"PRISMAUDIO_MODEL"
|
||||
],
|
||||
[
|
||||
2,
|
||||
2,
|
||||
0,
|
||||
3,
|
||||
0,
|
||||
"AUDIO"
|
||||
]
|
||||
],
|
||||
"groups": [],
|
||||
"config": {},
|
||||
"extra": {
|
||||
"ds": {
|
||||
"scale": 1.1674071890328979,
|
||||
"offset": [
|
||||
1814.5534800416863,
|
||||
500.0421331448515
|
||||
]
|
||||
},
|
||||
"ue_links": [],
|
||||
"links_added_by_ue": [],
|
||||
"frontendVersion": "1.42.8"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
@@ -1,421 +0,0 @@
|
||||
{
|
||||
"id": "2481bfbf-ce24-46c5-abdc-1d9163ff78ae",
|
||||
"revision": 0,
|
||||
"last_node_id": 12,
|
||||
"last_link_id": 30,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "VHS_LoadVideo",
|
||||
"pos": [
|
||||
-704,
|
||||
-256
|
||||
],
|
||||
"size": [
|
||||
288,
|
||||
474.7188081936685
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "meta_batch",
|
||||
"shape": 7,
|
||||
"type": "VHS_BatchManager",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "vae",
|
||||
"shape": 7,
|
||||
"type": "VAE",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
12,
|
||||
20
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "frame_count",
|
||||
"type": "INT",
|
||||
"slot_index": 1,
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"name": "audio",
|
||||
"type": "AUDIO",
|
||||
"slot_index": 2,
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"name": "video_info",
|
||||
"type": "VHS_VIDEOINFO",
|
||||
"slot_index": 3,
|
||||
"links": [
|
||||
21
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"cnr_id": "comfyui-videohelpersuite",
|
||||
"ver": "1.7.9",
|
||||
"Node name for S&R": "VHS_LoadVideo",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.8",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": {
|
||||
"video": "Railtransport_3_479.mp4",
|
||||
"force_rate": 0,
|
||||
"custom_width": 0,
|
||||
"custom_height": 0,
|
||||
"frame_load_cap": 0,
|
||||
"skip_first_frames": 0,
|
||||
"select_every_nth": 1,
|
||||
"format": "AnimateDiff",
|
||||
"videopreview": {
|
||||
"hidden": false,
|
||||
"paused": false,
|
||||
"params": {
|
||||
"force_rate": 0,
|
||||
"frame_load_cap": 0,
|
||||
"skip_first_frames": 0,
|
||||
"select_every_nth": 1,
|
||||
"filename": "Railtransport_3_479.mp4",
|
||||
"type": "input",
|
||||
"format": "video/mp4"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "PrismAudioModelLoader",
|
||||
"pos": [
|
||||
-160,
|
||||
-224
|
||||
],
|
||||
"size": [
|
||||
288,
|
||||
96
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "PRISMAUDIO_MODEL",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
26
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"aux_id": "ethanfel/ComfyUI-Prismaudio",
|
||||
"ver": "3894fcc9b40a19d959614d514d5dff65cdfb6eab",
|
||||
"Node name for S&R": "PrismAudioModelLoader",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"version": "7.8",
|
||||
"input_ue_unconnectable": {}
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"auto",
|
||||
"auto"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"type": "PrismAudioSampler",
|
||||
"pos": [
|
||||
256,
|
||||
-224
|
||||
],
|
||||
"size": [
|
||||
384,
|
||||
224
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "model",
|
||||
"type": "PRISMAUDIO_MODEL",
|
||||
"link": 26
|
||||
},
|
||||
{
|
||||
"name": "features",
|
||||
"type": "PRISMAUDIO_FEATURES",
|
||||
"link": 27
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "audio",
|
||||
"type": "AUDIO",
|
||||
"links": [
|
||||
29
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"aux_id": "ethanfel/ComfyUI-Prismaudio",
|
||||
"ver": "30631c0cb4d97cc6aed69a52e3ee4d89df03926c",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {}
|
||||
},
|
||||
"Node name for S&R": "PrismAudioSampler"
|
||||
},
|
||||
"widgets_values": [
|
||||
0,
|
||||
100,
|
||||
7,
|
||||
4096333446,
|
||||
"randomize"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"type": "PrismAudioFeatureExtractor",
|
||||
"pos": [
|
||||
-384,
|
||||
-64
|
||||
],
|
||||
"size": [
|
||||
544,
|
||||
288
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "video",
|
||||
"type": "IMAGE",
|
||||
"link": 20
|
||||
},
|
||||
{
|
||||
"name": "video_info",
|
||||
"shape": 7,
|
||||
"type": "VHS_VIDEOINFO",
|
||||
"link": 21
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "features",
|
||||
"type": "PRISMAUDIO_FEATURES",
|
||||
"links": [
|
||||
27
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "fps",
|
||||
"type": "FLOAT",
|
||||
"links": [
|
||||
30
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"aux_id": "ethanfel/ComfyUI-Prismaudio",
|
||||
"ver": "5b62be04471bf118b2cd3cc71431a302f5730b01",
|
||||
"Node name for S&R": "PrismAudioFeatureExtractor",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {},
|
||||
"version": "7.8"
|
||||
}
|
||||
},
|
||||
"widgets_values": [
|
||||
"Generate ambient countryside sounds with a gentle breeze rustling the leaves of a large tree. From the right, introduce a faint rumble of wheels on a track and a steam engine chugging. Allow the sounds to grow louder and pan from right to left as the steam train travels across the landscape. Include the powerful chugging and clattering of carriages in the soundscape, then gradually recede the sounds to the left. Ensure no additional background noise or music is present.\n",
|
||||
30,
|
||||
"managed_env",
|
||||
"/media/unraid/comfyui/output/prismaudiocache/",
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"type": "VHS_VideoCombine",
|
||||
"pos": [
|
||||
704,
|
||||
-256
|
||||
],
|
||||
"size": [
|
||||
384,
|
||||
552.75
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"link": 12
|
||||
},
|
||||
{
|
||||
"name": "audio",
|
||||
"shape": 7,
|
||||
"type": "AUDIO",
|
||||
"link": 29
|
||||
},
|
||||
{
|
||||
"name": "meta_batch",
|
||||
"shape": 7,
|
||||
"type": "VHS_BatchManager",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "vae",
|
||||
"shape": 7,
|
||||
"type": "VAE",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "frame_rate",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "frame_rate"
|
||||
},
|
||||
"link": 30
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "Filenames",
|
||||
"type": "VHS_FILENAMES",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"cnr_id": "comfyui-videohelpersuite",
|
||||
"ver": "1.7.9",
|
||||
"Node name for S&R": "VHS_VideoCombine",
|
||||
"ue_properties": {
|
||||
"widget_ue_connectable": {},
|
||||
"input_ue_unconnectable": {},
|
||||
"version": "7.8"
|
||||
}
|
||||
},
|
||||
"widgets_values": {
|
||||
"frame_rate": 30,
|
||||
"loop_count": 0,
|
||||
"filename_prefix": "AnimateDiff",
|
||||
"format": "video/h264-mp4",
|
||||
"pix_fmt": "yuv420p",
|
||||
"crf": 19,
|
||||
"save_metadata": true,
|
||||
"trim_to_audio": false,
|
||||
"pingpong": false,
|
||||
"save_output": false,
|
||||
"videopreview": {
|
||||
"hidden": false,
|
||||
"paused": false,
|
||||
"params": {
|
||||
"filename": "AnimateDiff_00001-audio.mp4",
|
||||
"subfolder": "",
|
||||
"type": "temp",
|
||||
"format": "video/h264-mp4",
|
||||
"frame_rate": 30,
|
||||
"workflow": "AnimateDiff_00001.png",
|
||||
"fullpath": "/basedir/temp/AnimateDiff_00001-audio.mp4"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
[
|
||||
12,
|
||||
1,
|
||||
0,
|
||||
9,
|
||||
0,
|
||||
"IMAGE"
|
||||
],
|
||||
[
|
||||
20,
|
||||
1,
|
||||
0,
|
||||
11,
|
||||
0,
|
||||
"IMAGE"
|
||||
],
|
||||
[
|
||||
21,
|
||||
1,
|
||||
3,
|
||||
11,
|
||||
1,
|
||||
"VHS_VIDEOINFO"
|
||||
],
|
||||
[
|
||||
26,
|
||||
2,
|
||||
0,
|
||||
12,
|
||||
0,
|
||||
"PRISMAUDIO_MODEL"
|
||||
],
|
||||
[
|
||||
27,
|
||||
11,
|
||||
0,
|
||||
12,
|
||||
1,
|
||||
"PRISMAUDIO_FEATURES"
|
||||
],
|
||||
[
|
||||
29,
|
||||
12,
|
||||
0,
|
||||
9,
|
||||
1,
|
||||
"AUDIO"
|
||||
],
|
||||
[
|
||||
30,
|
||||
11,
|
||||
1,
|
||||
9,
|
||||
4,
|
||||
"FLOAT"
|
||||
]
|
||||
],
|
||||
"groups": [],
|
||||
"config": {},
|
||||
"extra": {
|
||||
"ds": {
|
||||
"scale": 1.1674071890328979,
|
||||
"offset": [
|
||||
1814.5534800416863,
|
||||
500.0421331448515
|
||||
]
|
||||
},
|
||||
"ue_links": [],
|
||||
"links_added_by_ue": [],
|
||||
"frontendVersion": "1.42.8",
|
||||
"VHS_latentpreview": true,
|
||||
"VHS_latentpreviewrate": 0,
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
Reference in New Issue
Block a user