chore: remove all PrismAudio code from main branch

- Delete prismaudio_core/, data_utils/, scripts/, docs/plans/
- Delete PrismAudio nodes (feature_extractor, feature_loader, model_loader, sampler, text_only)
- Delete PrismAudio workflows (video_to_audio, text_to_audio)
- Clean nodes/utils.py: rename PRISMAUDIO_CATEGORY → SELVA_CATEGORY, remove unused helpers
- Strip PrismAudio-only deps from requirements.txt

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-04-04 17:58:31 +02:00
parent 679a607a85
commit 83b1da9520
43 changed files with 11 additions and 11958 deletions
+1 -1
View File
@@ -1,5 +1,5 @@
"""
ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026).
ComfyUI-SelVA: Text-guided video-to-audio generation using SelVA / MMAudio.
"""
import sys
import os
View File
View File
-337
View File
@@ -1,337 +0,0 @@
"""
PrismAudio feature extraction utilities.
Implements FeaturesUtils used by scripts/extract_features.py to extract:
- Text features via T5-Gemma (transformers)
- Video features via VideoPrism (JAX/Flax, google-deepmind/videoprism)
- Sync features via Synchformer visual encoder (PyTorch)
"""
import os
import torch
import torch.nn as nn
import numpy as np
class FeaturesUtils:
    """Lazy-loading feature extractors used for PrismAudio conditioning.

    Three encoders are wrapped:

    * T5-Gemma (transformers) for text features,
    * VideoPrism LvT (JAX) for joint video/text features,
    * a Synchformer visual encoder (PyTorch) for sync features.

    The T5 and VideoPrism models are loaded on first use; the Synchformer
    checkpoint is loaded eagerly in ``__init__`` when a valid path is given.
    """

    def __init__(self, vae_config_path=None, synchformer_ckpt=None, device=None):
        """
        Args:
            vae_config_path: accepted for call-site compatibility; not used by
                this class.
            synchformer_ckpt: path to the Synchformer checkpoint, or None to
                skip loading (``encode_video_with_sync`` then raises).
            device: torch device for the PyTorch encoders; defaults to CPU.
        """
        self.device = device or torch.device("cpu")
        self._t5_tokenizer = None
        self._t5_encoder = None
        self._vp_model = None
        self._vp_state = None
        self._vp_text_tokenizer = None
        self._jax_forward = None
        self._sync_model = None
        self._synchformer_ckpt = synchformer_ckpt
        self._load_synchformer()

    # ------------------------------------------------------------------
    # T5-Gemma text encoding
    # ------------------------------------------------------------------
    def _ensure_t5(self):
        """Load the T5-Gemma tokenizer and encoder once (no-op afterwards)."""
        if self._t5_encoder is not None:
            return
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        model_id = "google/t5gemma-l-l-ul2-it"
        print(f"[FeaturesUtils] Loading T5-Gemma: {model_id}")
        self._t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._t5_encoder = (
            AutoModelForSeq2SeqLM.from_pretrained(model_id)
            .get_encoder()
            .to(self.device)
            .eval()
        )

    def encode_t5_text(self, texts):
        """
        Args:
            texts: list of str
        Returns:
            Tensor [seq_len, 1024] (the leading batch axis is squeezed away
            when a single text is passed)
        """
        self._ensure_t5()
        # The encoder is parked on the CPU at the end of every call, so it has
        # to be moved back before each forward pass; otherwise the second call
        # would feed device-resident tokens to a CPU-resident model.
        self._t5_encoder.to(self.device)
        tokens = self._t5_tokenizer(
            texts, return_tensors="pt", padding=True
        ).to(self.device)
        with torch.no_grad():
            out = self._t5_encoder(**tokens)
        # Move encoder off GPU to save VRAM
        self._t5_encoder.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return out.last_hidden_state.squeeze(0)  # [seq_len, 1024]

    # ------------------------------------------------------------------
    # VideoPrism video + text encoding (JAX)
    # ------------------------------------------------------------------
    def _ensure_videoprism(self):
        """Load VideoPrism, its weights and tokenizer, and build the jitted forward."""
        if self._vp_model is not None:
            return
        from videoprism import models as vp
        import jax

        model_name = "videoprism_lvt_public_v1_large"
        print(f"[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...")
        self._vp_model = vp.get_model(model_name)
        self._vp_state = vp.load_pretrained_weights(model_name)
        self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en")
        jax_dev = jax.devices()[0]
        self._jax_forward = jax.jit(
            lambda x, y, z: self._vp_model.apply(
                self._vp_state, x, y, z, train=False, return_intermediate=True
            ),
            device=jax_dev,
        )

    def encode_video_and_text_with_videoprism(self, clip_input, texts):
        """
        Args:
            clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
            texts: list of str — CoT captions, passed to VideoPrism LvT text tower
        Returns:
            global_video_features: Tensor [1, D]
            video_features: Tensor [T, D] — per-frame L2-normalized embeddings
            global_text_features: Tensor [1, D]
        """
        self._ensure_videoprism()
        import jax.numpy as jnp
        from videoprism import models as vp

        # Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array
        frames = clip_input.squeeze(0)           # [T, C, H, W]
        frames = (frames + 1.0) / 2.0            # [-1,1] → [0,1]
        frames = frames.permute(0, 2, 3, 1)      # [T, H, W, C]
        frames_np = frames.cpu().numpy().astype(np.float32)
        frames_jax = jnp.array(frames_np)[None]  # [1, T, H, W, C]

        # Tokenize text (padding value 1.0 = pad, 0.0 = real token)
        text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts)

        # Joint video+text forward with intermediate outputs
        video_embeddings, text_embeddings, outputs = self._jax_forward(
            frames_jax, text_ids, text_paddings
        )

        # Per-frame features: [B, T, 1024] L2-normalized
        frame_embed_np = np.array(outputs["frame_embeddings"])            # [1, T, 1024]
        per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device)   # [T, 1024]

        # Global video embedding: [1024] → [1, 1024]
        global_video = torch.from_numpy(
            np.array(video_embeddings[0])
        ).unsqueeze(0).to(self.device)  # [1, 1024]

        # Global text embedding: [1024] → [1, 1024]
        global_text = torch.from_numpy(
            np.array(text_embeddings[0])
        ).unsqueeze(0).to(self.device)  # [1, 1024]

        return global_video, per_frame, global_text

    # ------------------------------------------------------------------
    # Synchformer sync feature encoding
    # ------------------------------------------------------------------
    def _load_synchformer(self):
        """Build the Synchformer encoder from ``self._synchformer_ckpt`` if it exists."""
        if not self._synchformer_ckpt:
            return
        if not os.path.exists(self._synchformer_ckpt):
            # A path was supplied but is wrong: say so now instead of failing
            # later with a generic "checkpoint not loaded" error.
            print(f"[FeaturesUtils] Warning: Synchformer checkpoint not found: {self._synchformer_ckpt}")
            return
        print(f"[FeaturesUtils] Loading Synchformer from: {self._synchformer_ckpt}")
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only
        # load checkpoints from a trusted source.
        state = torch.load(self._synchformer_ckpt, map_location="cpu", weights_only=False)
        # Checkpoint may be raw state_dict or wrapped in {"model": ...}
        if isinstance(state, dict) and "model" in state:
            state_dict = state["model"]
        else:
            state_dict = state
        self._sync_model = _SynchformerVisualEncoder(state_dict, self.device)
        self._sync_model.eval()

    def encode_video_with_sync(self, sync_input):
        """
        Args:
            sync_input: Tensor [1, T, C, H, W] float32, values in [-1, 1]
        Returns:
            sync_features: Tensor [num_segments, 768]
        Raises:
            RuntimeError: if no Synchformer checkpoint was loaded.
        """
        if self._sync_model is None:
            raise RuntimeError(
                "[FeaturesUtils] Synchformer checkpoint not loaded. "
                "Pass synchformer_ckpt to FeaturesUtils or set --synchformer_ckpt."
            )
        frames = sync_input.squeeze(0).to(self.device)  # [T, C, H, W]
        with torch.no_grad():
            return self._sync_model(frames)
# ------------------------------------------------------------------
# Synchformer visual encoder — TimeSformer-style ViT-B/16
# Architecture reverse-engineered from synchformer_state_dict.pth
# ------------------------------------------------------------------
import torch.nn.functional as F
class _PatchEmbed(nn.Module):
    """Split each frame into 16x16 patches and project every patch to 768 dims.

    Input [B, 3, 224, 224] becomes a token sequence [B, 196, 768].
    """

    def __init__(self):
        super().__init__()
        # A stride-16, kernel-16 convolution is a per-patch linear projection.
        self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)

    def forward(self, x):
        feature_map = self.proj(x)                         # [B, 768, H/16, W/16]
        tokens = torch.flatten(feature_map, start_dim=2)   # [B, 768, num_patches]
        return tokens.permute(0, 2, 1)                     # [B, num_patches, 768]
class _ViTAttn(nn.Module):
    """Multi-head self-attention with a fused QKV projection (timm-style layout)."""

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, tokens, width = x.shape
        # One linear layer yields Q, K and V; split them out per head.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, heads, N, head_dim]
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = F.softmax(scores, dim=-1)
        mixed = torch.matmul(weights, values)                           # [B, heads, N, head_dim]
        merged = mixed.transpose(1, 2).reshape(batch, tokens, width)    # heads back into channels
        return self.proj(merged)
class _BlockMLP(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear.

    The layers are named fc1/fc2 so the parameter keys match the checkpoint.
    """

    def __init__(self, dim=768, mlp_dim=3072):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))   # expand to mlp_dim
        return self.fc2(hidden)        # project back to dim
class _TimeSformerBlock(nn.Module):
    """
    Transformer block with separate spatial and temporal attention.

    Residual order: norm1 → spatial attn, norm3 → temporal attn, norm2 → MLP.
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = _ViTAttn(dim, num_heads)
        self.norm3 = nn.LayerNorm(dim)
        self.timeattn = _ViTAttn(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = _BlockMLP(dim)

    def forward(self, x, T):
        # x: [T, N, D] — the frame axis acts as the batch axis for spatial
        # attention. `T` is accepted for interface compatibility; the frame
        # count is taken from x itself.
        spatial = x + self.attn(self.norm1(x))

        # Swap frame and token axes so each spatial position attends across
        # frames, then swap back.
        per_token = spatial.transpose(0, 1)                             # [N, T, D]
        per_token = per_token + self.timeattn(self.norm3(per_token))
        temporal = per_token.transpose(0, 1)                            # [T, N, D]

        return temporal + self.mlp(self.norm2(temporal))
class _SpatialAttnAgg(nn.Module):
    """
    Pools the spatial patch tokens of each frame into a single vector.

    A learnable CLS token is prepended, one pre-norm attention + MLP layer is
    applied, and the CLS position is returned. Attribute names (self_attn,
    linear1, linear2, norm1, norm2) follow nn.TransformerEncoderLayer so the
    parameter keys line up with that layout.
    """

    def __init__(self, dim=768, num_heads=12):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.linear1 = nn.Linear(dim, dim * 4)
        self.linear2 = nn.Linear(dim * 4, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # x: [T, num_patches, dim] — patch tokens only, no CLS.
        num_frames = x.shape[0]
        cls_per_frame = self.cls_token.expand(num_frames, -1, -1)
        seq = torch.cat([cls_per_frame, x], dim=1)          # [T, 1 + num_patches, dim]

        normed = self.norm1(seq)
        attended, _ = self.self_attn(normed, normed, normed)
        seq = seq + attended

        feed_forward = self.linear2(F.gelu(self.linear1(self.norm2(seq))))
        seq = seq + feed_forward
        return seq[:, 0, :]                                 # [T, dim] — CLS per frame
class _SynchformerVisualEncoder(nn.Module):
    """
    TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint.
    Processes video in segments of 8 frames → [T_aligned, 768] per-frame features.
    """

    def __init__(self, state_dict, device):
        """
        Args:
            state_dict: full Synchformer state dict; only keys under the
                ``vfeat_extractor.`` prefix are used.
            device: torch device the module is moved to after loading.
        """
        super().__init__()
        self.device = device
        self.segment_frames = 8

        self.patch_embed = _PatchEmbed()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768))
        self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768))
        self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)])
        self.norm = nn.LayerNorm(768)
        self.spatial_attn_agg = _SpatialAttnAgg()

        # Load weights from vfeat_extractor.* prefix
        prefix = "vfeat_extractor."
        sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        # Exclude 3D patch embed (we use 2D only)
        sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")}
        # strict=False: mismatches are reported below rather than raised.
        missing, unexpected = self.load_state_dict(sub, strict=False)
        print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}")
        if missing:
            print(f"[FeaturesUtils]   missing keys (first 5): {missing[:5]}")

        self.to(device)

    def forward(self, frames):
        """
        Args:
            frames: [T, C, H, W] float32 in [-1, 1], at 25fps
        Returns:
            [T_aligned, 768] — per-frame features. T_aligned = floor(T/8)*8
            for T >= 8 (trailing frames that do not fill a segment are
            dropped); clips shorter than one segment are padded to 8 frames by
            repeating the last frame, so they yield [8, 768].
        Raises:
            ValueError: if ``frames`` contains no frames.
        """
        T = frames.shape[0]
        seg = self.segment_frames
        if T == 0:
            raise ValueError("[FeaturesUtils] Synchformer received an empty frame tensor.")
        if T < seg:
            # The temporal embedding has exactly `seg` slots, so a short clip
            # must be brought up to a full segment. Repeating the last frame
            # matches what implicit broadcasting already did for a 1-frame
            # clip and makes 2..7-frame clips work instead of raising a shape
            # error.
            pad = frames[-1:].expand(seg - T, -1, -1, -1)
            frames = torch.cat([frames, pad], dim=0)
            T = seg

        num_seg = T // seg
        results = [
            self._forward_segment(frames[i * seg:(i + 1) * seg])  # [8, C, H, W] each
            for i in range(num_seg)
        ]
        return torch.cat(results, dim=0)  # [T_aligned, 768]

    def _forward_segment(self, x):
        """Encode one segment [8, 3, 224, 224] into per-frame features [8, 768]."""
        T = x.shape[0]  # 8

        # Patch embedding + CLS token
        x = self.patch_embed(x)                        # [8, 196, 768]
        cls = self.cls_token.expand(T, -1, -1)
        x = torch.cat([cls, x], dim=1)                 # [8, 197, 768]

        # Positional + temporal embeddings
        x = x + self.pos_embed                            # broadcast (1,197,768)
        x = x + self.temp_embed.squeeze(0).unsqueeze(1)   # (8,1,768) broadcast

        # Transformer blocks (factorized space-time)
        for block in self.blocks:
            x = block(x, T)
        x = self.norm(x)

        # Aggregate spatial patches → 1 feature per frame
        return self.spatial_attn_agg(x[:, 1:, :])      # [8, 768]
@@ -1,194 +0,0 @@
# ComfyUI-PrismAudio Design Document
**Date:** 2026-03-27
**Status:** Approved
## Overview
ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE.
## Architecture
**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. No JAX/TensorFlow in the ComfyUI environment. Feature extraction via separate isolated environment.
## Project Structure
```
ComfyUI-PrismAudio/
├── __init__.py # Node registration
├── nodes/
│ ├── __init__.py
│ ├── model_loader.py # PrismAudioModelLoader
│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz)
│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge)
│ ├── sampler.py # PrismAudioSampler
│ ├── text_only.py # PrismAudioTextOnly
│ └── utils.py # Shared helpers
├── prismaudio_core/ # Extracted inference code from PrismAudio
│ ├── __init__.py
│ ├── configs/
│ │ └── prismaudio.json
│ ├── models/ # DiT, conditioners, autoencoders, etc.
│ ├── inference/ # sampling.py, generation.py
│ └── factory.py # create_model_from_config
├── scripts/
│ ├── extract_features.py # Standalone VideoPrism feature extraction
│ └── environment.yml # Conda env for extraction (JAX + TF)
├── requirements.txt # PyTorch-only deps (no JAX/TF)
└── README.md
```
## Nodes
### PrismAudioModelLoader
Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability |
| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] |
| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | |
| **Output** | | |
| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config |
**Token resolution order** (no widget — env/CLI only for security):
1. `HF_TOKEN` environment variable
2. `huggingface-cli login` cached token
3. None — fails on gated models with clear error message linking to license page
**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually.
### PrismAudioFeatureLoader
Loads pre-computed `.npz` feature files for maximum quality video-to-audio.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| npz_path | STRING | Path to .npz file |
| **Output** | | |
| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features |
### PrismAudioFeatureExtractor
Subprocess bridge — extracts features from video using VideoPrism in an isolated environment.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| video | IMAGE | ComfyUI video frames tensor |
| caption_cot | STRING | CoT description text |
| python_env | STRING | Path to python binary with JAX/TF (default: "python") |
| output_dir | STRING | Cache directory for .npz files (default: temp dir) |
| **Output** | | |
| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output |
**Caching:** Hashes video + text to avoid re-extraction on repeated runs.
### PrismAudioSampler
Main generation node — takes model + features, produces audio.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| model | PRISMAUDIO_MODEL | From ModelLoader |
| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor |
| cot_description | STRING | Multiline CoT text |
| duration | FLOAT | 1.0-30.0, defaults to video length |
| steps | INT | 1-100, default 24 |
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
| seed | INT | Controls noise generation |
| **Output** | | |
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
**Pipeline:**
1. Encode CoT text via T5-Gemma -> text_features
2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond)
3. Compute latent_seq_len = round(44100 / 2048 * duration)
4. Generate noise [1, 64, latent_seq_len] from seed
5. Discrete Euler sampling (rectified flow) with CFG
6. VAE decode -> stereo waveform at 44100 Hz
7. Normalize to [-1, 1], return as AUDIO
### PrismAudioTextOnly
Text-to-audio without video input.
| Field | Type | Details |
|-------|------|---------|
| **Inputs** | | |
| model | PRISMAUDIO_MODEL | From ModelLoader |
| text_prompt | STRING | Text description |
| duration | FLOAT | 1.0-30.0 |
| steps | INT | 1-100, default 24 |
| cfg_scale | FLOAT | 1.0-20.0, default 5.0 |
| seed | INT | Controls noise generation |
| **Output** | | |
| audio | AUDIO | {waveform: tensor, sample_rate: 44100} |
Uses empty tensors for the video and sync features; T5-Gemma encodes the text prompt.
## VRAM Management
Adaptive strategy using `comfy.model_management`:
| Available VRAM | Behavior |
|---|---|
| 24GB+ | Keep diffusion + VAE in VRAM |
| 12-24GB | Sequential offload between stages |
| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced |
| <8GB | Warn user, attempt with aggressive offload + fp16 |
Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()`
## Feature Extraction Paths
### Path 1: Pre-computed .npz (FeatureLoader)
The user runs `scripts/extract_features.py` externally in the extraction conda env and loads the resulting `.npz` file into ComfyUI. This keeps the original VideoPrism quality with zero risk to the ComfyUI environment.
### Path 2: Subprocess bridge (FeatureExtractor)
The node calls the extraction script as a subprocess using a user-specified Python binary. This gives a seamless in-ComfyUI experience while JAX runs isolated, and results are cached by content hash.
### Path 3: Text-only (TextOnly node)
No video features needed. T5-Gemma text encoding only (PyTorch-native).
## Dependencies
### ComfyUI environment (`requirements.txt`)
```
einops>=0.7.0
safetensors
huggingface_hub
transformers>=4.52.3
k-diffusion>=0.1.1
```
flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`.
### Extraction environment (`scripts/environment.yml`)
Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as ready-made conda env file for one-command setup.
## Model Files
Stored in `ComfyUI/models/prismaudio/`:
| File | Size | Source |
|------|------|--------|
| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio |
| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio |
| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio |
T5-Gemma (`google/t5gemma-l-l-ul2-it`) cached in standard HuggingFace cache.
Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)`
## Design Decisions
- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes.
- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env.
- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes — better models available than bundled Qwen2.5-VL.
- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`.
- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback.
File diff suppressed because it is too large Load Diff
-207
View File
@@ -1,207 +0,0 @@
import os
import sys
import hashlib
import subprocess
import tempfile
import torch
from .utils import PRISMAUDIO_CATEGORY
from .feature_loader import PrismAudioFeatureLoader
# Managed venv created automatically when python_env is left as default
_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__))
_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env")
# venv puts executables in "Scripts" (with an .exe suffix) on Windows and in
# "bin" elsewhere; hard-coding "bin/python" would never be found on Windows.
_MANAGED_PYTHON = os.path.join(
    _MANAGED_VENV,
    "Scripts" if os.name == "nt" else "bin",
    "python.exe" if os.name == "nt" else "python",
)
# Requirement specifiers installed, one pip invocation per entry and in this
# order, into the managed feature-extraction venv.
_EXTRACT_PACKAGES = [
"torch", "torchaudio", "torchvision",
# TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+
"tensorflow-cpu>=2.16.0",
# jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed)
"jax[cuda13]", "flax",
"transformers", "decord", "einops", "numpy", "mediapy",
# VideoPrism is installed straight from its git repository.
"git+https://github.com/google-deepmind/videoprism.git",
]
def _pip_install(pip, *packages, label=None):
    """Run ``pip install`` for the given packages, streaming pip's own output.

    Args:
        pip: path to the pip executable to invoke.
        *packages: requirement specifiers passed to pip unchanged.
        label: name used in log lines; defaults to the first package.

    Raises:
        RuntimeError: if pip exits with a non-zero status.
    """
    tag = label or packages[0]
    print(f"[PrismAudio]   installing {tag} ...", flush=True)
    command = [pip, "install", "--progress-bar", "on", *packages]
    # Output is not captured so the user sees pip's progress live.
    completed = subprocess.run(command, capture_output=False)
    if completed.returncode == 0:
        print(f"[PrismAudio]   {tag} OK", flush=True)
        return
    raise RuntimeError(
        f"[PrismAudio] Failed to install {tag} (exit {completed.returncode}). "
        "See pip output above for details."
    )
def _ensure_extract_env():
    """Create and populate the managed venv on first use.

    Returns:
        Path to the venv's Python interpreter (``_MANAGED_PYTHON``).

    Raises:
        subprocess.CalledProcessError: if venv creation or the pip upgrade fails.
        RuntimeError: if installing one of ``_EXTRACT_PACKAGES`` fails.

    The interpreter's presence is the "env is ready" marker, so if anything
    fails (or the user interrupts) during setup the half-built venv is removed
    before the error propagates. Otherwise the next call would find the
    interpreter and silently use an environment with missing packages.
    """
    if os.path.exists(_MANAGED_PYTHON):
        return _MANAGED_PYTHON

    import shutil
    if os.path.exists(_MANAGED_VENV):
        print("[PrismAudio] Removing incomplete venv and retrying...", flush=True)
        shutil.rmtree(_MANAGED_VENV)

    print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True)
    try:
        subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True)

        # pip lives next to the interpreter ("bin" or "Scripts"), so derive
        # its location from the interpreter path instead of hard-coding "bin".
        pip = os.path.join(os.path.dirname(_MANAGED_PYTHON), "pip")

        print("[PrismAudio] Upgrading pip...", flush=True)
        # Upgrade through "python -m pip": a running pip executable cannot
        # replace itself on some platforms.
        subprocess.run([_MANAGED_PYTHON, "-m", "pip", "install", "--upgrade", "pip"], check=True)

        total = len(_EXTRACT_PACKAGES)
        print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True)
        for i, pkg in enumerate(_EXTRACT_PACKAGES, 1):
            # Short display name: repo file name for git URLs, otherwise the
            # bare distribution name without version pins or extras.
            label = pkg.split("/")[-1] if pkg.startswith("git+") else pkg.split(">=")[0].split("==")[0].split("[")[0]
            print(f"[PrismAudio] [{i}/{total}] {label}", flush=True)
            _pip_install(pip, pkg, label=label)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a partially installed env behind.
        shutil.rmtree(_MANAGED_VENV, ignore_errors=True)
        raise

    print("[PrismAudio] Feature-extraction env ready.", flush=True)
    return _MANAGED_PYTHON
def _hash_inputs(video_tensor, cot_text):
    """Create a hash of the inputs for caching.

    Args:
        video_tensor: torch tensor holding the video frames.
        cot_text: caption text the features are conditioned on.

    Returns:
        The first 16 hex characters of a SHA-256 digest.

    The digest covers the tensor's shape, dtype and *entire* contents. Hashing
    only a leading slice makes clips that share their opening pixels (for
    example a black fade-in) or differ only in length collide, which would
    serve the wrong cached features. The data is fed to the hash through a
    buffer view, so no extra full-size byte copy is made.
    """
    h = hashlib.sha256()
    frames = video_tensor.detach().cpu().contiguous().numpy()
    h.update(f"{tuple(frames.shape)}|{frames.dtype}".encode())
    h.update(frames.reshape(-1).view("uint8"))
    h.update(b"\0")  # separator between pixel data and text
    h.update(cot_text.encode())
    return h.hexdigest()[:16]
def _save_frames_to_npy(video_tensor, output_path):
    """Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8.

    Lossless — avoids H.264 encode/decode roundtrip.

    Values are rounded to the nearest integer and clamped to [0, 255] before
    the cast. A bare ``astype("uint8")`` truncates, so a pixel that started
    as k/255 can come back as k-1 after float rounding error, and wraps
    around for values slightly outside [0, 1].

    Args:
        video_tensor: torch tensor of frames with values nominally in [0, 1].
        output_path: destination file path passed to ``np.save``.
    """
    import numpy as np

    scaled = video_tensor.detach().cpu().numpy() * 255
    frames_np = np.clip(np.rint(scaled), 0, 255).astype("uint8")
    np.save(output_path, frames_np)
class PrismAudioFeatureExtractor:
    """ComfyUI node that extracts PrismAudio features by running
    ``scripts/extract_features.py`` in a separate Python process.

    Results are cached as ``.npz`` files keyed by the input hash and the frame
    rate, and are read back through ``PrismAudioFeatureLoader``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "video": ("IMAGE",),
                "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}),
            },
            "optional": {
                "video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info output to auto-set fps."}),
                "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Ignored if video_info is connected."}),
                "python_env": (["managed_env", "comfyui_env"], {"tooltip": "managed_env: auto-created isolated venv with JAX/TF (recommended). comfyui_env: current ComfyUI Python — WARNING: may conflict with existing packages and destabilize ComfyUI."}),
                "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}),
                # NOTE(review): a STRING widget value is presumably stored
                # with the workflow, so a token typed here may be shared along
                # with it — confirm, and prefer the HF_TOKEN environment
                # variable.
                "hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES", "FLOAT")
    RETURN_NAMES = ("features", "fps")
    FUNCTION = "extract_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def extract_features(self, video, caption_cot, video_info=None, fps=30.0, python_env="managed_env", cache_dir="", hf_token=""):
        """Extract (or load cached) features for ``video`` and ``caption_cot``.

        Args:
            video: IMAGE tensor of frames.
            caption_cot: chain-of-thought caption passed to the script.
            video_info: optional dict; when given, its "loaded_fps" overrides ``fps``.
            fps: source frame rate passed to the script as ``--source_fps``.
            python_env: "managed_env" (auto-created venv) or "comfyui_env".
            cache_dir: directory for cached ``.npz`` files; empty = temp dir.
            hf_token: HuggingFace token; falls back to the HF_TOKEN env var.

        Returns:
            Tuple ``(features, fps)``.

        Raises:
            RuntimeError: if the Synchformer checkpoint is missing or the
                extraction subprocess exits with a non-zero status.
            subprocess.TimeoutExpired: if extraction exceeds 10 minutes.
        """
        # Resolve fps from VHS video_info if connected
        if video_info is not None:
            fps = video_info["loaded_fps"]

        # Resolve python binary
        if python_env == "comfyui_env":
            print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. "
                  "Installing them here may conflict with existing packages and destabilize ComfyUI.", flush=True)
            python_bin = sys.executable
        else:
            python_bin = _ensure_extract_env()

        # Determine cache directory
        if not cache_dir:
            cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features")
        os.makedirs(cache_dir, exist_ok=True)

        # Check cache. fps is part of the key because it is an input to the
        # extraction script (--source_fps): the same frames at a different
        # frame rate must not reuse a stale result.
        cache_hash = _hash_inputs(video, caption_cot)
        cache_key = f"{cache_hash}_fps{float(fps):g}"
        cached_path = os.path.join(cache_dir, f"{cache_key}.npz")
        if os.path.exists(cached_path):
            print(f"[PrismAudio] Using cached features: {cached_path}")
            loader = PrismAudioFeatureLoader()
            features, = loader.load_features(cached_path)
            return (features, float(fps))

        # Validate the checkpoint BEFORE writing any temp file, so that a
        # missing checkpoint cannot leave a large .npy behind.
        import folder_paths
        synchformer_ckpt = os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth")
        if not os.path.exists(synchformer_ckpt):
            raise RuntimeError(
                f"[PrismAudio] Synchformer checkpoint not found: {synchformer_ckpt}\n"
                "Download synchformer_state_dict.pth from FunAudioLLM/PrismAudio and place it in models/prismaudio/."
            )

        script_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "scripts", "extract_features.py"
        )

        # Build env: inherit current env, inject HF token if provided.
        # os.environ.copy() returns a plain dict. copy.copy(os.environ) would
        # return another environment mapping backed by the same data, so
        # assigning to it would export the token into this process as well.
        env = os.environ.copy()
        token = hf_token.strip() if hf_token else os.environ.get("HF_TOKEN", "")
        if token:
            env["HF_TOKEN"] = token
            env["HUGGING_FACE_HUB_TOKEN"] = token
        else:
            print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. "
                  "Add your token in the hf_token input or set HF_TOKEN env var.", flush=True)

        # The script writes to a scratch file that is renamed into place only
        # on success, so a crashed or timed-out run cannot leave a truncated
        # .npz that later counts as a cache hit. The name keeps the .npz
        # suffix in case the script saves with np.savez, which would
        # otherwise append one.
        partial_path = os.path.join(cache_dir, f"{cache_key}.partial.npz")

        # Save frames to temp file (lossless .npy, no codec roundtrip)
        import time
        t0 = time.perf_counter()
        frames = video.shape[0]
        print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True)
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp:
            tmp_video = tmp.name

        try:
            _save_frames_to_npy(video, tmp_video)
            print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True)

            cmd = [
                python_bin,
                script_path,
                "--video", tmp_video,
                "--cot_text", caption_cot,
                "--output", partial_path,
                "--source_fps", str(fps),
                "--synchformer_ckpt", synchformer_ckpt,
            ]

            print(f"[PrismAudio] Extracting features via subprocess (output streams live)...")
            # capture_output=False: let stdout/stderr stream directly to ComfyUI logs
            result = subprocess.run(
                cmd,
                capture_output=False,
                timeout=600,  # 10 minute timeout
                env=env,
            )
            if result.returncode != 0:
                raise RuntimeError(
                    f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. "
                    "See output above for details."
                )
            print("[PrismAudio] Feature extraction subprocess finished successfully.")
            if os.path.exists(partial_path):
                os.replace(partial_path, cached_path)
        finally:
            # Runs on success, failure and timeout alike.
            for leftover in (tmp_video, partial_path):
                if os.path.exists(leftover):
                    os.unlink(leftover)

        # Load the extracted features
        loader = PrismAudioFeatureLoader()
        features, = loader.load_features(cached_path)
        return (features, float(fps))
-53
View File
@@ -1,53 +0,0 @@
import os
import numpy as np
import torch
from .utils import PRISMAUDIO_CATEGORY
# Keys consumed by the conditioners (video_features, text_features, sync_features)
# global_video_features and global_text_features are NOT consumed by any conditioner
# in the prismaudio.json config — they are unused.
# Every key listed here ends up in the loaded feature dict: keys missing from
# the .npz file are replaced by zero tensors at load time.
REQUIRED_KEYS = [
"video_features",
"text_features",
"sync_features",
]
class PrismAudioFeatureLoader:
    """ComfyUI node that loads pre-computed PrismAudio features from a .npz file."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_FEATURES",)
    RETURN_NAMES = ("features",)
    FUNCTION = "load_features"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_features(self, npz_path):
        """Read the conditioning features from ``npz_path``.

        Args:
            npz_path: path to a ``.npz`` archive.

        Returns:
            One-element tuple holding a dict with a float32 tensor for every
            key in ``REQUIRED_KEYS`` (zeros when a key is absent) and, if the
            archive has it, a float ``"duration"``.

        Raises:
            FileNotFoundError: if ``npz_path`` does not exist.
        """
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}")

        features = {}
        # allow_pickle=False: every entry read here is a plain numeric array,
        # and unpickling a file from a user-supplied path could execute
        # arbitrary code. The context manager closes the underlying zip
        # handle, which np.load otherwise keeps open.
        with np.load(npz_path, allow_pickle=False) as data:
            for key in REQUIRED_KEYS:
                if key in data:
                    features[key] = torch.from_numpy(data[key]).float()
                else:
                    print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros")
                    # Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None
                    # Sync_MLP requires length divisible by 8 (segments of 8 frames)
                    if key == "sync_features":
                        features[key] = torch.zeros(8, 768)
                    else:
                        features[key] = torch.zeros(1, 1024)

            # Load duration if present
            if "duration" in data:
                features["duration"] = float(data["duration"])

        return (features,)
-154
View File
@@ -1,154 +0,0 @@
import os
import json
import torch
import folder_paths
import comfy.model_management as mm
import comfy.utils
from .utils import (
PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder,
get_device, get_offload_device, determine_precision, determine_offload_strategy,
soft_empty_cache, resolve_hf_token,
)
# HuggingFace repo for auto-download
HF_REPO_ID = "FunAudioLLM/PrismAudio"
# Logical component name -> checkpoint file name. The same name is used for
# the file inside the HuggingFace repo and inside the local model directory.
REQUIRED_FILES = {
"diffusion": "prismaudio.ckpt",
"vae": "vae.ckpt",
"synchformer": "synchformer_state_dict.pth",
}
def _download_if_missing(filename, model_dir, hf_token=None):
    """Return the local path of ``filename``, fetching it from HuggingFace if absent.

    Args:
        filename: file name inside the ``HF_REPO_ID`` repository.
        model_dir: local directory that is checked first and downloaded into.
        hf_token: optional HuggingFace access token.

    Raises:
        RuntimeError: if the download error looks like an authorization or
            gated-model failure.
        Exception: any other download error is re-raised unchanged.
    """
    local_path = os.path.join(model_dir, filename)
    if os.path.exists(local_path):
        return local_path

    from huggingface_hub import hf_hub_download

    print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...")
    try:
        return hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=model_dir,
            token=hf_token or None,
        )
    except Exception as e:
        message = str(e)
        # Authorization problems are recognised by status code or the word
        # "gated" in the error text.
        looks_unauthorized = (
            "401" in message or "403" in message or "gated" in message.lower()
        )
        if not looks_unauthorized:
            raise
        raise RuntimeError(
            f"[PrismAudio] Model '{filename}' requires license acceptance. "
            f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, "
            f"then set HF_TOKEN env var or run: huggingface-cli login"
        ) from e
class PrismAudioModelLoader:
    """ComfyUI node that builds the PrismAudio model and loads its diffusion and VAE weights."""

    @classmethod
    def INPUT_TYPES(cls):
        register_model_folder()
        precision_choices = ["auto", "fp32", "fp16", "bf16"]
        offload_choices = ["auto", "keep_in_vram", "offload_to_cpu"]
        return {
            "required": {
                "precision": (precision_choices,),
                "offload_strategy": (offload_choices,),
            },
        }

    RETURN_TYPES = ("PRISMAUDIO_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load_model"
    CATEGORY = PRISMAUDIO_CATEGORY

    def load_model(self, precision, offload_strategy):
        """Download missing checkpoints, build the model and load its weights.

        Args:
            precision: one of "auto", "fp32", "fp16", "bf16".
            offload_strategy: one of "auto", "keep_in_vram", "offload_to_cpu".

        Returns:
            One-element tuple holding a dict with keys "model", "dtype",
            "strategy", "config" and "model_dir".
        """
        device = get_device()
        dtype = determine_precision(precision, device)
        strategy = determine_offload_strategy(offload_strategy)
        token = resolve_hf_token()

        model_dir = get_prismaudio_model_dir()

        # Fetch any checkpoint that is not on disk yet.
        for required_name in REQUIRED_FILES.values():
            _download_if_missing(required_name, model_dir, hf_token=token)

        # Read the bundled model config.
        plugin_root = os.path.dirname(os.path.dirname(__file__))
        config_path = os.path.join(plugin_root, "prismaudio_core", "configs", "prismaudio.json")
        with open(config_path) as f:
            model_config = json.load(f)

        # Instantiate the (still untrained) model from the config.
        from prismaudio_core.factory import create_model_from_config
        model = create_model_from_config(model_config)

        # --- Diffusion weights -------------------------------------------
        diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"])
        diffusion_state = comfy.utils.load_torch_file(diffusion_path)
        # Some checkpoints nest the weights under "state_dict".
        if "state_dict" in diffusion_state:
            diffusion_state = diffusion_state["state_dict"]
        diff_result = model.load_state_dict(diffusion_state, strict=False)
        diff_missing = diff_result.missing_keys
        diff_unexpected = diff_result.unexpected_keys
        print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True)
        print(f"[PrismAudio] Diffusion load: missing={len(diff_missing)}, unexpected={len(diff_unexpected)}", flush=True)
        if diff_missing:
            print(f"[PrismAudio]   missing (first 10): {diff_missing[:10]}", flush=True)
        if diff_unexpected:
            print(f"[PrismAudio]   unexpected (first 5): {diff_unexpected[:5]}", flush=True)
        # Print a few checkpoint keys so prefix alignment can be checked by eye.
        sample_keys = list(diffusion_state.keys())[:5]
        print(f"[PrismAudio]   ckpt key samples: {sample_keys}", flush=True)

        # --- VAE weights (separate file) ---------------------------------
        # comfy.utils.load_torch_file is used for consistency and PyTorch 2.6+ compat.
        vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"])
        vae_full_state = comfy.utils.load_torch_file(vae_path)
        print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True)
        # Raw keys, to show the prefix actually present in the file.
        vae_sample_keys = list(vae_full_state.keys())[:8]
        print(f"[PrismAudio]   VAE raw key samples: {vae_sample_keys}", flush=True)

        # Drop a leading "autoencoder." from every key that has one.
        prefix = "autoencoder."
        vae_state = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in vae_full_state.items()
        }

        print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True)
        # Model-side keys, for comparison with the stripped checkpoint keys.
        model_vae_keys = list(model.pretransform.state_dict().keys())[:5]
        print(f"[PrismAudio]   pretransform model key samples: {model_vae_keys}", flush=True)

        # strict=False: vae.ckpt is a training checkpoint that also contains
        # discriminator, loss modules, and EMA wrappers not present in the
        # inference AudioAutoencoder — ignore those extra keys.
        # Loading into the inner AudioAutoencoder returns the IncompatibleKeys
        # result (AutoencoderPretransform.load_state_dict doesn't return it).
        vae_result = model.pretransform.model.load_state_dict(vae_state, strict=False)
        print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True)
        if vae_result.missing_keys:
            print(f"[PrismAudio]   VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True)

        # Precision: DiT + conditioners use the selected dtype; the VAE
        # (pretransform) stays in fp32 to avoid NaN from snake activations in fp16.
        model.model.to(dtype)          # DiTWrapper
        model.conditioner.to(dtype)    # MultiConditioner

        target_device = device if strategy == "keep_in_vram" else get_offload_device()
        model = model.to(target_device)

        model.eval()

        loaded = {
            "model": model,
            "dtype": dtype,
            "strategy": strategy,
            "config": model_config,
            "model_dir": model_dir,
        }
        return (loaded,)
-165
View File
@@ -1,165 +0,0 @@
import torch
import comfy.model_management as mm
import comfy.utils
from .utils import (
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
get_device, get_offload_device, soft_empty_cache,
)
class PrismAudioSampler:
    """ComfyUI node: sample audio latents from pre-extracted features and decode them.

    The node consumes the model dict produced by the loader (keys ``model``,
    ``dtype``, ``strategy``) and a features dict (``video_features``,
    ``sync_features``, ``text_features`` and optionally ``duration``).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        return {
            "required": {
                "model": ("PRISMAUDIO_MODEL",),
                "features": ("PRISMAUDIO_FEATURES",),
                "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}),
                "steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}),
                "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, features, duration, steps, cfg_scale, seed):
        """Generate audio for the given features.

        Args:
            model: Dict with ``model`` (the diffusion wrapper), ``dtype`` and
                ``strategy`` (``"offload_to_cpu"`` triggers device shuffling).
            features: Dict of conditioning tensors. ``video_features`` may be
                missing/None or all zeros, in which case the learned empty
                embeddings are substituted after conditioning.
            duration: Seconds of audio; ``<= 0`` means "use features['duration']".
            steps: Number of Euler sampling steps.
            cfg_scale: Classifier-free guidance scale passed to the sampler.
            seed: Seed for the initial noise.

        Returns:
            A 1-tuple holding ``{"waveform": tensor on CPU, "sample_rate": SAMPLE_RATE}``.

        Raises:
            ValueError: If ``duration <= 0`` and the features carry no duration.
        """
        device = get_device()
        dtype = model["dtype"]
        strategy = model["strategy"]
        diffusion = model["model"]

        # Resolve duration: 0 means use video duration from features
        if duration <= 0:
            if "duration" not in features:
                raise ValueError("[PrismAudio] duration=0 but features contain no duration. Set duration manually or use PrismAudioFeatureExtractor.")
            duration = features["duration"]
            print(f"[PrismAudio] Using video duration from features: {duration:.2f}s", flush=True)

        # Compute latent dimensions
        latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Note: no seq length config needed — the model adapts to input tensor shapes
        # dynamically via its transformer architecture.

        # Determine if video features are present (not None and not all zeros).
        # Coerce to a plain bool so it is safe to use in `if` and torch.tensor().
        raw_video = features.get("video_features")
        has_video = raw_video is not None and bool(raw_video.abs().sum() > 0)

        # Without video, fall back to the same zero placeholders the text-only
        # node feeds the conditioner ([1, 1024] video, [8, 768] sync); they are
        # replaced by the learned empty embeddings after conditioning.
        if raw_video is None:
            raw_video = torch.zeros(1, 1024)
        if has_video:
            # Real video requires real sync features; keep the original lookup.
            raw_sync = features["sync_features"]
        else:
            raw_sync = features.get("sync_features")
            if raw_sync is None:
                raw_sync = torch.zeros(8, 768)

        video_feat = raw_video.to(device, dtype=dtype)
        sync_feat = raw_sync.to(device, dtype=dtype)

        # Build metadata as a TUPLE of dicts (one per batch sample)
        # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this
        sample_meta = {
            "video_features": video_feat,
            "text_features": features["text_features"].to(device, dtype=dtype),
            "sync_features": sync_feat,
            "video_exist": torch.tensor(has_video),
        }
        metadata = (sample_meta,)

        # Move model to device if offloaded
        if strategy == "offload_to_cpu":
            diffusion.model.to(device)
            diffusion.conditioner.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            # Run conditioning
            conditioning = diffusion.conditioner(metadata, device)

            # Handle missing video: substitute learned empty embeddings
            if not has_video:
                _substitute_empty_features(diffusion, conditioning, device, dtype)

            # Assemble conditioning inputs for the DiT
            cond_inputs = diffusion.get_conditioning_inputs(conditioning)

            # Generate noise from seed (MPS doesn't support torch.Generator)
            gen_device = "cpu" if device.type == "mps" else device
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, latent_length],
                generator=generator,
                device=gen_device,
            ).to(device=device, dtype=dtype)

            # Sample with progress bar
            pbar = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            def on_step(info):
                pbar.update(1)

            fakes = sample_discrete_euler(
                diffusion.model,
                noise,
                steps,
                callback=on_step,
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

        fakes_f = fakes.float()
        print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)

        # Offload diffusion model and conditioner before VAE decode
        if strategy == "offload_to_cpu":
            diffusion.model.to(get_offload_device())
            diffusion.conditioner.to(get_offload_device())
            soft_empty_cache()

        diffusion.pretransform.to(device)

        # VAE decode in fp32 (snake activations overflow in fp16)
        with torch.amp.autocast(device_type=device.type, enabled=False):
            audio = diffusion.pretransform.decode(fakes_f)

        # Offload VAE
        if strategy == "offload_to_cpu":
            diffusion.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak normalize then clamp (matching reference: div by max abs before clamp)
        audio = audio.float()
        pre_norm_std = audio.std().item()
        pre_norm_peak = audio.abs().max().item()
        # Clamp the divisor so silent output does not divide by zero.
        peak = audio.abs().max().clamp(min=1e-8)
        audio = (audio / peak).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)

        # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int}
        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
def _substitute_empty_features(diffusion, conditioning, device, dtype):
    """Swap video/sync conditioning for the DiT's learned "empty" embeddings.

    ``conditioning`` maps each conditioner id to a mutable ``[tensor, mask]``
    pair and is modified in place. For every id handled here, the tensor is
    replaced by the corresponding learned embedding broadcast over the batch
    (one token per sample) and the mask by an all-ones ``[B, 1]`` tensor.

    Feeding zero features through the conditioner is not equivalent: it yields
    near-zero activations rather than the null signal the model was trained on.
    An id is skipped when the DiT lacks the embedding or the id is not present
    in ``conditioning``.
    """
    wrapper = diffusion.model
    # Unwrap one level when the wrapper exposes the underlying network.
    dit = getattr(wrapper, 'model', wrapper)

    replacements = (
        ('empty_clip_feat', 'video_features'),
        ('empty_sync_feat', 'sync_features'),
    )
    for embed_attr, cond_key in replacements:
        if not hasattr(dit, embed_attr) or cond_key not in conditioning:
            continue
        slot = conditioning[cond_key]
        learned = getattr(dit, embed_attr).to(device, dtype=dtype)
        n_samples = slot[0].shape[0]
        # Add a leading batch axis, then broadcast it to the batch size.
        slot[0] = learned.unsqueeze(0).expand(n_samples, -1, -1)
        slot[1] = torch.ones(n_samples, 1, device=device)
+2 -2
View File
@@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
# SelVA video preprocessing constants (from selva/utils/eval_utils.py)
_CLIP_SIZE = 384
@@ -68,7 +68,7 @@ class SelvaFeatureExtractor:
RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING")
RETURN_NAMES = ("features", "fps", "prompt")
FUNCTION = "extract_features"
CATEGORY = PRISMAUDIO_CATEGORY
CATEGORY = SELVA_CATEGORY
def extract_features(self, model, video, prompt, video_info=None, fps=30.0,
duration=0.0, cache_dir=""):
+2 -2
View File
@@ -3,7 +3,7 @@ from pathlib import Path
import torch
import folder_paths
from .utils import PRISMAUDIO_CATEGORY, get_offload_device, determine_offload_strategy
from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy
# Variant → (generator filename, mode, has_bigvgan)
_VARIANTS = {
@@ -96,7 +96,7 @@ class SelvaModelLoader:
RETURN_TYPES = ("SELVA_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load_model"
CATEGORY = PRISMAUDIO_CATEGORY
CATEGORY = SELVA_CATEGORY
def load_model(self, variant, precision, offload_strategy):
from selva_core.model.networks_generator import get_my_mmaudio
+2 -2
View File
@@ -1,7 +1,7 @@
import torch
import comfy.utils
from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache
from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache
class SelvaSampler:
@@ -35,7 +35,7 @@ class SelvaSampler:
RETURN_TYPES = ("AUDIO",)
RETURN_NAMES = ("audio",)
FUNCTION = "generate"
CATEGORY = PRISMAUDIO_CATEGORY
CATEGORY = SELVA_CATEGORY
def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed):
from selva_core.model.flow_matching import FlowMatching
-160
View File
@@ -1,160 +0,0 @@
import torch
import comfy.model_management as mm
import comfy.utils
from .utils import (
PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS,
get_device, get_offload_device, soft_empty_cache, resolve_hf_token,
)
from .sampler import _substitute_empty_features
class PrismAudioTextOnly:
    """ComfyUI node: generate audio from a text prompt alone (no video input)."""

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification for this node."""
        required = {
            "model": ("PRISMAUDIO_MODEL",),
            "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}),
            "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
            "steps": ("INT", {"default": 100, "min": 1, "max": 100}),
            "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1}),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}),
        }
        return {"required": required}

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate"
    CATEGORY = PRISMAUDIO_CATEGORY

    def generate(self, model, text_prompt, duration, steps, cfg_scale, seed):
        """Encode the prompt, sample latents, decode and peak-normalise the audio.

        ``model`` is the loader dict (``model``, ``dtype``, ``strategy``).
        Returns a 1-tuple with ``{"waveform": CPU tensor, "sample_rate": SAMPLE_RATE}``.
        """
        device = get_device()
        dtype = model["dtype"]
        offloading = model["strategy"] == "offload_to_cpu"
        net = model["model"]

        n_latent_frames = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO)

        # Text goes through T5-Gemma before the diffusion model is moved to the device.
        prompt_embedding = _encode_text_t5(text_prompt, device, dtype)

        # One metadata dict per batch sample, wrapped in a tuple.
        # Video/sync get zero tensors instead of None because Cond_MLP crashes on
        # None via pad_sequence, and Sync_MLP needs a length divisible by 8
        # (segments of 8 frames), hence the [8, 768] minimum. Both are swapped for
        # the learned empty embeddings once conditioning has run.
        metadata = (
            {
                "video_features": torch.zeros(1, 1024, device=device, dtype=dtype),
                "text_features": prompt_embedding.to(device, dtype=dtype),
                "sync_features": torch.zeros(8, 768, device=device, dtype=dtype),
                "video_exist": torch.tensor(False),
            },
        )

        if offloading:
            diffusion_parts = (net.model, net.conditioner)
            for part in diffusion_parts:
                part.to(device)
            soft_empty_cache()

        with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype):
            conditioning = net.conditioner(metadata, device)

            # There is never any video here, so always use the learned null signal.
            _substitute_empty_features(net, conditioning, device, dtype)

            cond_inputs = net.get_conditioning_inputs(conditioning)

            # MPS doesn't support torch.Generator, so seed on CPU in that case.
            noise_device = "cpu" if device.type == "mps" else device
            rng = torch.Generator(device=noise_device).manual_seed(seed)
            noise = torch.randn(
                [1, IO_CHANNELS, n_latent_frames],
                generator=rng,
                device=noise_device,
            )
            noise = noise.to(device=device, dtype=dtype)

            progress = comfy.utils.ProgressBar(steps)

            from prismaudio_core.inference.sampling import sample_discrete_euler

            fakes = sample_discrete_euler(
                net.model,
                noise,
                steps,
                callback=lambda _info: progress.update(1),
                **cond_inputs,
                cfg_scale=cfg_scale,
                batch_cfg=True,
            )

        fakes_f = fakes.float()
        print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True)

        if offloading:
            # Free the device before bringing the VAE in.
            net.model.to(get_offload_device())
            net.conditioner.to(get_offload_device())
            soft_empty_cache()

        net.pretransform.to(device)

        # Decode with autocast disabled: snake activations overflow in fp16.
        with torch.amp.autocast(device_type=device.type, enabled=False):
            audio = net.pretransform.decode(fakes_f)

        if offloading:
            net.pretransform.to(get_offload_device())
            soft_empty_cache()

        # Peak-normalise, then clamp to [-1, 1].
        audio = audio.float()
        pre_norm_std = audio.std().item()
        abs_max = audio.abs().max()
        pre_norm_peak = abs_max.item()
        audio = (audio / abs_max.clamp(min=1e-8)).clamp(-1, 1)
        print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True)
        print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True)

        return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},)
# Process-wide T5-Gemma encoder and tokenizer, loaded on first use.
_t5_model = None
_t5_tokenizer = None


def _encode_text_t5(text, device, dtype):
    """Return the T5-Gemma encoder hidden states for ``text``.

    The encoder is obtained through ``AutoModelForSeq2SeqLM.get_encoder()`` to
    match the reference ``FeaturesUtils.encode_t5_text()`` implementation, and
    the tokenizer is called without truncation (also matching the reference).

    The encoder is moved to ``device``/``dtype`` for the forward pass and back
    to the CPU afterwards. The result is ``last_hidden_state`` with its leading
    axis squeezed, i.e. ``[seq_len, dim]`` for a single prompt.
    """
    global _t5_model, _t5_tokenizer

    first_use = _t5_model is None
    if first_use:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        model_id = "google/t5gemma-l-l-ul2-it"
        token = resolve_hf_token()
        print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}")
        _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token)
        _t5_model = seq2seq.get_encoder()
        _t5_model.eval()

    _t5_model.to(device, dtype=dtype)

    encoded = _t5_tokenizer(text, return_tensors="pt", padding=True)
    encoded = encoded.to(device)

    with torch.no_grad():
        outputs = _t5_model(**encoded)

    # The encoder is only needed for this call; park it on the CPU to save VRAM.
    _t5_model.to("cpu")
    soft_empty_cache()

    hidden = outputs.last_hidden_state
    return hidden.squeeze(0)  # [seq_len, dim]
+4 -47
View File
@@ -1,21 +1,7 @@
import os
import torch
import folder_paths
import comfy.model_management as mm
PRISMAUDIO_CATEGORY = "PrismAudio"
SAMPLE_RATE = 44100
DOWNSAMPLING_RATIO = 2048
IO_CHANNELS = 64
def get_prismaudio_model_dir():
    """Return the ``prismaudio`` directory under ComfyUI's models dir, creating it if needed."""
    target = os.path.join(folder_paths.models_dir, "prismaudio")
    # exist_ok makes repeated calls harmless.
    os.makedirs(target, exist_ok=True)
    return target
def register_model_folder():
    """Register the PrismAudio model directory with ``folder_paths`` under the key ``prismaudio``."""
    folder_paths.add_model_folder_path("prismaudio", get_prismaudio_model_dir())
SELVA_CATEGORY = "SelVA"
def get_device():
    """Return the torch device chosen by ``comfy.model_management``."""
    torch_device = mm.get_torch_device()
    return torch_device
@@ -23,42 +9,13 @@ def get_device():
def get_offload_device():
    """Return the device ``comfy.model_management`` uses for offloaded UNet weights."""
    offload_target = mm.unet_offload_device()
    return offload_target
def get_free_memory(device=None):
    """Return ``mm.get_free_memory`` for ``device``, defaulting to ``get_device()`` when None."""
    target = get_device() if device is None else device
    return mm.get_free_memory(target)
def soft_empty_cache():
    """Delegate to ``comfy.model_management.soft_empty_cache``; returns nothing."""
    mm.soft_empty_cache()
    return None
def determine_precision(preference, device):
    """Map a precision preference to a torch dtype.

    ``"fp32"``, ``"fp16"`` and ``"bf16"`` map directly (any other explicit
    value raises ``KeyError``). ``"auto"`` picks float32 on CPU devices,
    otherwise bfloat16 when CUDA is available and supports it, else float16.
    """
    if preference == "auto":
        if device.type == "cpu":
            return torch.float32
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        return torch.bfloat16 if bf16_ok else torch.float16

    explicit = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    return explicit[preference]
def determine_offload_strategy(preference):
if preference != "auto":
return preference
free_mem = get_free_memory()
gb = free_mem / (1024 ** 3)
if gb >= 24:
free_mem = mm.get_free_memory(get_device())
if free_mem / (1024 ** 3) >= 16:
return "keep_in_vram"
else:
return "offload_to_cpu"
def try_import_flash_attn():
    """Return the ``flash_attn`` module when it can be imported, otherwise None."""
    try:
        import flash_attn
    except ImportError:
        return None
    else:
        return flash_attn
def resolve_hf_token():
    """Return the ``HF_TOKEN`` environment variable, or None when it is unset or empty."""
    # `or None` folds the empty-string case into None.
    return os.environ.get("HF_TOKEN") or None
return "offload_to_cpu"
-5
View File
@@ -1,5 +0,0 @@
"""
PrismAudio core inference modules.
Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch).
Only inference-critical code — no training, no JAX/TF dependencies.
"""
-141
View File
@@ -1,141 +0,0 @@
{
"model_type": "diffusion_cond",
"sample_size": 397312,
"sample_rate": 44100,
"audio_channels": 2,
"model": {
"pretransform": {
"type": "autoencoder",
"iterate_batch": true,
"config": {
"encoder": {
"type": "oobleck",
"config": {
"in_channels": 2,
"channels": 128,
"c_mults": [1, 2, 4, 8, 16],
"strides": [2, 4, 4, 8, 8],
"latent_dim": 128,
"use_snake": true
}
},
"decoder": {
"type": "oobleck",
"config": {
"out_channels": 2,
"channels": 128,
"c_mults": [1, 2, 4, 8, 16],
"strides": [2, 4, 4, 8, 8],
"latent_dim": 64,
"use_snake": true,
"final_tanh": false
}
},
"bottleneck": {
"type": "vae"
},
"latent_dim": 64,
"downsampling_ratio": 2048,
"io_channels": 2
}
},
"conditioning": {
"configs": [
{
"id": "video_features",
"type": "cond_mlp",
"config": {
"dim": 1024,
"output_dim": 1024
}
},
{
"id": "text_features",
"type": "cond_mlp",
"config": {
"dim": 1024,
"output_dim": 1024
}
},
{
"id": "sync_features",
"type": "sync_mlp",
"config": {
"dim": 768,
"output_dim": 1024
}
}
],
"cond_dim": 768
},
"diffusion": {
"cross_attention_cond_ids": ["video_features","text_features"],
"add_cond_ids": ["video_features"],
"sync_cond_ids": ["sync_features"],
"type": "dit",
"diffusion_objective": "rectified_flow",
"config": {
"io_channels": 64,
"embed_dim": 1024,
"depth": 24,
"num_heads": 16,
"cond_token_dim": 1024,
"add_token_dim": 1024,
"sync_token_dim": 1024,
"project_cond_tokens": false,
"transformer_type": "continuous_transformer",
"attn_kwargs":{
"qk_norm": "rns"
},
"use_gated": true,
"use_sync_gated": true
}
},
"io_channels": 64
},
"training": {
"use_ema": true,
"log_loss_info": false,
"cfg_dropout_prob": 0.1,
"pre_encoded": true,
"timestep_sampler": "trunc_logit_normal",
"optimizer_configs": {
"diffusion": {
"optimizer": {
"type": "AdamW",
"config": {
"lr": 1e-4,
"betas": [0.9, 0.999],
"weight_decay": 1e-3
}
},
"scheduler": {
"type": "InverseLR",
"config": {
"inv_gamma": 100000,
"power": 0.5,
"warmup": 0.99
}
}
}
},
"demo": {
"demo_every": 5000,
"demo_steps": 24,
"num_demos": 10,
"demo_cond": [
"dataset/videoprism/test/0Cu33yBwAPg_000060.npz",
"dataset/videoprism/test/bmKtI808DsU_000009.npz",
"dataset/videoprism/test/VC0c22cJTbM_000424.npz",
"dataset/videoprism/test/F3gsbUTdc2U_000090.npz",
"dataset/videoprism/test/WatvT8A8iug_000100.npz",
"dataset/videoprism/test/0nvBTp-q7tU_000112.npz",
"dataset/videoprism/test/3-PFuDkTM48_000080.npz",
"dataset/videoprism/test/luSAuu-BoPs_000232.npz",
"dataset/videoprism/test/__8UJxW0aOQ_000002.npz",
"dataset/videoprism/test/_0m_YMpQayA_000168.npz"
],
"demo_cfg_scales": [5]
}
}
}
-413
View File
@@ -1,413 +0,0 @@
"""
Model factory functions for PrismAudio inference.
Extracted from:
- PrismAudio/models/factory.py
- PrismAudio/models/autoencoders.py (create_autoencoder_from_config)
- PrismAudio/models/diffusion.py (create_diffusion_cond_from_config)
- PrismAudio/models/conditioners.py (create_multi_conditioner_from_conditioning_config)
Source: https://github.com/FunAudioLLM/ThinkSound (prismaudio branch)
Only inference-critical factory functions are retained.
"""
import json
import typing as tp
from typing import Dict, Any
import numpy as np
def create_model_from_config(model_config):
    """Build a model from a top-level config, dispatching on ``model_type``.

    ``'autoencoder'`` goes to ``create_autoencoder_from_config``; the
    conditioned-diffusion family goes to ``create_diffusion_cond_from_config``.

    Raises:
        AssertionError: If ``model_type`` is missing.
        NotImplementedError: If ``model_type`` is not recognised.
    """
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    if model_type == 'autoencoder':
        return create_autoencoder_from_config(model_config)

    diffusion_family = (
        'diffusion_cond',
        'diffusion_cond_inpaint',
        "diffusion_prior",
        "diffusion_infill",
        "mm_diffusion_cond",
    )
    if model_type in diffusion_family:
        return create_diffusion_cond_from_config(model_config)

    raise NotImplementedError(f'Unknown model type: {model_type}')
def create_pretransform_from_config(pretransform_config, sample_rate):
    """Build a pretransform module from its config section.

    Supported ``type`` values are ``autoencoder``, ``pqmf``, ``dac_pretrained``
    and ``audiocraft_pretrained``; ``wavelet`` is explicitly rejected. The
    result is put in eval mode and its parameters' ``requires_grad`` follows
    the config's ``enable_grad`` flag (default False).

    Raises:
        AssertionError: If ``type`` is missing.
        NotImplementedError: For ``wavelet`` or an unknown type.
    """
    kind = pretransform_config.get('type', None)
    assert kind is not None, 'type must be specified in pretransform config'

    if kind == 'autoencoder':
        from prismaudio_core.models.pretransforms import AutoencoderPretransform

        # The autoencoder factory expects a top-level config, so wrap the section
        # in one. A bit of a hack, but it avoids repeating the sample rate in the
        # pretransform config.
        wrapped = create_autoencoder_from_config(
            {"sample_rate": sample_rate, "model": pretransform_config["config"]}
        )
        pretransform = AutoencoderPretransform(
            wrapped,
            scale=pretransform_config.get("scale", 1.0),
            model_half=pretransform_config.get("model_half", False),
            iterate_batch=pretransform_config.get("iterate_batch", False),
            chunked=pretransform_config.get("chunked", False),
        )
    elif kind == 'wavelet':
        raise NotImplementedError("wavelet pretransform type is not supported")
    elif kind == 'pqmf':
        from prismaudio_core.models.pretransforms import PQMFPretransform

        pretransform = PQMFPretransform(**pretransform_config["config"])
    elif kind == 'dac_pretrained':
        from prismaudio_core.models.pretransforms import PretrainedDACPretransform

        pretransform = PretrainedDACPretransform(**pretransform_config["config"])
    elif kind == "audiocraft_pretrained":
        from prismaudio_core.models.pretransforms import AudiocraftCompressionPretransform

        pretransform = AudiocraftCompressionPretransform(**pretransform_config["config"])
    else:
        raise NotImplementedError(f'Unknown pretransform type: {kind}')

    pretransform.enable_grad = pretransform_config.get('enable_grad', False)
    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform
def create_bottleneck_from_config(bottleneck_config):
    """Build a bottleneck module from its config section.

    The ``rvq`` and ``rvq_vae`` types start from a shared set of quantizer
    defaults that the section's ``config`` overrides. When the section sets
    ``requires_grad`` to a falsy value, every parameter of the bottleneck is
    frozen.

    Raises:
        AssertionError: If ``type`` is missing.
        NotImplementedError: If ``type`` is not recognised.
    """
    kind = bottleneck_config.get('type', None)
    assert kind is not None, 'type must be specified in bottleneck config'

    def _quantizer_kwargs():
        # Defaults for the residual-VQ bottlenecks, overridden by the config.
        merged = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }
        merged.update(bottleneck_config["config"])
        return merged

    if kind == 'tanh':
        from prismaudio_core.models.bottleneck import TanhBottleneck
        bottleneck = TanhBottleneck()
    elif kind == 'vae':
        from prismaudio_core.models.bottleneck import VAEBottleneck
        bottleneck = VAEBottleneck()
    elif kind == 'rvq':
        from prismaudio_core.models.bottleneck import RVQBottleneck
        bottleneck = RVQBottleneck(**_quantizer_kwargs())
    elif kind == "dac_rvq":
        from prismaudio_core.models.bottleneck import DACRVQBottleneck
        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif kind == 'rvq_vae':
        from prismaudio_core.models.bottleneck import RVQVAEBottleneck
        bottleneck = RVQVAEBottleneck(**_quantizer_kwargs())
    elif kind == 'dac_rvq_vae':
        from prismaudio_core.models.bottleneck import DACRVQVAEBottleneck
        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif kind == 'l2_norm':
        from prismaudio_core.models.bottleneck import L2Bottleneck
        bottleneck = L2Bottleneck()
    elif kind == "wasserstein":
        from prismaudio_core.models.bottleneck import WassersteinBottleneck
        # This is the only type whose "config" section is optional.
        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif kind == "fsq":
        from prismaudio_core.models.bottleneck import FSQBottleneck
        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f'Unknown bottleneck type: {kind}')

    trainable = bottleneck_config.get('requires_grad', True)
    if not trainable:
        for weight in bottleneck.parameters():
            weight.requires_grad = False

    return bottleneck
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Create an AudioAutoencoder from a config dictionary.

    Originally in PrismAudio/models/autoencoders.py.

    ``config["model"]`` supplies the encoder/decoder sections plus the required
    ``latent_dim``, ``downsampling_ratio`` and ``io_channels``; ``sample_rate``
    is read from the top level of ``config``. ``bottleneck``, ``pretransform``,
    ``in_channels`` and ``out_channels`` are optional.

    Raises:
        AssertionError: If any required value is missing.
    """
    from prismaudio_core.models.autoencoders import (
        AudioAutoencoder,
        create_encoder_from_config,
        create_decoder_from_config,
    )

    ae_config = config["model"]

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    bottleneck_section = ae_config.get("bottleneck", None)

    # Required values, checked in this order so the first missing one is reported.
    required = {}
    for key in ("latent_dim", "downsampling_ratio", "io_channels"):
        value = ae_config.get(key, None)
        assert value is not None, f"{key} must be specified in model config"
        required[key] = value

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    pretransform_section = ae_config.get("pretransform", None)
    pretransform = (
        create_pretransform_from_config(pretransform_section, sample_rate)
        if pretransform_section is not None
        else None
    )

    bottleneck = (
        create_bottleneck_from_config(bottleneck_section)
        if bottleneck_section is not None
        else None
    )

    return AudioAutoencoder(
        encoder,
        decoder,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=ae_config.get("in_channels", None),
        out_channels=ae_config.get("out_channels", None),
        soft_clip=ae_config["decoder"].get("soft_clip", False),
        **required,
    )
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]):
    """Create a MultiConditioner from a conditioning config dictionary.

    Originally in PrismAudio/models/conditioners.py.

    Each entry of ``config["configs"]`` names a conditioner ``id``, a ``type``
    and a ``config`` of constructor keyword arguments. ``output_dim`` defaults
    to the top-level ``cond_dim`` unless the entry overrides it.

    Raises:
        ValueError: If an entry has an unknown ``type``.
        AssertionError: If a ``pretransform`` entry lacks ``sample_rate``.
    """
    from prismaudio_core.models.conditioners import (
        MultiConditioner,
        T5Conditioner,
        CLAPTextConditioner,
        CLIPTextConditioner,
        MetaCLIPTextConditioner,
        CLAPAudioConditioner,
        Cond_MLP,
        Global_MLP,
        Sync_MLP,
        Cond_MLP_1,
        Cond_ConvMLP,
        Cond_MLP_Global,
        Cond_MLP_Global_1,
        Cond_MLP_Global_2,
        Video_Global,
        Video_Sync,
        Text_Linear,
        CLIPConditioner,
        IntConditioner,
        NumberConditioner,
        PhonemeConditioner,
        TokenizerLUTConditioner,
        PretransformConditioner,
        mm_unchang,
    )
    from prismaudio_core.models.utils import load_ckpt_state_dict

    # Types that are built by passing the merged kwargs straight to a class.
    direct_types = {
        "t5": T5Conditioner,
        "clap_text": CLAPTextConditioner,
        "clip_text": CLIPTextConditioner,
        "metaclip_text": MetaCLIPTextConditioner,
        "clap_audio": CLAPAudioConditioner,
        "cond_mlp": Cond_MLP,
        "global_mlp": Global_MLP,
        "sync_mlp": Sync_MLP,
        "cond_mlp_1": Cond_MLP_1,
        "cond_convmlp": Cond_ConvMLP,
        "cond_mlp_global": Cond_MLP_Global,
        "cond_mlp_global_1": Cond_MLP_Global_1,
        "cond_mlp_global_2": Cond_MLP_Global_2,
        "video_global": Video_Global,
        "video_sync": Video_Sync,
        "text_linear": Text_Linear,
        "video_clip": CLIPConditioner,
        "int": IntConditioner,
        "number": NumberConditioner,
        "phoneme": PhonemeConditioner,
        "lut": TokenizerLUTConditioner,
        "mm_unchang": mm_unchang,
    }

    cond_dim = config["cond_dim"]
    default_keys = config.get("default_keys", {})

    conditioners = {}
    for entry in config["configs"]:
        cond_id = entry["id"]
        cond_type = entry["type"]

        kwargs = {"output_dim": cond_dim}
        kwargs.update(entry["config"])

        if cond_type == "pretransform":
            # This type wraps a pretransform, which is built (and optionally
            # loaded from a checkpoint) before the conditioner itself.
            sample_rate = kwargs.pop("sample_rate", None)
            assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"

            pretransform = create_pretransform_from_config(kwargs.pop("pretransform_config"), sample_rate=sample_rate)

            if kwargs.get("pretransform_ckpt_path", None) is not None:
                pretransform.load_state_dict(load_ckpt_state_dict(kwargs.pop("pretransform_ckpt_path")))

            conditioners[cond_id] = PretransformConditioner(pretransform, **kwargs)
        elif cond_type in direct_types:
            conditioners[cond_id] = direct_types[cond_type](**kwargs)
        else:
            raise ValueError(f"Unknown conditioner type: {cond_type}")

    return MultiConditioner(conditioners, default_keys=default_keys)
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
    """Create a ConditionedDiffusionModelWrapper from a config dictionary.

    Originally in PrismAudio/models/diffusion.py.

    Args:
        config: Top-level model config with ``model_type``, ``sample_rate`` and
            a ``model`` section holding ``diffusion`` (required),
            ``io_channels`` (required), and optional ``conditioning`` and
            ``pretransform`` sections.

    Returns:
        The wrapper instance selected by ``model_type``.

    Raises:
        AssertionError: If a required config value is missing.
        NotImplementedError: If the diffusion ``type`` or ``model_type`` is
            unsupported or unknown. Unknown values previously fell through and
            failed later with an unbound local variable.
    """
    from prismaudio_core.models.diffusion import (
        ConditionedDiffusionModelWrapper,
        MMConditionedDiffusionModelWrapper,
        UNetCFG1DWrapper,
        UNet1DCondWrapper,
        DiTWrapper,
    )

    model_config = config["model"]
    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'mmdit':
        raise NotImplementedError("mmdit diffusion model type is not supported")
    else:
        # Fail here with a clear message instead of leaving diffusion_model unbound.
        raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    add_cond_ids = diffusion_config.get('add_cond_ids', [])
    sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
    mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
    zero_init = diffusion_config.get('zero_init', False)

    pretransform = model_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    # Scale the minimum input length by the network's own downsampling.
    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    # Get the proper wrapper class
    extra_kwargs = {}

    if model_type == "mm_diffusion_cond":
        wrapper_fn = MMConditionedDiffusionModelWrapper
        extra_kwargs["diffusion_objective"] = diffusion_objective
        extra_kwargs["mm_cond_ids"] = mm_cond_ids
    elif model_type in ("diffusion_cond", "diffusion_cond_inpaint", "diffusion_infill"):
        wrapper_fn = ConditionedDiffusionModelWrapper
        extra_kwargs["diffusion_objective"] = diffusion_objective
    elif model_type == "diffusion_prior":
        raise NotImplementedError("diffusion_prior model type is not supported")
    else:
        # Fail here with a clear message instead of leaving wrapper_fn unbound.
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        add_cond_ids=add_cond_ids,
        sync_cond_ids=sync_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        zero_init=zero_init,
        **extra_kwargs
    )
-4
View File
@@ -1,4 +0,0 @@
from .sampling import sample_discrete_euler
from .utils import set_audio_channels, prepare_audio
__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"]
-29
View File
@@ -1,29 +0,0 @@
import torch
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args):
    """Integrate ``x`` from ``sigma_max`` down to 0 with fixed-size Euler steps.

    The interval ``[sigma_max, 0]`` is split into ``steps`` equal pieces. On
    each piece the model output is used as the update direction and ``x`` is
    advanced by ``dt * model(x, t, **extra_args)``. A callback hook is exposed
    so a host application can report progress.

    Args:
        model: Callable invoked as ``model(x, t, **extra_args)``; ``t`` is a
            1-D tensor holding the current time once per batch element.
        x: Starting tensor; its first dimension is used as the batch size.
        steps: Number of Euler steps to take.
        sigma_max: Time value at which integration starts.
        callback: Optional callable that receives ``{"i": step, "x": x}``
            after every step.
        **extra_args: Forwarded unchanged to every ``model`` call.

    Returns:
        The tensor ``x`` after the final step.
    """
    timeline = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype)
    for step in range(steps):
        t_now = timeline[step]
        delta = timeline[step + 1] - t_now
        # Broadcast the scalar time to one entry per batch element.
        batch_t = t_now * x.new_ones(x.shape[0])
        x = x + delta * model(x, batch_t, **extra_args)
        if callback is not None:
            callback({"i": step, "x": x})
    return x
-62
View File
@@ -1,62 +0,0 @@
import torch
import torch.nn.functional as F
from torchaudio import transforms as T
def set_audio_channels(audio, target_channels):
    """Return ``audio`` with its channel axis (dim 1) adapted to ``target_channels``.

    Args:
        audio: Audio tensor of shape [B, C, T].
        target_channels: 1 averages all channels into one; 2 duplicates a
            single channel or keeps the first two of a wider signal. Any
            other value leaves ``audio`` untouched.

    Returns:
        The converted tensor, or ``audio`` itself when no conversion applies.
    """
    if target_channels == 1:
        # Mono: average across the channel axis.
        return audio.mean(1, keepdim=True)

    if target_channels != 2:
        return audio

    # Stereo
    current = audio.shape[1]
    if current == 1:
        return audio.repeat(1, 2, 1)
    if current > 2:
        return audio[:, :2, :]
    return audio
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Move, resample, batch, length-fit and channel-fit an audio tensor.

    Args:
        audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T]).
        in_sr: Sample rate of ``audio``.
        target_sr: Sample rate to resample to (skipped when equal to ``in_sr``).
        target_length: Length in samples; shorter audio is right-padded with
            zeros, longer audio is cropped.
        target_channels: Channel count handed to ``set_audio_channels``.
        device: Torch device the audio (and resampler) is placed on.

    Returns:
        The processed tensor on ``device`` with a leading batch dimension.
    """
    audio = audio.to(device)

    if in_sr != target_sr:
        resampler = T.Resample(in_sr, target_sr).to(device)
        audio = resampler(audio)

    # Bring the tensor up to three dimensions.
    rank = audio.dim()
    if rank == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif rank == 2:
        audio = audio.unsqueeze(0)

    # Fit the time axis.
    n_samples = audio.shape[-1]
    if n_samples < target_length:
        audio = F.pad(audio, (0, target_length - n_samples))
    elif n_samples > target_length:
        audio = audio[:, :, :target_length]

    return set_audio_channels(audio, target_channels)
-9
View File
@@ -1,9 +0,0 @@
"""
PrismAudio model modules for inference.
Re-exports create_model_from_config from the factory module.
"""
from prismaudio_core.factory import create_model_from_config
__all__ = ["create_model_from_config"]
File diff suppressed because it is too large Load Diff
-821
View File
@@ -1,821 +0,0 @@
import torch
import math
import numpy as np
from torch import nn
from torch.nn import functional as F
from torchaudio import transforms as T
from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d
from typing import Literal, Dict, Any
from .blocks import SnakeBeta
from .bottleneck import Bottleneck, DiscreteBottleneck
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
from .pretransforms import Pretransform
from .utils import checkpoint
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Minimal stub for inference.utils.prepare_audio used by autoencoders.

    Resamples when the rates differ, fits the channel count by truncating or
    tiling along the first dimension, pads or crops the last dimension to
    ``target_length`` and finally prepends a batch dimension.

    Args:
        audio: Tensor whose first dimension is treated as channels and whose
            last dimension is treated as time.
        in_sr: Sample rate of ``audio``.
        target_sr: Sample rate to resample to (skipped when equal to ``in_sr``).
        target_length: Desired length of the last dimension.
        target_channels: Desired size of the first dimension.
        device: Device the resampler is moved to. The audio itself is not
            moved by this function.

    Returns:
        ``audio`` with a new leading dimension of size 1.
    """
    # Uses the module-level ``T`` / ``F`` imports instead of re-importing
    # torch and torchaudio on every call.
    if in_sr != target_sr:
        resample_tf = T.Resample(in_sr, target_sr).to(device)
        audio = resample_tf(audio)

    n_channels = audio.shape[0]
    if n_channels > target_channels:
        audio = audio[:target_channels]
    elif n_channels < target_channels:
        # Tile the channels enough times, then trim to the exact count.
        audio = audio.repeat(target_channels // n_channels + 1, 1)[:target_channels]

    length = audio.shape[-1]
    if length < target_length:
        audio = F.pad(audio, (0, target_length - length))
    elif length > target_length:
        audio = audio[..., :target_length]

    return audio.unsqueeze(0)
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
    """Build a pretransform via the factory module, importing it at call time.

    The import lives inside the function rather than at module level, so the
    factory module is only loaded when a pretransform is actually requested.
    """
    from prismaudio_core.factory import create_pretransform_from_config as build_pretransform

    return build_pretransform(pretransform, sample_rate)
def _lazy_create_bottleneck_from_config(bottleneck):
    """Build a bottleneck via the factory module, importing it at call time.

    The import lives inside the function rather than at module level, so the
    factory module is only loaded when a bottleneck is actually requested.
    """
    from prismaudio_core.factory import create_bottleneck_from_config as build_bottleneck

    return build_bottleneck(bottleneck)
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Create the activation module named by ``activation``.

    Args:
        activation: ``"elu"`` -> ``nn.ELU``, ``"snake"`` -> ``SnakeBeta(channels)``,
            ``"none"`` -> ``nn.Identity``.
        antialias: When true, the activation is wrapped in ``Activation1d``.
        channels: Passed to ``SnakeBeta``; ignored by the other activations.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    factories = {
        "elu": nn.ELU,
        "snake": lambda: SnakeBeta(channels),
        "none": nn.Identity,
    }
    make = factories.get(activation) if isinstance(activation, str) else None
    if make is None:
        raise ValueError(f"Unknown activation {activation}")

    module = make()
    return Activation1d(module) if antialias else module
class ResidualUnit(nn.Module):
    """Residual branch: activation, dilated k=7 conv, activation, k=1 conv.

    The dilated convolution is padded so the sequence length is preserved,
    which lets the branch output be added back onto the input.
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        act_name = "snake" if use_snake else "elu"
        kernel_size = 7
        # Length-preserving padding for a dilated convolution.
        same_padding = (dilation * (kernel_size - 1)) // 2

        first_act = get_activation(act_name, antialias=antialias_activation, channels=out_channels)
        dilated_conv = WNConv1d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, dilation=dilation, padding=same_padding)
        second_act = get_activation(act_name, antialias=antialias_activation, channels=out_channels)
        pointwise_conv = WNConv1d(in_channels=out_channels, out_channels=out_channels,
                                  kernel_size=1)

        self.layers = nn.Sequential(first_act, dilated_conv, second_act, pointwise_conv)

    def forward(self, x):
        """Return ``layers(x) + x``."""
        return self.layers(x) + x
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9) followed by a strided conv.

    ``antialias_activation`` only affects the activation placed before the
    strided convolution; the residual units are built with their own default
    for that flag.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        residual_stack = [
            ResidualUnit(in_channels=in_channels, out_channels=in_channels,
                         dilation=rate, use_snake=use_snake)
            for rate in (1, 3, 9)
        ]
        pre_down_act = get_activation("snake" if use_snake else "elu",
                                      antialias=antialias_activation, channels=in_channels)
        downsample = WNConv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))

        self.layers = nn.Sequential(*residual_stack, pre_down_act, downsample)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.layers(x)
class DecoderBlock(nn.Module):
    """Activation, upsampling layer, then three residual units (dilations 1, 3, 9).

    Upsampling is either a transposed convolution (default) or nearest-neighbour
    ``nn.Upsample`` followed by a stride-1 convolution. ``antialias_activation``
    only affects the leading activation; the residual units are built with
    their own default for that flag.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        if not use_nearest_upsample:
            upsample_layer = WNConvTranspose1d(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
        else:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2*stride,
                         stride=1,
                         bias=False,
                         padding='same')
            )

        leading_act = get_activation("snake" if use_snake else "elu",
                                     antialias=antialias_activation, channels=in_channels)
        residual_stack = [
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=rate, use_snake=use_snake)
            for rate in (1, 3, 9)
        ]

        self.layers = nn.Sequential(leading_act, upsample_layer, *residual_stack)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.layers(x)
class OobleckEncoder(nn.Module):
    """Convolutional encoder: input conv, a chain of ``EncoderBlock``s, output conv.

    A leading multiplier of 1 is prepended to ``c_mults``, so the number of
    encoder blocks equals ``len(c_mults)`` as passed in, and block ``i`` uses
    ``strides[i]``. ``antialias_activation`` is applied only to the final
    activation; it is not forwarded to the encoder blocks.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
        ):
        super().__init__()

        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        # Channel width at every stage boundary.
        widths = [mult * channels for mult in c_mults]

        stages = [
            WNConv1d(in_channels=in_channels, out_channels=widths[0], kernel_size=7, padding=3)
        ]

        for idx, (width_in, width_out) in enumerate(zip(widths, widths[1:])):
            stages.append(EncoderBlock(in_channels=width_in, out_channels=width_out,
                                       stride=strides[idx], use_snake=use_snake))

        stages.append(get_activation("snake" if use_snake else "elu",
                                     antialias=antialias_activation, channels=widths[-1]))
        stages.append(WNConv1d(in_channels=widths[-1], out_channels=latent_dim, kernel_size=3, padding=1))

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the encoder stack to ``x``."""
        return self.layers(x)
class OobleckDecoder(nn.Module):
    """Convolutional decoder: input conv, a chain of ``DecoderBlock``s, output conv.

    A leading multiplier of 1 is prepended to ``c_mults`` and the stages are
    walked from the widest to the narrowest. The stack ends with an
    activation, a bias-free k=7 convolution to ``out_channels`` and, when
    ``final_tanh`` is set, a ``Tanh``.
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        c_mults = [1] + c_mults
        self.depth = len(c_mults)

        widths = [mult * channels for mult in c_mults]

        stages = [
            WNConv1d(in_channels=latent_dim, out_channels=widths[-1], kernel_size=7, padding=3),
        ]

        # Stage ``level`` maps widths[level] -> widths[level-1] using strides[level-1].
        for level in reversed(range(1, self.depth)):
            stages.append(DecoderBlock(
                in_channels=widths[level],
                out_channels=widths[level - 1],
                stride=strides[level - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample
            ))

        stages.append(get_activation("snake" if use_snake else "elu",
                                     antialias=antialias_activation, channels=widths[0]))
        stages.append(WNConv1d(in_channels=widths[0], out_channels=out_channels,
                               kernel_size=7, padding=3, bias=False))
        stages.append(nn.Tanh() if final_tanh else nn.Identity())

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the decoder stack to ``x``."""
        return self.layers(x)
class DACEncoderWrapper(nn.Module):
    """Wrap the DAC ``Encoder`` with an optional 1x1 projection to ``latent_dim``.

    ``kwargs`` must contain ``d_model`` and ``strides``; an optional
    ``latent_dim`` entry is removed from ``kwargs`` before the remainder is
    forwarded to the DAC encoder.
    """

    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        d_model = kwargs["d_model"]
        encoder_out_dim = d_model * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
        self.latent_dim = latent_dim

        # Latent-dim support was added to DAC after this was first written, and
        # implemented differently, so this projection is kept for backwards compatibility.
        if latent_dim is None:
            self.proj_out = nn.Identity()
        else:
            self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1)

        if in_channels != 1:
            # Swap the first layer for one that accepts ``in_channels`` inputs.
            self.encoder.block[0] = WNConv1d(in_channels, d_model, kernel_size=7, padding=3)

    def forward(self, x):
        """Encode ``x`` and apply the output projection."""
        return self.proj_out(self.encoder(x))
class DACDecoderWrapper(nn.Module):
    """Wrap the DAC ``Decoder``, mapping ``latent_dim`` inputs to ``out_channels`` outputs.

    Remaining keyword arguments are forwarded to the DAC decoder unchanged.
    """

    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Decoder as DACDecoder

        decoder_kwargs = dict(kwargs, input_channel=latent_dim, d_out=out_channels)
        self.decoder = DACDecoder(**decoder_kwargs)
        self.latent_dim = latent_dim

    def forward(self, x):
        """Decode ``x``."""
        return self.decoder(x)
class AudioAutoencoder(nn.Module):
    """Encoder/decoder pair with optional bottleneck and pretransform stages.

    ``encode`` runs pretransform -> encoder -> bottleneck and ``decode`` runs
    bottleneck -> decoder -> pretransform. ``encode_audio`` / ``decode_audio``
    add an optional chunked mode that processes long inputs window by window.
    """

    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck = None,
        pretransform: Pretransform = None,
        in_channels = None,
        out_channels = None,
        soft_clip = False
    ):
        """Store the sub-modules and the I/O bookkeeping values.

        Args:
            encoder: Module mapping audio to latents; ``None`` makes ``encode``
                pass its input through to the bottleneck stage.
            decoder: Module mapping latents to audio.
            latent_dim: Channel count of the latents.
            downsampling_ratio: Number of audio samples per latent frame.
            sample_rate: Sample rate the model operates at.
            io_channels: Default for both ``in_channels`` and ``out_channels``.
            bottleneck: Optional bottleneck applied after the encoder / before
                the decoder.
            pretransform: Optional transform applied before the encoder / after
                the decoder.
            in_channels: Overrides ``io_channels`` on the input side.
            out_channels: Overrides ``io_channels`` on the output side.
            soft_clip: When true, ``decode`` applies ``tanh`` to its output.
        """
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels

        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform

        self.soft_clip = soft_clip

        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

    @staticmethod
    def _map_batch(fn, batch, iterate_batch):
        """Apply ``fn`` to ``batch`` at once, or item by item when ``iterate_batch`` is set."""
        if not iterate_batch:
            return fn(batch)
        return torch.cat([fn(batch[i:i+1]) for i in range(batch.shape[0])], dim=0)

    def _apply_pretransform(self, fn, batch, iterate_batch):
        """Run a pretransform method, disabling autograd unless ``enable_grad`` is set."""
        if self.pretransform.enable_grad:
            return self._map_batch(fn, batch, iterate_batch)
        with torch.no_grad():
            return self._map_batch(fn, batch, iterate_batch)

    @staticmethod
    def _split_into_chunks(x, chunk_size, hop_size):
        """Stack windows of ``chunk_size`` taken every ``hop_size`` along dim 2.

        A final window aligned to the end is appended when the regular windows
        do not end exactly at the end of ``x``. The caller guarantees
        ``x.shape[2] >= chunk_size`` and ``hop_size > 0``.
        """
        total_size = x.shape[2]
        starts = range(0, total_size - chunk_size + 1, hop_size)
        chunks = [x[:,:,start:start+chunk_size] for start in starts]
        if starts[-1] + chunk_size != total_size:
            # Final chunk
            chunks.append(x[:,:,-chunk_size:])
        return torch.stack(chunks)

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
        """Encode audio into latents.

        Args:
            audio: Input batch.
            return_info: When true, also return the info dict collected from
                the bottleneck.
            skip_pretransform: Skip the pretransform stage even if one is set.
            iterate_batch: Run pretransform and encoder one batch item at a time.
            **kwargs: Forwarded to ``bottleneck.encode``.

        Returns:
            ``latents``, or ``(latents, info)`` when ``return_info`` is true.
        """
        info = {}

        if self.pretransform is not None and not skip_pretransform:
            audio = self._apply_pretransform(self.pretransform.encode, audio, iterate_batch)

        if self.encoder is not None:
            latents = self._map_batch(self.encoder, audio, iterate_batch)
        else:
            latents = audio

        if self.bottleneck is not None:
            # TODO: Add iterate batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)

            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):
        """Decode latents into audio.

        Args:
            latents: Latent batch.
            iterate_batch: Run every stage one batch item at a time.
            **kwargs: Forwarded to the decoder, but only on the non-iterating
                path; the per-item path calls the decoder without them.

        Returns:
            The decoded batch, passed through ``tanh`` when ``soft_clip`` is set.
        """
        if self.bottleneck is not None:
            latents = self._map_batch(self.bottleneck.decode, latents, iterate_batch)

        if iterate_batch:
            decoded = self._map_batch(self.decoder, latents, True)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            # The per-item loop is bounded by ``decoded`` (the tensor being
            # sliced) on both the grad and no-grad paths.
            decoded = self._apply_pretransform(self.pretransform.decode, decoded, iterate_batch)

        if self.soft_clip:
            decoded = torch.tanh(decoded)

        return decoded

    def decode_tokens(self, tokens, **kwargs):
        '''
        Decode discrete tokens to audio

        Only works with discrete autoencoders. ``kwargs`` are forwarded both to
        ``bottleneck.decode_tokens`` and to ``decode``.
        '''

        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"

        latents = self.bottleneck.decode_tokens(tokens, **kwargs)

        return self.decode(latents, **kwargs)

    def preprocess_audio_for_encoder(self, audio, in_sr):
        '''
        Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
        If the model is mono, stereo audio will be converted to mono.
        Audio will be silence-padded to be a multiple of the model's downsampling ratio.
        Audio will be resampled to the model's sample rate.
        The output will have batch size 1 and be shape (1 x Channels x Length)
        '''
        return self.preprocess_audio_list_for_encoder([audio], [in_sr])

    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
        '''
        Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
        The audio in that list can be of different lengths and channels.
        in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
        All audio will be resampled to the model's sample rate.
        Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
        If the model is mono, all audio will be converted to mono.
        The output will be a tensor of shape (Batch x Channels x Length)
        '''
        batch_size = len(audio_list)
        if isinstance(in_sr_list, int):
            in_sr_list = [in_sr_list]*batch_size
        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"

        new_audio = []
        max_length = 0
        # resample & find the max length
        for i in range(batch_size):
            audio = audio_list[i]
            in_sr = in_sr_list[i]
            if len(audio.shape) == 3 and audio.shape[0] == 1:
                # batchsize 1 was given by accident. Just squeeze it.
                audio = audio.squeeze(0)
            elif len(audio.shape) == 1:
                # Mono signal, channel dimension is missing, unsqueeze it in
                audio = audio.unsqueeze(0)
            assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
            # Resample audio
            if in_sr != self.sample_rate:
                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
                audio = resample_tf(audio)
            new_audio.append(audio)
            max_length = max(max_length, audio.shape[-1])

        # Pad every audio to the same length, multiple of model's downsampling ratio
        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
        for i in range(batch_size):
            # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model.
            # Everything is already at the model rate, so both rates are
            # ``self.sample_rate`` and prepare_audio performs no resampling.
            new_audio[i] = prepare_audio(new_audio[i], in_sr=self.sample_rate, target_sr=self.sample_rate,
                                         target_length=padded_audio_length,
                                         target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)

        # convert to tensor
        return torch.stack(new_audio)

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
        Overlap and chunk_size params are both measured in number of latents (not audio samples)
        and therefore you likely could use the same values with decode_audio.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
        For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks

        In chunked mode ``kwargs`` are not forwarded to ``encode``. Audio shorter
        than one chunk is encoded in a single pass. Raises ValueError when
        ``overlap`` is not smaller than ``chunk_size``.
        '''
        if not chunked:
            # default behavior. Encode the entire audio in parallel
            return self.encode(audio, **kwargs)

        # CHUNKED ENCODING
        # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
        samples_per_latent = self.downsampling_ratio
        total_size = audio.shape[2] # in samples
        batch_size = audio.shape[0]
        chunk_size *= samples_per_latent # converting metric in latents to samples
        overlap *= samples_per_latent # converting metric in latents to samples
        hop_size = chunk_size - overlap
        if hop_size <= 0:
            raise ValueError("overlap must be smaller than chunk_size")
        if total_size < chunk_size:
            # Nothing to split: a single window already covers the input.
            return self.encode(audio)

        chunks = self._split_into_chunks(audio, chunk_size, hop_size)
        num_chunks = chunks.shape[0]
        # Note: y_size might be a different value from the latent length used in diffusion training
        # because we can encode audio of varying lengths
        # However, the audio should've been padded to a multiple of samples_per_latent by now.
        y_size = total_size // samples_per_latent
        # Create an empty latent, we will populate it with chunks as we encode them
        y_final = torch.zeros((batch_size,self.latent_dim,y_size), device=audio.device)
        # half of the overlap, in latents, is trimmed from each interior edge
        ol = overlap//samples_per_latent//2
        for i in range(num_chunks):
            # encode the chunk
            y_chunk = self.encode(chunks[i])
            is_first = i == 0
            is_last = i == num_chunks-1
            # figure out where to put the latents along the time domain
            if is_last:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size // samples_per_latent
                t_end = t_start + chunk_size // samples_per_latent
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if not is_first:
                # no overlap trimming for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if not is_last:
                # no overlap trimming for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunk into the output latent
            y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
        return y_final

    def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Decode latents to audio.
        If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
        Every autoencoder will have a different receptive field size, and thus ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
        For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks

        In chunked mode ``kwargs`` are not forwarded to ``decode``. Latents
        shorter than one chunk are decoded in a single pass. Raises ValueError
        when ``overlap`` is not smaller than ``chunk_size``.
        '''
        if not chunked:
            # default behavior. Decode the entire latent in parallel
            return self.decode(latents, **kwargs)

        # chunked decoding
        hop_size = chunk_size - overlap
        if hop_size <= 0:
            raise ValueError("overlap must be smaller than chunk_size")
        total_size = latents.shape[2]
        batch_size = latents.shape[0]
        if total_size < chunk_size:
            # Nothing to split: a single window already covers the input.
            return self.decode(latents)

        chunks = self._split_into_chunks(latents, chunk_size, hop_size)
        num_chunks = chunks.shape[0]
        # samples_per_latent is just the downsampling ratio
        samples_per_latent = self.downsampling_ratio
        # Create an empty waveform, we will populate it with chunks as we decode them
        y_size = total_size * samples_per_latent
        y_final = torch.zeros((batch_size,self.out_channels,y_size), device=latents.device)
        # half of the overlap, in samples, is trimmed from each interior edge
        ol = (overlap//2) * samples_per_latent
        for i in range(num_chunks):
            # decode the chunk
            y_chunk = self.decode(chunks[i])
            is_first = i == 0
            is_last = i == num_chunks-1
            # figure out where to put the audio along the time domain
            if is_last:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size * samples_per_latent
                t_end = t_start + chunk_size * samples_per_latent
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if not is_first:
                # no overlap trimming for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if not is_last:
                # no overlap trimming for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunked audio into our y_final output audio
            y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
        return y_final
class DiffusionAutoencoder(AudioAutoencoder):
    """Autoencoder whose decoding stage is a conditioned diffusion model.

    The (optional) decoder output is upsampled to the audio length and fed to
    the diffusion model as ``input_concat_cond``.
    """

    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs
    ):
        """Set up the base autoencoder, attach ``diffusion`` and halve the encoder weights.

        ``min_length`` becomes ``downsampling_ratio * diffusion_downsampling_ratio``.
        """
        super().__init__(*args, **kwargs)

        self.diffusion = diffusion

        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio

        if self.encoder is None:
            return

        # Shrink the initial encoder parameters to avoid saturated latents
        with torch.no_grad():
            for weight in self.encoder.parameters():
                weight.mul_(0.5)

    def decode(self, latents, steps=100):
        """Decode ``latents`` by sampling the diffusion model for ``steps`` steps."""
        target_length = latents.shape[2] * self.downsampling_ratio

        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)

        if self.decoder is not None:
            latents = self.decoder(latents)

        # Upsample latents to match diffusion length
        if latents.shape[2] != target_length:
            latents = F.interpolate(latents, size=target_length, mode='nearest')

        noise = torch.randn(latents.shape[0], self.io_channels, target_length, device=latents.device)

        # NOTE(review): the sampling module shown next to this file appears to
        # define only `sample_discrete_euler` -- confirm that `sample` exists
        # there, otherwise this import fails at call time.
        from prismaudio_core.inference.sampling import sample
        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)

        if self.pretransform is None:
            return decoded

        if self.pretransform.enable_grad:
            return self.pretransform.decode(decoded)

        with torch.no_grad():
            return self.pretransform.decode(decoded)
# AE factories
def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Instantiate the encoder described by ``encoder_config``.

    Args:
        encoder_config: Mapping with a ``"type"`` entry (``"oobleck"``,
            ``"seanet"``, ``"dac"`` or ``"local_attn"``), a ``"config"``
            mapping of constructor keyword arguments, and an optional
            ``"requires_grad"`` flag (default ``True``). The mapping is not
            modified.

    Returns:
        The constructed encoder; its parameters are frozen when
        ``requires_grad`` is false.

    Raises:
        AssertionError: If ``"type"`` is missing.
        ValueError: If ``"type"`` is not a known encoder type.
    """
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(
            **encoder_config["config"]
        )

    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder
        # Work on a shallow copy: reversing "ratios" in the caller's dict would
        # flip them back again if the same config were used a second time.
        seanet_encoder_config = dict(encoder_config["config"])

        #SEANet encoder expects strides in reverse order
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(
            **seanet_encoder_config
        )
    elif encoder_type == "dac":
        dac_config = encoder_config["config"]

        encoder = DACEncoderWrapper(**dac_config)
    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        local_attn_config = encoder_config["config"]

        encoder = TransformerEncoder1D(
            **local_attn_config
        )
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Instantiate the decoder described by ``decoder_config``.

    Args:
        decoder_config: Mapping with a ``"type"`` entry (``"oobleck"``,
            ``"seanet"``, ``"dac"`` or ``"local_attn"``), a ``"config"``
            mapping of constructor keyword arguments, and an optional
            ``"requires_grad"`` flag (default ``True``).

    Returns:
        The constructed decoder; its parameters are frozen when
        ``requires_grad`` is false.

    Raises:
        AssertionError: If ``"type"`` is missing.
        ValueError: If ``"type"`` is not a known decoder type.
    """
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(**decoder_config["config"])
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(**decoder_config["config"])
    elif decoder_type == "dac":
        decoder = DACDecoderWrapper(**decoder_config["config"])
    elif decoder_type == "local_attn":
        from .local_attention import TransformerDecoder1D

        decoder = TransformerDecoder1D(**decoder_config["config"])
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    trainable = decoder_config.get("requires_grad", True)
    if not trainable:
        # Freeze every parameter of the decoder.
        for weight in decoder.parameters():
            weight.requires_grad = False

    return decoder
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an ``AudioAutoencoder`` from a full model config.

    Args:
        config: Mapping with a top-level ``"sample_rate"`` and a ``"model"``
            section holding ``"encoder"``, ``"decoder"``, ``"latent_dim"``,
            ``"downsampling_ratio"``, ``"io_channels"`` and the optional
            ``"bottleneck"``, ``"pretransform"``, ``"in_channels"`` and
            ``"out_channels"`` entries. ``soft_clip`` is read from the
            decoder section.

    Raises:
        AssertionError: If ``latent_dim``, ``downsampling_ratio``,
            ``io_channels`` or ``sample_rate`` is missing.
    """
    model_cfg = config["model"]

    encoder = create_encoder_from_config(model_cfg["encoder"])
    decoder = create_decoder_from_config(model_cfg["decoder"])

    bottleneck_cfg = model_cfg.get("bottleneck")

    # Required entries, checked in a fixed order.
    latent_dim = model_cfg.get("latent_dim")
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = model_cfg.get("downsampling_ratio")
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = model_cfg.get("io_channels")
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate")
    assert sample_rate is not None, "sample_rate must be specified in model config"

    pretransform_cfg = model_cfg.get("pretransform")
    pretransform = None
    if pretransform_cfg is not None:
        pretransform = _lazy_create_pretransform_from_config(pretransform_cfg, sample_rate)

    bottleneck = None
    if bottleneck_cfg is not None:
        bottleneck = _lazy_create_bottleneck_from_config(bottleneck_cfg)

    autoencoder_kwargs = dict(
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=model_cfg.get("in_channels"),
        out_channels=model_cfg.get("out_channels"),
        soft_clip=model_cfg["decoder"].get("soft_clip", False),
    )
    return AudioAutoencoder(encoder, decoder, **autoencoder_kwargs)
def create_diffAE_from_config(config: Dict[str, Any]):
    """Build a ``DiffusionAutoencoder`` from a full model config.

    Args:
        config: Mapping with a top-level ``"sample_rate"`` and a ``"model"``
            section holding a ``"diffusion"`` entry (``"type"`` of ``"DAU1d"``,
            ``"adp_1d"`` or ``"dit"`` plus its ``"config"``), ``"latent_dim"``,
            ``"downsampling_ratio"``, ``"io_channels"`` and the optional
            ``"encoder"``, ``"decoder"``, ``"bottleneck"`` and
            ``"pretransform"`` entries.

    Raises:
        ValueError: If the diffusion type is not one of the supported names.
        AssertionError: If ``latent_dim``, ``downsampling_ratio``,
            ``io_channels`` or ``sample_rate`` is missing.
    """
    diffae_config = config["model"]

    if "encoder" in diffae_config:
        encoder = create_encoder_from_config(diffae_config["encoder"])
    else:
        encoder = None

    if "decoder" in diffae_config:
        decoder = create_decoder_from_config(diffae_config["decoder"])
    else:
        decoder = None

    diffusion_model_type = diffae_config["diffusion"]["type"]

    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``diffusion`` variable further down.
        raise ValueError(f"Unknown diffusion model type {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)

    pretransform = diffae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = _lazy_create_bottleneck_from_config(bottleneck)

    # Total downsampling performed inside the diffusion model itself.
    if diffusion_model_type == "DAU1d":
        diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
    elif diffusion_model_type == "adp_1d":
        diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
    else:
        # "dit" (unknown types were rejected above)
        diffusion_downsampling_ratio = 1

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform
    )
-331
View File
@@ -1,331 +0,0 @@
from functools import reduce
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.backends.cuda import sdp_kernel
from packaging import version
from dac.nn.layers import Snake1d
class ResidualBlock(nn.Module):
    """Sum of a sequential main branch and a skip branch.

    Args:
        main: Iterable of modules chained into ``nn.Sequential``.
        skip: Module for the skip path; an identity is used when it is
            omitted (or falsy).
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip or nn.Identity()

    def forward(self, input):
        """Return ``main(input) + skip(input)``."""
        branch = self.main(input)
        shortcut = self.skip(input)
        return branch + shortcut
class ResConvBlock(ResidualBlock):
    """Two-convolution residual block with group norm and GELU/Snake activations.

    The skip path is a bias-free 1x1 convolution when ``c_in != c_out`` and
    the identity otherwise. With ``is_last`` set, the second normalisation and
    activation are replaced by identities.
    """

    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
        shortcut = nn.Conv1d(c_in, c_out, 1, bias=False) if c_in != c_out else None
        half_kernel = kernel_size // 2

        def make_activation(width):
            return Snake1d(width) if use_snake else nn.GELU()

        body = [
            nn.Conv1d(c_in, c_mid, kernel_size, padding=half_kernel, bias=conv_bias),
            nn.GroupNorm(1, c_mid),
            make_activation(c_mid),
            nn.Conv1d(c_mid, c_out, kernel_size, padding=half_kernel, bias=conv_bias),
        ]
        if is_last:
            body += [nn.Identity(), nn.Identity()]
        else:
            body += [nn.GroupNorm(1, c_out), make_activation(c_out)]

        super().__init__(body, shortcut)
class SelfAttention1d(nn.Module):
    """Multi-head self-attention over the last axis of an ``(n, c, s)`` tensor.

    The input is group-normalised, projected to q/k/v with a 1x1 convolution,
    attended, projected back and added onto the input. When CUDA is available
    and torch is at least 2.0, ``F.scaled_dot_product_attention`` is used
    under an ``sdp_kernel`` configuration chosen from the device's compute
    capability; otherwise attention is computed explicitly.
    """

    def __init__(self, c_in, n_head=1, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv1d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)

        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        if self.use_flash:
            props = torch.cuda.get_device_properties(torch.device('cuda'))
            is_sm80 = props.major == 8 and props.minor == 0
            # Use flash attention for A100 GPUs, the other kernels elsewhere.
            self.sdp_kernel_config = (True, False, False) if is_sm80 else (False, True, True)

    def forward(self, input):
        """Return ``input`` plus the (dropout-regularised) attention output."""
        n, c, s = input.shape
        head_dim = c // self.n_head

        qkv = self.qkv_proj(self.norm(input))
        qkv = qkv.view([n, self.n_head * 3, head_dim, s]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)

        if self.use_flash:
            with sdp_kernel(*self.sdp_kernel_config):
                attended = F.scaled_dot_product_attention(q, k, v, is_causal=False)
                y = attended.contiguous().view([n, c, s])
        else:
            # Scale q and k by d^-1/4 each, i.e. the logits by d^-1/2 overall.
            scale = k.shape[3]**-0.25
            weights = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
            y = (weights @ v).transpose(2, 3).contiguous().view([n, c, s])

        return input + self.dropout(self.out_proj(y))
class SkipBlock(nn.Module):
    """Run a sequential branch and concatenate its output with the input on dim 1."""

    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, input):
        """Return ``cat([main(input), input], dim=1)``."""
        processed = self.main(input)
        return torch.cat((processed, input), dim=1)
class FourierFeatures(nn.Module):
    """Learned Fourier feature embedding.

    Holds a weight of shape ``[out_features // 2, in_features]`` drawn from a
    normal distribution scaled by ``std``. ``out_features`` must be even.
    """

    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        initial = torch.randn([out_features // 2, in_features]) * std
        self.weight = nn.Parameter(initial)

    def forward(self, input):
        """Return the cosines followed by the sines of ``2*pi*input @ weight.T``."""
        phase = 2 * math.pi * input @ self.weight.T
        return torch.cat((torch.cos(phase), torch.sin(phase)), dim=-1)
def expand_to_planes(input, shape):
    """Add a trailing axis to ``input`` and tile it ``shape[2]`` times.

    Only ``shape[2]`` is read from ``shape``.
    """
    length = shape[2]
    return input.unsqueeze(-1).repeat(1, 1, length)
# Named 1-D filter taps. Downsample1d uses the selected list as-is and
# Upsample1d uses it multiplied by 2; both look it up via their ``kernel``
# argument.
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
class Downsample1d(nn.Module):
    """Filter every channel with a fixed kernel and keep every second sample.

    Args:
        kernel: Key into ``_kernels`` selecting the filter taps.
        pad_mode: Padding mode passed to ``F.pad``.
        channels_last: If true, the input is permuted ``(0, 2, 1)`` before
            filtering and permuted back afterwards.
    """

    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel])
        self.pad = taps.shape[0] // 2 - 1
        self.register_buffer('kernel', taps)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)

        x = F.pad(x, (self.pad, self.pad), self.pad_mode)

        # Dense conv weight whose only non-zero entries sit on the channel
        # diagonal, so each channel is filtered independently.
        n_channels = x.shape[1]
        weight = x.new_zeros([n_channels, n_channels, self.kernel.shape[0]])
        diag = torch.arange(n_channels, device=x.device)
        weight[diag, diag] = self.kernel.to(weight)

        x = F.conv1d(x, weight, stride=2)

        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x
class Upsample1d(nn.Module):
    """Upsample every channel by 2 with a fixed transposed-convolution kernel.

    Args:
        kernel: Key into ``_kernels``; the selected taps are multiplied by 2.
        pad_mode: Padding mode passed to ``F.pad``.
        channels_last: If true, the input is permuted ``(0, 2, 1)`` before
            filtering and permuted back afterwards.
    """

    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel]) * 2
        self.pad = taps.shape[0] // 2 - 1
        self.register_buffer('kernel', taps)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)

        edge = (self.pad + 1) // 2
        x = F.pad(x, (edge, edge), self.pad_mode)

        # Dense weight with the kernel on the channel diagonal only, so each
        # channel is processed independently.
        n_channels = x.shape[1]
        weight = x.new_zeros([n_channels, n_channels, self.kernel.shape[0]])
        diag = torch.arange(n_channels, device=x.device)
        weight[diag, diag] = self.kernel.to(weight)

        x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)

        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x
def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    """Build a strided ``nn.Conv1d`` that downsamples by ``factor``.

    The kernel size is ``factor * kernel_multiplier + 1`` and the padding is
    ``factor * (kernel_multiplier // 2)``.

    Raises:
        AssertionError: If ``kernel_multiplier`` is odd.
    """
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
    half_multiplier = kernel_multiplier // 2
    return nn.Conv1d(
        in_channels,
        out_channels,
        factor * kernel_multiplier + 1,
        stride=factor,
        padding=factor * half_multiplier,
    )
def Upsample1d_2(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    """Build a learned upsampling layer for the given ``factor``.

    * ``factor == 1``: a kernel-3 ``nn.Conv1d`` with padding 1.
    * ``use_nearest``: ``nn.Upsample(mode="nearest")`` followed by the same
      kernel-3 convolution.
    * otherwise: an ``nn.ConvTranspose1d`` with kernel ``factor * 2`` and
      stride ``factor``.
    """
    if factor == 1:
        return nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )

    if not use_nearest:
        is_odd = factor % 2
        return nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
            padding=factor // 2 + is_odd,
            output_padding=is_odd,
        )

    resize = nn.Upsample(scale_factor=factor, mode="nearest")
    smooth = nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        padding=1,
    )
    return nn.Sequential(resize, smooth)
def zero_init(layer):
    """Zero ``layer.weight`` (and ``layer.bias`` when present) in place and return ``layer``."""
    nn.init.zeros_(layer.weight)
    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)
    return layer
class AdaRMSNorm(nn.Module):
    """RMS normalisation whose scale is predicted from a conditioning vector.

    The scale is ``linear(cond) + 1``; ``linear`` is a zero-initialised,
    bias-free ``nn.Linear(cond_features, features)``.
    """

    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        projection = nn.Linear(cond_features, features, bias=False)
        self.linear = zero_init(projection)

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        # Insert a middle axis so the per-sample scale broadcasts against x.
        scale = self.linear(cond)[:, None, :] + 1
        return rms_norm(x, scale, self.eps)
def normalize(x, eps=1e-4):
    """Divide ``x`` by ``eps + alpha * norm``.

    ``norm`` is the vector norm taken over every axis except the first
    (kept as size-1 axes), and ``alpha`` is ``sqrt(norm.numel() / x.numel())``.
    """
    reduce_dims = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True)
    alpha = np.sqrt(norm.numel() / x.numel())
    denominator = torch.add(eps, norm, alpha=alpha)
    return x / denominator
class ForcedWNConv1d(nn.Module):
    """Bias-free 1-D convolution whose weight is re-normalised on every call.

    In training mode the stored weight is first overwritten in place with its
    normalised version; the convolution then always uses
    ``normalize(weight) / sqrt(fan_in)`` with ``padding='same'``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        initial = torch.randn([out_channels, in_channels, kernel_size])
        self.weight = nn.Parameter(initial)

    def forward(self, x):
        if self.training:
            # Overwrite the parameter without recording the copy in autograd.
            with torch.no_grad():
                self.weight.copy_(normalize(self.weight))

        # Number of weights feeding one output channel.
        fan_in = self.weight[0].numel()
        effective_weight = normalize(self.weight) / math.sqrt(fan_in)
        return F.conv1d(x, effective_weight, padding='same')
# Kernels
# Module-level switch: when false, compile() returns functions unchanged.
use_compile = True


def compile(function, *args, **kwargs):
    """Wrap ``function`` with ``torch.compile`` when possible.

    Returns ``function`` unchanged if ``use_compile`` is false or if
    ``torch.compile`` raises ``RuntimeError``.
    """
    if use_compile:
        try:
            return torch.compile(function, *args, **kwargs)
        except RuntimeError:
            pass
    return function
@compile
def linear_geglu(x, weight, bias=None):
    """Linear projection followed by a GELU-gated product.

    The projection ``x @ weight.mT (+ bias)`` is split in two halves on the
    last axis; the first half is multiplied by ``gelu`` of the second.
    """
    projected = x @ weight.mT
    if bias is not None:
        projected = projected + bias
    value, gate = projected.chunk(2, dim=-1)
    return value * F.gelu(gate)
@compile
def rms_norm(x, scale, eps):
    """Scale ``x`` by ``scale / sqrt(mean(x**2, dim=-1) + eps)``.

    The statistics are computed in the promoted dtype of ``x``, ``scale`` and
    float32; the final factor is cast back to ``x.dtype`` before multiplying.
    """
    compute_dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_square = torch.mean(x.to(compute_dtype)**2, dim=-1, keepdim=True)
    factor = scale.to(compute_dtype) * torch.rsqrt(mean_square + eps)
    return x * factor.to(x.dtype)
# Layers
class LinearGEGLU(nn.Linear):
    """``nn.Linear`` with twice the requested outputs, reduced by a GEGLU gate.

    The underlying linear layer is built with ``out_features * 2`` outputs;
    ``out_features`` is then reset to the requested (post-gate) width.
    """

    def __init__(self, in_features, out_features, bias=True):
        doubled = out_features * 2
        super().__init__(in_features, doubled, bias=bias)
        self.out_features = out_features

    def forward(self, x):
        return linear_geglu(x, self.weight, self.bias)
class RMSNorm(nn.Module):
    """RMS normalisation with a scale initialised to ones.

    Args:
        shape: Shape of the scale tensor.
        fix_scale: If true the scale is a buffer (not trained); otherwise it
            is an ``nn.Parameter``.
        eps: Value added to the mean square before the reciprocal square root.
    """

    def __init__(self, shape, fix_scale = False, eps=1e-6):
        super().__init__()
        self.eps = eps
        ones = torch.ones(shape)
        if fix_scale:
            self.register_buffer("scale", ones)
        else:
            self.scale = nn.Parameter(ones)

    def extra_repr(self):
        return f"shape={tuple(self.scale.shape)}, eps={self.eps}"

    def forward(self, x):
        return rms_norm(x, self.scale, self.eps)
def snake_beta(x, alpha, beta):
    """Return ``x + sin(alpha * x)**2 / (beta + 1e-9)``."""
    periodic = torch.sin(x * alpha).pow(2)
    # The small constant keeps the reciprocal finite when beta is zero.
    inverse_beta = 1.0 / (beta + 0.000000001)
    return x + inverse_beta * periodic
# try:
# snake_beta = torch.compile(snake_beta)
# except RuntimeError:
# pass
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):
    """Per-channel ``snake_beta`` activation with learnable ``alpha`` and ``beta``.

    Args:
        in_features: Number of channels (length of ``alpha`` and ``beta``).
        alpha: Multiplier applied to the initial parameter values.
        alpha_trainable: Sets ``requires_grad`` on both parameters.
        alpha_logscale: If true the parameters start at zero and are
            exponentiated in ``forward``; otherwise they start at ones and
            are used directly.
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        self.alpha_logscale = alpha_logscale
        # Log-scale parameters start at zeros, linear-scale ones at ones.
        fill = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = nn.Parameter(fill(in_features) * alpha)
        self.beta = nn.Parameter(fill(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        # Reshape the per-channel parameters to [1, C, 1] to line up with x as [B, C, T].
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha, beta = torch.exp(alpha), torch.exp(beta)
        return snake_beta(x, alpha, beta)
-355
View File
@@ -1,355 +0,0 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from vector_quantize_pytorch import ResidualVQ, FSQ
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
class Bottleneck(nn.Module):
    """Base class for bottlenecks; ``encode`` and ``decode`` must be overridden.

    Args:
        is_discrete: Stored on the instance as ``is_discrete``.
    """

    def __init__(self, is_discrete: bool = False):
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        """Not implemented in the base class."""
        raise NotImplementedError

    def decode(self, x):
        """Not implemented in the base class."""
        raise NotImplementedError
class DiscreteBottleneck(Bottleneck):
    """Base class for bottlenecks flagged as discrete.

    Stores ``num_quantizers``, ``codebook_size`` and ``tokens_id`` on the
    instance; ``decode_tokens`` must be overridden.
    """

    def __init__(self, num_quantizers, codebook_size, tokens_id):
        super().__init__(is_discrete=True)
        self.tokens_id = tokens_id
        self.codebook_size = codebook_size
        self.num_quantizers = num_quantizers

    def decode_tokens(self, codes, **kwargs):
        """Not implemented in the base class."""
        raise NotImplementedError
class TanhBottleneck(Bottleneck):
    """Bottleneck that applies ``tanh`` on encode and is the identity on decode."""

    def __init__(self):
        super().__init__(is_discrete=False)
        self.tanh = nn.Tanh()

    def encode(self, x, return_info=False):
        """Return ``tanh(x)``, paired with an empty info dict when requested."""
        squashed = torch.tanh(x)
        if not return_info:
            return squashed
        return squashed, {}

    def decode(self, x):
        return x
def vae_sample(mean, scale):
    """Draw a reparameterised Gaussian sample and compute a KL term.

    The standard deviation is ``softplus(scale) + 1e-4``.

    Returns:
        ``(latents, kl)`` where ``latents = noise * stdev + mean`` and ``kl``
        is ``mean**2 + var - log(var) - 1`` summed over dim 1 and averaged
        over the remaining elements.
    """
    stdev = F.softplus(scale) + 1e-4
    var = stdev * stdev
    noise = torch.randn_like(mean)
    latents = noise * stdev + mean
    kl = (mean * mean + var - torch.log(var) - 1).sum(1).mean()
    return latents, kl
class VAEBottleneck(Bottleneck):
    """Bottleneck that samples latents from the two halves of its input.

    ``encode`` splits ``x`` in two along dim 1 (mean, scale) and samples with
    ``vae_sample``; ``decode`` is the identity.
    """

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False, **kwargs):
        """Return sampled latents, plus ``{"kl": kl}`` when ``return_info`` is true."""
        mean, scale = x.chunk(2, dim=1)
        latents, kl = vae_sample(mean, scale)
        if not return_info:
            return latents
        return latents, {"kl": kl}

    def decode(self, x):
        return x
def compute_mean_kernel(x, y):
    """Mean of ``exp(-d)`` over all row pairs of ``x`` and ``y``.

    ``d`` is the mean squared difference of a row pair (over axis 2 of the
    broadcast difference), further divided by ``x.shape[-1]``.
    """
    pairwise_diff = x[:, None] - y[None]
    kernel_input = pairwise_diff.pow(2).mean(2) / x.shape[-1]
    return torch.exp(-kernel_input).mean()
def compute_mmd(latents):
    """Kernel discrepancy between ``latents`` and standard normal noise.

    ``latents`` is permuted ``(0, 2, 1)`` and flattened to rows of width
    ``latents.shape[1]``; the result is
    ``k(latents, latents) + k(noise, noise) - 2 * k(latents, noise)``
    using ``compute_mean_kernel``.
    """
    flat = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
    noise = torch.randn_like(flat)

    latents_kernel = compute_mean_kernel(flat, flat)
    noise_kernel = compute_mean_kernel(noise, noise)
    cross_kernel = compute_mean_kernel(flat, noise)

    mmd = latents_kernel + noise_kernel - 2 * cross_kernel
    return mmd.mean()
class WassersteinBottleneck(Bottleneck):
    """Pass-through bottleneck that can report an MMD term and append noise.

    Args:
        noise_augment_dim: Number of random-noise channels concatenated to
            the input of ``decode`` (0 disables it).
        bypass_mmd: If true, ``encode`` reports a zero instead of computing
            the MMD.
    """

    def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
        super().__init__(is_discrete=False)
        self.noise_augment_dim = noise_augment_dim
        self.bypass_mmd = bypass_mmd

    def encode(self, x, return_info=False):
        """Return ``x`` unchanged, with ``info["mmd"]`` in training mode when info is requested."""
        info = {}

        if self.training and return_info:
            if self.bypass_mmd:
                # Create the placeholder on x's device so info["mmd"] lives on
                # the same device whether or not the MMD is bypassed.
                mmd = torch.tensor(0.0, device=x.device)
            else:
                mmd = compute_mmd(x)
            info["mmd"] = mmd

        if return_info:
            return x, info
        return x

    def decode(self, x):
        """Return ``x``, with ``noise_augment_dim`` noise channels appended on dim 1 if enabled."""
        if self.noise_augment_dim > 0:
            noise = torch.randn(x.shape[0], self.noise_augment_dim,
                                x.shape[-1]).type_as(x)
            x = torch.cat([x, noise], dim=1)
        return x
class L2Bottleneck(Bottleneck):
    """Bottleneck that applies ``F.normalize(x, dim=1)`` on both encode and decode."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        """Return the normalised input, paired with an empty info dict when requested."""
        unit = F.normalize(x, dim=1)
        if not return_info:
            return unit
        return unit, {}

    def decode(self, x):
        return F.normalize(x, dim=1)
class RVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by ``ResidualVQ``.

    ``quantizer_kwargs`` are forwarded to ``ResidualVQ`` and must contain
    ``num_quantizers`` and ``codebook_size``.
    """

    def __init__(self, **quantizer_kwargs):
        super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
        self.quantizer = ResidualVQ(**quantizer_kwargs)

    def encode(self, x, return_info=False, **kwargs):
        """Quantize ``x`` and optionally return indices and loss in an info dict."""
        info = {}

        # The quantizer is called on the "b n c" layout, so permute in and out.
        x = rearrange(x, "b c n -> b n c")
        x, indices, loss = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        if return_info:
            return x, info
        return x

    def decode(self, x, **kwargs):
        """Identity. Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its kwargs without raising TypeError."""
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` and pass them through ``decode``."""
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)
class RVQVAEBottleneck(DiscreteBottleneck):
    """Discrete bottleneck that VAE-samples its input before ``ResidualVQ``.

    ``quantizer_kwargs`` are forwarded to ``ResidualVQ`` and must contain
    ``num_quantizers`` and ``codebook_size``.
    """

    def __init__(self, **quantizer_kwargs):
        super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
        self.quantizer = ResidualVQ(**quantizer_kwargs)

    def encode(self, x, return_info=False, **kwargs):
        """Sample latents from ``x``, quantize them, and optionally return info.

        Extra keyword arguments are accepted (and ignored), matching the
        ``encode`` signature of the other quantizing bottlenecks.
        """
        info = {}

        # x carries mean and scale stacked along dim 1.
        x, kl = vae_sample(*x.chunk(2, dim=1))
        info["kl"] = kl

        x = rearrange(x, "b c n -> b n c")
        x, indices, loss = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        if return_info:
            return x, info
        return x

    def decode(self, x, **kwargs):
        """Identity. Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its kwargs without raising TypeError."""
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` and pass them through ``decode``."""
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)
class DACRVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by the DAC residual vector quantizer.

    Args:
        quantize_on_decode: If true, ``encode`` returns its input unquantized
            and quantization happens in ``decode`` instead.
        noise_augment_dim: Number of random-noise channels appended on dim 1
            in ``decode`` (0 disables it).
        **quantizer_kwargs: Forwarded to the quantizer; must contain
            ``n_codebooks`` and ``codebook_size``.
    """

    def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
        super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.quantize_on_decode = quantize_on_decode
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False, **kwargs):
        """Quantize ``x``; return ``z`` or ``(z, info)`` depending on ``return_info``."""
        info = {}
        info["pre_quantizer"] = x

        if self.quantize_on_decode:
            # Parenthesised on purpose: without the parentheses the conditional
            # binds only to the second tuple element and a tuple is always
            # returned, even when return_info is false.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        # Report the losses divided by the number of codebooks.
        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info
        return output["z"]

    def decode(self, x, **kwargs):
        """Optionally quantize and/or append noise channels, then return ``x``.

        Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its kwargs without raising TypeError.
        """
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        if self.noise_augment_dim > 0:
            noise = torch.randn(x.shape[0], self.noise_augment_dim,
                                x.shape[-1]).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, codes, **kwargs):
        """Convert ``codes`` to latents and pass them through ``decode``."""
        latents, _, _ = self.quantizer.from_codes(codes)
        return self.decode(latents, **kwargs)
class DACRVQVAEBottleneck(DiscreteBottleneck):
    """Discrete bottleneck that VAE-samples its input before the DAC quantizer.

    Args:
        quantize_on_decode: If true, ``encode`` returns the sampled latents
            unquantized and quantization happens in ``decode`` instead.
        **quantizer_kwargs: Forwarded to the quantizer; must contain
            ``n_codebooks`` and ``codebook_size``.
    """

    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, n_quantizers: int = None):
        """Sample latents from ``x`` and quantize them.

        Returns ``z`` or ``(z, info)`` depending on ``return_info``.
        """
        info = {}

        # x carries mean and scale stacked along dim 1.
        mean, scale = x.chunk(2, dim=1)
        x, kl = vae_sample(mean, scale)

        info["pre_quantizer"] = x
        info["kl"] = kl

        if self.quantize_on_decode:
            # Parenthesised on purpose: without the parentheses the conditional
            # binds only to the second tuple element and a tuple is always
            # returned, even when return_info is false.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        # Report the losses divided by the number of codebooks.
        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info
        return output["z"]

    def decode(self, x, **kwargs):
        """Optionally quantize, then return ``x``.

        Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its kwargs without raising TypeError.
        """
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]
        return x

    def decode_tokens(self, codes, **kwargs):
        """Convert ``codes`` to latents and pass them through ``decode``."""
        latents, _, _ = self.quantizer.from_codes(codes)
        return self.decode(latents, **kwargs)
class FSQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by ``FSQ``.

    Args:
        noise_augment_dim: Number of random-noise channels appended on dim 1
            in ``decode`` (0 disables it).
        **kwargs: Forwarded to ``FSQ``; must contain ``levels`` and may
            contain ``num_codebooks`` (defaults to 1 for bookkeeping).
    """

    def __init__(self, noise_augment_dim=0, **kwargs):
        super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
        self.noise_augment_dim = noise_augment_dim
        self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])

    def encode(self, x, return_info=False):
        """Quantize ``x`` in float32 and return it in its original dtype."""
        info = {}

        orig_dtype = x.dtype
        x = x.float()

        x = rearrange(x, "b c n -> b n c")
        x, indices = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        x = x.to(orig_dtype)

        # Reorder indices to match the expected format
        indices = rearrange(indices, "b n q -> b q n")
        info["quantizer_indices"] = indices

        if return_info:
            return x, info
        return x

    def decode(self, x, **kwargs):
        """Optionally append noise channels, then return ``x``.

        Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its kwargs without raising TypeError.
        """
        if self.noise_augment_dim > 0:
            noise = torch.randn(x.shape[0], self.noise_augment_dim,
                                x.shape[-1]).type_as(x)
            x = torch.cat([x, noise], dim=1)
        return x

    def decode_tokens(self, tokens, **kwargs):
        """Convert ``tokens`` to codes and pass them through ``decode``."""
        latents = self.quantizer.indices_to_codes(tokens)
        return self.decode(latents, **kwargs)
File diff suppressed because it is too large Load Diff
-884
View File
@@ -1,884 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
from functools import partial
import numpy as np
import typing as tp
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
from .conditioners import MultiConditioner
from .dit import DiffusionTransformer
from .pretransforms import Pretransform
from .adp import UNetCFG1d, UNet1d
# Lazy imports for factory functions to avoid circular imports
def _get_create_pretransform_from_config():
    """Import ``create_pretransform_from_config`` at call time and return it.

    The import is deferred to call time to avoid a circular import with
    ``prismaudio_core.factory``.
    """
    from prismaudio_core.factory import create_pretransform_from_config as factory_fn
    return factory_fn
def _get_create_multi_conditioner_from_conditioning_config():
    """Import ``create_multi_conditioner_from_conditioning_config`` at call time and return it.

    The import is deferred to call time to avoid a circular import with
    ``prismaudio_core.factory``.
    """
    from prismaudio_core.factory import (
        create_multi_conditioner_from_conditioning_config as factory_fn,
    )
    return factory_fn
class DiffusionModel(nn.Module):
    """Base class for diffusion models called as ``model(x, t, **kwargs)``.

    Constructor arguments are forwarded to ``nn.Module``; ``forward`` must be
    overridden.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, t, **kwargs):
        """Not implemented in the base class."""
        raise NotImplementedError()
class DiffusionModelWrapper(nn.Module):
    """Bundle a diffusion model with its I/O metadata and optional pretransform.

    The constructor arguments are stored as same-named attributes;
    ``forward`` delegates to ``model``.
    """

    def __init__(
        self,
        model: DiffusionModel,
        io_channels,
        sample_size,
        sample_rate,
        min_input_length,
        pretransform: tp.Optional[Pretransform] = None,
    ):
        super().__init__()
        self.io_channels = io_channels
        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.model = model

        # Stored as given; None means no pretransform.
        self.pretransform = pretransform

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)
class ConditionedDiffusionModel(nn.Module):
    """Base class for conditioned diffusion models.

    The ``supports_*`` keyword flags are stored as same-named attributes;
    remaining arguments are forwarded to ``nn.Module``. ``forward`` must be
    overridden.
    """

    def __init__(self,
                 *args,
                 supports_cross_attention: bool = False,
                 supports_input_concat: bool = False,
                 supports_global_cond: bool = False,
                 supports_prepend_cond: bool = False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.supports_prepend_cond = supports_prepend_cond
        self.supports_global_cond = supports_global_cond
        self.supports_input_concat = supports_input_concat
        self.supports_cross_attention = supports_cross_attention

    def forward(self,
                x: torch.Tensor,
                t: torch.Tensor,
                cross_attn_cond: torch.Tensor = None,
                cross_attn_mask: torch.Tensor = None,
                input_concat_cond: torch.Tensor = None,
                global_embed: torch.Tensor = None,
                prepend_cond: torch.Tensor = None,
                prepend_cond_mask: torch.Tensor = None,
                cfg_scale: float = 1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                **kwargs):
        """Not implemented in the base class."""
        raise NotImplementedError()
class ConditionedDiffusionModelWrapper(nn.Module):
    """
    A diffusion model that takes in conditioning.

    Wraps a conditioned diffusion model together with its conditioner and
    routes conditioning tensors to the model's keyword arguments according
    to the ``*_ids`` lists.

    Args:
        model: The wrapped diffusion model.
        conditioner: Conditioner stored on the wrapper.
        io_channels, sample_rate, min_input_length: Stored as attributes.
        diffusion_objective: Stored as an attribute.
        zero_init: If ``True``, Xavier-initialise every ``nn.Linear`` in the
            conditioner and call ``model.model.initialize_weights()``.
        pretransform: Optional pretransform stored as an attribute.
        cross_attn_cond_ids, global_cond_ids, input_concat_ids,
        prepend_cond_ids, add_cond_ids, sync_cond_ids: Conditioning keys
            routed to each input; ``None`` (the default) means no keys.
    """

    def __init__(
        self,
        model: ConditionedDiffusionModel,
        conditioner: MultiConditioner,
        io_channels,
        sample_rate,
        min_input_length: int,
        diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
        zero_init: bool = False,
        pretransform: tp.Optional[Pretransform] = None,
        cross_attn_cond_ids: tp.Optional[tp.List[str]] = None,
        global_cond_ids: tp.Optional[tp.List[str]] = None,
        input_concat_ids: tp.Optional[tp.List[str]] = None,
        prepend_cond_ids: tp.Optional[tp.List[str]] = None,
        add_cond_ids: tp.Optional[tp.List[str]] = None,
        sync_cond_ids: tp.Optional[tp.List[str]] = None,
    ):
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        # A fresh list per instance replaces the former shared mutable
        # default; lists passed by the caller are stored as given.
        self.cross_attn_cond_ids = [] if cross_attn_cond_ids is None else cross_attn_cond_ids
        self.global_cond_ids = [] if global_cond_ids is None else global_cond_ids
        self.input_concat_ids = [] if input_concat_ids is None else input_concat_ids
        self.prepend_cond_ids = [] if prepend_cond_ids is None else prepend_cond_ids
        self.add_cond_ids = [] if add_cond_ids is None else add_cond_ids
        self.sync_cond_ids = [] if sync_cond_ids is None else sync_cond_ids
        self.min_input_length = min_input_length

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        if zero_init is True:
            self.conditioner.apply(_basic_init)
            self.model.model.initialize_weights()

    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        """Assemble the model's conditioning keyword arguments.

        Args:
            conditioning_tensors: Mapping from conditioning key to a
                ``(tensor, mask)`` pair (entries are unpacked or indexed with
                ``[0]`` below).
            negative: If true, return only the ``negative_*`` subset of keys.

        Returns:
            A dict of keyword arguments; entries whose id list is empty are
            ``None``.
        """
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None
        add_input = None
        sync_input = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention inputs over the sequence dimension
            # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
            cross_attention_input = []
            cross_attention_masks = []

            for key in self.cross_attn_cond_ids:
                cross_attn_in, cross_attn_mask = conditioning_tensors[key]

                # Add sequence dimension if it's not there
                if len(cross_attn_in.shape) == 2:
                    cross_attn_in = cross_attn_in.unsqueeze(1)
                    # cross_attn_mask = cross_attn_mask.unsqueeze(1)

                cross_attention_input.append(cross_attn_in)
                cross_attention_masks.append(cross_attn_mask)

            cross_attention_input = torch.cat(cross_attention_input, dim=1)
            cross_attention_masks = torch.cat(cross_attention_masks, dim=1)

        if len(self.add_cond_ids) > 0:
            # Concatenate all additive conditioning inputs over dim 2
            # (the last axis once a sequence axis has been inserted).
            add_input = []

            for key in self.add_cond_ids:
                add_in = conditioning_tensors[key][0]

                # Add sequence dimension if it's not there
                if len(add_in.shape) == 2:
                    add_in = add_in.unsqueeze(1)

                add_input.append(add_in)

            add_input = torch.cat(add_input, dim=2)

        if len(self.sync_cond_ids) > 0:
            # Concatenate all sync conditioning inputs over dim 2
            # (the last axis once a sequence axis has been inserted).
            sync_input = []

            for key in self.sync_cond_ids:
                sync_in = conditioning_tensors[key][0]

                # Add sequence dimension if it's not there
                if len(sync_in.shape) == 2:
                    sync_in = sync_in.unsqueeze(1)

                sync_input.append(sync_in)

            sync_input = torch.cat(sync_input, dim=2)

        if len(self.global_cond_ids) > 0:
            # Sum all global conditioning inputs element-wise.
            global_conds = []

            for key in self.global_cond_ids:
                global_cond_input = conditioning_tensors[key][0]
                if len(global_cond_input.shape) == 2:
                    global_cond_input = global_cond_input.unsqueeze(1)
                global_conds.append(global_cond_input)

            global_cond = sum(global_conds)

            # Drop the axis that was inserted before summing.
            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if len(self.input_concat_ids) > 0:
            # Concatenate all input concat conditioning inputs over the channel dimension
            # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
            input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs over the sequence dimension
            # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
            prepend_conds = []
            prepend_cond_masks = []

            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
                if len(prepend_cond_input.shape) == 2:
                    prepend_cond_input = prepend_cond_input.unsqueeze(1)
                prepend_conds.append(prepend_cond_input)
                prepend_cond_masks.append(prepend_cond_mask)

            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "cross_attn_mask": cross_attention_masks,
                "global_cond": global_cond,
                "input_concat_cond": input_concat_cond,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask,
                "add_cond": add_input,
                "sync_cond": sync_input
            }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        """Call the wrapped model with conditioning inputs built from ``cond``."""
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        """Run ``generate_diffusion_cond`` with this wrapper as the model."""
        # Imported at call time rather than at module import.
        from prismaudio_core.inference.generation import generate_diffusion_cond
        return generate_diffusion_cond(self, *args, **kwargs)
class UNetCFG1DWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``UNetCFG1d`` through the conditioned-model interface.

    Constructor arguments are forwarded to ``UNetCFG1d``; all of its
    parameters are halved in place after construction.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)

        self.model = UNetCFG1d(*args, **kwargs)

        # Halve the freshly initialised weights.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                input_concat_cond=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                **kwargs):
        """Translate the generic conditioning arguments to ``UNetCFG1d`` keywords.

        ``negative_global_cond``, ``negative_input_concat_cond``,
        ``prepend_cond`` and ``prepend_cond_mask`` are accepted but unused.
        """
        # The wrapped model takes concatenation inputs as a list.
        channels_list = [input_concat_cond] if input_concat_cond is not None else None

        return self.model(
            x,
            t,
            embedding=cross_attn_cond,
            embedding_mask=cross_attn_mask,
            features=global_cond,
            channels_list=channels_list,
            embedding_scale=cfg_scale,
            embedding_mask_proba=cfg_dropout_prob,
            batch_cfg=batch_cfg,
            rescale_cfg=rescale_cfg,
            negative_embedding=negative_cross_attn_cond,
            negative_embedding_mask=negative_cross_attn_mask,
            **kwargs)
class UNet1DCondWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``UNet1d`` through the conditioned-model interface.

    Constructor arguments are forwarded to ``UNet1d``; all of its parameters
    are halved in place after construction.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)

        self.model = UNet1d(*args, **kwargs)

        # Halve the freshly initialised weights.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                global_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                **kwargs):
        """Call ``UNet1d`` with the global and input-concat conditioning only.

        The remaining named conditioning and CFG arguments are accepted but
        unused.
        """
        channels_list = None
        if input_concat_cond is not None:
            target_len = x.shape[2]
            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != target_len:
                input_concat_cond = F.interpolate(input_concat_cond, (target_len, ), mode='nearest')
            channels_list = [input_concat_cond]

        return self.model(
            x,
            t,
            features=global_cond,
            channels_list=channels_list,
            **kwargs)
class UNet1DUncondWrapper(DiffusionModel):
    """Unconditional wrapper around ``UNet1d``.

    ``in_channels`` is passed to ``UNet1d`` by keyword and stored as
    ``io_channels``; all model parameters are halved in place after
    construction.
    """

    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = UNet1d(*args, in_channels=in_channels, **kwargs)
        self.io_channels = in_channels

        # Halve the freshly initialised weights.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)
class DAU1DCondWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``DiffusionAttnUnet1D`` through the conditioned-model interface.

    Constructor arguments are forwarded to ``DiffusionAttnUnet1D``; all of
    its parameters are halved in place after construction.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)

        self.model = DiffusionAttnUnet1D(*args, **kwargs)

        # Halve the freshly initialised weights.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                **kwargs):
        """Call the wrapped model with ``input_concat_cond`` as ``cond``.

        Every other conditioning or CFG argument, including ``**kwargs``, is
        accepted but not forwarded.
        """
        return self.model(x, t, cond=input_concat_cond)
class DiffusionAttnUnet1D(nn.Module):
    """Nested 1-D U-Net built from ``ResConvBlock``/``SelfAttention1d`` stages.

    The network is assembled from the innermost level outwards: each level
    above the first wraps the previous ``block`` in a ``SkipBlock`` with
    down/up-sampling around it, and the first level wraps everything in an
    ``nn.Sequential`` that maps to ``io_channels``. All parameters of the
    finished network are halved in place.

    Args:
        io_channels: Channels of the input signal and of the output.
        depth: Number of levels.
        n_attn_layers: Levels ``i >= depth - n_attn_layers`` get attention
            (only when this is positive).
        channels: Channel width per level.
        cond_dim: Channels of the optional conditioning signal.
        cond_noise_aug: If true, noise is added to the conditioning signal
            and its level is embedded as extra input channels.
        kernel_size, conv_bias, use_snake: Forwarded to ``ResConvBlock``.
        learned_resample: Use ``Downsample1d_2``/``Upsample1d_2`` instead of
            the fixed ``"cubic"`` kernels.
        strides: Stride per level above the first (a leading 1 is prepended).

    Raises:
        ValueError: If a stride above 2 is used without ``learned_resample``.
    """

    def __init__(
        self,
        io_channels = 2,
        depth=14,
        n_attn_layers = 6,
        channels = [128, 128, 256, 256] + [512] * 10,
        cond_dim = 0,
        cond_noise_aug = False,
        kernel_size = 5,
        learned_resample = False,
        strides = [2] * 13,
        conv_bias = True,
        use_snake = False
    ):
        super().__init__()

        self.cond_noise_aug = cond_noise_aug
        self.io_channels = io_channels

        if self.cond_noise_aug:
            self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.timestep_embed = FourierFeatures(1, 16)

        attn_layer = depth - n_attn_layers
        # Builds a new list; the default argument itself is not mutated.
        strides = [1] + strides

        conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)

        block = nn.Identity()
        for i in range(depth, 0, -1):
            c = channels[i - 1]
            stride = strides[i - 1]
            if stride > 2 and not learned_resample:
                raise ValueError("Must have stride 2 without learned resampling")

            if i == 1:
                # Outermost level: timestep embedding (16) plus, with noise
                # augmentation, the augmentation-level embedding (another 16).
                cond_embed_dim = 32 if self.cond_noise_aug else 16
                block = nn.Sequential(
                    conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, c),
                    block,
                    conv_block(c * 2, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, io_channels, is_last=True),
                )
                continue

            c_prev = channels[i - 2]
            add_attn = i >= attn_layer and n_attn_layers > 0

            def attention(width):
                # Attention with width // 32 heads, or a no-op at levels
                # without attention.
                return SelfAttention1d(width, width // 32) if add_attn else nn.Identity()

            # Layers are created in forward order so parameter initialisation
            # happens in the same sequence as the layers are applied.
            layers = [
                Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
                conv_block(c_prev, c, c),
                attention(c),
                conv_block(c, c, c),
                attention(c),
                conv_block(c, c, c),
                attention(c),
                block,
                # The innermost level wraps an Identity, so its SkipBlock
                # output is not doubled by a nested skip.
                conv_block(c * 2 if i != depth else c, c, c),
                attention(c),
                conv_block(c, c, c),
                attention(c),
                conv_block(c, c, c_prev),
                attention(c_prev),
                Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic"),
            ]
            block = SkipBlock(*layers)

        self.net = block

        # Halve the freshly initialised weights.
        with torch.no_grad():
            for weight in self.net.parameters():
                weight *= 0.5

    def forward(self, x, t, cond=None, cond_aug_scale=None):
        """Run the network on ``x`` at timestep ``t``.

        Args:
            x: Input signal; its third axis sets the length used for the
                embeddings and for resizing ``cond``.
            t: Timesteps, one per batch element.
            cond: Optional conditioning signal, linearly interpolated to the
                length of ``x`` when the lengths differ.
            cond_aug_scale: Fixed noise-augmentation level; when ``None`` a
                level is drawn from the Sobol engine. Only used when
                ``cond_noise_aug`` is enabled.
        """
        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
        inputs = [x, timestep_embed]

        if cond is not None:
            if cond.shape[2] != x.shape[2]:
                cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)

            if self.cond_noise_aug:
                batch = cond.shape[0]
                # Get a random number between 0 and 1, uniformly sampled
                if cond_aug_scale is None:
                    aug_level = self.rng.draw(batch)[:, 0].to(cond)
                else:
                    aug_level = torch.tensor([cond_aug_scale]).repeat([batch]).to(cond)

                # Add noise to the conditioning signal
                cond = cond + torch.randn_like(cond) * aug_level[:, None, None]

                # Embed the noise level with the same module used for timesteps
                aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
                inputs.append(aug_level_embed)

            inputs.append(cond)

        return self.net(torch.cat(inputs, dim=1))
class DiTWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``DiffusionTransformer`` through the conditioned-model interface.

    Constructor arguments are forwarded to ``DiffusionTransformer``.
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)

        self.model = DiffusionTransformer(*args, **kwargs)

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                input_concat_cond=None,
                negative_input_concat_cond=None,
                global_cond=None,
                negative_global_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):
        """Translate the generic conditioning arguments to transformer keywords.

        ``cross_attn_mask`` is passed as ``cross_attn_cond_mask`` and
        ``global_cond`` as ``global_embed``. ``negative_input_concat_cond``,
        ``negative_global_cond`` and ``rescale_cfg`` are accepted but unused.

        Raises:
            AssertionError: If ``batch_cfg`` is false.
        """
        assert batch_cfg, "batch_cfg must be True for DiTWrapper"

        model_kwargs = dict(
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_mask,
            negative_cross_attn_cond=negative_cross_attn_cond,
            negative_cross_attn_mask=negative_cross_attn_mask,
            input_concat_cond=input_concat_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            global_embed=global_cond,
        )
        return self.model(x, t, **model_kwargs, **kwargs)
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
    """
    A diffusion model that takes in conditioning.

    Multi-modal variant: only ``mm_cond_ids`` may be non-empty, and it must
    contain ``"metaclip_features"``, ``"sync_features"`` and
    ``"metaclip_text_features"``.

    Args:
        model: The wrapped diffusion model.
        conditioner: Conditioner stored on the wrapper.
        io_channels, sample_rate, min_input_length: Stored as attributes.
        diffusion_objective: Stored as an attribute.
        pretransform: Optional pretransform stored as an attribute.
        cross_attn_cond_ids, global_cond_ids, input_concat_ids,
        prepend_cond_ids, add_cond_ids: Must be empty (``None`` counts as
            empty).
        mm_cond_ids: Conditioning keys used by the multi-modal model.

    Raises:
        AssertionError: If any of the id-list constraints above is violated.
    """

    def __init__(
        self,
        model,
        conditioner: MultiConditioner,
        io_channels,
        sample_rate,
        min_input_length: int,
        diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
        pretransform: tp.Optional[Pretransform] = None,
        cross_attn_cond_ids: tp.Optional[tp.List[str]] = None,
        global_cond_ids: tp.Optional[tp.List[str]] = None,
        input_concat_ids: tp.Optional[tp.List[str]] = None,
        prepend_cond_ids: tp.Optional[tp.List[str]] = None,
        add_cond_ids: tp.Optional[tp.List[str]] = None,
        mm_cond_ids: tp.Optional[tp.List[str]] = None,
    ):
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        # A fresh list per instance replaces the former shared mutable
        # default; lists passed by the caller are stored as given.
        self.cross_attn_cond_ids = [] if cross_attn_cond_ids is None else cross_attn_cond_ids
        self.global_cond_ids = [] if global_cond_ids is None else global_cond_ids
        self.input_concat_ids = [] if input_concat_ids is None else input_concat_ids
        self.prepend_cond_ids = [] if prepend_cond_ids is None else prepend_cond_ids
        self.add_cond_ids = [] if add_cond_ids is None else add_cond_ids
        self.min_input_length = min_input_length
        self.mm_cond_ids = [] if mm_cond_ids is None else mm_cond_ids

        assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
        assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
        assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
        assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
        assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
        assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
        assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper"
        assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
        assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"

    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        """Pick the multi-modal conditioning entries the model expects.

        ``metaclip_features``, ``sync_features`` and
        ``metaclip_text_features`` are required keys; ``inpaint_masked_input``,
        ``t5_features`` and ``metaclip_global_text_features`` are optional and
        default to ``None``.

        Raises:
            AssertionError: If ``negative`` is not equal to ``False``.
            KeyError: If a required key is missing.
        """
        assert negative == False, "negative conditioning is not supported for MMDiTWrapper"

        return {
            "clip_f": conditioning_tensors["metaclip_features"],
            "sync_f": conditioning_tensors["sync_features"],
            "text_f": conditioning_tensors["metaclip_text_features"],
            "inpaint_masked_input": conditioning_tensors.get("inpaint_masked_input"),
            "t5_features": conditioning_tensors.get("t5_features"),
            "metaclip_global_text_features": conditioning_tensors.get("metaclip_global_text_features"),
        }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        """Call the wrapped model with conditioning inputs built from ``cond``."""
        return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        """Run ``generate_diffusion_cond`` with this wrapper as the model."""
        # Imported at call time rather than at module import.
        from prismaudio_core.inference.generation import generate_diffusion_cond
        return generate_diffusion_cond(self, *args, **kwargs)
class DiTUncondWrapper(DiffusionModel):
    """Unconditional diffusion model backed by a ``DiffusionTransformer``.

    All parameters of the freshly built transformer are multiplied by 0.5
    in place at construction time.
    """

    def __init__(self, io_channels, *args, **kwargs):
        super().__init__()

        self.io_channels = io_channels
        self.model = DiffusionTransformer(*args, io_channels=io_channels, **kwargs)

        # Halve the initial weights without recording the op in autograd.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self, x, t, **kwargs):
        """Forward ``x`` and timestep ``t`` (plus any extra kwargs) to the transformer."""
        prediction = self.model(x, t, **kwargs)
        return prediction
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
    """
    Build an unconditional ``DiffusionModelWrapper`` from a config dict.

    Reads ``config["model"]`` (keys ``type``, ``config``, optional
    ``pretransform``) plus the top-level ``sample_size`` and ``sample_rate``.

    Raises:
        AssertionError: If the model type, sample size or sample rate is missing.
        NotImplementedError: If the model type is not one of ``DAU1d``,
            ``adp_uncond_1d`` or ``dit``.
    """
    model_section = config["model"]

    model_type = model_section.get('type', None)
    network_kwargs = model_section.get('config', {})
    assert model_type is not None, "Must specify model type in config"

    pretransform = model_section.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    # Without a pretransform the model works on raw samples, so any length is valid.
    min_input_length = 1
    if pretransform is not None:
        build_pretransform = _get_create_pretransform_from_config()
        pretransform = build_pretransform(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio

    if model_type == 'DAU1d':
        network = DiffusionAttnUnet1D(**network_kwargs)
    elif model_type == "adp_uncond_1d":
        network = UNet1DUncondWrapper(**network_kwargs)
    elif model_type == "dit":
        network = DiTUncondWrapper(**network_kwargs)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(
        network,
        io_channels=network.io_channels,
        sample_size=sample_size,
        sample_rate=sample_rate,
        pretransform=pretransform,
        min_input_length=min_input_length,
    )
def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
    """
    Build a ``DiffusionModelWrapper`` for an infill model from a config dict.

    Reads ``config["model"]["diffusion"]`` (keys ``type`` and ``config``), the
    optional ``config["model"]["pretransform"]``, and the top-level
    ``sample_size`` and ``sample_rate``.

    Raises:
        AssertionError: If the model type, sample size or sample rate is missing.
        NotImplementedError: If the model type is not one of ``DAU1d``,
            ``adp_uncond_1d`` or ``dit``.
    """
    diffusion_uncond_config = config["model"]

    diffusion_config = diffusion_uncond_config.get('diffusion', {})
    model_type = diffusion_config.get('type', None)
    model_config = diffusion_config.get("config", {})
    assert model_type is not None, "Must specify model type in config"

    pretransform = diffusion_uncond_config.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    if pretransform is not None:
        pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(**model_config)
    elif model_type == "adp_uncond_1d":
        # io_channels defaults to 64 only when the config does not set it.
        # Passing it both explicitly and through **model_config raised
        # "got multiple values for keyword argument 'io_channels'" whenever
        # the config did set it.
        model = UNet1DUncondWrapper(**{"io_channels": 64, **model_config})
    elif model_type == "dit":
        model = DiTUncondWrapper(**model_config)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(
        model,
        io_channels=model.io_channels,
        sample_size=sample_size,
        sample_rate=sample_rate,
        pretransform=pretransform,
        min_input_length=min_input_length,
    )
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
    """
    Build a conditioned diffusion model wrapper from a config dict.

    Reads ``config["model"]`` (keys ``diffusion``, ``io_channels``, optional
    ``conditioning`` and ``pretransform``), ``config["model_type"]`` and
    ``config["sample_rate"]``. The diffusion section supplies the backbone
    ``type``/``config``, the conditioning id lists, ``diffusion_objective``
    and ``zero_init``.

    Returns:
        ``MMConditionedDiffusionModelWrapper`` for ``model_type ==
        "mm_diffusion_cond"``; ``ConditionedDiffusionModelWrapper`` for
        ``"diffusion_cond"``, ``"diffusion_cond_inpaint"`` and
        ``"diffusion_infill"``.

    Raises:
        AssertionError: If a required config entry is missing.
        NotImplementedError: For an unknown diffusion model type or model type.
    """
    model_config = config["model"]

    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)
    else:
        raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config)

    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    add_cond_ids = diffusion_config.get('add_cond_ids', [])
    sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
    mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
    zero_init = diffusion_config.get('zero_init', False)

    pretransform = model_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    # The backbone shortens the sequence further, so the minimum valid input
    # length grows by its total downsampling factor / patch size.
    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    # Get the proper wrapper class, plus the keyword arguments only it accepts.
    extra_kwargs = {"diffusion_objective": diffusion_objective}
    if model_type == "mm_diffusion_cond":
        wrapper_fn = MMConditionedDiffusionModelWrapper
        extra_kwargs["mm_cond_ids"] = mm_cond_ids
        # MMConditionedDiffusionModelWrapper.__init__ has no sync_cond_ids or
        # zero_init parameter; passing them made this path always raise TypeError.
    elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
        wrapper_fn = ConditionedDiffusionModelWrapper
        extra_kwargs["sync_cond_ids"] = sync_cond_ids
        extra_kwargs["zero_init"] = zero_init
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        add_cond_ids=add_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        **extra_kwargs
    )
-539
View File
@@ -1,539 +0,0 @@
import typing as tp
import math
import torch
# from beartype.typing import Tuple
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from .blocks import FourierFeatures
from .transformer import ContinuousTransformer
from .utils import mask_from_frac_lengths, resample
class DiffusionTransformer(nn.Module):
    """
    Transformer denoiser for diffusion over channel-first sequences.

    ``x`` is a 3-D ``(batch, channels, time)`` tensor. Optional conditioning can
    be given as cross-attention tokens, a global embedding, channel-concatenated
    inputs, prepended tokens, and ``add_cond`` / ``sync_cond`` tokens that are
    projected here and handed to the underlying ``ContinuousTransformer``.

    ``forward`` casts inputs to the model dtype and implements CFG dropout and
    classifier-free guidance (CFG) on top of ``_forward``.
    """

    def __init__(self,
        io_channels=32,
        patch_size=1,
        embed_dim=768,
        cond_token_dim=0,
        project_cond_tokens=True,
        global_cond_dim=0,
        project_global_cond=True,
        input_concat_dim=0,
        prepend_cond_dim=0,
        cond_ctx_dim=0,
        depth=12,
        num_heads=8,
        transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
        add_token_dim=0,
        sync_token_dim=0,
        use_mlp=False,
        use_zero_init=False,
        **kwargs):
        """
        Build the projections and the transformer.

        Extra ``**kwargs`` are forwarded to ``ContinuousTransformer``.

        Raises:
            ValueError: If ``transformer_type`` is not ``"continuous_transformer"``.
        """
        super().__init__()

        self.cond_token_dim = cond_token_dim

        # Timestep embeddings
        timestep_features_dim = 256
        self.timestep_cond_type = timestep_cond_type
        self.timestep_features = FourierFeatures(1, timestep_features_dim)

        if timestep_cond_type == "input_concat":
            # The timestep embedding (embed_dim wide, see to_timestep_embed) is
            # concatenated onto the input channels in _forward, so reserve room
            # for it. The previous code read an unassigned local
            # `timestep_embed_dim` here and raised UnboundLocalError.
            input_concat_dim += embed_dim

        self.to_timestep_embed = self._projection(timestep_features_dim, embed_dim, bias=True)

        self.use_mlp = use_mlp

        if cond_token_dim > 0:
            # Cross-attention conditioning tokens
            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = self._projection(cond_token_dim, cond_embed_dim)
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = self._projection(global_cond_dim, global_embed_dim)

        if add_token_dim > 0:
            # "add" conditioning tokens
            add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
            self.to_add_embed = self._projection(add_token_dim, add_embed_dim)

        if sync_token_dim > 0:
            # "sync" conditioning tokens
            sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
            self.to_sync_embed = self._projection(sync_token_dim, sync_embed_dim)

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = self._projection(prepend_cond_dim, embed_dim)

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        # Transformer
        self.transformer_type = transformer_type
        self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
        self.global_cond_type = global_cond_type

        if self.transformer_type == "continuous_transformer":
            global_dim = None
            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to the embed_dim already at this point
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend=cond_token_dim > 0,
                cond_token_dim=cond_embed_dim,
                global_cond_dim=global_dim,
                **kwargs
            )
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        # 1x1 convs used as residual branches in _forward; zero weights make
        # them start out as the identity (conv(x) + x == x).
        self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
        self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
        nn.init.zeros_(self.preprocess_conv.weight)
        nn.init.zeros_(self.postprocess_conv.weight)

    @staticmethod
    def _projection(in_dim, out_dim, bias=False):
        """Return the Linear -> SiLU -> Linear block used by every conditioning projection."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim, bias=bias),
            nn.SiLU(),
            nn.Linear(out_dim, out_dim, bias=bias),
        )

    @staticmethod
    def _cfg_dropout(cond, drop_prob):
        """Replace whole batch items of ``cond`` with zeros, each with probability ``drop_prob``."""
        null_embed = torch.zeros_like(cond)
        dropout_mask = torch.bernoulli(
            torch.full((cond.shape[0], 1, 1), drop_prob, device=cond.device)
        ).to(torch.bool)
        return torch.where(dropout_mask, null_embed, cond)

    @staticmethod
    def _cfg_batch(cond, bsz, null_uncond):
        """
        Build the doubled CFG batch entry for ``cond``.

        When ``cond`` has batch size ``bsz`` the result is ``[cond, uncond]``
        stacked on the batch axis, where ``uncond`` is zeros if ``null_uncond``
        else ``cond`` itself. ``None`` and tensors with a different batch size
        are returned unchanged.
        """
        if cond is None or cond.shape[0] != bsz:
            return cond
        uncond = torch.zeros_like(cond) if null_uncond else cond
        return torch.cat([cond, uncond], dim=0)

    def initialize_weights(self):
        """Re-initialize weights: Xavier-uniform for linears, then special cases below."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02)
        nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02)

        # Zero-out the adaLN modulation output layers:
        if self.global_cond_type == "adaLN":
            for block in self.transformer.layers:
                nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
                nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        nn.init.constant_(self.empty_clip_feat, 0)
        nn.init.constant_(self.empty_sync_feat, 0)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        add_cond=None,
        add_masks=None,
        sync_cond=None,
        return_info=False,
        **kwargs):
        """
        Single (non-CFG) pass: project conditioning, run the transformer, un-patch.

        Returns the output in ``(batch, channels, time)`` layout, or
        ``(output, info)`` when ``return_info`` is true.
        """
        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:
            # reshape from (b, n, c) to (b, c, n)
            if input_concat_cond.shape[1] != x.shape[1]:
                input_concat_cond = input_concat_cond.transpose(1, 2)
            # Interpolate input_concat_cond to the same length as x
            input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]))  # (b, embed_dim)

        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if self.timestep_cond_type == "global":
            if global_embed is not None:
                if len(global_embed.shape) == 3:
                    timestep_embed = timestep_embed.unsqueeze(1)
                global_embed = global_embed + timestep_embed
            else:
                global_embed = timestep_embed
        elif self.timestep_cond_type == "input_concat":
            # (b, embed_dim) -> (b, embed_dim, time) so it can be concatenated
            # on the channel axis. unsqueeze(1) put embed_dim on the time axis,
            # where expand() cannot broadcast it.
            x = torch.cat([x, timestep_embed.unsqueeze(-1).expand(-1, -1, x.shape[2])], dim=1)

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend" and global_embed is not None:
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                if len(global_embed.shape) == 2:
                    prepend_inputs = global_embed.unsqueeze(1)
                else:
                    prepend_inputs = global_embed
                prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
            else:
                if prepend_mask is None:
                    # prepend_cond came without a mask: treat all of its tokens
                    # as valid. Previously torch.cat was called on None here.
                    prepend_mask = torch.ones(prepend_inputs.shape[:2], device=x.device, dtype=torch.bool)
                # Prepend inputs are the prepend conditioning + the global embed
                if len(global_embed.shape) == 2:
                    prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
                else:
                    prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1)
                prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)

            prepend_length = prepend_inputs.shape[1]

        x = self.preprocess_conv(x) + x
        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            b, seq_len, c = x.shape

            # Amount of padding needed to make the length a multiple of patch_size
            pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size

            if pad_amount > 0:
                # Zero-pad at the end of the time dimension
                x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        if add_cond is not None:
            add_cond = self.to_add_embed(add_cond)
            # Interpolate add_cond to the same length as x
            if add_cond.shape[1] != x.shape[1]:
                add_cond = add_cond.transpose(1, 2)
                add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
                add_cond = add_cond.transpose(1, 2)

        if sync_cond is not None:
            sync_cond = self.to_sync_embed(sync_cond)

        if self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

            if return_info:
                output, info = output

        # Back to channel-first, dropping the prepended conditioning tokens
        output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
            # Remove the padding added before patching
            if pad_amount > 0:
                output = output[:, :, :seq_len]

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
        self,
        x,
        t,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        input_concat_cond=None,
        global_embed=None,
        negative_global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        add_cond=None,
        sync_cond=None,
        cfg_scale=1.0,
        cfg_dropout_prob=0.0,
        causal=False,
        scale_phi=0.0,
        mask=None,
        return_info=False,
        **kwargs):
        """
        Run the model, optionally with CFG dropout or classifier-free guidance.

        CFG dropout is applied when ``cfg_dropout_prob > 0`` and
        ``cfg_scale == 1``. Guidance is applied when ``cfg_scale != 1`` and at
        least one of ``cross_attn_cond``, ``prepend_cond`` or ``add_cond`` is
        given; otherwise a single ``_forward`` pass is made.

        Returns the model output, or ``(output, info)`` when ``return_info``
        is true.
        """
        assert not causal, "Causal mode is not supported for DiffusionTransformer"

        # Unpacking also enforces that x is 3-D.
        bsz, _, _ = x.shape

        model_dtype = next(self.parameters()).dtype
        x = x.to(model_dtype)
        t = t.to(model_dtype)

        (
            cross_attn_cond,
            negative_cross_attn_cond,
            input_concat_cond,
            global_embed,
            negative_global_embed,
            prepend_cond,
            add_cond,
            sync_cond,
        ) = (
            None if cond is None else cond.to(model_dtype)
            for cond in (
                cross_attn_cond,
                negative_cross_attn_cond,
                input_concat_cond,
                global_embed,
                negative_global_embed,
                prepend_cond,
                add_cond,
                sync_cond,
            )
        )

        if cross_attn_cond_mask is not None:
            cross_attn_cond_mask = cross_attn_cond_mask.bool()
            cross_attn_cond_mask = None  # Temporarily disabling conditioning masks due to kernel issue for flash attention

        if prepend_cond_mask is not None:
            prepend_cond_mask = prepend_cond_mask.bool()

        # CFG dropout
        if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
            if cross_attn_cond is not None:
                cross_attn_cond = self._cfg_dropout(cross_attn_cond, cfg_dropout_prob)

            if prepend_cond is not None:
                prepend_cond = self._cfg_dropout(prepend_cond, cfg_dropout_prob)

            if add_cond is not None:
                add_cond = self._cfg_dropout(add_cond, cfg_dropout_prob)

            if sync_cond is not None:
                sync_cond = self._cfg_dropout(sync_cond, cfg_dropout_prob)

        # NOTE(review): sync_cond alone does not trigger the CFG path below —
        # confirm whether that is intended.
        if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
            # Classifier-free guidance
            # Concatenate conditioned and unconditioned inputs on the batch dimension
            batch_inputs = torch.cat([x, x], dim=0)
            batch_timestep = torch.cat([t, t], dim=0)

            # Global and input-concat conditioning are shared by both halves.
            batch_global_cond = self._cfg_batch(global_embed, bsz, null_uncond=False)
            batch_input_concat_cond = self._cfg_batch(input_concat_cond, bsz, null_uncond=False)

            batch_cond = None
            batch_cond_masks = None

            # Handle CFG for cross-attention conditioning
            if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz:
                null_embed = torch.zeros_like(cross_attn_cond)

                # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
                if negative_cross_attn_cond is not None:
                    # If there's a negative cross-attention mask, set the masked tokens to the null embed
                    if negative_cross_attn_mask is not None:
                        negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
                        negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)

                    batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
                else:
                    batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)

                if cross_attn_cond_mask is not None:
                    batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
            elif cross_attn_cond is not None:
                batch_cond = cross_attn_cond

            # Handle CFG for prepend conditioning (unconditional half is zeros)
            batch_prepend_cond = self._cfg_batch(prepend_cond, bsz, null_uncond=True)
            batch_prepend_cond_mask = None
            if prepend_cond is not None and prepend_cond.shape[0] == bsz and prepend_cond_mask is not None:
                batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)

            # Handle CFG for add / sync conditioning (unconditional half is zeros)
            batch_add_cond = self._cfg_batch(add_cond, bsz, null_uncond=True)
            batch_sync_cond = self._cfg_batch(sync_cond, bsz, null_uncond=True)

            if mask is not None:
                batch_masks = torch.cat([mask, mask], dim=0)
            else:
                batch_masks = None

            batch_output = self._forward(
                batch_inputs,
                batch_timestep,
                cross_attn_cond=batch_cond,
                cross_attn_cond_mask=batch_cond_masks,
                mask=batch_masks,
                input_concat_cond=batch_input_concat_cond,
                global_embed=batch_global_cond,
                prepend_cond=batch_prepend_cond,
                prepend_cond_mask=batch_prepend_cond_mask,
                add_cond=batch_add_cond,
                sync_cond=batch_sync_cond,
                return_info=return_info,
                **kwargs)

            if return_info:
                batch_output, info = batch_output

            cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
            cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

            # CFG Rescale
            if scale_phi != 0.0:
                cond_out_std = cond_output.std(dim=1, keepdim=True)
                out_cfg_std = cfg_output.std(dim=1, keepdim=True)
                output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
            else:
                output = cfg_output

            if return_info:
                return output, info

            return output

        else:
            return self._forward(
                x,
                t,
                cross_attn_cond=cross_attn_cond,
                cross_attn_cond_mask=cross_attn_cond_mask,
                input_concat_cond=input_concat_cond,
                global_embed=global_embed,
                prepend_cond=prepend_cond,
                prepend_cond_mask=prepend_cond_mask,
                add_cond=add_cond,
                sync_cond=sync_cond,
                mask=mask,
                return_info=return_info,
                **kwargs
            )
-275
View File
@@ -1,275 +0,0 @@
import torch
from einops import rearrange
from torch import nn
from .blocks import AdaRMSNorm
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
from .utils import checkpoint
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
class ContinuousLocalTransformer(nn.Module):
    """
    Stack of local-attention transformer layers over ``(batch, seq, dim)`` input.

    Each layer is: norm -> self-attention (window ``local_attn_window_size``)
    -> optional cross-attention -> norm -> feed-forward, with residual
    connections. Norms are ``AdaRMSNorm`` when ``cond_dim > 0``, otherwise
    ``LayerNorm``.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_in = None,
        dim_out = None,
        causal = False,
        local_attn_window_size = 64,
        heads = 8,
        ff_mult = 2,
        cond_dim = 0,
        cross_attn_cond_dim = 0,
        **kwargs
    ):
        super().__init__()

        head_width = dim // heads

        self.layers = nn.ModuleList([])

        if dim_in is not None:
            self.project_in = nn.Linear(dim_in, dim)
        else:
            self.project_in = nn.Identity()

        if dim_out is not None:
            self.project_out = nn.Linear(dim, dim_out)
        else:
            self.project_out = nn.Identity()

        self.local_attn_window_size = local_attn_window_size
        self.cond_dim = cond_dim
        self.cross_attn_cond_dim = cross_attn_cond_dim

        self.rotary_pos_emb = RotaryEmbedding(max(head_width // 2, 32))

        def build_norm():
            # Conditioned norm only when a conditioning width is configured.
            if cond_dim > 0:
                return AdaRMSNorm(dim, cond_dim, eps=1e-8)
            return LayerNorm(dim)

        for _ in range(depth):
            # Sub-modules are created in the same order they are stored.
            attn_norm = build_norm()
            self_attn = Attention(
                dim=dim,
                dim_heads=head_width,
                causal=causal,
                zero_init_output=True,
                natten_kernel_size=local_attn_window_size,
            )
            if self.cross_attn_cond_dim > 0:
                cross_attn = Attention(
                    dim=dim,
                    dim_heads=head_width,
                    dim_context = cross_attn_cond_dim,
                    zero_init_output=True
                )
            else:
                cross_attn = nn.Identity()
            ff_norm = build_norm()
            feed_forward = FeedForward(dim = dim, mult = ff_mult, no_bias=True)

            self.layers.append(nn.ModuleList([attn_norm, self_attn, cross_attn, ff_norm, feed_forward]))

    def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
        """Apply input projection, all layers, and output projection to ``x``."""
        x = checkpoint(self.project_in, x)

        if prepend_cond is not None:
            # Prepended tokens go in front along the sequence axis.
            x = torch.cat([prepend_cond, x], dim=1)

        rotary = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
        has_cond = cond is not None

        for attn_norm, self_attn, cross_attn, ff_norm, feed_forward in self.layers:
            # Self-attention sub-block
            skip = x
            x = checkpoint(attn_norm, x, cond) if has_cond else checkpoint(attn_norm, x)
            x = checkpoint(self_attn, x, mask = mask, rotary_pos_emb=rotary) + skip

            # Cross-attention sub-block (no norm, residual on its own input)
            if cross_attn_cond is not None:
                x = checkpoint(cross_attn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x

            # Feed-forward sub-block
            skip = x
            x = checkpoint(ff_norm, x, cond) if has_cond else checkpoint(ff_norm, x)
            x = checkpoint(feed_forward, x) + skip

        return checkpoint(self.project_out, x)
class TransformerDownsampleBlock1D(nn.Module):
    """
    Local transformer followed by a sequence downsample.

    After the transformer, every ``downsample_ratio`` consecutive positions are
    folded into the channel axis and projected back to ``embed_dim``.
    """

    def __init__(
        self,
        in_channels,
        embed_dim = 768,
        depth = 3,
        heads = 12,
        downsample_ratio = 2,
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        self.downsample_ratio = downsample_ratio

        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size=local_attn_window_size,
            **kwargs
        )

        # Only project the input when its width differs from the embed width.
        if in_channels != embed_dim:
            self.project_in = nn.Linear(in_channels, embed_dim, bias=False)
        else:
            self.project_in = nn.Identity()

        folded_width = embed_dim * self.downsample_ratio
        self.project_down = nn.Linear(folded_width, embed_dim, bias=False)

    def forward(self, x):
        """Project in, run the transformer, then shorten the sequence by ``downsample_ratio``."""
        hidden = self.transformer(checkpoint(self.project_in, x))

        # Trade sequence length for channels
        folded = rearrange(hidden, "b (n r) c -> b n (c r)", r=self.downsample_ratio)

        # Project back to embed dim
        return checkpoint(self.project_down, folded)
class TransformerUpsampleBlock1D(nn.Module):
    """
    Sequence upsample followed by a local transformer.

    Channels are widened by ``upsample_ratio`` and then unfolded into
    ``upsample_ratio`` times as many sequence positions before the transformer.
    """

    def __init__(
        self,
        in_channels,
        embed_dim,
        depth = 3,
        heads = 12,
        upsample_ratio = 2,
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        self.upsample_ratio = upsample_ratio

        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size = local_attn_window_size,
            **kwargs
        )

        # Only project the input when its width differs from the embed width.
        if in_channels != embed_dim:
            self.project_in = nn.Linear(in_channels, embed_dim, bias=False)
        else:
            self.project_in = nn.Identity()

        widened = embed_dim * self.upsample_ratio
        self.project_up = nn.Linear(embed_dim, widened, bias=False)

    def forward(self, x):
        """Project in, lengthen the sequence by ``upsample_ratio``, then run the transformer."""
        # Project to embed dim, then widen the channel dim
        widened = checkpoint(self.project_up, checkpoint(self.project_in, x))

        # Trade channels for sequence length
        unfolded = rearrange(widened, "b n (c r) -> b (n r) c", r=self.upsample_ratio)

        return self.transformer(unfolded)
class TransformerEncoder1D(nn.Module):
    """
    Encoder made of ``TransformerDownsampleBlock1D`` stages.

    Takes channel-first ``(batch, channels, length)`` input, works internally
    in ``(batch, length, channels)``, and returns channel-first output. One
    stage is built per entry of ``depths``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dims = [96, 192, 384, 768],
        heads = [12, 12, 12, 12],
        depths = [3, 3, 3, 3],
        ratios = [2, 2, 2, 2],
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        stages = []
        for idx, stage_depth in enumerate(depths):
            # The first stage receives embed_dims[0]; later ones the previous stage's width.
            incoming = embed_dims[max(idx - 1, 0)]
            stage = TransformerDownsampleBlock1D(
                in_channels = incoming,
                embed_dim = embed_dims[idx],
                heads = heads[idx],
                depth = stage_depth,
                downsample_ratio = ratios[idx],
                local_attn_window_size = local_attn_window_size,
                **kwargs
            )
            stages.append(stage)

        self.layers = nn.Sequential(*stages)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Encode channel-first ``x`` and return a channel-first result."""
        tokens = rearrange(x, "b c n -> b n c")
        tokens = checkpoint(self.project_in, tokens)
        tokens = self.layers(tokens)
        tokens = checkpoint(self.project_out, tokens)
        return rearrange(tokens, "b n c -> b c n")
class TransformerDecoder1D(nn.Module):
    """
    Decoder made of ``TransformerUpsampleBlock1D`` stages.

    Takes channel-first ``(batch, channels, length)`` input, works internally
    in ``(batch, length, channels)``, and returns channel-first output. One
    stage is built per entry of ``depths``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dims = [768, 384, 192, 96],
        heads = [12, 12, 12, 12],
        depths = [3, 3, 3, 3],
        ratios = [2, 2, 2, 2],
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        stages = []
        for idx, stage_depth in enumerate(depths):
            # The first stage receives embed_dims[0]; later ones the previous stage's width.
            incoming = embed_dims[max(idx - 1, 0)]
            stage = TransformerUpsampleBlock1D(
                in_channels = incoming,
                embed_dim = embed_dims[idx],
                heads = heads[idx],
                depth = stage_depth,
                upsample_ratio = ratios[idx],
                local_attn_window_size = local_attn_window_size,
                **kwargs
            )
            stages.append(stage)

        self.layers = nn.Sequential(*stages)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Decode channel-first ``x`` and return a channel-first result."""
        tokens = rearrange(x, "b c n -> b n c")
        tokens = checkpoint(self.project_in, tokens)
        tokens = self.layers(tokens)
        tokens = checkpoint(self.project_out, tokens)
        return rearrange(tokens, "b n c -> b c n")
@@ -1 +0,0 @@
# mmmodules package
@@ -1 +0,0 @@
# mmmodules.model package
@@ -1,95 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
class ChannelLastConv1d(nn.Conv1d):
    """``nn.Conv1d`` that takes and returns tensors with the last two axes swapped.

    The input is permuted with ``(0, 2, 1)`` before the convolution and the
    result is permuted back the same way.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_first = x.permute(0, 2, 1)
        convolved = super().forward(channels_first)
        return convolved.permute(0, 2, 1)
# https://github.com/Stability-AI/sd3-ref
class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

        Args:
            dim (int): Input and output dimension.
            hidden_dim (int): Requested hidden dimension; it is scaled by 2/3
                (truncated) and then rounded up to a multiple of ``multiple_of``.
            multiple_of (int): Granularity the hidden dimension is rounded up to.

        Attributes:
            w1 (nn.Linear): ``dim -> hidden`` projection fed through SiLU (no bias).
            w2 (nn.Linear): ``hidden -> dim`` output projection (no bias).
            w3 (nn.Linear): ``dim -> hidden`` projection multiplied with the
                SiLU branch (no bias).
        """
        super().__init__()

        scaled = int(2 * hidden_dim / 3)
        num_chunks = (scaled + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        activated = F.silu(self.w1(x))
        gated = activated * self.w3(x)
        return self.w2(gated)
class ConvMLP(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Convolutional gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

        Args:
            dim (int): Input and output dimension.
            hidden_dim (int): Requested hidden dimension; it is scaled by 2/3
                (truncated) and then rounded up to a multiple of ``multiple_of``.
            multiple_of (int): Granularity the hidden dimension is rounded up to.
            kernel_size (int): Kernel size of all three convolutions.
            padding (int): Padding of all three convolutions.

        Attributes:
            w1 (ChannelLastConv1d): ``dim -> hidden`` conv fed through SiLU (no bias).
            w2 (ChannelLastConv1d): ``hidden -> dim`` output conv (no bias).
            w3 (ChannelLastConv1d): ``dim -> hidden`` conv multiplied with the
                SiLU branch (no bias).
        """
        super().__init__()

        scaled = int(2 * hidden_dim / 3)
        num_chunks = (scaled + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_chunks

        # Settings shared by all three convolutions.
        conv_kwargs = dict(bias=False, kernel_size=kernel_size, padding=padding)

        self.w1 = ChannelLastConv1d(dim, inner_dim, **conv_kwargs)
        self.w2 = ChannelLastConv1d(inner_dim, dim, **conv_kwargs)
        self.w3 = ChannelLastConv1d(dim, inner_dim, **conv_kwargs)

    def forward(self, x):
        activated = F.silu(self.w1(x))
        gated = activated * self.w3(x)
        return self.w2(gated)
-393
View File
@@ -1,393 +0,0 @@
import math
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from scipy.optimize import fmin
from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord
class PQMF(nn.Module):
    """
    Pseudo Quadrature Mirror Filter (PQMF) bank for splitting a signal into
    frequency bands and rebuilding it, using the polyphase helpers of this module.

    Parameters:
    - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
    - num_bands (int): Number of frequency bands. It must be a power of 2.
    """

    def __init__(self, attenuation, num_bands):
        super().__init__()

        # A power of two has an integral base-2 logarithm.
        log_bands = math.log2(num_bands)
        assert log_bands == int(log_bands), "'num_bands' must be a power of 2."

        # Prototype low-pass -> cosine-modulated bank -> padded to a power-of-two length.
        prototype = design_prototype_filter(attenuation, num_bands)
        bank = generate_modulated_filter_bank(prototype, num_bands)
        bank = pad_to_nearest_power_of_two(bank)

        # Buffers so the filters follow the module across devices and dtypes.
        self.register_buffer("filter_bank", bank)
        self.register_buffer("prototype", prototype)
        self.num_bands = num_bands

    def forward(self, signal):
        """Decompose the signal into multiple frequency bands."""
        # Bring the input to Batch x Channels x Length, then zero-pad the
        # length up to a multiple of num_bands before the analysis step.
        batched = prepare_signal_dimensions(signal)
        batched = pad_signal(batched, self.num_bands)
        bands = polyphase_analysis(batched, self.filter_bank)
        return apply_alias_cancellation(bands)

    def inverse(self, bands):
        """Reconstruct the signal from its frequency bands."""
        corrected = apply_alias_cancellation(bands)
        return polyphase_synthesis(corrected, self.filter_bank)
def prepare_signal_dimensions(signal):
    """
    Bring a signal to the Batch x Channels x Length layout.

    Parameters
    ----------
    signal : torch.Tensor or numpy.ndarray
        The input signal. 1-D input is treated as mono; for 2-D input the
        larger axis is taken to be the length. Other ranks pass through.

    Returns
    -------
    torch.Tensor
        The signal as a tensor, with leading axes added where needed.

    Raises
    ------
    ValueError
        If the input is neither a numpy array nor a PyTorch tensor.
    """
    if isinstance(signal, np.ndarray):
        signal = torch.from_numpy(signal)
    elif not isinstance(signal, torch.Tensor):
        raise ValueError("Input should be either a numpy array or a PyTorch tensor.")

    rank = signal.dim()
    if rank == 1:
        # Mono signal: 1 x 1 x Length.
        return signal.unsqueeze(0).unsqueeze(0)
    if rank == 2:
        # Multi-channel signal: make the larger axis (Length) the last one.
        length_first = signal.shape[0] > signal.shape[1]
        oriented = signal.T if length_first else signal
        return oriented.unsqueeze(0)
    return signal
def pad_signal(signal, num_bands):
    """
    Zero-pad the end of a signal so its length is divisible by ``num_bands``.

    Parameters
    ----------
    signal : torch.Tensor
        Input tensor whose last dimension is the signal length.
    num_bands : int
        The divisor the length must become a multiple of.

    Returns
    -------
    torch.Tensor
        The right-padded tensor, or the input object itself when no padding
        is required.
    """
    leftover = signal.shape[-1] % num_bands
    if leftover <= 0:
        return signal
    return nn.functional.pad(signal, (0, num_bands - leftover))
def generate_modulated_filter_bank(prototype_filter, num_bands):
    """
    Build a bank of cosine-modulated filters from a prototype filter.

    Parameters
    ----------
    prototype_filter : torch.Tensor
        Prototype filter that every band is derived from.
    num_bands : int
        Number of filters (sub-bands) to generate.

    Returns
    -------
    torch.Tensor
        One modulated filter per row.
    """
    # Column vector of band indices so it broadcasts against the tap axis.
    band_ids = torch.arange(num_bands).reshape(-1, 1)

    # Tap positions symmetric around zero.
    half_len = prototype_filter.shape[-1] // 2
    taps = torch.arange(-half_len, half_len + 1)

    # Phase alternates between +pi/4 and -pi/4 from band to band.
    phase = (-1)**band_ids * np.pi / 4
    angular_step = (2 * band_ids + 1) * np.pi / (2 * num_bands)

    return 2 * prototype_filter * torch.cos(angular_step * taps + phase)
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
    """
    Design a lowpass FIR filter using the Kaiser window.

    Parameters
    ----------
    angular_cutoff : float
        Cutoff expressed as an angular frequency, where pi is the Nyquist
        frequency.
    attenuation : float
        Desired stopband attenuation in decibels (dB).
    filter_length : int, optional
        Number of taps. When omitted, the ``kaiserord`` estimate is used,
        bumped to an odd number.

    Returns
    -------
    ndarray
        The lowpass filter coefficients.
    """
    estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)

    if filter_length is None:
        # Force an odd tap count so the filter has a centre tap.
        filter_length = 2 * (estimated_length // 2) + 1

    # fs = 2*pi puts the Nyquist frequency at pi, matching the old `nyq=np.pi`
    # call; `nyq` is deprecated in favour of `fs`.
    return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, fs=2 * np.pi)
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
    """
    Score a candidate cutoff using the criterion from https://ieeexplore.ieee.org/document/681427

    Parameters
    ----------
    angular_cutoff : float
        Angular frequency cutoff of the filter.
    attenuation : float
        Desired stopband attenuation in dB.
    num_bands : int
        Number of bands of the multiband system.
    filter_length : int, optional
        Desired length of the filter.

    Returns
    -------
    float
        The objective (loss) value: the largest magnitude among the samples
        of the filter's self-convolution taken every ``2 * num_bands`` steps
        from the centre, with the centre sample itself excluded.
    """
    taps = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
    autocorr = np.convolve(taps, taps[::-1], "full")

    centre = autocorr.shape[-1] // 2
    decimated = autocorr[centre::2 * num_bands]

    # Skip the first (centre) sample.
    return np.max(np.abs(decimated[1:]))
def design_prototype_filter(attenuation, num_bands, filter_length=None):
    """
    Design the prototype filter for a multiband system.

    The cutoff is found by minimising ``evaluate_filter_objective`` with
    ``scipy.optimize.fmin``, starting from ``1 / num_bands``.

    Parameters
    ----------
    attenuation : float
        The desired stopband attenuation in dB.
    num_bands : int
        Number of bands of the multiband system.
    filter_length : int, optional
        Desired length of the filter. If not provided, it is derived from
        the other specs.

    Returns
    -------
    torch.Tensor
        The prototype filter coefficients as a float32 tensor.
    """
    def objective(cutoff):
        # fmin passes its current point as a length-1 array. Unwrap it to a
        # plain float so downstream scalar conversions do not rely on
        # NumPy's deprecated implicit size-1-array-to-scalar conversion.
        scalar_cutoff = float(np.asarray(cutoff).ravel()[0])
        return evaluate_filter_objective(scalar_cutoff, attenuation, num_bands, filter_length)

    optimal_angular_cutoff = fmin(objective, 1 / num_bands, disp=0)[0]

    prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
    return torch.tensor(prototype_filter, dtype=torch.float32)
def pad_to_nearest_power_of_two(x):
    """
    Zero-pad the last dimension of ``x`` on both sides up to the next power
    of two. When the padding is odd, the extra element goes on the right.

    Parameters:
    -----------
    x : torch.Tensor
        The input tensor to be padded.

    Returns:
    --------
    torch.Tensor
        The padded tensor.
    """
    length = x.shape[-1]
    deficit = 2**math.ceil(math.log2(length)) - length
    left = deficit // 2
    return nn.functional.pad(x, (left, deficit - left))
def apply_alias_cancellation(x):
    """
    Flip the sign of every second element (starting at index 0) of every
    second row (starting at index 1) over the last two dimensions.

    The same sign pattern is applied after analysis and before synthesis so
    the aliasing introduced per band is counteracted on reconstruction.

    Parameters:
    -----------
    x : torch.Tensor
        The input tensor.

    Returns:
    --------
    torch.Tensor
        A new tensor with the selected elements negated.
    """
    # +1 everywhere, -1 on odd rows / even columns.
    signs = torch.ones_like(x)
    signs[..., 1::2, ::2] = -1
    return torch.mul(x, signs)
def ensure_odd_length(tensor):
    """
    Make the last dimension of a tensor odd-sized.

    Parameters:
    -----------
    tensor : torch.Tensor
        Input tensor whose last dimension might need padding.

    Returns:
    --------
    torch.Tensor
        The input object itself when its last dimension is already odd,
        otherwise a copy with one zero appended at the end.
    """
    if tensor.shape[-1] % 2 != 0:
        return tensor
    return nn.functional.pad(tensor, (0, 1))
def polyphase_analysis(signal, filter_bank):
    """
    Split a signal into sub-bands with a polyphase filter bank.

    Parameters:
    -----------
    signal : torch.Tensor
        Input signal tensor with shape (Batch x Channels x Length).
    filter_bank : torch.Tensor
        Filter bank tensor with shape (Bands x Length).

    Returns:
    --------
    torch.Tensor
        Signal split into sub-bands. (Batch x Channels x Bands x Length)
    """
    band_count = filter_bank.shape[0]
    channel_count = signal.shape[1]

    # Split the signal into `band_count` polyphase components and fold
    # Batch x Channels into a single leading axis for the convolution.
    phases = rearrange(signal, "b c (t n) -> (b c) n t", n=band_count)

    # Same polyphase split for each filter of the bank.
    kernels = rearrange(filter_bank, "c (t n) -> c n t", n=band_count)

    # Pad by half the kernel length, then drop the trailing extra sample.
    subbands = nn.functional.conv1d(phases, kernels, padding=kernels.shape[-1] // 2)
    subbands = subbands[..., :-1]

    # Unfold the leading axis back into Batch x Channels.
    return rearrange(subbands, "(b c) n t -> b c n t", c=channel_count)
def polyphase_synthesis(signal, filter_bank):
    """
    Polyphase inverse: rebuild a signal from its sub-bands.

    Parameters
    ----------
    signal : torch.Tensor
        Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
    filter_bank : torch.Tensor
        Analysis filter bank (shape: Bands x Length).

    Returns
    -------
    torch.Tensor
        Reconstructed signal (shape: Batch x Channels X Length)
    """
    band_count = filter_bank.shape[0]
    channel_count = signal.shape[1]

    # Time-reverse the analysis filters and split them into polyphase components.
    kernels = rearrange(filter_bank.flip(-1), "c (t n) -> n c t", n=band_count)

    # Fold Batch x Channels into a single leading axis for the convolution.
    flat_bands = rearrange(signal, "b c n t -> (b c) n t")

    pad = int(kernels.shape[-1] // 2 + 1)
    rebuilt = nn.functional.conv1d(flat_bands, kernels, padding=pad)

    # Drop the trailing extra sample and apply the band-count gain.
    rebuilt = rebuilt[..., :-1] * band_count

    # Reverse the component order, interleave the components back along
    # time, and cut the leading samples.
    rebuilt = rebuilt.flip(1)
    rebuilt = rearrange(rebuilt, "(b c) n t -> b c (t n)", c=channel_count, n=band_count)
    return rebuilt[..., 2 * kernels.shape[1]:]
-239
View File
@@ -1,239 +0,0 @@
import torch
from einops import rearrange
from torch import nn
class Pretransform(nn.Module):
    """Base class for pretransforms; subclasses implement the four codec methods."""

    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        # Values supplied by the subclass.
        self.enable_grad = enable_grad
        self.io_channels = io_channels
        self.is_discrete = is_discrete

        # Left unset here; subclasses overwrite them when they know the values.
        self.encoded_channels = None
        self.downsampling_ratio = None

    def encode(self, x):
        """Encode ``x``; must be overridden."""
        raise NotImplementedError

    def decode(self, z):
        """Decode ``z``; must be overridden."""
        raise NotImplementedError

    def tokenize(self, x):
        """Turn ``x`` into tokens; must be overridden."""
        raise NotImplementedError

    def decode_tokens(self, tokens):
        """Decode ``tokens``; must be overridden."""
        raise NotImplementedError
class AutoencoderPretransform(Pretransform):
    """Pretransform that wraps an autoencoder model kept frozen and in eval mode."""

    def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
        """
        Args:
            model: Wrapped autoencoder. This class reads its ``io_channels``,
                ``bottleneck``, ``downsampling_ratio``, ``sample_rate`` and
                ``latent_dim`` attributes.
            scale: Encoded latents are divided by it; latents to decode are
                multiplied by it.
            model_half: Run the wrapped model in float16.
            iterate_batch: Forwarded to ``encode_audio`` / ``decode_audio``.
            chunked: Forwarded to ``encode_audio`` / ``decode_audio``.
        """
        discrete = model.bottleneck is not None and model.bottleneck.is_discrete
        super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=discrete)

        self.model = model
        self.model.requires_grad_(False).eval()
        self.scale = scale
        self.downsampling_ratio = model.downsampling_ratio
        self.io_channels = model.io_channels
        self.sample_rate = model.sample_rate

        self.model_half = model_half
        self.iterate_batch = iterate_batch

        self.encoded_channels = model.latent_dim

        self.chunked = chunked

        # Quantizer metadata only exists for a discrete bottleneck.
        self.num_quantizers = model.bottleneck.num_quantizers if discrete else None
        self.codebook_size = model.bottleneck.codebook_size if discrete else None

        if self.model_half:
            self.model.half()

    def encode(self, x, **kwargs):
        """Encode audio with the wrapped model and divide the latents by ``scale``."""
        if self.model_half:
            x = x.half()
            # Re-cast on every call so the model stays in float16.
            self.model.to(torch.float16)

        encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)

        if self.model_half:
            encoded = encoded.float()

        return encoded / self.scale

    def decode(self, z, **kwargs):
        """Multiply the latents by ``scale`` and decode them with the wrapped model."""
        z = z * self.scale

        if self.model_half:
            z = z.half()
            # Re-cast on every call so the model stays in float16.
            self.model.to(torch.float16)

        decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)

        if self.model_half:
            decoded = decoded.float()

        return decoded

    def tokenize(self, x, **kwargs):
        """Return the bottleneck tokens for ``x``; requires a discrete model."""
        assert self.model.is_discrete, "Cannot tokenize with a continuous model"

        _, info = self.model.encode(x, return_info = True, **kwargs)

        return info[self.model.bottleneck.tokens_id]

    def decode_tokens(self, tokens, **kwargs):
        """Decode tokens with the wrapped model; requires a discrete model."""
        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"

        return self.model.decode_tokens(tokens, **kwargs)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the wrapped model and return that call's result.

        Returning the result (instead of discarding it) lets callers inspect
        missing or unexpected keys, as with a regular ``nn.Module``.
        """
        return self.model.load_state_dict(state_dict, strict=strict)
class PQMFPretransform(Pretransform):
    """Pretransform that exposes a PQMF bank's bands as extra channels."""

    def __init__(self, attenuation=100, num_bands=16):
        # TODO: Fix PQMF to take in in-channels
        super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
        from .pqmf import PQMF
        self.pqmf = PQMF(attenuation, num_bands)

    def encode(self, x):
        """Split ``x`` (Batch x Channels x Time) into bands merged with the channel axis."""
        # pqmf.forward yields Batch x Channels x Bands x Time; fold channels
        # and bands together to get back to a three-axis tensor.
        bands = self.pqmf.forward(x)
        return rearrange(bands, "b c n t -> b (c n) t")

    def decode(self, x):
        """Rebuild Batch x Channels x Time from Batch x (Channels Bands) x Time."""
        bands = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
        return self.pqmf.inverse(bands)
class PretrainedDACPretransform(Pretransform):
    """Pretransform backed by a pretrained model downloaded through the ``dac`` package."""

    def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        import dac

        checkpoint_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
        self.model = dac.DAC.load(checkpoint_path)

        self.quantize_on_decode = quantize_on_decode

        # 512 for the "44khz" model, 320 for every other model type.
        self.downsampling_ratio = 512 if model_type == "44khz" else 320

        self.io_channels = 1
        self.scale = scale
        self.chunked = chunked

        self.encoded_channels = self.model.latent_dim
        self.num_quantizers = self.model.n_codebooks
        self.codebook_size = self.model.codebook_size

    def encode(self, x):
        """Encode ``x``; latents are quantized here only when ``quantize_on_decode`` is off."""
        output = self.model.encoder(x)

        if not self.quantize_on_decode:
            z, _, _, _, _ = self.model.quantizer(output, n_quantizers=self.model.n_codebooks)
            output = z

        if self.scale != 1.0:
            output = output / self.scale

        return output

    def decode(self, z):
        """Rescale ``z``, quantize it when ``quantize_on_decode`` is on, and decode."""
        if self.scale != 1.0:
            z = z * self.scale

        if self.quantize_on_decode:
            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)

        return self.model.decode(z)

    def tokenize(self, x):
        """Return the second element of ``self.model.encode(x)``."""
        encoded = self.model.encode(x)
        return encoded[1]

    def decode_tokens(self, tokens):
        """Map codes back to latents with the quantizer and decode them."""
        latents = self.model.quantizer.from_codes(tokens)
        return self.model.decode(latents)
class AudiocraftCompressionPretransform(Pretransform):
    """Token-only pretransform backed by an Audiocraft ``CompressionModel``.

    Only ``tokenize`` / ``decode_tokens`` are supported; the continuous
    ``encode`` / ``decode`` methods always raise.
    """

    def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        try:
            from audiocraft.models import CompressionModel
        except ImportError as err:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") from err

        self.model = CompressionModel.get_pretrained(model_type)

        self.quantize_on_decode = quantize_on_decode

        self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)

        self.sample_rate = self.model.sample_rate

        self.io_channels = self.model.channels

        self.scale = scale

        # encoded_channels is not set here; it keeps the base-class value (None).

        self.num_quantizers = self.model.num_codebooks

        self.codebook_size = self.model.cardinality

        self.model.to(torch.float16).eval().requires_grad_(False)

    def encode(self, x):
        """Always raises ``AssertionError``: continuous encoding is unsupported."""
        # Raised explicitly (not via `assert False`) so the failure survives
        # `python -O`, where assert statements are removed.
        raise AssertionError("Audiocraft compression models do not support continuous encoding")

    def decode(self, z):
        """Always raises ``AssertionError``: continuous decoding is unsupported."""
        # Raised explicitly (not via `assert False`) so the failure survives
        # `python -O`, where assert statements are removed.
        raise AssertionError("Audiocraft compression models do not support continuous decoding")

    def tokenize(self, x):
        """Encode ``x`` (cast to float16) with autocast disabled and return the first output."""
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.encode(x.to(torch.float16))[0]

    def decode_tokens(self, tokens):
        """Decode ``tokens`` with autocast disabled."""
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.decode(tokens)
-989
View File
@@ -1,989 +0,0 @@
from functools import reduce, partial
from packaging import version
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.cuda.amp import autocast
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from typing import Callable, Literal
try:
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
flash_attn_kvpacked_func = None
flash_attn_func = None
from .utils import compile, checkpoint
try:
import natten
except ImportError:
natten = None
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Return ``x * (1 + scale) + shift``."""
    gain = 1 + scale
    return x * gain + shift
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
def create_causal_mask(i, j, device):
    """Return an ``i x j`` boolean mask that is True above diagonal ``j - i``."""
    all_true = torch.ones((i, j), device = device, dtype = torch.bool)
    return torch.triu(all_true, diagonal = j - i + 1)
def or_reduce(masks):
    """Combine all masks with ``|``; at least one mask is required."""
    first, *others = masks
    return reduce(lambda acc, mask: acc | mask, others, first)
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
    """Learned positional embedding table, scaled by ``dim ** -0.5``."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return the scaled embeddings for the positions of ``x``.

        ``pos`` defaults to ``0 .. x.shape[1] - 1``. When ``seq_start_pos``
        is given it is subtracted from the positions, clamped at zero.
        """
        seq_len = x.shape[1]
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = x.device)

        if seq_start_pos is not None:
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        return self.emb(pos) * self.scale
class ScaledSinusoidalEmbedding(nn.Module):
    """Sinusoidal positional embedding multiplied by a learned scalar."""

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        # Learned scale, initialised to dim ** -0.5.
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        exponents = torch.arange(half_dim).float() / half_dim
        # Non-persistent: recomputed at construction instead of being saved in the state dict.
        self.register_buffer('inv_freq', theta ** -exponents, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return ``[sin, cos]`` features for the positions of ``x``, times the scale.

        ``pos`` defaults to ``0 .. x.shape[1] - 1``; ``seq_start_pos``, when
        given, is subtracted from the positions.
        """
        if pos is None:
            pos = torch.arange(x.shape[1], device = x.device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        # Outer product of positions and inverse frequencies.
        angles = einsum('i, j -> i j', pos, self.inv_freq)
        features = torch.cat((angles.sin(), angles.cos()), dim = -1)
        return features * self.scale
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding frequencies, with an optional xpos scale."""

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        """
        Args:
            dim: Rotary feature size; frequencies are built for ``dim // 2`` pairs.
            use_xpos: Also return a per-position scale from ``forward``.
            scale_base: Divisor applied to the centred position in the xpos exponent.
            interpolation_factor: Positions are divided by this value (must be >= 1).
            base: Frequency base.
            base_rescale_factor: Rescales ``base`` (see the note below).
        """
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        """Run ``forward`` on positions ``0 .. seq_len - 1``."""
        device = self.inv_freq.device

        t = torch.arange(seq_len, device = device)
        return self.forward(t)

    @autocast(enabled = False)
    def forward(self, t):
        """Return ``(freqs, scale)`` for the 1-D position tensor ``t``.

        ``scale`` is the float ``1.`` when xpos is disabled, otherwise a
        tensor with the same shape as ``freqs``.
        """
        device = self.inv_freq.device

        t = t.to(torch.float32)

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # The original body referenced an undefined `seq_len` here, so the
        # xpos path raised NameError; take it from the position tensor.
        seq_len = t.shape[-1]
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        # (n, 1) exponent broadcast against the (dim // 2,) scale buffer.
        scale = self.scale ** power.unsqueeze(-1)
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
def rotate_half(x):
    """Split the last axis into two halves ``(a, b)`` and return ``(-b, a)`` concatenated."""
    halves = rearrange(x, '... (j d) -> ... j d', j = 2)
    first, second = halves.unbind(dim = -2)
    return torch.cat((-second, first), dim = -1)
@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Rotate the leading ``freqs.shape[-1]`` features of ``t`` by ``freqs``.

    Features beyond that width pass through unchanged. The arithmetic runs
    in at least float32 and the result is cast back to ``t``'s dtype.
    """
    out_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    compute_dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rot_dim = freqs.shape[-1]
    seq_len = t.shape[-2]
    t = t.to(compute_dtype)
    # Keep only the last `seq_len` entries along the first axis of freqs.
    freqs = freqs.to(compute_dtype)[-seq_len:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    rotated = t[..., :rot_dim]
    passthrough = t[..., rot_dim:]
    rotated = (rotated * freqs.cos() * scale) + (rotate_half(rotated) * freqs.sin() * scale)

    rotated = rotated.to(out_dtype)
    passthrough = passthrough.to(out_dtype)

    return torch.cat((rotated, passthrough), dim = -1)
# norms
class DynamicTanh(nn.Module):
    """Element-wise ``gamma * tanh(alpha * x) + beta`` with learned parameters."""

    def __init__(self, dim, init_alpha=10.0):
        super().__init__()
        # Scalar slope inside the tanh, then per-feature gain and offset.
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        return self.gamma * squashed + self.beta
class RunningInstanceNorm(nn.Module):
    """Normalise with running mean/std buffers that are updated in training mode.

    Optionally applies ``asinh`` saturation and a learned scalar gain.
    """

    def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(1,1,dim))
        self.register_buffer("running_std", torch.ones(1,1,dim))
        self.saturate = saturate
        self.eps = eps
        self.momentum = momentum
        self.dim = dim
        self.trainable_gain = trainable_gain
        if self.trainable_gain:
            self.gain = nn.Parameter(torch.ones(1))

    def _update_stats(self, x):
        """Blend the statistics of ``x`` (taken over dims 0 and 1) into the running buffers."""
        detached = x.detach()
        batch_mean = detached.mean(dim = [0,1]).view(1, 1, self.dim)
        batch_std = detached.std(dim = [0,1]).view(1, 1, self.dim)
        self.running_mean = self.running_mean * self.momentum + batch_mean * (1 - self.momentum)
        # The std is floored at eps so the division in forward stays finite.
        self.running_std = (self.running_std * self.momentum + batch_std * (1 - self.momentum)).clip(min = self.eps)

    def forward(self, x):
        if self.training:
            self._update_stats(x)

        out = (x - self.running_mean) / self.running_std

        if self.saturate:
            out = torch.asinh(out)
        if self.trainable_gain:
            out = out * self.gain
        return out
class LayerNorm(nn.Module):
    """Layer norm over the last axis with optional fixed scale, bias, and fp32 compute."""

    def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
        """
        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less

        ``gamma`` is a buffer when ``fix_scale`` is set, otherwise a parameter.
        ``beta`` is a parameter when ``bias`` is set, otherwise a zero buffer.
        """
        super().__init__()

        if fix_scale:
            self.register_buffer("gamma", torch.ones(dim))
        else:
            self.gamma = nn.Parameter(torch.ones(dim))

        if bias:
            self.beta = nn.Parameter(torch.zeros(dim))
        else:
            self.register_buffer("beta", torch.zeros(dim))

        self.eps = eps
        self.force_fp32 = force_fp32

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        if self.force_fp32:
            # Normalise in float32, then return in the caller's dtype.
            out = F.layer_norm(x.float(), normalized_shape, weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
            return out.to(x.dtype)
        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta, eps=self.eps)
class LayerScale(nn.Module):
    """Multiply the input by a learned per-feature scale."""

    def __init__(self, dim, init_val = 1e-5):
        super().__init__()
        initial = torch.full([dim], init_val)
        self.scale = nn.Parameter(initial)

    def forward(self, x):
        scaled = x * self.scale
        return scaled
class GLU(nn.Module):
    """Gated linear unit: project to ``2 * dim_out``, split, and gate one half with the other."""

    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        use_conv = False,
        conv_kernel_size = 3,
    ):
        super().__init__()
        self.act = activation
        if use_conv:
            self.proj = nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
        else:
            self.proj = nn.Linear(dim_in, dim_out * 2)
        self.use_conv = use_conv

    def forward(self, x):
        if not self.use_conv:
            projected = self.proj(x)
        else:
            # Conv1d wants channels before the sequence axis; swap around the call.
            projected = rearrange(x, 'b n d -> b d n')
            projected = self.proj(projected)
            projected = rearrange(projected, 'b d n -> b n d')

        value, gate = projected.chunk(2, dim = -1)
        return value * self.act(gate)
class FeedForward(nn.Module):
    """Feed-forward block: input projection with activation, then an output projection.

    The input projection is a ``GLU`` when ``glu`` is set, otherwise a plain
    linear/conv layer followed by SiLU. The output projection is zero
    initialised by default.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
    ):
        super().__init__()
        hidden = int(dim * mult)

        # Default to SwiGLU
        act = nn.SiLU()

        if dim_out is None:
            dim_out = dim

        use_bias = not no_bias
        conv_pad = conv_kernel_size // 2

        if glu:
            proj_in = GLU(dim, hidden, act)
        elif use_conv:
            proj_in = nn.Sequential(
                Rearrange('b n d -> b d n'),
                nn.Conv1d(dim, hidden, conv_kernel_size, padding = conv_pad, bias = use_bias),
                Rearrange('b n d -> b d n'),
                act
            )
        else:
            proj_in = nn.Sequential(
                nn.Identity(),
                nn.Linear(dim, hidden, bias = use_bias),
                nn.Identity(),
                act
            )

        if use_conv:
            proj_out = nn.Conv1d(hidden, dim_out, conv_kernel_size, padding = conv_pad, bias = use_bias)
        else:
            proj_out = nn.Linear(hidden, dim_out, bias = use_bias)

        # init last linear layer to 0
        if zero_init_output:
            nn.init.zeros_(proj_out.weight)
            if use_bias:
                nn.init.zeros_(proj_out.bias)

        # Four slots in both modes so the sub-module indices stay the same.
        if use_conv:
            self.ff = nn.Sequential(
                proj_in,
                Rearrange('b d n -> b n d'),
                proj_out,
                Rearrange('b n d -> b d n'),
            )
        else:
            self.ff = nn.Sequential(
                proj_in,
                nn.Identity(),
                proj_out,
                nn.Identity(),
            )

    def forward(self, x):
        return self.ff(x)
class Attention(nn.Module):
def __init__(
    self,
    dim,
    dim_heads = 64,
    dim_context = None,
    causal = False,
    zero_init_output=True,
    qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
    differential = False,
    feat_scale = False
):
    """
    Build the projections and optional q/k norm modules.

    Args:
        dim: Feature size of the query input and of the output.
        dim_heads: Feature size per head; head counts are ``dim // dim_heads``
            and ``dim_kv // dim_heads``.
        dim_context: Feature size of a separate context. When given,
            separate ``to_q`` / ``to_kv`` projections are created instead of
            the fused ``to_qkv``.
        causal: Stored as ``self.causal``.
        zero_init_output: Zero-initialise the output projection weight.
        qk_norm: One of 'l2', 'ln', 'rns', 'dyt', 'none'.
        differential: Widen the projections with extra q/k chunks.
        feat_scale: Create the ``lambda_dc`` / ``lambda_hf`` parameters.

    Raises:
        ValueError: If ``qk_norm`` is not one of the accepted values.
    """
    super().__init__()
    self.dim = dim
    self.dim_heads = dim_heads
    self.differential = differential

    dim_kv = dim_context if dim_context is not None else dim

    self.num_heads = dim // dim_heads
    self.kv_heads = dim_kv // dim_heads

    if dim_context is not None:
        if differential:
            # q + q_diff, and k + k_diff + v
            self.to_q = nn.Linear(dim, dim * 2, bias=False)
            self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
        else:
            self.to_q = nn.Linear(dim, dim, bias=False)
            self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
    else:
        if differential:
            # q, k, v, q_diff, k_diff
            self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
        else:
            self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

    self.to_out = nn.Linear(dim, dim, bias=False)

    if zero_init_output:
        nn.init.zeros_(self.to_out.weight)

    if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
        # The message now lists every accepted value, not just three of them.
        raise ValueError(f'qk_norm must be one of ["l2", "ln", "rns", "dyt", "none"], got {qk_norm}')

    self.qk_norm = qk_norm

    # 'l2' and 'none' need no modules; the others normalise over dim_heads.
    if self.qk_norm == "ln":
        self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
        self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
    elif self.qk_norm == 'rns':
        self.q_norm = nn.RMSNorm(dim_heads)
        self.k_norm = nn.RMSNorm(dim_heads)
    elif self.qk_norm == 'dyt':
        self.q_norm = DynamicTanh(dim_heads)
        self.k_norm = DynamicTanh(dim_heads)

    self.sdp_kwargs = dict(
        enable_flash = True,
        enable_math = True,
        enable_mem_efficient = True
    )

    self.feat_scale = feat_scale

    if self.feat_scale:
        self.lambda_dc = nn.Parameter(torch.zeros(dim))
        self.lambda_hf = nn.Parameter(torch.zeros(dim))

    self.causal = causal
@compile
def apply_qk_layernorm(self, q, k):
    """Apply ``q_norm`` / ``k_norm`` and cast each result back to its input dtype."""
    normed_q = self.q_norm(q).to(q.dtype)
    normed_k = self.k_norm(k).to(k.dtype)
    return normed_q, normed_k
def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
    """Run attention on ``q``, ``k``, ``v`` with the available backend.

    Uses ``flash_attn_func`` when the flash-attn import succeeded, otherwise
    ``F.scaled_dot_product_attention``. FlexAttention arguments are dropped
    when ``causal`` is truthy and rejected with ``NotImplementedError``
    otherwise.
    """
    if self.num_heads != self.kv_heads:
        # Repeat interleave kv_heads to match q_heads for grouped query attention
        group_size = self.num_heads // self.kv_heads
        k = k.repeat_interleave(group_size, dim = 1)
        v = v.repeat_interleave(group_size, dim = 1)

    wants_flex = flex_attention_block_mask is not None or flex_attention_score_mod is not None

    if causal and wants_flex:
        # Causal attention discards the FlexAttention arguments.
        flex_attention_block_mask = None
        flex_attention_score_mod = None
        wants_flex = False

    if wants_flex:
        raise NotImplementedError(
            "FlexAttention is not available in this build. "
            "flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments."
        )

    if not HAS_FLASH_ATTN:
        return F.scaled_dot_product_attention(q, k, v, is_causal = causal)

    in_dtype = q.dtype
    q, k, v = (rearrange(t, 'b h n d -> b n h d') for t in (q, k, v))

    # Anything that is not already float16/bfloat16 is cast to bfloat16 for
    # the flash-attn call and cast back afterwards.
    if in_dtype != torch.float16 and in_dtype != torch.bfloat16:
        q, k, v = (t.to(torch.bfloat16) for t in (q, k, v))

    window = flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1]
    out = flash_attn_func(q, k, v, causal = causal, window_size=window)

    return rearrange(out.to(in_dtype), 'b n h d -> b h n d')
# NOTE: the @compile decorator is commented out for this method.
def forward(
    self,
    x,
    context = None,
    rotary_pos_emb = None,
    causal = None,
    flex_attention_block_mask = None,
    flex_attention_score_mod = None,
    flash_attn_sliding_window = None
):
    """Project ``x`` to q/k/v, apply norm and rotary embedding, attend, and project out.

    Args:
        x: input laid out as ``b n (h d)`` after projection (batch, sequence, features).
        context: optional key/value source; when given (and separate ``to_q`` /
            ``to_kv`` projections exist) keys and values come from it instead of ``x``.
        rotary_pos_emb: optional tuple whose first element holds the rotary frequencies.
        causal: overrides ``self.causal`` when not ``None``.
        flex_attention_block_mask, flex_attention_score_mod, flash_attn_sliding_window:
            forwarded unchanged to ``apply_attn``.

    Returns:
        Tensor in ``b n (h d)`` layout after ``self.to_out`` (and optional feature scaling).
    """
    q_heads = self.num_heads
    kv_heads = self.kv_heads
    kv_source = x if context is None else context

    def split_heads(t, heads):
        return rearrange(t, 'b n (h d) -> b h n d', h = heads)

    if hasattr(self, 'to_q'):
        # Separate projections: one for queries, one for keys/values
        if self.differential:
            q, q_diff = self.to_q(x).chunk(2, dim=-1)
            q, q_diff = split_heads(q, q_heads), split_heads(q_diff, q_heads)
            q = torch.stack([q, q_diff], dim = 1)
            k, k_diff, v = (split_heads(t, kv_heads) for t in self.to_kv(kv_source).chunk(3, dim=-1))
            k = torch.stack([k, k_diff], dim = 1)
        else:
            q = split_heads(self.to_q(x), q_heads)
            k, v = (split_heads(t, kv_heads) for t in self.to_kv(kv_source).chunk(2, dim=-1))
    else:
        # Single fused projection computed from x only
        if self.differential:
            q, k, v, q_diff, k_diff = (split_heads(t, q_heads) for t in self.to_qkv(x).chunk(5, dim=-1))
            q = torch.stack([q, q_diff], dim = 1)
            k = torch.stack([k, k_diff], dim = 1)
        else:
            q, k, v = (split_heads(t, q_heads) for t in self.to_qkv(x).chunk(3, dim=-1))

    # Query/key normalisation: unit-length vectors for "l2", learned norm layers for anything but "none"
    if self.qk_norm == "l2":
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
    elif self.qk_norm != "none":
        q, k = self.apply_qk_layernorm(q, k)

    if rotary_pos_emb is not None:
        freqs, _ = rotary_pos_emb
        # The rotation is computed in float32 and the results are cast to v's dtype afterwards
        q = q.to(torch.float32)
        k = k.to(torch.float32)
        freqs = freqs.to(torch.float32)
        q_len, k_len = q.shape[-2], k.shape[-2]
        # The shorter of the two sequences gets its frequencies scaled by the length ratio
        if q_len >= k_len:
            ratio = q_len / k_len
            q_freqs, k_freqs = freqs, ratio * freqs
        else:
            ratio = k_len / q_len
            q_freqs, k_freqs = ratio * freqs, freqs
        q = apply_rotary_pos_emb(q, q_freqs)
        k = apply_rotary_pos_emb(k, k_freqs)
        q = q.to(v.dtype)
        k = k.to(v.dtype)

    if causal is None:
        causal = self.causal
    # Causal masking is switched off for a single query position
    if q.shape[-2] == 1 and causal:
        causal = False

    attn_args = dict(
        causal = causal,
        flex_attention_block_mask = flex_attention_block_mask,
        flex_attention_score_mod = flex_attention_score_mod,
        flash_attn_sliding_window = flash_attn_sliding_window,
    )
    if self.differential:
        # Differential attention: two attention maps over the same values, subtracted
        q, q_diff = q.unbind(dim = 1)
        k, k_diff = k.unbind(dim = 1)
        out = self.apply_attn(q, k, v, **attn_args)
        out_diff = self.apply_attn(q_diff, k_diff, v, **attn_args)
        out = out - out_diff
    else:
        out = self.apply_attn(q, k, v, **attn_args)

    # Merge heads back into the feature axis, then mix them with the output projection
    out = rearrange(out, ' b h n d -> b n (h d)')
    out = self.to_out(out)

    if self.feat_scale:
        # Split into the per-sequence mean (DC) and the remainder, and add learned multiples of each
        out_dc = out.mean(dim=-2, keepdim=True)
        out_hf = out - out_dc
        out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf
    return out
class ConformerModule(nn.Module):
    """Convolution module: norm, pointwise conv, GLU, depthwise conv, norm, SiLU, pointwise conv.

    Input and output are laid out as ``b n d``; the tensor is rearranged to
    ``b d n`` around each ``nn.Conv1d`` call.
    """

    def __init__(
        self,
        dim,
        norm_kwargs = None,
    ):
        """
        Args:
            dim: feature width, used for every layer in the module.
            norm_kwargs: extra keyword arguments for both ``LayerNorm`` layers
                (``None`` means no extra arguments).
        """
        super().__init__()
        # A None default avoids sharing one dict object between all instances
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs
        self.dim = dim
        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        # kernel_size=17 with padding=8 keeps the sequence length unchanged
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    # NOTE: the @compile decorator is commented out for this method.
    def forward(self, x):
        """Apply the module to ``x`` (``b n d``) and return a tensor in the same layout."""
        x = self.in_norm(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.glu(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.depthwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.mid_norm(x)
        x = self.swish(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv_2(x)
        x = rearrange(x, 'b d n -> b n d')
        return x
class TransformerBlock(nn.Module):
    """One transformer layer.

    Order of operations in ``forward``: self-attention, optional cross-attention,
    optional conformer module, optional sync FiLM modulation, feed-forward. Each
    branch is added residually. When ``global_cond_dim`` is set and a
    ``global_cond`` is passed, the self-attention and feed-forward branches are
    additionally modulated by scale/shift/gate values derived from it.

    NOTE(review): the nesting of statements was reconstructed from a listing with
    its indentation removed; confirm against the original file that
    ``cross_attn_scale`` / ``conformer_scale`` are created only inside their
    respective ``if`` branches.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        cross_attend = False,
        dim_context = None,
        global_cond_dim = None,
        causal = False,
        zero_init_branch_outputs = True,
        conformer = False,
        layer_ix = -1,
        remove_norms = False,
        add_rope = False,
        layer_scale = False,
        use_sync_block_film = False,
        attn_kwargs = None,
        ff_kwargs = None,
        norm_kwargs = None
    ):
        """
        Args:
            dim: model width.
            dim_heads: per-head width, capped at ``dim``.
            cross_attend: build a cross-attention branch.
            dim_context: feature width of the cross-attention context.
            global_cond_dim: when not ``None``, create the learned scale/shift/gate offsets.
            causal: passed to both attention modules.
            zero_init_branch_outputs: passed as ``zero_init_output`` to the attention
                and feed-forward modules; forced to ``False`` when ``layer_scale`` is set.
            conformer: add a ``ConformerModule`` branch.
            layer_ix: stored on the block; not otherwise used here.
            remove_norms: use ``DynamicTanh`` in place of ``LayerNorm``.
            add_rope: build a per-block ``RotaryEmbedding`` used when ``forward``
                receives no ``rotary_pos_emb``.
            layer_scale: wrap each branch output in ``LayerScale``.
            use_sync_block_film: build the per-block sync FiLM generator.
            attn_kwargs, ff_kwargs, norm_kwargs: extra keyword arguments for the
                attention, feed-forward and norm layers (``None`` means none).
        """
        super().__init__()
        # None defaults avoid sharing one dict object between all instances
        attn_kwargs = {} if attn_kwargs is None else attn_kwargs
        ff_kwargs = {} if ff_kwargs is None else ff_kwargs
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs

        self.dim = dim
        self.dim_heads = min(dim_heads,dim)
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        if layer_scale and zero_init_branch_outputs:
            zero_init_branch_outputs = False

        self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
        self.add_rope = add_rope
        self.self_attn = Attention(
            dim,
            dim_heads = self.dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            **attn_kwargs
        )
        self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()

        if cross_attend:
            self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
            self.cross_attn = Attention(
                dim,
                dim_heads = self.dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                **attn_kwargs
            )
            self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
        self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()
        self.layer_ix = layer_ix

        self.conformer = None
        if conformer:
            self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
            self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.global_cond_dim = global_cond_dim
        if global_cond_dim is not None:
            # Learned offsets added to the incoming conditioning before it is split into six chunks
            self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)

        self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None

        if use_sync_block_film:
            self.sync_film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False) # produces scale and shift from sync_cond
            )

    def _apply_sync_film(self, x, sync_cond, prepend_length):
        """FiLM-modulate the last ``x.shape[1] - prepend_length`` positions of ``x`` with ``sync_cond``."""
        scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
        if not prepend_length:
            return x * (1 + scale) + shift
        # Leading (prepended) positions pass through unmodulated
        prepend_part = x[:, :prepend_length, :]
        audio_part = x[:, prepend_length:, :]
        modulated_audio_part = audio_part * (1 + scale) + shift
        return torch.cat([prepend_part, modulated_audio_part], dim=1)

    def _cross_attn_conformer_sync(
        self,
        x,
        context,
        cross_attention_block_mask,
        cross_attention_score_mod,
        cross_attention_flash_sliding_window,
        sync_cond,
        prepend_length,
    ):
        """Apply the optional cross-attention, conformer and sync FiLM stages shared by both forward paths."""
        if context is not None and self.cross_attend:
            x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
        if self.conformer is not None:
            x = x + self.conformer_scale(self.conformer(x))
        if sync_cond is not None and hasattr(self, 'sync_film_generator'):
            x = self._apply_sync_film(x, sync_cond, prepend_length)
        return x

    @compile
    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        rotary_pos_emb = None,
        self_attention_block_mask = None,
        self_attention_score_mod = None,
        cross_attention_block_mask = None,
        cross_attention_score_mod = None,
        self_attention_flash_sliding_window = None,
        cross_attention_flash_sliding_window = None,
        sync_cond = None,
        prepend_length=0
    ):
        """Run the block on ``x``.

        Args:
            x: input sequence; the sequence axis is dim -2.
            context: cross-attention context; ignored unless the block cross-attends.
            global_cond: conditioning whose last axis is split into six chunks
                (scale/shift/gate for self-attention and feed-forward). A 2-D
                tensor gets a sequence axis inserted.
            rotary_pos_emb: rotary embedding for self-attention; when ``None`` and
                ``add_rope`` was set, one is built from the sequence length.
            self_attention_* / cross_attention_*: forwarded to the respective
                attention module.
            sync_cond: input to the per-block FiLM generator, if the block has one.
            prepend_length: number of leading positions of ``x`` that the sync
                FiLM leaves untouched. # assumes sync_cond covers exactly the remaining positions - TODO confirm

        Returns:
            Tensor with the same layout as ``x``.
        """
        if rotary_pos_emb is None and self.add_rope:
            rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])

        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
            modulation = self.to_scale_shift_gate + global_cond
            if len(global_cond.shape) == 2:
                modulation = modulation.unsqueeze(1)
            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = modulation.chunk(6, dim=-1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
            x = x * torch.sigmoid(1 - gate_self)
            x = self.self_attn_scale(x)
            x = x + residual

            # Same prepend-aware sync FiLM as the unconditioned path, so prepended tokens are not modulated
            x = self._cross_attn_conformer_sync(
                x, context, cross_attention_block_mask, cross_attention_score_mod,
                cross_attention_flash_sliding_window, sync_cond, prepend_length,
            )

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = self.ff_scale(x)
            x = x + residual
        else:
            x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))
            x = self._cross_attn_conformer_sync(
                x, context, cross_attention_block_mask, cross_attention_score_mod,
                cross_attention_flash_sliding_window, sync_cond, prepend_length,
            )
            x = x + self.ff_scale(self.ff(self.ff_norm(x)))
        return x
class ContinuousTransformer(nn.Module):
    """Stack of ``TransformerBlock`` layers with input/output projections and conditioning hooks.

    NOTE(review): the nesting of statements was reconstructed from a listing with
    its indentation removed. In particular, ``use_add_norm`` and ``fusion_mlp``
    are applied here only when ``add_cond`` is given - confirm against the
    original file.
    """

    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        pre_cross_attn_ix=-1,
        final_cross_attn_ix=-1,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        num_memory_tokens=0,
        sliding_window=None,
        use_mlp=False,
        use_add_norm=False,
        use_gated=False,
        use_final_layer=False,
        use_zeros=False,
        use_conv=False,
        use_fusion_mlp=False,
        use_film=False,
        use_sync_film=False,
        use_sync_gated=False,
        **kwargs
    ):
        """
        Args:
            dim: model width.
            depth: number of ``TransformerBlock`` layers.
            dim_in, dim_out: when given, linear projections into/out of ``dim``.
            dim_heads: per-head width passed to every block.
            cross_attend, pre_cross_attn_ix, final_cross_attn_ix: layer ``i``
                cross-attends when ``cross_attend`` is set and
                ``pre_cross_attn_ix <= i < final_cross_attn_ix`` (``-1`` disables a bound).
            cond_token_dim: context width for cross-attention.
            global_cond_dim: width of the global conditioning; enables the embedder.
            rotary_pos_emb: build a shared ``RotaryEmbedding``.
            use_sinusoidal_emb, use_abs_pos_emb, abs_pos_emb_max_length: additive
                positional embeddings.
            num_memory_tokens: number of learned tokens put in front of the sequence.
            sliding_window: passed to every layer as ``self_attention_flash_sliding_window``.
            use_mlp: use a two-layer MLP for ``project_in`` (uses ``dim_in`` directly).
            use_add_norm, use_gated, use_film, use_fusion_mlp: how ``add_cond`` is merged.
            use_final_layer, use_zeros: final adaLN modulation and zero initialisation.
            use_conv: create the two temporal ``Conv1d`` modules (not used in ``forward`` here).
            use_sync_film, use_sync_gated: how ``sync_cond`` is merged at the input.
            **kwargs: forwarded to every ``TransformerBlock``.
        """
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        if use_mlp:
            self.project_in = nn.Sequential(
                nn.Linear(dim_in, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim, bias=False)
            )
        else:
            self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
        self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()

        self.video_temporal_conv = None
        self.audio_temporal_conv = None
        self.fusion_mlp = None
        if use_conv:
            self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
            self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
        if use_fusion_mlp:
            self.fusion_mlp = nn.Sequential(
                nn.Linear(dim, dim),
                nn.SiLU(),
                nn.Linear(dim, dim)
            )

        if rotary_pos_emb:
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
        else:
            self.rotary_pos_emb = None

        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)
        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)

        self.adaLN_modulation = None
        # Always defined, so forward() can test it even when no global conditioning was configured
        self.global_cond_embedder = None
        if global_cond_dim is not None:
            if use_final_layer:
                self.norm_final = LayerNorm(dim)
                self.adaLN_modulation = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(
                        dim, 2 * dim, bias=True
                    ),
                )
                if use_zeros:
                    nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
                    nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
                    # project_out is nn.Identity (no weight) when dim_out is None
                    if isinstance(self.project_out, nn.Linear):
                        nn.init.constant_(self.project_out.weight, 0)
            self.global_cond_embedder = nn.Sequential(
                nn.Linear(global_cond_dim, dim),
                nn.SiLU(),
                nn.Linear(dim, dim * 6)
            )
            if use_zeros:
                nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
                nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
                nn.init.constant_(self.global_cond_embedder[0].weight, 0)
                nn.init.constant_(self.global_cond_embedder[0].bias, 0)

        self.final_cross_attn_ix = final_cross_attn_ix
        self.use_gated = use_gated
        self.use_film = use_film
        self.use_add_norm = use_add_norm
        if self.use_add_norm:
            self.add_norm = nn.LayerNorm(dim)
        if use_gated:
            self.gate = nn.Parameter(torch.ones(1, 1, dim))
        if use_film:
            self.film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False) # produces scale and shift from add_cond
            )
        else:
            self.film_generator = None
        if use_sync_film:
            self.sync_film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False) # produces scale and shift from sync_cond
            )
        else:
            self.sync_film_generator = None
        if use_sync_gated:
            self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
        else:
            self.sync_gate = None

        self.sliding_window = sliding_window

        for i in range(depth):
            should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads = dim_heads,
                    cross_attend = should_cross_attend,
                    dim_context = cond_token_dim,
                    global_cond_dim = global_cond_dim,
                    causal = causal,
                    zero_init_branch_outputs = zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        add_cond = None,
        sync_cond = None,
        global_cond = None,
        return_info = False,
        use_checkpointing = True,
        exit_layer_ix = None,
        video_dropout_prob = 0.0,
        **kwargs
    ):
        """Run the transformer.

        Args:
            x: input sequence, cast to the model's parameter dtype and projected with ``project_in``.
            mask, prepend_mask, video_dropout_prob: accepted but not used in this method.
            prepend_embeds: embeddings concatenated in front of ``x`` along the sequence axis.
            add_cond: conditioning merged into ``x`` (gated add, FiLM, or plain add).
            sync_cond: conditioning linearly resampled to the length of ``x`` and
                merged by FiLM or a sigmoid gate if configured; it is also passed to every layer.
            global_cond: global conditioning; embedded when an embedder exists,
                otherwise handed to the layers unchanged.
            return_info: also return ``{"hidden_states": [...]}`` with every layer output.
            use_checkpointing: wrap each layer call in ``checkpoint``.
            exit_layer_ix: return right after this layer, skipping the final
                modulation and ``project_out``.
            **kwargs: forwarded to every layer.

        Returns:
            The output tensor, or ``(output, info)`` when ``return_info`` is set.
        """
        batch = x.shape[0]
        model_dtype = next(self.parameters()).dtype
        x = x.to(model_dtype)
        prepend_length = 0
        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if add_cond is not None:
            if self.use_gated:
                gate = torch.sigmoid(self.gate)
                x = x + gate * add_cond
            elif self.use_film:
                scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
            else:
                x = x + add_cond
            if self.use_add_norm:
                x = self.add_norm(x)
            if self.fusion_mlp is not None:
                x = self.fusion_mlp(x)

        if sync_cond is not None:
            # Resample sync_cond to match audio sequence length if needed
            if sync_cond.shape[1] != x.shape[1]:
                sync_cond = torch.nn.functional.interpolate(
                    sync_cond.transpose(1, 2), size=x.shape[1],
                    mode='linear', align_corners=False,
                ).transpose(1, 2)
            if self.sync_film_generator is not None:
                scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
            elif self.sync_gate is not None:
                gate_value = torch.sigmoid(self.sync_gate)
                x = x + gate_value * sync_cond
            # With neither option configured, sync_cond is not merged here; it is still passed to the layers.

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]
            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
            x = torch.cat((prepend_embeds, x), dim = -2)

        if self.num_memory_tokens > 0:
            memory_tokens = self.memory_tokens.expand(batch, -1, -1)
            x = torch.cat((memory_tokens, x), dim=1)

        # Memory tokens and prepended embeddings both sit in front of the positions that
        # sync_cond was resampled to, so the layers must skip all of them.
        prefix_length = prepend_length + self.num_memory_tokens

        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        if global_cond is not None and self.global_cond_embedder is not None:
            global_cond_embed = self.global_cond_embedder(global_cond)
        else:
            global_cond_embed = global_cond

        # Iterate over the transformer layers
        for layer_ix, layer in enumerate(self.layers):
            if use_checkpointing:
                x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prefix_length, **kwargs)
            else:
                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prefix_length, **kwargs)
            if return_info:
                info["hidden_states"].append(x)
            if exit_layer_ix is not None and layer_ix == exit_layer_ix:
                # Early exit: strip the memory tokens and return without the output head
                x = x[:, self.num_memory_tokens:, :]
                if return_info:
                    return x, info
                return x

        x = x[:, self.num_memory_tokens:, :]

        if global_cond is not None and self.adaLN_modulation is not None:
            if len(global_cond.shape) == 2:
                global_cond = global_cond.unsqueeze(1)
            shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
            x = modulate(self.norm_final(x), shift, scale)

        x = self.project_out(x)
        if return_info:
            return x, info
        return x
-180
View File
@@ -1,180 +0,0 @@
import torch
from safetensors.torch import load_file
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
from torch.nn.utils import remove_weight_norm
def load_ckpt_state_dict(ckpt_path, prefix=None):
    """Load a state dict from a checkpoint file, optionally keeping one key prefix.

    Args:
        ckpt_path: path to the checkpoint. Files ending in ``.safetensors`` are
            read with ``load_file``; anything else is read with ``torch.load``
            and its ``"state_dict"`` entry is used.
        prefix: when given, only keys starting with it are kept and the leading
            prefix is removed from each kept key.

    Returns:
        The (possibly filtered) state dict.
    """
    if ckpt_path.endswith(".safetensors"):
        state_dict = load_file(ckpt_path)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    if prefix is None:
        return state_dict

    # Slice off only the leading prefix; str.replace would also delete later
    # occurrences of the same text inside the key.
    prefix_length = len(prefix)
    return {
        key[prefix_length:]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
def remove_weight_norm_from_model(model):
    """Remove weight normalisation from every submodule that has it.

    Walks ``model.modules()`` and calls ``remove_weight_norm`` on each module
    exposing a ``weight`` attribute. Modules that have a ``weight`` but no
    weight-norm hook are left untouched.

    Returns:
        The same ``model`` object, modified in place.
    """
    for module in model.modules():
        if not hasattr(module, "weight"):
            continue
        try:
            remove_weight_norm(module)
        except ValueError:
            # remove_weight_norm raises ValueError when the module has no weight norm; skip it
            continue
    return model
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
# License can be found in LICENSES/LICENSE_META.txt
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """Sample indices from probabilities stored on the last dimension of ``input``.

    Works for any number of leading dimensions.

    Args:
        input (torch.Tensor): probabilities, candidates on the last dimension.
        num_samples (int): number of indices to draw per distribution.
        replacement (bool): draw with replacement (only used when ``num_samples != 1``).

    Keyword args:
        generator (torch.Generator): optional random number generator.

    Returns:
        torch.Tensor: indices; the last dimension has ``num_samples`` entries.
    """
    if num_samples != 1:
        # General case: flatten the leading dimensions for torch.multinomial, then restore them
        flat_probs = input.reshape(-1, input.shape[-1])
        drawn = torch.multinomial(flat_probs, num_samples=num_samples, replacement=replacement, generator=generator)
        return drawn.reshape(*input.shape[:-1], -1)

    # Single sample: divide by exponential noise and take the argmax
    noise = torch.empty_like(input).exponential_(1, generator=generator)
    return torch.argmax(input / noise, dim=-1, keepdim=True).to(torch.int64)
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample one token per distribution from its ``k`` most probable candidates.

    ``probs`` is modified in place: entries below the k-th largest value are
    zeroed and the rest renormalised.

    Args:
        probs (torch.Tensor): probabilities, candidates on the last dimension.
        k (int): number of candidates to keep.

    Returns:
        torch.Tensor: sampled token indices.
    """
    # Smallest value that still belongs to the top k, kept with a trailing axis of size 1
    kth_largest = torch.topk(probs, k, dim=-1).values[..., [-1]]
    probs.mul_((probs >= kth_largest).float())
    probs.div_(probs.sum(dim=-1, keepdim=True))
    return multinomial(probs, num_samples=1)
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample one token per distribution from its top-``p`` (nucleus) candidates.

    Args:
        probs (torch.Tensor): probabilities, candidates on the last dimension.
        p (float): cumulative-probability threshold.

    Returns:
        torch.Tensor: sampled token indices, expressed in the original candidate order.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    # Probability mass that precedes each candidate in the sorted order
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    drop = mass_before > p
    sorted_probs.mul_((~drop).float())
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    choice = multinomial(sorted_probs, num_samples=1)
    # Translate positions in the sorted order back to original indices
    return torch.gather(sorted_idx, -1, choice)
def next_power_of_two(n):
    """Return the smallest power of two that is greater than or equal to the integer ``n``."""
    exponent = (n - 1).bit_length()
    return 2 ** exponent
def next_multiple_of_64(n):
    """Round ``n`` up to the nearest multiple of 64 (multiples of 64 are returned unchanged)."""
    blocks, _ = divmod(n + 63, 64)
    return blocks * 64
# mask construction helpers
def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    """Build a boolean mask that is True on ``[start, end)`` along a new trailing axis.

    Args:
        seq_len: length of the trailing axis.
        start: first included position for each entry (converted with ``.long()``).
        end: first excluded position for each entry; same shape as ``start``.

    Returns:
        Bool tensor of shape ``(*start.shape, seq_len)``.
    """
    assert start.shape == end.shape
    positions = torch.arange(seq_len, device = start.device, dtype = torch.long)
    # The bounds get a trailing axis of size 1 and broadcast against the positions,
    # which works for any number of leading dimensions.
    lower = start[..., None].long()
    upper = end[..., None].long()
    return (positions >= lower) & (positions < upper)
def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    """Build a mask with one randomly placed span per entry of ``frac_lengths``.

    Each span covers ``frac_lengths * seq_len`` positions (truncated to an
    integer) and starts at a uniformly random offset that keeps it inside
    ``seq_len``.

    Returns:
        The mask produced by ``mask_from_start_end_indices``.
    """
    span = (frac_lengths * seq_len).long()
    latest_start = seq_len - span
    offset_frac = torch.zeros_like(frac_lengths, device = frac_lengths.device).float().uniform_(0, 1)
    begin = (latest_start * offset_frac).clamp(min = 0)
    return mask_from_start_end_indices(seq_len, begin, begin + span)
def _build_spline(video_feat, video_t, target_t):
    """Core of the cubic-spline interpolation: fit at ``video_t``, evaluate at ``target_t``.

    ``video_feat`` has its last two axes swapped before fitting, and the
    evaluated result has them swapped back.

    NOTE(review): ``natural_cubic_spline_coeffs`` and ``NaturalCubicSpline`` are
    not imported in the visible part of this module (the ``torchcubicspline``
    import is commented out), so calling this presumably raises ``NameError`` -
    confirm and restore the import if the function is meant to be used.
    """
    knots = video_feat.permute(0,2,1)
    fitted = NaturalCubicSpline(natural_cubic_spline_coeffs(video_t, knots))
    return fitted.evaluate(target_t).permute(0,2,1)
def resample(video_feat, audio_latent):
    """Resample video features along time to the audio latent length.

    Both time axes are placed on a shared 0-9 grid (the original note says "9s").

    Args:
        video_feat: tensor ``[B, Tv, D]`` (e.g. ``[B, 72, D]``).
        audio_latent: either the target length as an ``int``, or a tensor
            (e.g. ``[B, D', 194]``). For a tensor, dim 1 is taken as the length
            unless its size is 64, in which case dim 2 is used.

    Returns:
        Tensor ``[B, Ta, D]``.

    Raises:
        TypeError: if ``audio_latent`` is neither a tensor nor an int.
    """
    B, Tv, D = video_feat.shape

    if isinstance(audio_latent, torch.Tensor):
        # A size of 64 on dim 1 is treated as the channel axis rather than time
        time_axis = 2 if audio_latent.shape[1] == 64 else 1
        Ta = audio_latent.shape[time_axis]
    elif isinstance(audio_latent, int):
        Ta = audio_latent
    else:
        raise TypeError("audio_latent must be either a tensor or an int")

    # Build the time stamps for both sequences on the same grid
    source_times = torch.linspace(0, 9, Tv, device=video_feat.device)
    target_times = torch.linspace(0, 9, Ta, device=video_feat.device)

    # Move to (batch, feature, time) for the spline helper
    feature_major = video_feat.permute(0, 2, 1)  # [B, D, Tv]

    # Cubic-spline interpolation
    aligned = _build_spline(feature_major, source_times, target_times)  # [B, D, Ta]
    return aligned.permute(0, 2, 1)  # [B, Ta, D]
def checkpoint(function, *args, **kwargs):
    """Call ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False`` unless the caller set it."""
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
import os

# torch.compile is opt-in: it is only used when ENABLE_TORCH_COMPILE is exactly "1"
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"


def compile(function, *args, **kwargs):
    """Return ``torch.compile(function, ...)`` when enabled, otherwise ``function`` itself.

    If ``torch.compile`` raises ``RuntimeError``, the original function is returned.
    """
    if not enable_torch_compile:
        return function
    try:
        return torch.compile(function, *args, **kwargs)
    except RuntimeError:
        return function
-6
View File
@@ -1,11 +1,5 @@
einops>=0.7.0
einops-exts
safetensors
huggingface_hub
transformers>=4.52.3
k-diffusion>=0.1.1
alias-free-torch
descript-audio-codec
vector-quantize-pytorch
scipy
tqdm
-21
View File
@@ -1,21 +0,0 @@
name: prismaudio-extract
channels:
- conda-forge
- defaults
dependencies:
- python=3.10
- pip
- ffmpeg<7
- pip:
- torch>=2.6.0
- torchaudio>=2.6.0
- torchvision>=0.21.0
- tensorflow-cpu==2.15.0
- jax
- jaxlib
- transformers>=4.52.3
- decord
- einops>=0.7.0
- numpy
- mediapy
- git+https://github.com/google-deepmind/videoprism.git
-168
View File
@@ -1,168 +0,0 @@
#!/usr/bin/env python3
"""
Standalone PrismAudio feature extraction script.
Runs in a separate Python env with JAX/TF installed (auto-created by PrismAudioFeatureExtractor).
Usage:
python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz
"""
import argparse
import os
import sys
import time
import numpy as np
import torch
# Put the plugin root (the parent of this script's directory) on sys.path so that
# data_utils (and prismaudio_core) can be imported
_SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
_PLUGIN_DIR = os.path.split(_SCRIPT_DIR)[0]
if sys.path.count(_PLUGIN_DIR) == 0:
    sys.path.insert(0, _PLUGIN_DIR)
def _step(n, total, label):
    """Print the header for step ``n`` of ``total`` and return the start time.

    The returned value comes from ``time.perf_counter()`` and is meant to be
    passed to ``_done``.
    """
    # A separator keeps the step counter and the label from running together
    print(f"[extract] Step {n}/{total}: {label}...", flush=True)
    return time.perf_counter()
def _done(t0, extra=""):
    """Print how long a step took, measured from ``t0``, followed by ``extra`` if it is non-empty."""
    elapsed = time.perf_counter() - t0
    suffix = ""
    if extra:
        suffix = f" {extra}"
    print(f"[extract] done in {elapsed:.1f}s{suffix}", flush=True)
def _sample_frame_indices(source_fps, target_fps, duration, total_frames):
    """Return source-frame indices that resample the video to ``target_fps``, clamped to the last frame."""
    last_frame = total_frames - 1
    return [
        min(int(i * source_fps / target_fps), last_frame)
        for i in range(int(duration * target_fps))
    ]


def _frames_to_batch(transforms, frames, size, device):
    """Resize, centre-crop and normalise each frame, then stack them into one batch on ``device``."""
    pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return torch.stack([pipeline(frame) for frame in frames]).unsqueeze(0).to(device)


def main():
    """Extract text, video and sync features for one video and save them to an ``.npz`` file.

    Reads its options from the command line, samples frames at the CLIP and
    sync frame rates, runs the three encoders of ``FeaturesUtils`` and writes
    the results with ``np.savez``. Exits with status 1 when the input is
    missing or no frames can be sampled from it.
    """
    t_total = time.perf_counter()

    parser = argparse.ArgumentParser(description="PrismAudio feature extraction")
    parser.add_argument("--video", required=True, help="Path to input video")
    parser.add_argument("--cot_text", required=True, help="Chain-of-thought description")
    parser.add_argument("--output", required=True, help="Output .npz path")
    parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint")
    parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON")
    parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)")
    parser.add_argument("--clip_fps", type=float, default=4.0)
    parser.add_argument("--clip_size", type=int, default=288)
    parser.add_argument("--sync_fps", type=float, default=25.0)
    parser.add_argument("--sync_size", type=int, default=224)
    args = parser.parse_args()

    print(f"[extract] Python : {sys.executable}", flush=True)
    print(f"[extract] Video : {args.video}", flush=True)
    print(f"[extract] Output : {args.output}", flush=True)
    print(f"[extract] CoT text : {args.cot_text[:80]}{'...' if len(args.cot_text) > 80 else ''}", flush=True)

    if not os.path.exists(args.video):
        print(f"[extract] ERROR: video not found: {args.video}", flush=True)
        sys.exit(1)

    print(f"[extract] Device : {'cuda' if torch.cuda.is_available() else 'cpu'}", flush=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------------------------
    t0 = _step(1, 6, "importing dependencies")
    from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils
    import torchvision.transforms as T
    _done(t0)

    # ------------------------------------------------------------------
    t0 = _step(2, 6, "loading models (T5, VideoPrism, Synchformer)")
    feat_utils = FeaturesUtils(
        vae_config_path=args.vae_config,
        synchformer_ckpt=args.synchformer_ckpt,
        device=device,
    )
    _done(t0)

    # ------------------------------------------------------------------
    t0 = _step(3, 6, "reading and preprocessing video")
    if args.video.endswith(".npy"):
        # Pre-decoded frames; the frame rate must be supplied by the caller
        all_frames = np.load(args.video)  # [T, H, W, C] uint8
        fps = args.source_fps
        total_frames = all_frames.shape[0]

        def load_frames(indices):
            return all_frames[indices]
    else:
        from decord import VideoReader, cpu
        vr = VideoReader(args.video, ctx=cpu(0))
        fps = vr.get_avg_fps()
        total_frames = len(vr)

        def load_frames(indices):
            return vr.get_batch(indices).asnumpy()

    if fps <= 0:
        # Guards the duration computation below against a zero or negative frame rate
        print(f"[extract] ERROR: invalid frame rate {fps} for: {args.video}", flush=True)
        sys.exit(1)

    duration = total_frames / fps
    print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True)

    clip_indices = _sample_frame_indices(fps, args.clip_fps, duration, total_frames)
    sync_indices = _sample_frame_indices(fps, args.sync_fps, duration, total_frames)
    if not clip_indices or not sync_indices:
        # Stacking an empty frame list would fail later with a far less helpful message
        print(f"[extract] ERROR: video too short to sample frames: {args.video}", flush=True)
        sys.exit(1)

    clip_frames = load_frames(clip_indices)
    print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True)
    sync_frames = load_frames(sync_indices)
    print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True)

    clip_input = _frames_to_batch(T, clip_frames, args.clip_size, device)
    sync_input = _frames_to_batch(T, sync_frames, args.sync_size, device)
    _done(t0)

    # ------------------------------------------------------------------
    t0 = _step(4, 6, "encoding text with T5-Gemma")
    text_features = feat_utils.encode_t5_text([args.cot_text])
    _done(t0, f"shape={tuple(text_features.shape)}")

    # ------------------------------------------------------------------
    t0 = _step(5, 6, "encoding video with VideoPrism")
    global_video_features, video_features, global_text_features = \
        feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text])
    _done(t0, f"video={tuple(video_features.shape)} global={tuple(global_video_features.shape)}")

    # ------------------------------------------------------------------
    t0 = _step(6, 6, "encoding video with Synchformer")
    sync_features = feat_utils.encode_video_with_sync(sync_input)
    _done(t0, f"shape={tuple(sync_features.shape)}")

    # ------------------------------------------------------------------
    t0 = time.perf_counter()
    print(f"[extract] Saving features to {args.output} ...", flush=True)
    np.savez(
        args.output,
        video_features=video_features.cpu().float().numpy(),
        global_video_features=global_video_features.cpu().float().numpy(),
        text_features=text_features.cpu().float().numpy(),
        global_text_features=global_text_features.cpu().float().numpy(),
        sync_features=sync_features.cpu().float().numpy(),
        caption_cot=args.cot_text,
        duration=duration,
    )
    print(f"[extract] Saved in {time.perf_counter() - t0:.1f}s", flush=True)
    print(f"[extract] Total time: {time.perf_counter() - t_total:.1f}s", flush=True)


if __name__ == "__main__":
    main()
-44
View File
@@ -1,44 +0,0 @@
#!/usr/bin/env bash
# Install the PrismAudio feature-extraction environment using pip venv.
# Use this instead of environment.yml when conda is unavailable (e.g. NVIDIA Docker).
#
# Usage:
# bash scripts/install_extract_env.sh [/path/to/venv]
#
# Default venv path: /opt/prismaudio-extract
# After installation, point the PrismAudioFeatureExtractor node's python_env to:
# <venv>/bin/python (Linux/Mac)
# <venv>\Scripts\python.exe (Windows)
# Abort on any failing command, on use of an unset variable, and on a failure inside a pipeline.
set -euo pipefail
# The first positional argument overrides the default venv location.
VENV_DIR="${1:-/opt/prismaudio-extract}"
echo "[PrismAudio] Creating venv at: ${VENV_DIR}"
python3 -m venv "${VENV_DIR}"
# All installs below go through the venv's own pip, so the system Python is left untouched.
# NOTE(review): bin/pip is the POSIX venv layout; the header mentions a Windows path, but this
# script does not handle the Scripts\ layout - confirm whether Windows is meant to be supported.
PIP="${VENV_DIR}/bin/pip"
echo "[PrismAudio] Upgrading pip..."
"${PIP}" install --upgrade pip
echo "[PrismAudio] Installing PyTorch stack..."
"${PIP}" install torch torchaudio torchvision
echo "[PrismAudio] Installing feature-extraction dependencies..."
# Only tensorflow-cpu carries a version constraint; the other packages are installed unpinned.
"${PIP}" install \
"tensorflow-cpu>=2.16.0" \
"jax[cpu]" \
"jaxlib" \
"transformers" \
"decord" \
"einops" \
"numpy" \
"mediapy"
echo "[PrismAudio] Installing VideoPrism..."
# VideoPrism is installed straight from its Git repository.
"${PIP}" install "git+https://github.com/google-deepmind/videoprism.git"
echo ""
# Print the interpreter path to enter in the PrismAudioFeatureExtractor node.
echo "[PrismAudio] Done. Set python_env in PrismAudioFeatureExtractor to:"
echo " ${VENV_DIR}/bin/python"
-158
View File
@@ -1,158 +0,0 @@
{
"id": "a1c3e5f7-b2d4-4e6a-8c0f-1a3b5c7d9e2f",
"revision": 0,
"last_node_id": 3,
"last_link_id": 2,
"nodes": [
{
"id": 1,
"type": "PrismAudioModelLoader",
"pos": [
-160,
-224
],
"size": [
288,
96
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "model",
"type": "PRISMAUDIO_MODEL",
"slot_index": 0,
"links": [
1
]
}
],
"properties": {
"aux_id": "ethanfel/ComfyUI-Prismaudio",
"ver": "62a3c5d",
"Node name for S&R": "PrismAudioModelLoader",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.8",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"auto",
"auto"
]
},
{
"id": 2,
"type": "PrismAudioTextOnly",
"pos": [
192,
-224
],
"size": [
480,
222
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "PRISMAUDIO_MODEL",
"link": 1
}
],
"outputs": [
{
"name": "audio",
"type": "AUDIO",
"slot_index": 0,
"links": [
2
]
}
],
"properties": {
"aux_id": "ethanfel/ComfyUI-Prismaudio",
"ver": "62a3c5d",
"Node name for S&R": "PrismAudioTextOnly",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.8",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"A large dog barks sharply twice in an outdoor setting, with ambient background noise of rustling leaves and a gentle breeze. The sound is clear and close, recorded at ground level.",
10.0,
100,
7.0,
0,
"randomize"
]
},
{
"id": 3,
"type": "PreviewAudio",
"pos": [
736,
-224
],
"size": [
300,
76
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "audio",
"type": "AUDIO",
"link": 2
}
],
"outputs": [],
"properties": {
"Node name for S&R": "PreviewAudio"
},
"widgets_values": []
}
],
"links": [
[
1,
1,
0,
2,
0,
"PRISMAUDIO_MODEL"
],
[
2,
2,
0,
3,
0,
"AUDIO"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1.1674071890328979,
"offset": [
1814.5534800416863,
500.0421331448515
]
},
"ue_links": [],
"links_added_by_ue": [],
"frontendVersion": "1.42.8"
},
"version": 0.4
}
-421
View File
@@ -1,421 +0,0 @@
{
"id": "2481bfbf-ce24-46c5-abdc-1d9163ff78ae",
"revision": 0,
"last_node_id": 12,
"last_link_id": 30,
"nodes": [
{
"id": 1,
"type": "VHS_LoadVideo",
"pos": [
-704,
-256
],
"size": [
288,
474.7188081936685
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"name": "meta_batch",
"shape": 7,
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"slot_index": 0,
"links": [
12,
20
]
},
{
"name": "frame_count",
"type": "INT",
"slot_index": 1,
"links": []
},
{
"name": "audio",
"type": "AUDIO",
"slot_index": 2,
"links": []
},
{
"name": "video_info",
"type": "VHS_VIDEOINFO",
"slot_index": 3,
"links": [
21
]
}
],
"properties": {
"cnr_id": "comfyui-videohelpersuite",
"ver": "1.7.9",
"Node name for S&R": "VHS_LoadVideo",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.8",
"input_ue_unconnectable": {}
}
},
"widgets_values": {
"video": "Railtransport_3_479.mp4",
"force_rate": 0,
"custom_width": 0,
"custom_height": 0,
"frame_load_cap": 0,
"skip_first_frames": 0,
"select_every_nth": 1,
"format": "AnimateDiff",
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"force_rate": 0,
"frame_load_cap": 0,
"skip_first_frames": 0,
"select_every_nth": 1,
"filename": "Railtransport_3_479.mp4",
"type": "input",
"format": "video/mp4"
}
}
}
},
{
"id": 2,
"type": "PrismAudioModelLoader",
"pos": [
-160,
-224
],
"size": [
288,
96
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "model",
"type": "PRISMAUDIO_MODEL",
"slot_index": 0,
"links": [
26
]
}
],
"properties": {
"aux_id": "ethanfel/ComfyUI-Prismaudio",
"ver": "3894fcc9b40a19d959614d514d5dff65cdfb6eab",
"Node name for S&R": "PrismAudioModelLoader",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.8",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"auto",
"auto"
]
},
{
"id": 12,
"type": "PrismAudioSampler",
"pos": [
256,
-224
],
"size": [
384,
224
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "PRISMAUDIO_MODEL",
"link": 26
},
{
"name": "features",
"type": "PRISMAUDIO_FEATURES",
"link": 27
}
],
"outputs": [
{
"name": "audio",
"type": "AUDIO",
"links": [
29
]
}
],
"properties": {
"aux_id": "ethanfel/ComfyUI-Prismaudio",
"ver": "30631c0cb4d97cc6aed69a52e3ee4d89df03926c",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
},
"Node name for S&R": "PrismAudioSampler"
},
"widgets_values": [
0,
100,
7,
4096333446,
"randomize"
]
},
{
"id": 11,
"type": "PrismAudioFeatureExtractor",
"pos": [
-384,
-64
],
"size": [
544,
288
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "video",
"type": "IMAGE",
"link": 20
},
{
"name": "video_info",
"shape": 7,
"type": "VHS_VIDEOINFO",
"link": 21
}
],
"outputs": [
{
"name": "features",
"type": "PRISMAUDIO_FEATURES",
"links": [
27
]
},
{
"name": "fps",
"type": "FLOAT",
"links": [
30
]
}
],
"properties": {
"aux_id": "ethanfel/ComfyUI-Prismaudio",
"ver": "5b62be04471bf118b2cd3cc71431a302f5730b01",
"Node name for S&R": "PrismAudioFeatureExtractor",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.8"
}
},
"widgets_values": [
"Generate ambient countryside sounds with a gentle breeze rustling the leaves of a large tree. From the right, introduce a faint rumble of wheels on a track and a steam engine chugging. Allow the sounds to grow louder and pan from right to left as the steam train travels across the landscape. Include the powerful chugging and clattering of carriages in the soundscape, then gradually recede the sounds to the left. Ensure no additional background noise or music is present.\n",
30,
"managed_env",
"/media/unraid/comfyui/output/prismaudiocache/",
""
]
},
{
"id": 9,
"type": "VHS_VideoCombine",
"pos": [
704,
-256
],
"size": [
384,
552.75
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 12
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": 29
},
{
"name": "meta_batch",
"shape": 7,
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": null
},
{
"name": "frame_rate",
"type": "FLOAT",
"widget": {
"name": "frame_rate"
},
"link": 30
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null
}
],
"properties": {
"cnr_id": "comfyui-videohelpersuite",
"ver": "1.7.9",
"Node name for S&R": "VHS_VideoCombine",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.8"
}
},
"widgets_values": {
"frame_rate": 30,
"loop_count": 0,
"filename_prefix": "AnimateDiff",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"trim_to_audio": false,
"pingpong": false,
"save_output": false,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "AnimateDiff_00001-audio.mp4",
"subfolder": "",
"type": "temp",
"format": "video/h264-mp4",
"frame_rate": 30,
"workflow": "AnimateDiff_00001.png",
"fullpath": "/basedir/temp/AnimateDiff_00001-audio.mp4"
}
}
}
}
],
"links": [
[
12,
1,
0,
9,
0,
"IMAGE"
],
[
20,
1,
0,
11,
0,
"IMAGE"
],
[
21,
1,
3,
11,
1,
"VHS_VIDEOINFO"
],
[
26,
2,
0,
12,
0,
"PRISMAUDIO_MODEL"
],
[
27,
11,
0,
12,
1,
"PRISMAUDIO_FEATURES"
],
[
29,
12,
0,
9,
1,
"AUDIO"
],
[
30,
11,
1,
9,
4,
"FLOAT"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1.1674071890328979,
"offset": [
1814.5534800416863,
500.0421331448515
]
},
"ue_links": [],
"links_added_by_ue": [],
"frontendVersion": "1.42.8",
"VHS_latentpreview": true,
"VHS_latentpreviewrate": 0,
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"version": 0.4
}