diff --git a/__init__.py b/__init__.py index c9d8310..cbe7c8f 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,5 @@ """ -ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026). +ComfyUI-SelVA: Text-guided video-to-audio generation using SelVA / MMAudio. """ import sys import os diff --git a/data_utils/__init__.py b/data_utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/data_utils/v2a_utils/__init__.py b/data_utils/v2a_utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py deleted file mode 100644 index 245d5e7..0000000 --- a/data_utils/v2a_utils/feature_utils_288.py +++ /dev/null @@ -1,337 +0,0 @@ -""" -PrismAudio feature extraction utilities. - -Implements FeaturesUtils used by scripts/extract_features.py to extract: - - Text features via T5-Gemma (transformers) - - Video features via VideoPrism (JAX/Flax, google-deepmind/videoprism) - - Sync features via Synchformer visual encoder (PyTorch) -""" - -import os -import torch -import torch.nn as nn -import numpy as np - - -class FeaturesUtils: - def __init__(self, vae_config_path=None, synchformer_ckpt=None, device=None): - self.device = device or torch.device("cpu") - self._t5_tokenizer = None - self._t5_encoder = None - self._vp_model = None - self._vp_state = None - self._vp_text_tokenizer = None - self._sync_model = None - - self._synchformer_ckpt = synchformer_ckpt - self._load_synchformer() - - # ------------------------------------------------------------------ - # T5-Gemma text encoding - # ------------------------------------------------------------------ - - def _ensure_t5(self): - if self._t5_encoder is not None: - return - from transformers import AutoTokenizer, AutoModelForSeq2SeqLM - model_id = "google/t5gemma-l-l-ul2-it" - print(f"[FeaturesUtils] Loading T5-Gemma: {model_id}") - self._t5_tokenizer = AutoTokenizer.from_pretrained(model_id) - self._t5_encoder = ( - AutoModelForSeq2SeqLM.from_pretrained(model_id) - .get_encoder() - .to(self.device) - .eval() - ) - - def encode_t5_text(self, texts): - """ - Args: - texts: list of str - Returns: - Tensor [seq_len, 1024] - """ - self._ensure_t5() - tokens = self._t5_tokenizer( - texts, return_tensors="pt", padding=True - ).to(self.device) - with torch.no_grad(): - out = self._t5_encoder(**tokens) - # Move encoder off GPU to save VRAM - self._t5_encoder.to("cpu") - torch.cuda.empty_cache() - return out.last_hidden_state.squeeze(0) # [seq_len, 1024] - - # ------------------------------------------------------------------ - # VideoPrism video + text encoding (JAX) - # ------------------------------------------------------------------ - - def _ensure_videoprism(self): - if self._vp_model is not None: - return - from videoprism import models as vp - import jax - model_name = "videoprism_lvt_public_v1_large" - print(f"[FeaturesUtils] Loading VideoPrism LvT large (1024-dim joint video-text)...") - self._vp_model = vp.get_model(model_name) - self._vp_state = vp.load_pretrained_weights(model_name) - self._vp_text_tokenizer = vp.load_text_tokenizer("c4_en") - jax_dev = jax.devices()[0] - self._jax_forward = jax.jit( - lambda x, y, z: self._vp_model.apply( - self._vp_state, x, y, z, train=False, return_intermediate=True - ), - device=jax_dev, - ) - - def encode_video_and_text_with_videoprism(self, clip_input, texts): - """ - Args: - clip_input: Tensor [1, T, C, H, W] float32, values in [-1, 1] - texts: list of str — CoT captions, passed to VideoPrism LvT text tower - Returns: - global_video_features: Tensor [1, D] - video_features: Tensor [T, D] — per-frame L2-normalized embeddings - global_text_features: Tensor [1, D] - """ - self._ensure_videoprism() - import jax.numpy as jnp - from videoprism import models as vp - - # Normalise from [-1,1] to [0,1] and convert to [B, T, H, W, C] JAX array - frames = clip_input.squeeze(0) # [T, C, H, W] - frames = (frames + 1.0) / 2.0 # [-1,1] → [0,1] - frames = frames.permute(0, 2, 3, 1) # [T, H, W, C] - frames_np = frames.cpu().numpy().astype(np.float32) - frames_jax = jnp.array(frames_np)[None] # [1, T, H, W, C] - - # Tokenize text (padding value 1.0 = pad, 0.0 = real token) - text_ids, text_paddings = vp.tokenize_texts(self._vp_text_tokenizer, texts) - - # Joint video+text forward with intermediate outputs - video_embeddings, text_embeddings, outputs = self._jax_forward( - frames_jax, text_ids, text_paddings - ) - - # Per-frame features: [B, T, 1024] L2-normalized - frame_embed_np = np.array(outputs["frame_embeddings"]) # [1, T, 1024] - per_frame = torch.from_numpy(frame_embed_np[0]).to(self.device) # [T, 1024] - - # Global video embedding: [1024] → [1, 1024] - global_video = torch.from_numpy( - np.array(video_embeddings[0]) - ).unsqueeze(0).to(self.device) # [1, 1024] - - # Global text embedding: [1024] → [1, 1024] - global_text = torch.from_numpy( - np.array(text_embeddings[0]) - ).unsqueeze(0).to(self.device) # [1, 1024] - - return global_video, per_frame, global_text - - # ------------------------------------------------------------------ - # Synchformer sync feature encoding - # ------------------------------------------------------------------ - - def _load_synchformer(self): - if not self._synchformer_ckpt or not os.path.exists(self._synchformer_ckpt): - return - - print(f"[FeaturesUtils] Loading Synchformer from: {self._synchformer_ckpt}") - state = torch.load(self._synchformer_ckpt, map_location="cpu", weights_only=False) - - # Checkpoint may be raw state_dict or wrapped in {"model": ...} - if isinstance(state, dict) and "model" in state: - state_dict = state["model"] - else: - state_dict = state - - self._sync_model = _SynchformerVisualEncoder(state_dict, self.device) - self._sync_model.eval() - - def encode_video_with_sync(self, sync_input): - """ - Args: - sync_input: Tensor [1, T, C, H, W] float32, values in [-1, 1] - Returns: - sync_features: Tensor [num_segments, 768] - """ - if self._sync_model is None: - raise RuntimeError( - "[FeaturesUtils] Synchformer checkpoint not loaded. " - "Pass synchformer_ckpt to FeaturesUtils or set --synchformer_ckpt." - ) - frames = sync_input.squeeze(0).to(self.device) # [T, C, H, W] - with torch.no_grad(): - return self._sync_model(frames) - - -# ------------------------------------------------------------------ -# Synchformer visual encoder — TimeSformer-style ViT-B/16 -# Architecture reverse-engineered from synchformer_state_dict.pth -# ------------------------------------------------------------------ - -import torch.nn.functional as F - - -class _PatchEmbed(nn.Module): - """2D patch embedding: [B, 3, 224, 224] → [B, 196, 768].""" - def __init__(self): - super().__init__() - self.proj = nn.Conv2d(3, 768, kernel_size=16, stride=16) - - def forward(self, x): - return self.proj(x).flatten(2).transpose(1, 2) - - -class _ViTAttn(nn.Module): - """ViT-style QKV attention (timm convention: qkv as single Linear).""" - def __init__(self, dim=768, num_heads=12): - super().__init__() - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.qkv = nn.Linear(dim, dim * 3) - self.proj = nn.Linear(dim, dim) - - def forward(self, x): - B, N, D = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - attn = F.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1) - return self.proj((attn @ v).transpose(1, 2).reshape(B, N, D)) - - -class _BlockMLP(nn.Module): - """Two-layer MLP with GELU, keys fc1/fc2 to match checkpoint.""" - def __init__(self, dim=768, mlp_dim=3072): - super().__init__() - self.fc1 = nn.Linear(dim, mlp_dim) - self.fc2 = nn.Linear(mlp_dim, dim) - - def forward(self, x): - return self.fc2(F.gelu(self.fc1(x))) - - -class _TimeSformerBlock(nn.Module): - """ - Factorized space-time attention block. - norm1 → spatial attn → norm3 → temporal attn → norm2 → MLP - """ - def __init__(self, dim=768, num_heads=12): - super().__init__() - self.norm1 = nn.LayerNorm(dim) - self.attn = _ViTAttn(dim, num_heads) - self.norm3 = nn.LayerNorm(dim) - self.timeattn = _ViTAttn(dim, num_heads) - self.norm2 = nn.LayerNorm(dim) - self.mlp = _BlockMLP(dim) - - def forward(self, x, T): - # x: [T, N, D] (T frames treated as batch, N=197 spatial tokens) - x = x + self.attn(self.norm1(x)) - # Temporal attention: for each spatial position, attend across T frames - # [T, N, D] → [N, T, D] → attend → [N, T, D] → [T, N, D] - xt = x.permute(1, 0, 2) - xt = xt + self.timeattn(self.norm3(xt)) - x = xt.permute(1, 0, 2) - x = x + self.mlp(self.norm2(x)) - return x - - -class _SpatialAttnAgg(nn.Module): - """ - Aggregates 196 spatial patches → 1 feature per frame using a - TransformerEncoderLayer with a learnable CLS token. - Key names match nn.TransformerEncoderLayer: self_attn, linear1, linear2, norm1, norm2. - """ - def __init__(self, dim=768, num_heads=12): - super().__init__() - self.cls_token = nn.Parameter(torch.zeros(1, 1, dim)) - self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True) - self.linear1 = nn.Linear(dim, dim * 4) - self.linear2 = nn.Linear(dim * 4, dim) - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - def forward(self, x): - # x: [T, 196, 768] — spatial patches (CLS stripped) - T = x.shape[0] - cls = self.cls_token.expand(T, -1, -1) - x = torch.cat([cls, x], dim=1) # [T, 197, 768] - xn = self.norm1(x) - x = x + self.self_attn(xn, xn, xn)[0] - x = x + self.linear2(F.gelu(self.linear1(self.norm2(x)))) - return x[:, 0, :] # [T, 768] — CLS per frame - - -class _SynchformerVisualEncoder(nn.Module): - """ - TimeSformer-style ViT-B/16 visual encoder for the PrismAudio Synchformer checkpoint. - Processes video in segments of 8 frames → [T_aligned, 768] per-frame features. - """ - - def __init__(self, state_dict, device): - super().__init__() - self.device = device - self.segment_frames = 8 - - self.patch_embed = _PatchEmbed() - self.cls_token = nn.Parameter(torch.zeros(1, 1, 768)) - self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768)) - self.temp_embed = nn.Parameter(torch.zeros(1, 8, 768)) - self.blocks = nn.ModuleList([_TimeSformerBlock() for _ in range(12)]) - self.norm = nn.LayerNorm(768) - self.spatial_attn_agg = _SpatialAttnAgg() - - # Load weights from vfeat_extractor.* prefix - prefix = "vfeat_extractor." - sub = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} - # Exclude 3D patch embed (we use 2D only) - sub = {k: v for k, v in sub.items() if not k.startswith("patch_embed_3d")} - missing, unexpected = self.load_state_dict(sub, strict=False) - print(f"[FeaturesUtils] Synchformer loaded — missing={len(missing)}, unexpected={len(unexpected)}") - if missing: - print(f"[FeaturesUtils] missing keys (first 5): {missing[:5]}") - - self.to(device) - - def forward(self, frames): - """ - Args: - frames: [T, C, H, W] float32 in [-1, 1], at 25fps - Returns: - [T_aligned, 768] — per-frame features (T_aligned = floor(T/8)*8) - """ - T = frames.shape[0] - seg = self.segment_frames - num_seg = max(1, T // seg) - T_aligned = num_seg * seg - - results = [] - for i in range(num_seg): - chunk = frames[i * seg:(i + 1) * seg] # [8, C, H, W] - results.append(self._forward_segment(chunk)) - return torch.cat(results, dim=0) # [T_aligned, 768] - - def _forward_segment(self, x): - # x: [8, 3, 224, 224] - T = x.shape[0] # 8 - - # Patch embedding + CLS token - x = self.patch_embed(x) # [8, 196, 768] - cls = self.cls_token.expand(T, -1, -1) - x = torch.cat([cls, x], dim=1) # [8, 197, 768] - - # Positional + temporal embeddings - x = x + self.pos_embed # broadcast (1,197,768) - x = x + self.temp_embed.squeeze(0).unsqueeze(1) # (8,1,768) broadcast - - # Transformer blocks (factorized space-time) - for block in self.blocks: - x = block(x, T) - - x = self.norm(x) - - # Aggregate spatial patches → 1 feature per frame - return self.spatial_attn_agg(x[:, 1:, :]) # [8, 768] diff --git a/docs/plans/2026-03-27-comfyui-prismaudio-design.md b/docs/plans/2026-03-27-comfyui-prismaudio-design.md deleted file mode 100644 index 44ed361..0000000 --- a/docs/plans/2026-03-27-comfyui-prismaudio-design.md +++ /dev/null @@ -1,194 +0,0 @@ -# ComfyUI-PrismAudio Design Document - -**Date:** 2026-03-27 -**Status:** Approved - -## Overview - -ComfyUI nodes for PrismAudio (ICLR 2026) — video-to-audio and text-to-audio generation. PrismAudio uses decomposed Chain-of-Thought reasoning across 4 dimensions (Semantic, Temporal, Aesthetic, Spatial) with a 518M parameter DiT diffusion model and Stable Audio 2.0 VAE. - -## Architecture - -**Approach C: Selective Code Extraction** — Extract only inference-critical code from PrismAudio into a self-contained `prismaudio_core/` module. No JAX/TensorFlow in the ComfyUI environment. Feature extraction via separate isolated environment. - -## Project Structure - -``` -ComfyUI-PrismAudio/ -├── __init__.py # Node registration -├── nodes/ -│ ├── __init__.py -│ ├── model_loader.py # PrismAudioModelLoader -│ ├── feature_loader.py # PrismAudioFeatureLoader (loads .npz) -│ ├── feature_extractor.py # PrismAudioFeatureExtractor (subprocess bridge) -│ ├── sampler.py # PrismAudioSampler -│ ├── text_only.py # PrismAudioTextOnly -│ └── utils.py # Shared helpers -├── prismaudio_core/ # Extracted inference code from PrismAudio -│ ├── __init__.py -│ ├── configs/ -│ │ └── prismaudio.json -│ ├── models/ # DiT, conditioners, autoencoders, etc. -│ ├── inference/ # sampling.py, generation.py -│ └── factory.py # create_model_from_config -├── scripts/ -│ ├── extract_features.py # Standalone VideoPrism feature extraction -│ └── environment.yml # Conda env for extraction (JAX + TF) -├── requirements.txt # PyTorch-only deps (no JAX/TF) -└── README.md -``` - -## Nodes - -### PrismAudioModelLoader - -Loads the diffusion model + VAE. Auto-downloads from HuggingFace if weights not found locally. - -| Field | Type | Details | -|-------|------|---------| -| **Inputs** | | | -| precision | COMBO | [auto, fp32, fp16, bf16] — auto detects GPU capability | -| offload_strategy | COMBO | [auto, keep_in_vram, offload_to_cpu] | -| *(no hf_token widget — security risk, would be saved to workflow JSON)* | | | -| **Output** | | | -| model | PRISMAUDIO_MODEL | Dict containing diffusion model + VAE + config | - -**Token resolution order** (no widget — env/CLI only for security): -1. `HF_TOKEN` environment variable -2. `huggingface-cli login` cached token -3. None — fails on gated models with clear error message linking to license page - -**Auto-download:** Uses `huggingface_hub.hf_hub_download()` from `FunAudioLLM/PrismAudio`. Models stored in `ComfyUI/models/prismaudio/`. Users can also place files manually. - -### PrismAudioFeatureLoader - -Loads pre-computed `.npz` feature files for maximum quality video-to-audio. - -| Field | Type | Details | -|-------|------|---------| -| **Inputs** | | | -| npz_path | STRING | Path to .npz file | -| **Output** | | | -| features | PRISMAUDIO_FEATURES | Dict with video_features, global_video_features, text_features, global_text_features, sync_features | - -### PrismAudioFeatureExtractor - -Subprocess bridge — extracts features from video using VideoPrism in an isolated environment. - -| Field | Type | Details | -|-------|------|---------| -| **Inputs** | | | -| video | IMAGE | ComfyUI video frames tensor | -| caption_cot | STRING | CoT description text | -| python_env | STRING | Path to python binary with JAX/TF (default: "python") | -| output_dir | STRING | Cache directory for .npz files (default: temp dir) | -| **Output** | | | -| features | PRISMAUDIO_FEATURES | Same format as FeatureLoader output | - -**Caching:** Hashes video + text to avoid re-extraction on repeated runs. - -### PrismAudioSampler - -Main generation node — takes model + features, produces audio. - -| Field | Type | Details | -|-------|------|---------| -| **Inputs** | | | -| model | PRISMAUDIO_MODEL | From ModelLoader | -| features | PRISMAUDIO_FEATURES | From FeatureLoader or FeatureExtractor | -| cot_description | STRING | Multiline CoT text | -| duration | FLOAT | 1.0-30.0, defaults to video length | -| steps | INT | 1-100, default 24 | -| cfg_scale | FLOAT | 1.0-20.0, default 5.0 | -| seed | INT | Controls noise generation | -| **Output** | | | -| audio | AUDIO | {waveform: tensor, sample_rate: 44100} | - -**Pipeline:** -1. Encode CoT text via T5-Gemma -> text_features -2. Assemble conditioning (cross_attn_cond, add_cond, sync_cond) -3. Compute latent_seq_len = round(44100 / 2048 * duration) -4. Generate noise [1, 64, latent_seq_len] from seed -5. Discrete Euler sampling (rectified flow) with CFG -6. VAE decode -> stereo waveform at 44100 Hz -7. Normalize to [-1, 1], return as AUDIO - -### PrismAudioTextOnly - -Text-to-audio without video input. - -| Field | Type | Details | -|-------|------|---------| -| **Inputs** | | | -| model | PRISMAUDIO_MODEL | From ModelLoader | -| text_prompt | STRING | Text description | -| duration | FLOAT | 1.0-30.0 | -| steps | INT | 1-100, default 24 | -| cfg_scale | FLOAT | 1.0-20.0, default 5.0 | -| seed | INT | Controls noise generation | -| **Output** | | | -| audio | AUDIO | {waveform: tensor, sample_rate: 44100} | - -Uses empty tensors for video/sync features, T5-Gemma encodes the text prompt. - -## VRAM Management - -Adaptive strategy using `comfy.model_management`: - -| Available VRAM | Behavior | -|---|---| -| 24GB+ | Keep diffusion + VAE in VRAM | -| 12-24GB | Sequential offload between stages | -| 8-12GB | Aggressive offload, one component on GPU at a time, fp16 forced | -| <8GB | Warn user, attempt with aggressive offload + fp16 | - -Key APIs: `mm.get_torch_device()`, `mm.get_free_memory()`, `mm.soft_empty_cache()`, `mm.unet_offload_device()` - -## Feature Extraction Paths - -### Path 1: Pre-computed .npz (FeatureLoader) -User runs `scripts/extract_features.py` externally in the extraction conda env. Loads result into ComfyUI. Original VideoPrism quality, zero ComfyUI env risk. - -### Path 2: Subprocess bridge (FeatureExtractor) -Node calls extraction script as subprocess using a user-specified Python binary. Seamless in-ComfyUI experience, JAX runs isolated. Caches results by content hash. - -### Path 3: Text-only (TextOnly node) -No video features needed. T5-Gemma text encoding only (PyTorch-native). - -## Dependencies - -### ComfyUI environment (`requirements.txt`) -``` -einops>=0.7.0 -safetensors -huggingface_hub -transformers>=4.52.3 -k-diffusion>=0.1.1 -``` - -flash-attn: Optional, detected at runtime. Falls back to `torch.nn.functional.scaled_dot_product_attention`. - -### Extraction environment (`scripts/environment.yml`) -Separate conda environment with JAX, tensorflow-cpu==2.15.0, VideoPrism, Synchformer, decord. Provided as ready-made conda env file for one-command setup. - -## Model Files - -Stored in `ComfyUI/models/prismaudio/`: - -| File | Size | Source | -|------|------|--------| -| prismaudio.ckpt | ~2GB | FunAudioLLM/PrismAudio | -| vae.ckpt | ~2.5GB | FunAudioLLM/PrismAudio | -| synchformer_state_dict.pth | ~950MB | FunAudioLLM/PrismAudio | - -T5-Gemma (`google/t5gemma-l-l-ul2-it`) cached in standard HuggingFace cache. - -Registered via: `folder_paths.add_model_folder_path("prismaudio", ...)` - -## Design Decisions - -- **Composable**: Standard AUDIO output, CoT as plain STRING input. No reinventing save/preview/mux nodes. -- **No JAX/TF in ComfyUI env**: All JAX-dependent code isolated in extraction script/env. -- **LLM-agnostic CoT**: Users bring their own CoT generation via existing LLM nodes — better models available than bundled Qwen2.5-VL. -- **HF token via env/CLI only**: No widget (ComfyUI saves all STRING values to workflow JSON). Uses `HF_TOKEN` env var or `huggingface-cli login`. -- **flash-attn optional**: Avoids installation headaches, uses PyTorch SDPA as fallback. diff --git a/docs/plans/2026-03-27-comfyui-prismaudio-implementation.md b/docs/plans/2026-03-27-comfyui-prismaudio-implementation.md deleted file mode 100644 index c1e5daf..0000000 --- a/docs/plans/2026-03-27-comfyui-prismaudio-implementation.md +++ /dev/null @@ -1,1372 +0,0 @@ -# ComfyUI-PrismAudio Implementation Plan - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. - -**Goal:** Build ComfyUI custom nodes for PrismAudio video-to-audio and text-to-audio generation with adaptive VRAM management and isolated feature extraction. - -**Architecture:** Selective code extraction from PrismAudio `prismaudio` branch into `prismaudio_core/` module. 5 ComfyUI nodes (ModelLoader, FeatureLoader, FeatureExtractor, Sampler, TextOnly). Feature extraction via subprocess bridge to isolated JAX/TF environment. Auto-download from HuggingFace with gated model support. - -**Tech Stack:** PyTorch, ComfyUI APIs (folder_paths, comfy.model_management, comfy.utils), HuggingFace Hub, transformers (T5-Gemma), einops, k-diffusion, safetensors - ---- - -## Bug Fixes Applied (from review) - -This plan incorporates fixes for all 14 bugs identified during review: - -1. **sample_discrete_euler callback**: Copy function into prismaudio_core, add callback param to the sampling loop -2. **Metadata format**: Return `(dict,)` tuple, not flat dict — matches `MultiConditioner.forward(batch_metadata: List[Dict])` -3. **video_exist**: Use `torch.tensor(True/False)`, not Python bool -4. **None features**: Use zero tensors of correct shape, never None — `pad_sequence(None)` crashes -5. **update_seq_lengths removed**: Does not exist in source. Model adapts to input shapes dynamically — no seq length config needed -6. **Sequence length config**: Not needed — model handles variable lengths natively via input tensor shapes -7. **T5-Gemma class**: Use `AutoModelForSeq2SeqLM.get_encoder()`, not `AutoModel.encoder` -8. **Peak normalization**: Add `audio / audio.abs().max().clamp(min=1e-8)` before clamp -9. **Empty feature substitution**: Match reference approach — substitute on raw conditioning output with correct shapes -10. **hf_token security**: Remove STRING widget entirely. Rely on env var / `huggingface-cli login` only. Document in README -11. **Synchformer size**: Corrected to ~950MB in docs -12. **T5 truncation**: Match reference — `truncation=False`, no max_length -13. **Remove global_video/text_features from metadata**: Not consumed by any conditioner -14. **Add tqdm to requirements** - -### Bug Fixes Applied (from second review) - -15. **Sync_MLP zero-tensor crash**: Sync zero-tensor fallback must be `[8, 768]` not `[1, 768]` — Sync_MLP does `length // 8` which gives 0 for length=1, causing `F.interpolate` on empty tensor -16. **sample_discrete_euler undefined `i`**: Loop needs `enumerate()` — `for i, (t_curr, t_prev) in enumerate(zip(...))` -17. **_update_seq_lengths removed entirely**: Was a no-op (attributes don't exist on DiT). Model handles variable lengths natively — function deleted -18. **cot_description removed from Sampler**: Was dead code — features already contain pre-computed text_features -19. **Conditioner VRAM leak**: Add `diffusion.conditioner.to(get_offload_device())` after generation in offload path -20. **VAE size corrected**: ~2.52GB, not ~300MB - -### Bug Fixes Applied (from third review) - -21. **Remove video_features substitution**: `_substitute_empty_features` should only substitute sync_features. Reference code checks for `metaclip_features` (wrong key for prismaudio config), so video substitution never runs. Cond_MLP with zero input + bias-free linears naturally produces near-zero output -22. **Remove dead `sample()` and `sample_rf`**: Wrong noise schedule (linear vs cosine), never called for rectified_flow. Only keep `sample_discrete_euler` -23. **VAE decode in fp32**: Keep pretransform in fp32 even when rest of model is fp16/bf16 — snake activations overflow in fp16 -24. **Lazy imports in nodes/__init__.py**: Use try/except to allow incremental development -25. **MPS Generator guard**: `torch.Generator(device="cpu")` on Apple Silicon, move noise to device after -26. **Use comfy.utils.load_torch_file for VAE**: Consistent with diffusion loading, handles PyTorch 2.6+ weights_only default -27. **Task 10 stale reference**: Remove mention of `_update_seq_lengths` - -### Bug Fixes Applied (from fourth review) - -28. **TextOnly missing MPS guard**: Fix-on-fix regression — MPS Generator guard was applied to Sampler but not TextOnly -29. **TextOnly noise dtype**: Was passing dtype to torch.randn directly (fp16 noise), now generates fp32 then converts (matching Sampler) -30. **Sync substitution seq length**: Low-severity divergence from reference, accepted (DiT handles variable-length sync_cond) - ---- - -### Task 1: Project Scaffolding - -**Files:** -- Create: `__init__.py` -- Create: `nodes/__init__.py` -- Create: `nodes/utils.py` -- Create: `requirements.txt` - -**Step 1: Create requirements.txt** - -``` -einops>=0.7.0 -safetensors -huggingface_hub -transformers>=4.52.3 -k-diffusion>=0.1.1 -alias-free-torch -descript-audio-codec -tqdm -``` - -**Step 2: Create nodes/utils.py with shared helpers** - -```python -import os -import torch -import folder_paths -import comfy.model_management as mm - -PRISMAUDIO_CATEGORY = "PrismAudio" -SAMPLE_RATE = 44100 -DOWNSAMPLING_RATIO = 2048 -IO_CHANNELS = 64 - -def get_prismaudio_model_dir(): - """Get or create the prismaudio model directory.""" - model_dir = os.path.join(folder_paths.models_dir, "prismaudio") - os.makedirs(model_dir, exist_ok=True) - return model_dir - -def register_model_folder(): - """Register prismaudio model folder with ComfyUI.""" - model_dir = get_prismaudio_model_dir() - folder_paths.add_model_folder_path("prismaudio", model_dir) - -def get_device(): - return mm.get_torch_device() - -def get_offload_device(): - return mm.unet_offload_device() - -def get_free_memory(device=None): - if device is None: - device = get_device() - return mm.get_free_memory(device) - -def soft_empty_cache(): - mm.soft_empty_cache() - -def determine_precision(preference, device): - """Determine the best precision for the given device.""" - if preference != "auto": - return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference] - if device.type == "cpu": - return torch.float32 - if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): - return torch.bfloat16 - return torch.float16 - -def determine_offload_strategy(preference): - """Determine offload strategy based on available VRAM.""" - if preference != "auto": - return preference - free_mem = get_free_memory() - gb = free_mem / (1024 ** 3) - if gb >= 24: - return "keep_in_vram" - else: - return "offload_to_cpu" - -def try_import_flash_attn(): - """Try to import flash attention, return None if unavailable.""" - try: - import flash_attn - return flash_attn - except ImportError: - return None - -def resolve_hf_token(): - """Resolve HF token from env var or cached login. No widget — security risk.""" - env_token = os.environ.get("HF_TOKEN") - if env_token: - return env_token - # huggingface_hub will use cached token automatically if None is passed - return None -``` - -**Step 3: Create nodes/__init__.py** - -```python -NODE_CLASS_MAPPINGS = {} -NODE_DISPLAY_NAME_MAPPINGS = {} - -# Lazy imports — allows incremental development (nodes can be added one at a time) -_NODES = { - "PrismAudioModelLoader": (".model_loader", "PrismAudioModelLoader", "PrismAudio Model Loader"), - "PrismAudioFeatureLoader": (".feature_loader", "PrismAudioFeatureLoader", "PrismAudio Feature Loader"), - "PrismAudioFeatureExtractor": (".feature_extractor", "PrismAudioFeatureExtractor", "PrismAudio Feature Extractor"), - "PrismAudioSampler": (".sampler", "PrismAudioSampler", "PrismAudio Sampler"), - "PrismAudioTextOnly": (".text_only", "PrismAudioTextOnly", "PrismAudio Text Only"), -} - -for key, (module_path, class_name, display_name) in _NODES.items(): - try: - import importlib - mod = importlib.import_module(module_path, package=__name__) - NODE_CLASS_MAPPINGS[key] = getattr(mod, class_name) - NODE_DISPLAY_NAME_MAPPINGS[key] = display_name - except (ImportError, AttributeError) as e: - print(f"[PrismAudio] Skipping {key}: {e}") -``` - -**Step 4: Create top-level __init__.py** - -```python -""" -ComfyUI-PrismAudio: Video-to-Audio and Text-to-Audio generation using PrismAudio (ICLR 2026). -""" -from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS - -__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -``` - -**Step 5: Commit** - -```bash -git add __init__.py nodes/__init__.py nodes/utils.py requirements.txt -git commit -m "feat: project scaffolding with shared utils and node registration" -``` - ---- - -### Task 2: Extract prismaudio_core — Model Config + Factory - -**Files:** -- Create: `prismaudio_core/__init__.py` -- Create: `prismaudio_core/configs/prismaudio.json` (copy from PrismAudio repo) -- Create: `prismaudio_core/factory.py` (adapted from `PrismAudio/models/factory.py`) - -**Step 1: Create prismaudio_core/__init__.py** - -```python -""" -PrismAudio core inference modules. -Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch). -Only inference-critical code — no training, no JAX/TF dependencies. -""" -``` - -**Step 2: Copy prismaudio.json config** - -Fetch from `https://raw.githubusercontent.com/FunAudioLLM/ThinkSound/prismaudio/PrismAudio/configs/model_configs/prismaudio.json` and save to `prismaudio_core/configs/prismaudio.json`. This is a JSON config file with no code — copy verbatim. - -**Step 3: Create factory.py** - -Extract from `PrismAudio/models/factory.py`. Keep only these functions (remove training-related code): -- `create_model_from_config(model_config)` — entry point -- `create_diffusion_cond_from_config(config)` — creates the full model -- `create_pretransform_from_config(pretransform_config, sample_rate)` — VAE -- `create_autoencoder_from_config(config)` — AudioAutoencoder -- `create_bottleneck_from_config(config)` — VAEBottleneck -- `create_multi_conditioner_from_conditioning_config(config)` — conditioners - -All imports should reference `prismaudio_core.models.*` instead of `PrismAudio.models.*`. - -**Step 4: Commit** - -```bash -git add prismaudio_core/ -git commit -m "feat: extract prismaudio_core config and model factory" -``` - ---- - -### Task 3: Extract prismaudio_core — Model Modules - -**Files:** -- Create: `prismaudio_core/models/__init__.py` -- Create: `prismaudio_core/models/dit.py` (from `PrismAudio/models/dit.py`) -- Create: `prismaudio_core/models/diffusion.py` (from `PrismAudio/models/diffusion.py`) -- Create: `prismaudio_core/models/conditioners.py` (from `PrismAudio/models/conditioners.py`) -- Create: `prismaudio_core/models/autoencoders.py` (from `PrismAudio/models/autoencoders.py`) -- Create: `prismaudio_core/models/pretransforms.py` (from `PrismAudio/models/pretransforms.py`) -- Create: `prismaudio_core/models/blocks.py` (from `PrismAudio/models/blocks.py`) -- Create: `prismaudio_core/models/utils.py` (from `PrismAudio/models/utils.py`) -- Create: `prismaudio_core/models/bottleneck.py` (from `PrismAudio/models/bottleneck.py`) -- Create: `prismaudio_core/models/transformer.py` (from `PrismAudio/models/transformer.py`) -- Create: `prismaudio_core/models/local_attention.py` (if used by transformer) - -**Step 1: Extract model files** - -For each file, extract from the PrismAudio repo. Key modifications: -- Change all internal imports from `PrismAudio.models.*` to `prismaudio_core.models.*` -- Remove training-only code (loss functions, training step methods, gradient checkpointing setup) -- Keep all inference paths intact - -**Critical classes to preserve:** - -From `dit.py`: -- `DiffusionTransformer` — full class with `forward()`, CFG logic, conditioning assembly -- `FourierFeatures` — timestep embedding -- Keep `empty_clip_feat` and `empty_sync_feat` learned parameters (nn.Parameter, zeros) - -From `diffusion.py`: -- `ConditionedDiffusionModelWrapper` — with `get_conditioning_inputs()` and routing logic -- `DiTWrapper` — thin wrapper that passes all kwargs through -- `create_diffusion_cond_from_config()` — factory function - -From `conditioners.py`: -- `Cond_MLP` (type `"cond_mlp"`) — for video_features and text_features. Uses `pad_sequence`, 2-layer MLP, returns `[embeddings, ones_mask]`. During eval with batch<16, doubles batch with null embed for CFG -- `Sync_MLP` (type `"sync_mlp"`) — for sync_features with learnable `sync_pos_emb` of shape (1,1,8,dim), reshapes into segments of 8, interpolates to target length -- `MultiConditioner` — iterates over `batch_metadata: List[Dict]`, collects per-sample inputs, calls each conditioner. Returns dict of `{key: (tensor, mask)}` -- `create_multi_conditioner_from_conditioning_config()` — factory - -From `autoencoders.py`: -- `AudioAutoencoder` — with `encode_audio()` and `decode_audio()` -- `OobleckEncoder`, `OobleckDecoder` — with ResidualUnit, snake activation -- Dependencies: `alias_free_torch` (SnakeBeta), `dac.nn` (WNConv1d, WNConvTranspose1d) - -From `pretransforms.py`: -- `AutoencoderPretransform` — wraps AudioAutoencoder, `encode()` and `decode()` methods - -From `bottleneck.py`: -- `VAEBottleneck` — reparameterization trick (split mean/logvar, sample) - -From `blocks.py`: -- Any shared blocks used by the above (attention blocks, FeedForward, etc.) - -From `transformer.py`: -- `ContinuousTransformer` — the core transformer with cross-attention, used by DiffusionTransformer - -From `utils.py`: -- `load_ckpt_state_dict()` — handles .safetensors and .ckpt, optional prefix stripping -- `remove_weight_norm_from_model()` — used in some inference paths - -**Step 2: Handle flash-attn gracefully in transformer.py / blocks.py** - -Replace hard `import flash_attn` with: -```python -try: - from flash_attn import flash_attn_func - HAS_FLASH_ATTN = True -except ImportError: - HAS_FLASH_ATTN = False -``` - -In the attention forward pass, use: -```python -if HAS_FLASH_ATTN: - out = flash_attn_func(q, k, v, ...) -else: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, ...) -``` - -**Step 3: Verify imports resolve** - -Run: `python -c "from prismaudio_core.factory import create_model_from_config; print('OK')"` from the project root (with ComfyUI's python). - -Expected: `OK` (or import errors to fix iteratively) - -**Step 4: Commit** - -```bash -git add prismaudio_core/models/ -git commit -m "feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)" -``` - ---- - -### Task 4: Extract prismaudio_core — Inference/Sampling (with callback fix) - -**Files:** -- Create: `prismaudio_core/inference/__init__.py` -- Create: `prismaudio_core/inference/sampling.py` (MODIFIED from `PrismAudio/inference/sampling.py`) -- Create: `prismaudio_core/inference/utils.py` (from `PrismAudio/inference/utils.py`) - -**Step 1: Extract sampling.py WITH callback support added** - -The original `sample_discrete_euler` uses `tqdm` and has no callback parameter. -We MUST copy and modify it to add callback support for ComfyUI progress bars. - -```python -import torch -from tqdm import trange - -def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args): - """Discrete Euler sampler for rectified flow, with optional callback. - - Modified from PrismAudio to add callback parameter for ComfyUI progress reporting. - Original uses tqdm internally. - - Args: - model: The diffusion model (DiTWrapper) - x: Initial noise tensor [B, C, T] - steps: Number of sampling steps - sigma_max: Maximum sigma (default 1.0 for rectified flow) - callback: Optional callable({"i": step, "x": current_x}) for progress - **extra_args: Passed to model() — includes cross_attn_cond, add_cond, - sync_cond, cfg_scale, batch_cfg, etc. - """ - t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype) - - for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])): - dt = t_next - t_curr - t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device) - x = x + dt * model(x, t_curr_tensor, **extra_args) - if callback is not None: - callback({"i": i, "x": x}) - - return x - - - -# Note: sample_rf() and sample() (v-diffusion) are NOT included. -# PrismAudio uses rectified_flow objective which only needs sample_discrete_euler. -# Including unused samplers with potentially wrong math is a maintenance hazard. -``` - -**Step 2: Extract inference/utils.py** - -Keep: -- `set_audio_channels(audio, target_channels)` -- `prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device)` - -**Step 3: Verify sampling import** - -Run: `python -c "from prismaudio_core.inference.sampling import sample_discrete_euler; print('OK')"` - -Expected: `OK` - -**Step 4: Commit** - -```bash -git add prismaudio_core/inference/ -git commit -m "feat: extract prismaudio_core inference with callback-enabled sampling" -``` - ---- - -### Task 5: PrismAudioModelLoader Node - -**Files:** -- Create: `nodes/model_loader.py` - -**Step 1: Write the node** - -Key design decisions: -- No hf_token widget (security risk — saved to workflow JSON). Uses env var / cached login only. -- Creates model with default config. Duration-dependent seq lengths handled at sample time. -- The model config's `sample_size: 397312` corresponds to ~9s default. For other durations, - the Sampler node will update seq lengths on the DiT before each generation. - -```python -import os -import json -import torch -import folder_paths -import comfy.model_management as mm -import comfy.utils - -from .utils import ( - PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder, - get_device, get_offload_device, determine_precision, determine_offload_strategy, - soft_empty_cache, resolve_hf_token, -) - -# HuggingFace repo for auto-download -HF_REPO_ID = "FunAudioLLM/PrismAudio" -REQUIRED_FILES = { - "diffusion": "prismaudio.ckpt", - "vae": "vae.ckpt", - "synchformer": "synchformer_state_dict.pth", -} - - -def _download_if_missing(filename, model_dir, hf_token=None): - """Download a model file from HuggingFace if not present locally.""" - filepath = os.path.join(model_dir, filename) - if os.path.exists(filepath): - return filepath - - from huggingface_hub import hf_hub_download - print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...") - try: - downloaded = hf_hub_download( - repo_id=HF_REPO_ID, - filename=filename, - local_dir=model_dir, - token=hf_token or None, - ) - return downloaded - except Exception as e: - if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower(): - raise RuntimeError( - f"[PrismAudio] Model '{filename}' requires license acceptance. " - f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, " - f"then set HF_TOKEN env var or run: huggingface-cli login" - ) from e - raise - - -class PrismAudioModelLoader: - @classmethod - def INPUT_TYPES(cls): - register_model_folder() - return { - "required": { - "precision": (["auto", "fp32", "fp16", "bf16"],), - "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],), - }, - } - - RETURN_TYPES = ("PRISMAUDIO_MODEL",) - RETURN_NAMES = ("model",) - FUNCTION = "load_model" - CATEGORY = PRISMAUDIO_CATEGORY - - def load_model(self, precision, offload_strategy): - device = get_device() - dtype = determine_precision(precision, device) - strategy = determine_offload_strategy(offload_strategy) - token = resolve_hf_token() - model_dir = get_prismaudio_model_dir() - - # Auto-download missing files - for key, filename in REQUIRED_FILES.items(): - _download_if_missing(filename, model_dir, hf_token=token) - - # Load config - config_path = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "prismaudio_core", "configs", "prismaudio.json" - ) - with open(config_path) as f: - model_config = json.load(f) - - # Create model from config - from prismaudio_core.factory import create_model_from_config - model = create_model_from_config(model_config) - - # Load diffusion weights - diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"]) - diffusion_state = comfy.utils.load_torch_file(diffusion_path) - # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...} - if "state_dict" in diffusion_state: - diffusion_state = diffusion_state["state_dict"] - model.load_state_dict(diffusion_state, strict=False) - - # Load VAE weights separately - # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat - vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"]) - vae_full_state = comfy.utils.load_torch_file(vae_path) - # Strip "autoencoder." prefix from keys - vae_state = {} - prefix = "autoencoder." - for k, v in vae_full_state.items(): - if k.startswith(prefix): - vae_state[k[len(prefix):]] = v - else: - vae_state[k] = v - model.pretransform.load_state_dict(vae_state) - - # Apply precision: DiT + conditioners in user-selected dtype, - # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16 - model.model.to(dtype) # DiTWrapper - model.conditioner.to(dtype) # MultiConditioner - # model.pretransform stays in fp32 - - if strategy == "keep_in_vram": - model = model.to(device) - else: - model = model.to(get_offload_device()) - - model.eval() - - return ({ - "model": model, - "dtype": dtype, - "strategy": strategy, - "config": model_config, - "model_dir": model_dir, - },) -``` - -**Step 2: Test that ComfyUI discovers the node** - -Run ComfyUI and check that "PrismAudio Model Loader" appears in the node list. - -**Step 3: Commit** - -```bash -git add nodes/model_loader.py -git commit -m "feat: PrismAudioModelLoader node with auto-download and adaptive VRAM" -``` - ---- - -### Task 6: PrismAudioFeatureLoader Node - -**Files:** -- Create: `nodes/feature_loader.py` - -**Step 1: Write the node** - -```python -import os -import numpy as np -import torch -from .utils import PRISMAUDIO_CATEGORY - -# Keys consumed by the conditioners (video_features, text_features, sync_features) -# global_video_features and global_text_features are NOT consumed by any conditioner -# in the prismaudio.json config — they are unused. -REQUIRED_KEYS = [ - "video_features", - "text_features", - "sync_features", -] - - -class PrismAudioFeatureLoader: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}), - }, - } - - RETURN_TYPES = ("PRISMAUDIO_FEATURES",) - RETURN_NAMES = ("features",) - FUNCTION = "load_features" - CATEGORY = PRISMAUDIO_CATEGORY - - def load_features(self, npz_path): - if not os.path.exists(npz_path): - raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}") - - data = np.load(npz_path, allow_pickle=True) - - features = {} - for key in REQUIRED_KEYS: - if key in data: - features[key] = torch.from_numpy(data[key]).float() - else: - print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros") - # Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None - # Sync_MLP requires length divisible by 8 (segments of 8 frames) - if key == "sync_features": - features[key] = torch.zeros(8, 768) - else: - features[key] = torch.zeros(1, 1024) - - # Load duration if present - if "duration" in data: - features["duration"] = float(data["duration"]) - - return (features,) -``` - -**Step 2: Commit** - -```bash -git add nodes/feature_loader.py -git commit -m "feat: PrismAudioFeatureLoader node for pre-computed .npz files" -``` - ---- - -### Task 7: PrismAudioFeatureExtractor Node (Subprocess Bridge) - -**Files:** -- Create: `nodes/feature_extractor.py` -- Create: `scripts/extract_features.py` -- Create: `scripts/environment.yml` - -**Step 1: Create the conda environment.yml** - -```yaml -name: prismaudio-extract -channels: - - conda-forge - - defaults -dependencies: - - python=3.10 - - pip - - ffmpeg<7 - - pip: - - torch>=2.6.0 - - torchaudio>=2.6.0 - - torchvision>=0.21.0 - - tensorflow-cpu==2.15.0 - - jax - - jaxlib - - transformers>=4.52.3 - - decord - - einops>=0.7.0 - - numpy - - mediapy - - git+https://github.com/google-deepmind/videoprism.git -``` - -**Step 2: Create scripts/extract_features.py** - -This is a standalone script that: -1. Takes `--video`, `--cot_text`, `--output` arguments -2. Loads VideoPrism, T5-Gemma, Synchformer -3. Extracts features from the video -4. Saves as `.npz` - -```python -#!/usr/bin/env python3 -""" -Standalone PrismAudio feature extraction script. -Run in a separate conda env with JAX/TF installed. - -Usage: - python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz - -Setup: - conda env create -f environment.yml - conda activate prismaudio-extract -""" - -import argparse -import os -import sys -import numpy as np -import torch - - -def main(): - parser = argparse.ArgumentParser(description="PrismAudio feature extraction") - parser.add_argument("--video", required=True, help="Path to input video") - parser.add_argument("--cot_text", required=True, help="Chain-of-thought description") - parser.add_argument("--output", required=True, help="Output .npz path") - parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint") - parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON") - parser.add_argument("--clip_fps", type=float, default=4.0) - parser.add_argument("--clip_size", type=int, default=288) - parser.add_argument("--sync_fps", type=float, default=25.0) - parser.add_argument("--sync_size", type=int, default=224) - args = parser.parse_args() - - if not os.path.exists(args.video): - print(f"Error: Video not found: {args.video}") - sys.exit(1) - - # Import feature extraction utils (requires JAX/TF) - from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils - import torchvision.transforms as T - from decord import VideoReader, cpu - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Initialize feature extractor - feat_utils = FeaturesUtils( - vae_config_path=args.vae_config, - synchformer_ckpt=args.synchformer_ckpt, - device=device, - ) - - # Load and preprocess video - vr = VideoReader(args.video, ctx=cpu(0)) - fps = vr.get_avg_fps() - total_frames = len(vr) - duration = total_frames / fps - - # Extract CLIP frames (4fps, 288x288) - clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] - clip_indices = [min(i, total_frames - 1) for i in clip_indices] - clip_frames = vr.get_batch(clip_indices).asnumpy() - - clip_transform = T.Compose([ - T.ToPILImage(), - T.Resize(args.clip_size), - T.CenterCrop(args.clip_size), - T.ToTensor(), - T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ]) - clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device) - - # Extract Sync frames (25fps, 224x224) - sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] - sync_indices = [min(i, total_frames - 1) for i in sync_indices] - sync_frames = vr.get_batch(sync_indices).asnumpy() - - sync_transform = T.Compose([ - T.ToPILImage(), - T.Resize(args.sync_size), - T.CenterCrop(args.sync_size), - T.ToTensor(), - T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ]) - sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device) - - # Extract features - print("[PrismAudio] Encoding text with T5-Gemma...") - text_features = feat_utils.encode_t5_text([args.cot_text]) - - print("[PrismAudio] Encoding video with VideoPrism...") - global_video_features, video_features, global_text_features = \ - feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text]) - - print("[PrismAudio] Encoding video with Synchformer...") - sync_features = feat_utils.encode_video_with_sync(sync_input) - - # Save as .npz - np.savez( - args.output, - video_features=video_features.cpu().numpy(), - global_video_features=global_video_features.cpu().numpy(), - text_features=text_features.cpu().numpy(), - global_text_features=global_text_features.cpu().numpy(), - sync_features=sync_features.cpu().numpy(), - caption_cot=args.cot_text, - duration=duration, - ) - print(f"[PrismAudio] Features saved to {args.output}") - - -if __name__ == "__main__": - main() -``` - -**Step 3: Create the feature extractor node** - -```python -import os -import hashlib -import subprocess -import tempfile -import torch - -from .utils import PRISMAUDIO_CATEGORY -from .feature_loader import PrismAudioFeatureLoader - - -def _hash_inputs(video_tensor, cot_text): - """Create a hash of the inputs for caching.""" - h = hashlib.sha256() - h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed - h.update(cot_text.encode()) - return h.hexdigest()[:16] - - -def _save_video_tensor_to_mp4(video_tensor, output_path, fps=30): - """Save ComfyUI IMAGE tensor [T,H,W,C] to MP4.""" - import torchvision.io as tvio - # ComfyUI IMAGE is [T,H,W,C] float32 [0,1] - frames = (video_tensor * 255).to(torch.uint8) - # torchvision write_video expects [T,H,W,C] uint8 - tvio.write_video(output_path, frames, fps=fps) - - -class PrismAudioFeatureExtractor: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": ("IMAGE",), - "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}), - }, - "optional": { - "python_env": ("STRING", {"default": "python", "tooltip": "Path to python binary with JAX/TF (e.g., /path/to/conda/envs/prismaudio-extract/bin/python)"}), - "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}), - "synchformer_ckpt": ("STRING", {"default": "", "tooltip": "Path to synchformer checkpoint (auto-resolved if empty)"}), - }, - } - - RETURN_TYPES = ("PRISMAUDIO_FEATURES",) - RETURN_NAMES = ("features",) - FUNCTION = "extract_features" - CATEGORY = PRISMAUDIO_CATEGORY - - def extract_features(self, video, caption_cot, python_env="python", cache_dir="", synchformer_ckpt=""): - # Determine cache directory - if not cache_dir: - cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features") - os.makedirs(cache_dir, exist_ok=True) - - # Check cache - cache_hash = _hash_inputs(video, caption_cot) - cached_path = os.path.join(cache_dir, f"{cache_hash}.npz") - if os.path.exists(cached_path): - print(f"[PrismAudio] Using cached features: {cached_path}") - loader = PrismAudioFeatureLoader() - return loader.load_features(cached_path) - - # Save video to temp file - with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: - tmp_video = tmp.name - _save_video_tensor_to_mp4(video, tmp_video) - - # Build subprocess command - script_path = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "scripts", "extract_features.py" - ) - - cmd = [ - python_env, - script_path, - "--video", tmp_video, - "--cot_text", caption_cot, - "--output", cached_path, - ] - if synchformer_ckpt: - cmd.extend(["--synchformer_ckpt", synchformer_ckpt]) - - print(f"[PrismAudio] Extracting features via subprocess...") - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=600, # 10 minute timeout - ) - if result.returncode != 0: - raise RuntimeError( - f"[PrismAudio] Feature extraction failed:\n{result.stderr}" - ) - print(result.stdout) - finally: - if os.path.exists(tmp_video): - os.unlink(tmp_video) - - # Load the extracted features - loader = PrismAudioFeatureLoader() - return loader.load_features(cached_path) -``` - -**Step 4: Commit** - -```bash -git add nodes/feature_extractor.py scripts/extract_features.py scripts/environment.yml -git commit -m "feat: PrismAudioFeatureExtractor node with subprocess bridge and conda env" -``` - ---- - -### Task 8: PrismAudioSampler Node - -**Files:** -- Create: `nodes/sampler.py` - -**Step 1: Write the sampler node** - -This is the core node. Key fixes from review: -- Metadata is a TUPLE of dicts, not a flat dict -- video_exist is torch.tensor, not Python bool -- Empty features are zero tensors, not None -- Peak normalization before clamp -- Sequence lengths set on DiT config before sampling (matching predict.py approach) -- No callback kwarg forwarded to model — callback is handled by our modified sample_discrete_euler - -```python -import torch -import comfy.model_management as mm -import comfy.utils - -from .utils import ( - PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS, - get_device, get_offload_device, soft_empty_cache, -) - - -class PrismAudioSampler: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "model": ("PRISMAUDIO_MODEL",), - "features": ("PRISMAUDIO_FEATURES",), - "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds"}), - "steps": ("INT", {"default": 24, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}), - "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), - }, - } - - RETURN_TYPES = ("AUDIO",) - RETURN_NAMES = ("audio",) - FUNCTION = "generate" - CATEGORY = PRISMAUDIO_CATEGORY - - def generate(self, model, features, duration, steps, cfg_scale, seed): - device = get_device() - dtype = model["dtype"] - strategy = model["strategy"] - diffusion = model["model"] - - # Compute latent dimensions - latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO) - - # Note: no seq length config needed — the model adapts to input tensor shapes - # dynamically via its transformer architecture. - - # Determine if video features are present (not all zeros) - has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0 - - # Build metadata as a TUPLE of dicts (one per batch sample) - # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this - sample_meta = { - "video_features": features["video_features"].to(device, dtype=dtype), - "text_features": features["text_features"].to(device, dtype=dtype), - "sync_features": features["sync_features"].to(device, dtype=dtype), - "video_exist": torch.tensor(has_video), - } - metadata = (sample_meta,) - - # Move model to device if offloaded - if strategy == "offload_to_cpu": - diffusion.model.to(device) - diffusion.conditioner.to(device) - soft_empty_cache() - - with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype): - # Run conditioning - conditioning = diffusion.conditioner(metadata, device) - - # Handle missing video: substitute learned empty embeddings - if not has_video: - _substitute_empty_features(diffusion, conditioning, device, dtype) - - # Assemble conditioning inputs for the DiT - cond_inputs = diffusion.get_conditioning_inputs(conditioning) - - # Generate noise from seed (MPS doesn't support torch.Generator) - gen_device = "cpu" if device.type == "mps" else device - generator = torch.Generator(device=gen_device).manual_seed(seed) - noise = torch.randn( - [1, IO_CHANNELS, latent_length], - generator=generator, - device=gen_device, - ).to(device=device, dtype=dtype) - - # Sample with progress bar - pbar = comfy.utils.ProgressBar(steps) - - from prismaudio_core.inference.sampling import sample_discrete_euler - - def on_step(info): - pbar.update(1) - - fakes = sample_discrete_euler( - diffusion.model, - noise, - steps, - callback=on_step, - **cond_inputs, - cfg_scale=cfg_scale, - batch_cfg=True, - ) - - # Offload diffusion model and conditioner before VAE decode - if strategy == "offload_to_cpu": - diffusion.model.to(get_offload_device()) - diffusion.conditioner.to(get_offload_device()) - soft_empty_cache() - diffusion.pretransform.to(device) - - # VAE decode in fp32 (snake activations overflow in fp16) - with torch.amp.autocast(device_type=device.type, enabled=False): - audio = diffusion.pretransform.decode(fakes.float()) - - # Offload VAE - if strategy == "offload_to_cpu": - diffusion.pretransform.to(get_offload_device()) - soft_empty_cache() - - # Peak normalize then clamp (matching reference: div by max abs before clamp) - audio = audio.float() - peak = audio.abs().max().clamp(min=1e-8) - audio = (audio / peak).clamp(-1, 1) - - # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int} - return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},) - - -def _substitute_empty_features(diffusion, conditioning, device, dtype): - """Replace sync conditioning with learned empty embedding when video is absent. - - Only substitutes sync_features — NOT video_features. The reference code - (predict.py/app.py) checks for 'metaclip_features' which doesn't exist in the - prismaudio.json config, so video substitution never runs. Cond_MLP with zero - input + bias-free linear layers naturally produces near-zero output. - - The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim]. - """ - dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model - - # Only substitute sync_features (matching reference behavior for prismaudio config) - if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning: - empty = dit.empty_sync_feat.to(device, dtype=dtype) - cond_tensor = conditioning['sync_features'][0] - batch_size = cond_tensor.shape[0] - empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) - conditioning['sync_features'][0] = empty_expanded - conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device) -``` - -**Step 2: Verify the node registers** - -Start ComfyUI, check "PrismAudio Sampler" appears in add-node menu. - -**Step 3: Commit** - -```bash -git add nodes/sampler.py -git commit -m "feat: PrismAudioSampler node with correct metadata format and peak normalization" -``` - ---- - -### Task 9: PrismAudioTextOnly Node - -**Files:** -- Create: `nodes/text_only.py` - -**Step 1: Write the text-only node** - -Key fixes from review: -- Uses `AutoModelForSeq2SeqLM.get_encoder()`, not `AutoModel.encoder` -- No truncation (matching reference) -- Metadata is tuple of dicts with torch.tensor(False) for video_exist -- Zero tensors for video/sync features, not None -- Peak normalization - -```python -import torch -import comfy.model_management as mm -import comfy.utils - -from .utils import ( - PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS, - get_device, get_offload_device, soft_empty_cache, resolve_hf_token, -) -from .sampler import _substitute_empty_features - - -class PrismAudioTextOnly: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "model": ("PRISMAUDIO_MODEL",), - "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Text description for audio generation"}), - "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}), - "steps": ("INT", {"default": 24, "min": 1, "max": 100}), - "cfg_scale": ("FLOAT", {"default": 5.0, "min": 1.0, "max": 20.0, "step": 0.1}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), - }, - } - - RETURN_TYPES = ("AUDIO",) - RETURN_NAMES = ("audio",) - FUNCTION = "generate" - CATEGORY = PRISMAUDIO_CATEGORY - - def generate(self, model, text_prompt, duration, steps, cfg_scale, seed): - device = get_device() - dtype = model["dtype"] - strategy = model["strategy"] - diffusion = model["model"] - - latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO) - - # Encode text with T5-Gemma - text_features = _encode_text_t5(text_prompt, device, dtype) - - # Build metadata: tuple of one dict per sample - # Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence) - # Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768] - # These will be substituted with learned empty embeddings after conditioning - sample_meta = { - "video_features": torch.zeros(1, 1024, device=device, dtype=dtype), - "text_features": text_features.to(device, dtype=dtype), - "sync_features": torch.zeros(8, 768, device=device, dtype=dtype), - "video_exist": torch.tensor(False), - } - metadata = (sample_meta,) - - if strategy == "offload_to_cpu": - diffusion.model.to(device) - diffusion.conditioner.to(device) - soft_empty_cache() - - with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype): - conditioning = diffusion.conditioner(metadata, device) - - # Substitute empty features for video/sync - _substitute_empty_features(diffusion, conditioning, device, dtype) - - cond_inputs = diffusion.get_conditioning_inputs(conditioning) - - # Generate noise from seed (MPS doesn't support torch.Generator) - gen_device = "cpu" if device.type == "mps" else device - generator = torch.Generator(device=gen_device).manual_seed(seed) - noise = torch.randn( - [1, IO_CHANNELS, latent_length], - generator=generator, - device=gen_device, - ).to(device=device, dtype=dtype) - - pbar = comfy.utils.ProgressBar(steps) - - from prismaudio_core.inference.sampling import sample_discrete_euler - - def on_step(info): - pbar.update(1) - - fakes = sample_discrete_euler( - diffusion.model, - noise, - steps, - callback=on_step, - **cond_inputs, - cfg_scale=cfg_scale, - batch_cfg=True, - ) - - if strategy == "offload_to_cpu": - diffusion.model.to(get_offload_device()) - diffusion.conditioner.to(get_offload_device()) - soft_empty_cache() - diffusion.pretransform.to(device) - - # VAE decode in fp32 (snake activations overflow in fp16) - with torch.amp.autocast(device_type=device.type, enabled=False): - audio = diffusion.pretransform.decode(fakes.float()) - - if strategy == "offload_to_cpu": - diffusion.pretransform.to(get_offload_device()) - soft_empty_cache() - - # Peak normalize then clamp - audio = audio.float() - peak = audio.abs().max().clamp(min=1e-8) - audio = (audio / peak).clamp(-1, 1) - - return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},) - - -# T5-Gemma encoder singleton -_t5_model = None -_t5_tokenizer = None - - -def _encode_text_t5(text, device, dtype): - """Encode text using T5-Gemma. - - Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference - FeaturesUtils.encode_t5_text() implementation. - No truncation applied (matching reference behavior). - """ - global _t5_model, _t5_tokenizer - - if _t5_model is None: - from transformers import AutoTokenizer, AutoModelForSeq2SeqLM - model_id = "google/t5gemma-l-l-ul2-it" - token = resolve_hf_token() - print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}") - _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token) - _t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder() - _t5_model.eval() - - _t5_model.to(device, dtype=dtype) - - tokens = _t5_tokenizer( - text, - return_tensors="pt", - padding=True, - ).to(device) - - with torch.no_grad(): - outputs = _t5_model(**tokens) - - # Move T5 off GPU after encoding to save VRAM - _t5_model.to("cpu") - soft_empty_cache() - - return outputs.last_hidden_state.squeeze(0) # [seq_len, dim] -``` - -**Step 2: Commit** - -```bash -git add nodes/text_only.py -git commit -m "feat: PrismAudioTextOnly node with correct T5-Gemma encoding" -``` - ---- - -### Task 10: Integration Testing & Polish - -**Files:** -- Modify: `nodes/__init__.py` (verify all imports work) -- Modify: `__init__.py` (verify top-level registration) - -**Step 1: Verify all node imports resolve** - -Run from ComfyUI's Python: -```bash -cd /path/to/ComfyUI -python -c " -import sys -sys.path.insert(0, 'custom_nodes/ComfyUI-PrismAudio') -from nodes import NODE_CLASS_MAPPINGS -print('Registered nodes:', list(NODE_CLASS_MAPPINGS.keys())) -" -``` - -Expected output: -``` -Registered nodes: ['PrismAudioModelLoader', 'PrismAudioFeatureLoader', 'PrismAudioFeatureExtractor', 'PrismAudioSampler', 'PrismAudioTextOnly'] -``` - -**Step 2: Fix any import errors iteratively** - -Common issues: -- `prismaudio_core` internal imports may reference wrong module paths -- Missing model submodules in `prismaudio_core/models/` -- flash-attn fallback not properly guarded - -**Step 3: Test model loading (requires GPU + model files)** - -```bash -python -c " -from prismaudio_core.factory import create_model_from_config -import json -with open('prismaudio_core/configs/prismaudio.json') as f: - config = json.load(f) -model = create_model_from_config(config) -print('Model created, params:', sum(p.numel() for p in model.parameters()) / 1e6, 'M') -" -``` - -Expected: `Model created, params: ~518 M` - -**Step 4: End-to-end test with pre-computed features** - -If you have a `.npz` feature file from the PrismAudio repo's demo data, test the full pipeline in ComfyUI: -1. PrismAudioModelLoader -> PrismAudioFeatureLoader -> PrismAudioSampler -> Preview Audio node - -**Step 5: Verify variable duration handling** - -Test with multiple durations (5s, 10s, 20s) to ensure the model adapts to different -input shapes and produces audio of the expected length. - -**Step 6: Commit** - -```bash -git add -A -git commit -m "feat: integration fixes and verification" -``` - ---- - -### Task 11: README - -**Files:** -- Create: `README.md` - -**Step 1: Write README covering:** - -- What PrismAudio is (brief, link to paper) -- Installation (clone, pip install requirements, optional extraction env setup) -- Node descriptions with input/output tables -- Example workflows (quality path with FeatureExtractor, quick path with FeatureLoader, text-only) -- HuggingFace authentication (2 methods: `HF_TOKEN` env var, `huggingface-cli login`) - - Note: hf_token is NOT a node widget for security reasons - - Which models may be gated (T5-Gemma, potentially Stable Audio VAE) -- Model file sizes: diffusion ~2.7GB, VAE ~2.5GB, synchformer ~950MB -- Extraction env setup via conda environment.yml -- Troubleshooting (VRAM, flash-attn optional, gated models) -- Credits and license - -**Step 2: Commit** - -```bash -git add README.md -git commit -m "docs: README with installation and usage instructions" -``` - ---- - -## Dependency Graph - -``` -Task 1 (scaffolding) - ├── Task 2 (core config + factory) ──┐ - │ └── Task 3 (core models) ──────┤ - │ └── Task 4 (core sampling)┤ - │ ├── Task 5 (ModelLoader node) - │ ├── Task 6 (FeatureLoader node) - │ ├── Task 7 (FeatureExtractor node) - │ ├── Task 8 (Sampler node) - │ └── Task 9 (TextOnly node) - └────────────────────────────────────────── Task 10 (Integration) - └── Task 11 (README) -``` - -Tasks 5-9 can be parallelized after Task 4 is complete. Task 3 is the heaviest — it involves extracting and adapting ~10 model files. diff --git a/nodes/feature_extractor.py b/nodes/feature_extractor.py deleted file mode 100644 index 724f074..0000000 --- a/nodes/feature_extractor.py +++ /dev/null @@ -1,207 +0,0 @@ -import os -import sys -import hashlib -import subprocess -import tempfile -import torch - -from .utils import PRISMAUDIO_CATEGORY -from .feature_loader import PrismAudioFeatureLoader - -# Managed venv created automatically when python_env is left as default -_PLUGIN_DIR = os.path.dirname(os.path.dirname(__file__)) -_MANAGED_VENV = os.path.join(_PLUGIN_DIR, "_extract_env") -_MANAGED_PYTHON = os.path.join(_MANAGED_VENV, "bin", "python") - -_EXTRACT_PACKAGES = [ - "torch", "torchaudio", "torchvision", - # TF 2.15 only supports Python <=3.11; use >=2.16 for Python 3.12+ - "tensorflow-cpu>=2.16.0", - # jax[cuda13] includes jaxlib; pip-managed CUDA libs (no local toolkit needed) - "jax[cuda13]", "flax", - "transformers", "decord", "einops", "numpy", "mediapy", - "git+https://github.com/google-deepmind/videoprism.git", -] - - -def _pip_install(pip, *packages, label=None): - """Install one or more packages with visible output; raise on failure.""" - tag = label or packages[0] - print(f"[PrismAudio] installing {tag} ...", flush=True) - result = subprocess.run( - [pip, "install", "--progress-bar", "on"] + list(packages), - capture_output=False, - ) - if result.returncode != 0: - raise RuntimeError( - f"[PrismAudio] Failed to install {tag} (exit {result.returncode}). " - "See pip output above for details." - ) - print(f"[PrismAudio] {tag} OK", flush=True) - - -def _ensure_extract_env(): - """Create and populate the managed venv on first use.""" - if os.path.exists(_MANAGED_PYTHON): - return _MANAGED_PYTHON - - import shutil - if os.path.exists(_MANAGED_VENV): - print("[PrismAudio] Removing incomplete venv and retrying...", flush=True) - shutil.rmtree(_MANAGED_VENV) - - print(f"[PrismAudio] Creating feature-extraction venv at: {_MANAGED_VENV}", flush=True) - subprocess.run([sys.executable, "-m", "venv", _MANAGED_VENV], check=True) - - pip = os.path.join(_MANAGED_VENV, "bin", "pip") - - print("[PrismAudio] Upgrading pip...", flush=True) - subprocess.run([pip, "install", "--upgrade", "pip"], check=True) - - total = len(_EXTRACT_PACKAGES) - print(f"[PrismAudio] Installing {total} package groups — this may take several minutes...", flush=True) - - for i, pkg in enumerate(_EXTRACT_PACKAGES, 1): - label = pkg.split("/")[-1] if pkg.startswith("git+") else pkg.split(">=")[0].split("==")[0].split("[")[0] - print(f"[PrismAudio] [{i}/{total}] {label}", flush=True) - _pip_install(pip, pkg, label=label) - - print("[PrismAudio] Feature-extraction env ready.", flush=True) - return _MANAGED_PYTHON - - -def _hash_inputs(video_tensor, cot_text): - """Create a hash of the inputs for caching.""" - h = hashlib.sha256() - h.update(video_tensor.cpu().numpy().tobytes()[:1024 * 1024]) # First 1MB for speed - h.update(cot_text.encode()) - return h.hexdigest()[:16] - - -def _save_frames_to_npy(video_tensor, output_path): - """Save ComfyUI IMAGE tensor [T,H,W,C] float32 [0,1] to .npy as uint8. - - Lossless — avoids H.264 encode/decode roundtrip. - """ - import numpy as np - frames_np = (video_tensor.cpu().numpy() * 255).astype("uint8") - np.save(output_path, frames_np) - - -class PrismAudioFeatureExtractor: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": ("IMAGE",), - "caption_cot": ("STRING", {"default": "", "multiline": True, "tooltip": "Chain-of-thought description"}), - }, - "optional": { - "video_info": ("VHS_VIDEOINFO", {"tooltip": "Connect VHS LoadVideo info output to auto-set fps."}), - "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 0.001, "tooltip": "Frame rate of the input video. Ignored if video_info is connected."}), - "python_env": (["managed_env", "comfyui_env"], {"tooltip": "managed_env: auto-created isolated venv with JAX/TF (recommended). comfyui_env: current ComfyUI Python — WARNING: may conflict with existing packages and destabilize ComfyUI."}), - "cache_dir": ("STRING", {"default": "", "tooltip": "Directory to cache extracted features. Empty = temp dir"}), - "hf_token": ("STRING", {"default": "", "tooltip": "HuggingFace token for gated models (e.g. google/t5gemma). Get yours at huggingface.co/settings/tokens"}), - }, - } - - RETURN_TYPES = ("PRISMAUDIO_FEATURES", "FLOAT") - RETURN_NAMES = ("features", "fps") - FUNCTION = "extract_features" - CATEGORY = PRISMAUDIO_CATEGORY - - def extract_features(self, video, caption_cot, video_info=None, fps=30.0, python_env="managed_env", cache_dir="", hf_token=""): - # Resolve fps from VHS video_info if connected - if video_info is not None: - fps = video_info["loaded_fps"] - - # Resolve python binary - if python_env == "comfyui_env": - print("[PrismAudio] WARNING: using ComfyUI Python env — JAX/TF/videoprism must already be installed. " - "Installing them here may conflict with existing packages and destabilize ComfyUI.", flush=True) - python_bin = sys.executable - else: - python_bin = _ensure_extract_env() - - # Determine cache directory - if not cache_dir: - cache_dir = os.path.join(tempfile.gettempdir(), "prismaudio_features") - os.makedirs(cache_dir, exist_ok=True) - - # Check cache - cache_hash = _hash_inputs(video, caption_cot) - cached_path = os.path.join(cache_dir, f"{cache_hash}.npz") - if os.path.exists(cached_path): - print(f"[PrismAudio] Using cached features: {cached_path}") - loader = PrismAudioFeatureLoader() - features, = loader.load_features(cached_path) - return (features, float(fps)) - - # Save frames to temp file (lossless .npy, no codec roundtrip) - import time - t0 = time.perf_counter() - frames = video.shape[0] - print(f"[PrismAudio] Saving {frames} frames to .npy (fps={fps})...", flush=True) - with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp: - tmp_video = tmp.name - _save_frames_to_npy(video, tmp_video) - print(f"[PrismAudio] Frames saved in {time.perf_counter() - t0:.1f}s", flush=True) - - # Build subprocess command - script_path = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "scripts", "extract_features.py" - ) - - import folder_paths - synchformer_ckpt = os.path.join(folder_paths.models_dir, "prismaudio", "synchformer_state_dict.pth") - if not os.path.exists(synchformer_ckpt): - raise RuntimeError( - f"[PrismAudio] Synchformer checkpoint not found: {synchformer_ckpt}\n" - "Download synchformer_state_dict.pth from FunAudioLLM/PrismAudio and place it in models/prismaudio/." - ) - - cmd = [ - python_bin, - script_path, - "--video", tmp_video, - "--cot_text", caption_cot, - "--output", cached_path, - "--source_fps", str(fps), - "--synchformer_ckpt", synchformer_ckpt, - ] - - # Build env: inherit current env, inject HF token if provided - import copy - env = copy.copy(os.environ) - token = hf_token.strip() if hf_token else os.environ.get("HF_TOKEN", "") - if token: - env["HF_TOKEN"] = token - env["HUGGING_FACE_HUB_TOKEN"] = token - else: - print("[PrismAudio] Warning: no HF_TOKEN set — gated models (e.g. t5gemma) will fail. " - "Add your token in the hf_token input or set HF_TOKEN env var.", flush=True) - - print(f"[PrismAudio] Extracting features via subprocess (output streams live)...") - try: - # capture_output=False: let stdout/stderr stream directly to ComfyUI logs - result = subprocess.run( - cmd, - capture_output=False, - timeout=600, # 10 minute timeout - env=env, - ) - if result.returncode != 0: - raise RuntimeError( - f"[PrismAudio] Feature extraction subprocess exited with code {result.returncode}. " - "See output above for details." - ) - print("[PrismAudio] Feature extraction subprocess finished successfully.") - finally: - if os.path.exists(tmp_video): - os.unlink(tmp_video) - - # Load the extracted features - loader = PrismAudioFeatureLoader() - features, = loader.load_features(cached_path) - return (features, float(fps)) diff --git a/nodes/feature_loader.py b/nodes/feature_loader.py deleted file mode 100644 index 6c57901..0000000 --- a/nodes/feature_loader.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -import numpy as np -import torch -from .utils import PRISMAUDIO_CATEGORY - -# Keys consumed by the conditioners (video_features, text_features, sync_features) -# global_video_features and global_text_features are NOT consumed by any conditioner -# in the prismaudio.json config — they are unused. -REQUIRED_KEYS = [ - "video_features", - "text_features", - "sync_features", -] - - -class PrismAudioFeatureLoader: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "npz_path": ("STRING", {"default": "", "tooltip": "Path to pre-computed .npz feature file"}), - }, - } - - RETURN_TYPES = ("PRISMAUDIO_FEATURES",) - RETURN_NAMES = ("features",) - FUNCTION = "load_features" - CATEGORY = PRISMAUDIO_CATEGORY - - def load_features(self, npz_path): - if not os.path.exists(npz_path): - raise FileNotFoundError(f"[PrismAudio] Feature file not found: {npz_path}") - - data = np.load(npz_path, allow_pickle=True) - - features = {} - for key in REQUIRED_KEYS: - if key in data: - features[key] = torch.from_numpy(data[key]).float() - else: - print(f"[PrismAudio] Warning: key '{key}' not found in {npz_path}, using zeros") - # Provide zero tensor rather than None — Cond_MLP/Sync_MLP crash on None - # Sync_MLP requires length divisible by 8 (segments of 8 frames) - if key == "sync_features": - features[key] = torch.zeros(8, 768) - else: - features[key] = torch.zeros(1, 1024) - - # Load duration if present - if "duration" in data: - features["duration"] = float(data["duration"]) - - return (features,) diff --git a/nodes/model_loader.py b/nodes/model_loader.py deleted file mode 100644 index a004ba9..0000000 --- a/nodes/model_loader.py +++ /dev/null @@ -1,154 +0,0 @@ -import os -import json -import torch -import folder_paths -import comfy.model_management as mm -import comfy.utils - -from .utils import ( - PRISMAUDIO_CATEGORY, get_prismaudio_model_dir, register_model_folder, - get_device, get_offload_device, determine_precision, determine_offload_strategy, - soft_empty_cache, resolve_hf_token, -) - -# HuggingFace repo for auto-download -HF_REPO_ID = "FunAudioLLM/PrismAudio" -REQUIRED_FILES = { - "diffusion": "prismaudio.ckpt", - "vae": "vae.ckpt", - "synchformer": "synchformer_state_dict.pth", -} - - -def _download_if_missing(filename, model_dir, hf_token=None): - """Download a model file from HuggingFace if not present locally.""" - filepath = os.path.join(model_dir, filename) - if os.path.exists(filepath): - return filepath - - from huggingface_hub import hf_hub_download - print(f"[PrismAudio] Downloading {filename} from {HF_REPO_ID}...") - try: - downloaded = hf_hub_download( - repo_id=HF_REPO_ID, - filename=filename, - local_dir=model_dir, - token=hf_token or None, - ) - return downloaded - except Exception as e: - if "401" in str(e) or "403" in str(e) or "gated" in str(e).lower(): - raise RuntimeError( - f"[PrismAudio] Model '{filename}' requires license acceptance. " - f"Visit https://huggingface.co/{HF_REPO_ID} to accept the license, " - f"then set HF_TOKEN env var or run: huggingface-cli login" - ) from e - raise - - -class PrismAudioModelLoader: - @classmethod - def INPUT_TYPES(cls): - register_model_folder() - return { - "required": { - "precision": (["auto", "fp32", "fp16", "bf16"],), - "offload_strategy": (["auto", "keep_in_vram", "offload_to_cpu"],), - }, - } - - RETURN_TYPES = ("PRISMAUDIO_MODEL",) - RETURN_NAMES = ("model",) - FUNCTION = "load_model" - CATEGORY = PRISMAUDIO_CATEGORY - - def load_model(self, precision, offload_strategy): - device = get_device() - dtype = determine_precision(precision, device) - strategy = determine_offload_strategy(offload_strategy) - token = resolve_hf_token() - model_dir = get_prismaudio_model_dir() - - # Auto-download missing files - for key, filename in REQUIRED_FILES.items(): - _download_if_missing(filename, model_dir, hf_token=token) - - # Load config - config_path = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "prismaudio_core", "configs", "prismaudio.json" - ) - with open(config_path) as f: - model_config = json.load(f) - - # Create model from config - from prismaudio_core.factory import create_model_from_config - model = create_model_from_config(model_config) - - # Load diffusion weights - diffusion_path = os.path.join(model_dir, REQUIRED_FILES["diffusion"]) - diffusion_state = comfy.utils.load_torch_file(diffusion_path) - # Handle wrapped state dicts: some ckpts wrap in {"state_dict": ...} - if "state_dict" in diffusion_state: - diffusion_state = diffusion_state["state_dict"] - diff_result = model.load_state_dict(diffusion_state, strict=False) - print(f"[PrismAudio] Diffusion ckpt: {len(diffusion_state)} keys in file", flush=True) - print(f"[PrismAudio] Diffusion load: missing={len(diff_result.missing_keys)}, unexpected={len(diff_result.unexpected_keys)}", flush=True) - if diff_result.missing_keys: - print(f"[PrismAudio] missing (first 10): {diff_result.missing_keys[:10]}", flush=True) - if diff_result.unexpected_keys: - print(f"[PrismAudio] unexpected (first 5): {diff_result.unexpected_keys[:5]}", flush=True) - # Sample a few ckpt keys to verify prefix alignment - sample_keys = list(diffusion_state.keys())[:5] - print(f"[PrismAudio] ckpt key samples: {sample_keys}", flush=True) - - # Load VAE weights separately - # Use comfy.utils.load_torch_file for consistency and PyTorch 2.6+ compat - vae_path = os.path.join(model_dir, REQUIRED_FILES["vae"]) - vae_full_state = comfy.utils.load_torch_file(vae_path) - print(f"[PrismAudio] VAE ckpt: {len(vae_full_state)} keys in file", flush=True) - # Sample raw keys to see actual prefix - vae_sample_keys = list(vae_full_state.keys())[:8] - print(f"[PrismAudio] VAE raw key samples: {vae_sample_keys}", flush=True) - # Strip "autoencoder." prefix from keys - vae_state = {} - prefix = "autoencoder." - for k, v in vae_full_state.items(): - if k.startswith(prefix): - vae_state[k[len(prefix):]] = v - else: - vae_state[k] = v - print(f"[PrismAudio] VAE after strip: {len(vae_state)} keys", flush=True) - # Sample model keys to compare - model_vae_keys = list(model.pretransform.state_dict().keys())[:5] - print(f"[PrismAudio] pretransform model key samples: {model_vae_keys}", flush=True) - # strict=False: vae.ckpt is a training checkpoint that also contains - # discriminator, loss modules, and EMA wrappers not present in the - # inference AudioAutoencoder — ignore those extra keys. - # Load directly into the inner AudioAutoencoder to get IncompatibleKeys back - # (AutoencoderPretransform.load_state_dict doesn't return the result) - vae_result = model.pretransform.model.load_state_dict(vae_state, strict=False) - print(f"[PrismAudio] VAE load: missing={len(vae_result.missing_keys)}, unexpected={len(vae_result.unexpected_keys)}", flush=True) - if vae_result.missing_keys: - print(f"[PrismAudio] VAE missing (first 10): {vae_result.missing_keys[:10]}", flush=True) - - # Apply precision: DiT + conditioners in user-selected dtype, - # but keep VAE (pretransform) in fp32 to avoid NaN from snake activations in fp16 - model.model.to(dtype) # DiTWrapper - model.conditioner.to(dtype) # MultiConditioner - # model.pretransform stays in fp32 - - if strategy == "keep_in_vram": - model = model.to(device) - else: - model = model.to(get_offload_device()) - - model.eval() - - return ({ - "model": model, - "dtype": dtype, - "strategy": strategy, - "config": model_config, - "model_dir": model_dir, - },) diff --git a/nodes/sampler.py b/nodes/sampler.py deleted file mode 100644 index 1aaf90f..0000000 --- a/nodes/sampler.py +++ /dev/null @@ -1,165 +0,0 @@ -import torch -import comfy.model_management as mm -import comfy.utils - -from .utils import ( - PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS, - get_device, get_offload_device, soft_empty_cache, -) - - -class PrismAudioSampler: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "model": ("PRISMAUDIO_MODEL",), - "features": ("PRISMAUDIO_FEATURES",), - "duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 30.0, "step": 0.1, "tooltip": "Audio duration in seconds. Set to 0 to use the video duration from features automatically."}), - "steps": ("INT", {"default": 100, "min": 1, "max": 100, "tooltip": "Number of sampling steps"}), - "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1, "tooltip": "Classifier-free guidance scale"}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), - }, - } - - RETURN_TYPES = ("AUDIO",) - RETURN_NAMES = ("audio",) - FUNCTION = "generate" - CATEGORY = PRISMAUDIO_CATEGORY - - def generate(self, model, features, duration, steps, cfg_scale, seed): - device = get_device() - dtype = model["dtype"] - strategy = model["strategy"] - diffusion = model["model"] - - # Resolve duration: 0 means use video duration from features - if duration <= 0: - if "duration" not in features: - raise ValueError("[PrismAudio] duration=0 but features contain no duration. Set duration manually or use PrismAudioFeatureExtractor.") - duration = features["duration"] - print(f"[PrismAudio] Using video duration from features: {duration:.2f}s", flush=True) - - # Compute latent dimensions - latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO) - - # Note: no seq length config needed — the model adapts to input tensor shapes - # dynamically via its transformer architecture. - - # Determine if video features are present (not all zeros) - has_video = features.get("video_features") is not None and features["video_features"].abs().sum() > 0 - - video_feat = features["video_features"].to(device, dtype=dtype) - sync_feat = features["sync_features"].to(device, dtype=dtype) - - # Build metadata as a TUPLE of dicts (one per batch sample) - # MultiConditioner.forward(batch_metadata: List[Dict]) iterates over this - sample_meta = { - "video_features": video_feat, - "text_features": features["text_features"].to(device, dtype=dtype), - "sync_features": sync_feat, - "video_exist": torch.tensor(has_video), - } - metadata = (sample_meta,) - - # Move model to device if offloaded - if strategy == "offload_to_cpu": - diffusion.model.to(device) - diffusion.conditioner.to(device) - soft_empty_cache() - - with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype): - # Run conditioning - conditioning = diffusion.conditioner(metadata, device) - - # Handle missing video: substitute learned empty embeddings - if not has_video: - _substitute_empty_features(diffusion, conditioning, device, dtype) - - # Assemble conditioning inputs for the DiT - cond_inputs = diffusion.get_conditioning_inputs(conditioning) - - # Generate noise from seed (MPS doesn't support torch.Generator) - gen_device = "cpu" if device.type == "mps" else device - generator = torch.Generator(device=gen_device).manual_seed(seed) - noise = torch.randn( - [1, IO_CHANNELS, latent_length], - generator=generator, - device=gen_device, - ).to(device=device, dtype=dtype) - - # Sample with progress bar - pbar = comfy.utils.ProgressBar(steps) - - from prismaudio_core.inference.sampling import sample_discrete_euler - - def on_step(info): - pbar.update(1) - - fakes = sample_discrete_euler( - diffusion.model, - noise, - steps, - callback=on_step, - **cond_inputs, - cfg_scale=cfg_scale, - batch_cfg=True, - ) - - fakes_f = fakes.float() - print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True) - - # Offload diffusion model and conditioner before VAE decode - if strategy == "offload_to_cpu": - diffusion.model.to(get_offload_device()) - diffusion.conditioner.to(get_offload_device()) - soft_empty_cache() - diffusion.pretransform.to(device) - - # VAE decode in fp32 (snake activations overflow in fp16) - with torch.amp.autocast(device_type=device.type, enabled=False): - audio = diffusion.pretransform.decode(fakes_f) - - # Offload VAE - if strategy == "offload_to_cpu": - diffusion.pretransform.to(get_offload_device()) - soft_empty_cache() - - # Peak normalize then clamp (matching reference: div by max abs before clamp) - audio = audio.float() - pre_norm_std = audio.std().item() - pre_norm_peak = audio.abs().max().item() - peak = audio.abs().max().clamp(min=1e-8) - audio = (audio / peak).clamp(-1, 1) - print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True) - - # Return as ComfyUI AUDIO: {"waveform": [B, channels, samples], "sample_rate": int} - return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},) - - -def _substitute_empty_features(diffusion, conditioning, device, dtype): - """Replace video/sync conditioning with learned empty embeddings when video is absent. - - empty_clip_feat and empty_sync_feat are learned null embeddings in the conditioner - output space (1024-dim). Passing zero features through bias-free Cond_MLP produces - near-zero activations, NOT the learned null signal the model was trained with. - - The conditioner returns {key: [tensor, mask]} where tensor is [B, seq, dim]. - """ - dit = diffusion.model.model if hasattr(diffusion.model, 'model') else diffusion.model - - # Substitute video_features with learned empty_clip_feat - if hasattr(dit, 'empty_clip_feat') and 'video_features' in conditioning: - empty = dit.empty_clip_feat.to(device, dtype=dtype) # [1, 1024] - batch_size = conditioning['video_features'][0].shape[0] - empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024] - conditioning['video_features'][0] = empty_expanded - conditioning['video_features'][1] = torch.ones(batch_size, 1, device=device) - - # Substitute sync_features with learned empty_sync_feat - if hasattr(dit, 'empty_sync_feat') and 'sync_features' in conditioning: - empty = dit.empty_sync_feat.to(device, dtype=dtype) # [1, 1024] - batch_size = conditioning['sync_features'][0].shape[0] - empty_expanded = empty.unsqueeze(0).expand(batch_size, -1, -1) # [B, 1, 1024] - conditioning['sync_features'][0] = empty_expanded - conditioning['sync_features'][1] = torch.ones(batch_size, 1, device=device) diff --git a/nodes/selva_feature_extractor.py b/nodes/selva_feature_extractor.py index 995adf9..c0ad5ae 100644 --- a/nodes/selva_feature_extractor.py +++ b/nodes/selva_feature_extractor.py @@ -6,7 +6,7 @@ import numpy as np import torch import torch.nn.functional as F -from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache +from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache # SelVA video preprocessing constants (from selva/utils/eval_utils.py) _CLIP_SIZE = 384 @@ -68,7 +68,7 @@ class SelvaFeatureExtractor: RETURN_TYPES = ("SELVA_FEATURES", "FLOAT", "STRING") RETURN_NAMES = ("features", "fps", "prompt") FUNCTION = "extract_features" - CATEGORY = PRISMAUDIO_CATEGORY + CATEGORY = SELVA_CATEGORY def extract_features(self, model, video, prompt, video_info=None, fps=30.0, duration=0.0, cache_dir=""): diff --git a/nodes/selva_model_loader.py b/nodes/selva_model_loader.py index e2ea162..dedfab5 100644 --- a/nodes/selva_model_loader.py +++ b/nodes/selva_model_loader.py @@ -3,7 +3,7 @@ from pathlib import Path import torch import folder_paths -from .utils import PRISMAUDIO_CATEGORY, get_offload_device, determine_offload_strategy +from .utils import SELVA_CATEGORY, get_offload_device, determine_offload_strategy # Variant → (generator filename, mode, has_bigvgan) _VARIANTS = { @@ -96,7 +96,7 @@ class SelvaModelLoader: RETURN_TYPES = ("SELVA_MODEL",) RETURN_NAMES = ("model",) FUNCTION = "load_model" - CATEGORY = PRISMAUDIO_CATEGORY + CATEGORY = SELVA_CATEGORY def load_model(self, variant, precision, offload_strategy): from selva_core.model.networks_generator import get_my_mmaudio diff --git a/nodes/selva_sampler.py b/nodes/selva_sampler.py index 15bcca1..f1d172f 100644 --- a/nodes/selva_sampler.py +++ b/nodes/selva_sampler.py @@ -1,7 +1,7 @@ import torch import comfy.utils -from .utils import PRISMAUDIO_CATEGORY, get_device, get_offload_device, soft_empty_cache +from .utils import SELVA_CATEGORY, get_device, get_offload_device, soft_empty_cache class SelvaSampler: @@ -35,7 +35,7 @@ class SelvaSampler: RETURN_TYPES = ("AUDIO",) RETURN_NAMES = ("audio",) FUNCTION = "generate" - CATEGORY = PRISMAUDIO_CATEGORY + CATEGORY = SELVA_CATEGORY def generate(self, model, features, prompt, negative_prompt, duration, steps, cfg_strength, seed): from selva_core.model.flow_matching import FlowMatching diff --git a/nodes/text_only.py b/nodes/text_only.py deleted file mode 100644 index 3197ff6..0000000 --- a/nodes/text_only.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch -import comfy.model_management as mm -import comfy.utils - -from .utils import ( - PRISMAUDIO_CATEGORY, SAMPLE_RATE, DOWNSAMPLING_RATIO, IO_CHANNELS, - get_device, get_offload_device, soft_empty_cache, resolve_hf_token, -) -from .sampler import _substitute_empty_features - - -class PrismAudioTextOnly: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "model": ("PRISMAUDIO_MODEL",), - "text_prompt": ("STRING", {"default": "", "multiline": True, "tooltip": "Detailed chain-of-thought description of the audio scene. Use long, descriptive text — e.g. 'A large dog barks sharply twice, with ambient outdoor background noise. The sound is clear and close.' Short prompts produce lower quality."}), - "duration": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}), - "steps": ("INT", {"default": 100, "min": 1, "max": 100}), - "cfg_scale": ("FLOAT", {"default": 7.0, "min": 1.0, "max": 20.0, "step": 0.1}), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFFFF}), - }, - } - - RETURN_TYPES = ("AUDIO",) - RETURN_NAMES = ("audio",) - FUNCTION = "generate" - CATEGORY = PRISMAUDIO_CATEGORY - - def generate(self, model, text_prompt, duration, steps, cfg_scale, seed): - device = get_device() - dtype = model["dtype"] - strategy = model["strategy"] - diffusion = model["model"] - - latent_length = round(SAMPLE_RATE * duration / DOWNSAMPLING_RATIO) - - # Encode text with T5-Gemma - text_features = _encode_text_t5(text_prompt, device, dtype) - - # Build metadata: tuple of one dict per sample - # Use zero tensors for video/sync (not None — Cond_MLP crashes on None via pad_sequence) - # Sync_MLP requires length divisible by 8 (segments of 8 frames) — minimum [8, 768] - # These will be substituted with learned empty embeddings after conditioning - sample_meta = { - "video_features": torch.zeros(1, 1024, device=device, dtype=dtype), - "text_features": text_features.to(device, dtype=dtype), - "sync_features": torch.zeros(8, 768, device=device, dtype=dtype), - "video_exist": torch.tensor(False), - } - metadata = (sample_meta,) - - if strategy == "offload_to_cpu": - diffusion.model.to(device) - diffusion.conditioner.to(device) - soft_empty_cache() - - with torch.no_grad(), torch.amp.autocast(device_type=device.type, dtype=dtype): - conditioning = diffusion.conditioner(metadata, device) - - # Substitute empty features for video/sync - _substitute_empty_features(diffusion, conditioning, device, dtype) - - cond_inputs = diffusion.get_conditioning_inputs(conditioning) - - # Generate noise from seed (MPS doesn't support torch.Generator) - gen_device = "cpu" if device.type == "mps" else device - generator = torch.Generator(device=gen_device).manual_seed(seed) - noise = torch.randn( - [1, IO_CHANNELS, latent_length], - generator=generator, - device=gen_device, - ).to(device=device, dtype=dtype) - - pbar = comfy.utils.ProgressBar(steps) - - from prismaudio_core.inference.sampling import sample_discrete_euler - - def on_step(info): - pbar.update(1) - - fakes = sample_discrete_euler( - diffusion.model, - noise, - steps, - callback=on_step, - **cond_inputs, - cfg_scale=cfg_scale, - batch_cfg=True, - ) - - fakes_f = fakes.float() - print(f"[PrismAudio] latent stats: shape={tuple(fakes_f.shape)} mean={fakes_f.mean():.4f} std={fakes_f.std():.4f} min={fakes_f.min():.4f} max={fakes_f.max():.4f}", flush=True) - - if strategy == "offload_to_cpu": - diffusion.model.to(get_offload_device()) - diffusion.conditioner.to(get_offload_device()) - soft_empty_cache() - diffusion.pretransform.to(device) - - # VAE decode in fp32 (snake activations overflow in fp16) - with torch.amp.autocast(device_type=device.type, enabled=False): - audio = diffusion.pretransform.decode(fakes_f) - - if strategy == "offload_to_cpu": - diffusion.pretransform.to(get_offload_device()) - soft_empty_cache() - - # Peak normalize then clamp - audio = audio.float() - pre_norm_std = audio.std().item() - pre_norm_peak = audio.abs().max().item() - peak = audio.abs().max().clamp(min=1e-8) - audio = (audio / peak).clamp(-1, 1) - print(f"[PrismAudio] audio stats (pre-norm): std={pre_norm_std:.4f} peak={pre_norm_peak:.4f}", flush=True) - print(f"[PrismAudio] audio shape: {tuple(audio.shape)}", flush=True) - - return ({"waveform": audio.cpu(), "sample_rate": SAMPLE_RATE},) - - -# T5-Gemma encoder singleton -_t5_model = None -_t5_tokenizer = None - - -def _encode_text_t5(text, device, dtype): - """Encode text using T5-Gemma. - - Uses AutoModelForSeq2SeqLM.get_encoder() to match the reference - FeaturesUtils.encode_t5_text() implementation. - No truncation applied (matching reference behavior). - """ - global _t5_model, _t5_tokenizer - - if _t5_model is None: - from transformers import AutoTokenizer, AutoModelForSeq2SeqLM - model_id = "google/t5gemma-l-l-ul2-it" - token = resolve_hf_token() - print(f"[PrismAudio] Loading T5-Gemma text encoder: {model_id}") - _t5_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token) - _t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, token=token).get_encoder() - _t5_model.eval() - - _t5_model.to(device, dtype=dtype) - - tokens = _t5_tokenizer( - text, - return_tensors="pt", - padding=True, - ).to(device) - - with torch.no_grad(): - outputs = _t5_model(**tokens) - - # Move T5 off GPU after encoding to save VRAM - _t5_model.to("cpu") - soft_empty_cache() - - return outputs.last_hidden_state.squeeze(0) # [seq_len, dim] diff --git a/nodes/utils.py b/nodes/utils.py index e016ad6..031b00c 100644 --- a/nodes/utils.py +++ b/nodes/utils.py @@ -1,21 +1,7 @@ -import os import torch -import folder_paths import comfy.model_management as mm -PRISMAUDIO_CATEGORY = "PrismAudio" -SAMPLE_RATE = 44100 -DOWNSAMPLING_RATIO = 2048 -IO_CHANNELS = 64 - -def get_prismaudio_model_dir(): - model_dir = os.path.join(folder_paths.models_dir, "prismaudio") - os.makedirs(model_dir, exist_ok=True) - return model_dir - -def register_model_folder(): - model_dir = get_prismaudio_model_dir() - folder_paths.add_model_folder_path("prismaudio", model_dir) +SELVA_CATEGORY = "SelVA" def get_device(): return mm.get_torch_device() @@ -23,42 +9,13 @@ def get_device(): def get_offload_device(): return mm.unet_offload_device() -def get_free_memory(device=None): - if device is None: - device = get_device() - return mm.get_free_memory(device) - def soft_empty_cache(): mm.soft_empty_cache() -def determine_precision(preference, device): - if preference != "auto": - return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[preference] - if device.type == "cpu": - return torch.float32 - if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): - return torch.bfloat16 - return torch.float16 - def determine_offload_strategy(preference): if preference != "auto": return preference - free_mem = get_free_memory() - gb = free_mem / (1024 ** 3) - if gb >= 24: + free_mem = mm.get_free_memory(get_device()) + if free_mem / (1024 ** 3) >= 16: return "keep_in_vram" - else: - return "offload_to_cpu" - -def try_import_flash_attn(): - try: - import flash_attn - return flash_attn - except ImportError: - return None - -def resolve_hf_token(): - env_token = os.environ.get("HF_TOKEN") - if env_token: - return env_token - return None + return "offload_to_cpu" diff --git a/prismaudio_core/__init__.py b/prismaudio_core/__init__.py deleted file mode 100644 index 064b327..0000000 --- a/prismaudio_core/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -PrismAudio core inference modules. -Extracted from https://github.com/FunAudioLLM/ThinkSound (prismaudio branch). -Only inference-critical code — no training, no JAX/TF dependencies. -""" diff --git a/prismaudio_core/configs/prismaudio.json b/prismaudio_core/configs/prismaudio.json deleted file mode 100644 index 19d24a0..0000000 --- a/prismaudio_core/configs/prismaudio.json +++ /dev/null @@ -1,141 +0,0 @@ -{ - "model_type": "diffusion_cond", - "sample_size": 397312, - "sample_rate": 44100, - "audio_channels": 2, - "model": { - "pretransform": { - "type": "autoencoder", - "iterate_batch": true, - "config": { - "encoder": { - "type": "oobleck", - "config": { - "in_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 128, - "use_snake": true - } - }, - "decoder": { - "type": "oobleck", - "config": { - "out_channels": 2, - "channels": 128, - "c_mults": [1, 2, 4, 8, 16], - "strides": [2, 4, 4, 8, 8], - "latent_dim": 64, - "use_snake": true, - "final_tanh": false - } - }, - "bottleneck": { - "type": "vae" - }, - "latent_dim": 64, - "downsampling_ratio": 2048, - "io_channels": 2 - } - }, - "conditioning": { - "configs": [ - { - "id": "video_features", - "type": "cond_mlp", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "text_features", - "type": "cond_mlp", - "config": { - "dim": 1024, - "output_dim": 1024 - } - }, - { - "id": "sync_features", - "type": "sync_mlp", - "config": { - "dim": 768, - "output_dim": 1024 - } - } - ], - "cond_dim": 768 - }, - "diffusion": { - "cross_attention_cond_ids": ["video_features","text_features"], - "add_cond_ids": ["video_features"], - "sync_cond_ids": ["sync_features"], - "type": "dit", - "diffusion_objective": "rectified_flow", - "config": { - "io_channels": 64, - "embed_dim": 1024, - "depth": 24, - "num_heads": 16, - "cond_token_dim": 1024, - "add_token_dim": 1024, - "sync_token_dim": 1024, - "project_cond_tokens": false, - "transformer_type": "continuous_transformer", - "attn_kwargs":{ - "qk_norm": "rns" - }, - "use_gated": true, - "use_sync_gated": true - } - }, - "io_channels": 64 - }, - "training": { - "use_ema": true, - "log_loss_info": false, - "cfg_dropout_prob": 0.1, - "pre_encoded": true, - "timestep_sampler": "trunc_logit_normal", - "optimizer_configs": { - "diffusion": { - "optimizer": { - "type": "AdamW", - "config": { - "lr": 1e-4, - "betas": [0.9, 0.999], - "weight_decay": 1e-3 - } - }, - "scheduler": { - "type": "InverseLR", - "config": { - "inv_gamma": 100000, - "power": 0.5, - "warmup": 0.99 - } - } - } - }, - "demo": { - "demo_every": 5000, - "demo_steps": 24, - "num_demos": 10, - "demo_cond": [ - "dataset/videoprism/test/0Cu33yBwAPg_000060.npz", - "dataset/videoprism/test/bmKtI808DsU_000009.npz", - "dataset/videoprism/test/VC0c22cJTbM_000424.npz", - "dataset/videoprism/test/F3gsbUTdc2U_000090.npz", - "dataset/videoprism/test/WatvT8A8iug_000100.npz", - "dataset/videoprism/test/0nvBTp-q7tU_000112.npz", - "dataset/videoprism/test/3-PFuDkTM48_000080.npz", - "dataset/videoprism/test/luSAuu-BoPs_000232.npz", - "dataset/videoprism/test/__8UJxW0aOQ_000002.npz", - "dataset/videoprism/test/_0m_YMpQayA_000168.npz" - ], - "demo_cfg_scales": [5] - } - } -} \ No newline at end of file diff --git a/prismaudio_core/factory.py b/prismaudio_core/factory.py deleted file mode 100644 index 7eeef44..0000000 --- a/prismaudio_core/factory.py +++ /dev/null @@ -1,413 +0,0 @@ -""" -Model factory functions for PrismAudio inference. - -Extracted from: - - PrismAudio/models/factory.py - - PrismAudio/models/autoencoders.py (create_autoencoder_from_config) - - PrismAudio/models/diffusion.py (create_diffusion_cond_from_config) - - PrismAudio/models/conditioners.py (create_multi_conditioner_from_conditioning_config) - -Source: https://github.com/FunAudioLLM/ThinkSound (prismaudio branch) -Only inference-critical factory functions are retained. -""" - -import json -import typing as tp -from typing import Dict, Any - -import numpy as np - - -def create_model_from_config(model_config): - model_type = model_config.get('model_type', None) - - assert model_type is not None, 'model_type must be specified in model config' - - if model_type == 'autoencoder': - return create_autoencoder_from_config(model_config) - elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond": - return create_diffusion_cond_from_config(model_config) - else: - raise NotImplementedError(f'Unknown model type: {model_type}') - - -def create_pretransform_from_config(pretransform_config, sample_rate): - pretransform_type = pretransform_config.get('type', None) - - assert pretransform_type is not None, 'type must be specified in pretransform config' - - if pretransform_type == 'autoencoder': - from prismaudio_core.models.pretransforms import AutoencoderPretransform - - # Create fake top-level config to pass sample rate to autoencoder constructor - # This is a bit of a hack but it keeps us from re-defining the sample rate in the config - autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} - autoencoder = create_autoencoder_from_config(autoencoder_config) - - scale = pretransform_config.get("scale", 1.0) - model_half = pretransform_config.get("model_half", False) - iterate_batch = pretransform_config.get("iterate_batch", False) - chunked = pretransform_config.get("chunked", False) - - pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) - elif pretransform_type == 'wavelet': - raise NotImplementedError("wavelet pretransform type is not supported") - elif pretransform_type == 'pqmf': - from prismaudio_core.models.pretransforms import PQMFPretransform - pqmf_config = pretransform_config["config"] - pretransform = PQMFPretransform(**pqmf_config) - elif pretransform_type == 'dac_pretrained': - from prismaudio_core.models.pretransforms import PretrainedDACPretransform - pretrained_dac_config = pretransform_config["config"] - pretransform = PretrainedDACPretransform(**pretrained_dac_config) - elif pretransform_type == "audiocraft_pretrained": - from prismaudio_core.models.pretransforms import AudiocraftCompressionPretransform - - audiocraft_config = pretransform_config["config"] - pretransform = AudiocraftCompressionPretransform(**audiocraft_config) - else: - raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') - - enable_grad = pretransform_config.get('enable_grad', False) - pretransform.enable_grad = enable_grad - - pretransform.eval().requires_grad_(pretransform.enable_grad) - - return pretransform - - -def create_bottleneck_from_config(bottleneck_config): - bottleneck_type = bottleneck_config.get('type', None) - - assert bottleneck_type is not None, 'type must be specified in bottleneck config' - - if bottleneck_type == 'tanh': - from prismaudio_core.models.bottleneck import TanhBottleneck - bottleneck = TanhBottleneck() - elif bottleneck_type == 'vae': - from prismaudio_core.models.bottleneck import VAEBottleneck - bottleneck = VAEBottleneck() - elif bottleneck_type == 'rvq': - from prismaudio_core.models.bottleneck import RVQBottleneck - - quantizer_params = { - "dim": 128, - "codebook_size": 1024, - "num_quantizers": 8, - "decay": 0.99, - "kmeans_init": True, - "kmeans_iters": 50, - "threshold_ema_dead_code": 2, - } - - quantizer_params.update(bottleneck_config["config"]) - - bottleneck = RVQBottleneck(**quantizer_params) - elif bottleneck_type == "dac_rvq": - from prismaudio_core.models.bottleneck import DACRVQBottleneck - - bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) - - elif bottleneck_type == 'rvq_vae': - from prismaudio_core.models.bottleneck import RVQVAEBottleneck - - quantizer_params = { - "dim": 128, - "codebook_size": 1024, - "num_quantizers": 8, - "decay": 0.99, - "kmeans_init": True, - "kmeans_iters": 50, - "threshold_ema_dead_code": 2, - } - - quantizer_params.update(bottleneck_config["config"]) - - bottleneck = RVQVAEBottleneck(**quantizer_params) - - elif bottleneck_type == 'dac_rvq_vae': - from prismaudio_core.models.bottleneck import DACRVQVAEBottleneck - bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) - elif bottleneck_type == 'l2_norm': - from prismaudio_core.models.bottleneck import L2Bottleneck - bottleneck = L2Bottleneck() - elif bottleneck_type == "wasserstein": - from prismaudio_core.models.bottleneck import WassersteinBottleneck - bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) - elif bottleneck_type == "fsq": - from prismaudio_core.models.bottleneck import FSQBottleneck - bottleneck = FSQBottleneck(**bottleneck_config["config"]) - else: - raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') - - requires_grad = bottleneck_config.get('requires_grad', True) - if not requires_grad: - for param in bottleneck.parameters(): - param.requires_grad = False - - return bottleneck - - -def create_autoencoder_from_config(config: Dict[str, Any]): - """Create an AudioAutoencoder from a config dictionary. - - Originally in PrismAudio/models/autoencoders.py. - """ - from prismaudio_core.models.autoencoders import ( - AudioAutoencoder, - create_encoder_from_config, - create_decoder_from_config, - ) - - ae_config = config["model"] - - encoder = create_encoder_from_config(ae_config["encoder"]) - decoder = create_decoder_from_config(ae_config["decoder"]) - - bottleneck = ae_config.get("bottleneck", None) - - latent_dim = ae_config.get("latent_dim", None) - assert latent_dim is not None, "latent_dim must be specified in model config" - downsampling_ratio = ae_config.get("downsampling_ratio", None) - assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" - io_channels = ae_config.get("io_channels", None) - assert io_channels is not None, "io_channels must be specified in model config" - sample_rate = config.get("sample_rate", None) - assert sample_rate is not None, "sample_rate must be specified in model config" - - in_channels = ae_config.get("in_channels", None) - out_channels = ae_config.get("out_channels", None) - - pretransform = ae_config.get("pretransform", None) - - if pretransform is not None: - pretransform = create_pretransform_from_config(pretransform, sample_rate) - - if bottleneck is not None: - bottleneck = create_bottleneck_from_config(bottleneck) - - soft_clip = ae_config["decoder"].get("soft_clip", False) - - return AudioAutoencoder( - encoder, - decoder, - io_channels=io_channels, - latent_dim=latent_dim, - downsampling_ratio=downsampling_ratio, - sample_rate=sample_rate, - bottleneck=bottleneck, - pretransform=pretransform, - in_channels=in_channels, - out_channels=out_channels, - soft_clip=soft_clip - ) - - -def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]): - """Create a MultiConditioner from a conditioning config dictionary. - - Originally in PrismAudio/models/conditioners.py. - """ - from prismaudio_core.models.conditioners import ( - MultiConditioner, - T5Conditioner, - CLAPTextConditioner, - CLIPTextConditioner, - MetaCLIPTextConditioner, - CLAPAudioConditioner, - Cond_MLP, - Global_MLP, - Sync_MLP, - Cond_MLP_1, - Cond_ConvMLP, - Cond_MLP_Global, - Cond_MLP_Global_1, - Cond_MLP_Global_2, - Video_Global, - Video_Sync, - Text_Linear, - CLIPConditioner, - IntConditioner, - NumberConditioner, - PhonemeConditioner, - TokenizerLUTConditioner, - PretransformConditioner, - mm_unchang, - ) - from prismaudio_core.models.utils import load_ckpt_state_dict - - conditioners = {} - cond_dim = config["cond_dim"] - - default_keys = config.get("default_keys", {}) - - for conditioner_info in config["configs"]: - id = conditioner_info["id"] - - conditioner_type = conditioner_info["type"] - - conditioner_config = {"output_dim": cond_dim} - - conditioner_config.update(conditioner_info["config"]) - if conditioner_type == "t5": - conditioners[id] = T5Conditioner(**conditioner_config) - elif conditioner_type == "clap_text": - conditioners[id] = CLAPTextConditioner(**conditioner_config) - elif conditioner_type == "clip_text": - conditioners[id] = CLIPTextConditioner(**conditioner_config) - elif conditioner_type == "metaclip_text": - conditioners[id] = MetaCLIPTextConditioner(**conditioner_config) - elif conditioner_type == "clap_audio": - conditioners[id] = CLAPAudioConditioner(**conditioner_config) - elif conditioner_type == "cond_mlp": - conditioners[id] = Cond_MLP(**conditioner_config) - elif conditioner_type == "global_mlp": - conditioners[id] = Global_MLP(**conditioner_config) - elif conditioner_type == "sync_mlp": - conditioners[id] = Sync_MLP(**conditioner_config) - elif conditioner_type == "cond_mlp_1": - conditioners[id] = Cond_MLP_1(**conditioner_config) - elif conditioner_type == "cond_convmlp": - conditioners[id] = Cond_ConvMLP(**conditioner_config) - elif conditioner_type == "cond_mlp_global": - conditioners[id] = Cond_MLP_Global(**conditioner_config) - elif conditioner_type == "cond_mlp_global_1": - conditioners[id] = Cond_MLP_Global_1(**conditioner_config) - elif conditioner_type == "cond_mlp_global_2": - conditioners[id] = Cond_MLP_Global_2(**conditioner_config) - elif conditioner_type == "video_global": - conditioners[id] = Video_Global(**conditioner_config) - elif conditioner_type == "video_sync": - conditioners[id] = Video_Sync(**conditioner_config) - elif conditioner_type == "text_linear": - conditioners[id] = Text_Linear(**conditioner_config) - elif conditioner_type == "video_clip": - conditioners[id] = CLIPConditioner(**conditioner_config) - elif conditioner_type == "int": - conditioners[id] = IntConditioner(**conditioner_config) - elif conditioner_type == "number": - conditioners[id] = NumberConditioner(**conditioner_config) - elif conditioner_type == "phoneme": - conditioners[id] = PhonemeConditioner(**conditioner_config) - elif conditioner_type == "lut": - conditioners[id] = TokenizerLUTConditioner(**conditioner_config) - elif conditioner_type == "pretransform": - sample_rate = conditioner_config.pop("sample_rate", None) - assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" - - pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) - - if conditioner_config.get("pretransform_ckpt_path", None) is not None: - pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) - - conditioners[id] = PretransformConditioner(pretransform, **conditioner_config) - elif conditioner_type == "mm_unchang": - conditioners[id] = mm_unchang(**conditioner_config) - else: - raise ValueError(f"Unknown conditioner type: {conditioner_type}") - - return MultiConditioner(conditioners, default_keys=default_keys) - - -def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): - """Create a ConditionedDiffusionModelWrapper from a config dictionary. - - Originally in PrismAudio/models/diffusion.py. - """ - from prismaudio_core.models.diffusion import ( - ConditionedDiffusionModelWrapper, - MMConditionedDiffusionModelWrapper, - UNetCFG1DWrapper, - UNet1DCondWrapper, - DiTWrapper, - ) - - model_config = config["model"] - - model_type = config["model_type"] - - diffusion_config = model_config.get('diffusion', None) - assert diffusion_config is not None, "Must specify diffusion config" - - diffusion_model_type = diffusion_config.get('type', None) - assert diffusion_model_type is not None, "Must specify diffusion model type" - - diffusion_model_config = diffusion_config.get('config', None) - assert diffusion_model_config is not None, "Must specify diffusion model config" - - if diffusion_model_type == 'adp_cfg_1d': - diffusion_model = UNetCFG1DWrapper(**diffusion_model_config) - elif diffusion_model_type == 'adp_1d': - diffusion_model = UNet1DCondWrapper(**diffusion_model_config) - elif diffusion_model_type == 'dit': - diffusion_model = DiTWrapper(**diffusion_model_config) - elif diffusion_model_type == 'mmdit': - raise NotImplementedError("mmdit diffusion model type is not supported") - - io_channels = model_config.get('io_channels', None) - assert io_channels is not None, "Must specify io_channels in model config" - - sample_rate = config.get('sample_rate', None) - assert sample_rate is not None, "Must specify sample_rate in config" - - diffusion_objective = diffusion_config.get('diffusion_objective', 'v') - - conditioning_config = model_config.get('conditioning', None) - - conditioner = None - if conditioning_config is not None: - conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) - - cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) - add_cond_ids = diffusion_config.get('add_cond_ids', []) - sync_cond_ids = diffusion_config.get('sync_cond_ids', []) - global_cond_ids = diffusion_config.get('global_cond_ids', []) - input_concat_ids = diffusion_config.get('input_concat_ids', []) - prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) - mm_cond_ids = diffusion_config.get('mm_cond_ids', []) - zero_init = diffusion_config.get('zero_init', False) - pretransform = model_config.get("pretransform", None) - - if pretransform is not None: - pretransform = create_pretransform_from_config(pretransform, sample_rate) - min_input_length = pretransform.downsampling_ratio - else: - min_input_length = 1 - - if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d": - min_input_length *= np.prod(diffusion_model_config["factors"]) - elif diffusion_model_type == "dit": - min_input_length *= diffusion_model.model.patch_size - - # Get the proper wrapper class - - extra_kwargs = {} - - if model_type == "mm_diffusion_cond": - wrapper_fn = MMConditionedDiffusionModelWrapper - extra_kwargs["diffusion_objective"] = diffusion_objective - extra_kwargs["mm_cond_ids"] = mm_cond_ids - - if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill': - wrapper_fn = ConditionedDiffusionModelWrapper - extra_kwargs["diffusion_objective"] = diffusion_objective - - elif model_type == "diffusion_prior": - raise NotImplementedError("diffusion_prior model type is not supported") - - return wrapper_fn( - diffusion_model, - conditioner, - min_input_length=min_input_length, - sample_rate=sample_rate, - cross_attn_cond_ids=cross_attention_ids, - global_cond_ids=global_cond_ids, - input_concat_ids=input_concat_ids, - prepend_cond_ids=prepend_cond_ids, - add_cond_ids=add_cond_ids, - sync_cond_ids=sync_cond_ids, - pretransform=pretransform, - io_channels=io_channels, - zero_init=zero_init, - **extra_kwargs - ) diff --git a/prismaudio_core/inference/__init__.py b/prismaudio_core/inference/__init__.py deleted file mode 100644 index 9160888..0000000 --- a/prismaudio_core/inference/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .sampling import sample_discrete_euler -from .utils import set_audio_channels, prepare_audio - -__all__ = ["sample_discrete_euler", "set_audio_channels", "prepare_audio"] diff --git a/prismaudio_core/inference/sampling.py b/prismaudio_core/inference/sampling.py deleted file mode 100644 index 2326edf..0000000 --- a/prismaudio_core/inference/sampling.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch - - -@torch.no_grad() -def sample_discrete_euler(model, x, steps, sigma_max=1, callback=None, **extra_args): - """Discrete Euler sampler for rectified flow, with optional callback. - - Modified from PrismAudio to add callback parameter for ComfyUI progress reporting. - Original uses tqdm internally. - - Args: - model: The diffusion model (DiTWrapper) - x: Initial noise tensor [B, C, T] - steps: Number of sampling steps - sigma_max: Maximum sigma (default 1.0 for rectified flow) - callback: Optional callable({"i": step, "x": current_x}) for progress - **extra_args: Passed to model() — includes cross_attn_cond, add_cond, - sync_cond, cfg_scale, batch_cfg, etc. - """ - t = torch.linspace(sigma_max, 0, steps + 1, device=x.device, dtype=x.dtype) - - for i, (t_curr, t_next) in enumerate(zip(t[:-1], t[1:])): - dt = t_next - t_curr - t_curr_tensor = t_curr * torch.ones(x.shape[0], dtype=x.dtype, device=x.device) - x = x + dt * model(x, t_curr_tensor, **extra_args) - if callback is not None: - callback({"i": i, "x": x}) - - return x diff --git a/prismaudio_core/inference/utils.py b/prismaudio_core/inference/utils.py deleted file mode 100644 index c47c97b..0000000 --- a/prismaudio_core/inference/utils.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -import torch.nn.functional as F -from torchaudio import transforms as T - - -def set_audio_channels(audio, target_channels): - """Convert audio tensor to target number of channels. - - Args: - audio: Audio tensor of shape [B, C, T] - target_channels: Desired number of channels (1 for mono, 2 for stereo) - - Returns: - Audio tensor with the target number of channels. - """ - if target_channels == 1: - # Convert to mono - audio = audio.mean(1, keepdim=True) - elif target_channels == 2: - # Convert to stereo - if audio.shape[1] == 1: - audio = audio.repeat(1, 2, 1) - elif audio.shape[1] > 2: - audio = audio[:, :2, :] - return audio - - -def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): - """Resample, pad/trim, and convert channels of an audio tensor. - - Args: - audio: Audio tensor (1D, 2D [C, T], or 3D [B, C, T]) - in_sr: Input sample rate - target_sr: Target sample rate - target_length: Target length in samples (padded or cropped) - target_channels: Target number of channels - device: Torch device to place the audio on - - Returns: - Audio tensor of shape [B, target_channels, target_length] on device. - """ - audio = audio.to(device) - - if in_sr != target_sr: - resample_tf = T.Resample(in_sr, target_sr).to(device) - audio = resample_tf(audio) - - # Add batch dimension - if audio.dim() == 1: - audio = audio.unsqueeze(0).unsqueeze(0) - elif audio.dim() == 2: - audio = audio.unsqueeze(0) - - # Pad or crop to target_length - if audio.shape[-1] < target_length: - audio = F.pad(audio, (0, target_length - audio.shape[-1])) - elif audio.shape[-1] > target_length: - audio = audio[:, :, :target_length] - - audio = set_audio_channels(audio, target_channels) - - return audio diff --git a/prismaudio_core/models/__init__.py b/prismaudio_core/models/__init__.py deleted file mode 100644 index aed753a..0000000 --- a/prismaudio_core/models/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -""" -PrismAudio model modules for inference. - -Re-exports create_model_from_config from the factory module. -""" - -from prismaudio_core.factory import create_model_from_config - -__all__ = ["create_model_from_config"] diff --git a/prismaudio_core/models/adp.py b/prismaudio_core/models/adp.py deleted file mode 100644 index 49eb526..0000000 --- a/prismaudio_core/models/adp.py +++ /dev/null @@ -1,1588 +0,0 @@ -# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License -# License can be found in LICENSES/LICENSE_ADP.txt - -import math -from inspect import isfunction -from math import ceil, floor, log, pi, log2 -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union -from packaging import version - -import torch -import torch.nn as nn -from einops import rearrange, reduce, repeat -from einops.layers.torch import Rearrange -from einops_exts import rearrange_many -from torch import Tensor, einsum -from torch.backends.cuda import sdp_kernel -from torch.nn import functional as F -from dac.nn.layers import Snake1d - -""" -Utils -""" - - -class ConditionedSequential(nn.Module): - def __init__(self, *modules): - super().__init__() - self.module_list = nn.ModuleList(*modules) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None): - for module in self.module_list: - x = module(x, mapping) - return x - -T = TypeVar("T") - -def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: - if exists(val): - return val - return d() if isfunction(d) else d - -def exists(val: Optional[T]) -> T: - return val is not None - -def closest_power_2(x: float) -> int: - exponent = log2(x) - distance_fn = lambda z: abs(x - 2 ** z) # noqa - exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) - return 2 ** int(exponent_closest) - -def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: - return_dicts: Tuple[Dict, Dict] = ({}, {}) - for key in d.keys(): - no_prefix = int(not key.startswith(prefix)) - return_dicts[no_prefix][key] = d[key] - return return_dicts - -def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: - kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) - if keep_prefix: - return kwargs_with_prefix, kwargs - kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} - return kwargs_no_prefix, kwargs - -""" -Convolutional Blocks -""" -import typing as tp - -# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License -# License available in LICENSES/LICENSE_META.txt - -def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, - padding_total: int = 0) -> int: - """See `pad_for_conv1d`.""" - length = x.shape[-1] - n_frames = (length - kernel_size + padding_total) / stride + 1 - ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) - return ideal_length - length - - -def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): - """Pad for a convolution to make sure that the last window is full. - Extra padding is added at the end. This is required to ensure that we can rebuild - an output of the same length, as otherwise, even with padding, some time steps - might get removed. - For instance, with total padding = 4, kernel size = 4, stride = 2: - 0 0 1 2 3 4 5 0 0 # (0s are padding) - 1 2 3 # (output frames of a convolution, last 0 is never used) - 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) - 1 2 3 4 # once you removed padding, we are missing one time step ! - """ - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) - return F.pad(x, (0, extra_padding)) - - -def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): - """Tiny wrapper around F.pad, just to allow for reflect padding on small input. - If this is the case, we insert extra 0 padding to the right before the reflection happen. - """ - length = x.shape[-1] - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == 'reflect': - max_pad = max(padding_left, padding_right) - extra_pad = 0 - if length <= max_pad: - extra_pad = max_pad - length + 1 - x = F.pad(x, (0, extra_pad)) - padded = F.pad(x, paddings, mode, value) - end = padded.shape[-1] - extra_pad - return padded[..., :end] - else: - return F.pad(x, paddings, mode, value) - - -def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): - """Remove padding from x, handling properly zero padding. Only for 1d!""" - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - assert (padding_left + padding_right) <= x.shape[-1] - end = x.shape[-1] - padding_right - return x[..., padding_left: end] - - -class Conv1d(nn.Conv1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x: Tensor, causal=False) -> Tensor: - kernel_size = self.kernel_size[0] - stride = self.stride[0] - dilation = self.dilation[0] - kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations - padding_total = kernel_size - stride - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) - if causal: - # Left padding for causal - x = pad1d(x, (padding_total, extra_padding)) - else: - # Asymmetric padding required for odd strides - padding_right = padding_total // 2 - padding_left = padding_total - padding_right - x = pad1d(x, (padding_left, padding_right + extra_padding)) - return super().forward(x) - -class ConvTranspose1d(nn.ConvTranspose1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x: Tensor, causal=False) -> Tensor: - kernel_size = self.kernel_size[0] - stride = self.stride[0] - padding_total = kernel_size - stride - - y = super().forward(x) - - # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be - # removed at the very end, when keeping only the right length for the output, - # as removing it here would require also passing the length at the matching layer - # in the encoder. - if causal: - padding_right = ceil(padding_total) - padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) - else: - # Asymmetric padding required for odd strides - padding_right = padding_total // 2 - padding_left = padding_total - padding_right - y = unpad1d(y, (padding_left, padding_right)) - return y - - -def Downsample1d( - in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 -) -> nn.Module: - assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" - - return Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * kernel_multiplier + 1, - stride=factor - ) - - -def Upsample1d( - in_channels: int, out_channels: int, factor: int, use_nearest: bool = False -) -> nn.Module: - - if factor == 1: - return Conv1d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3 - ) - - if use_nearest: - return nn.Sequential( - nn.Upsample(scale_factor=factor, mode="nearest"), - Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3 - ), - ) - else: - return ConvTranspose1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * 2, - stride=factor - ) - - -class ConvBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - kernel_size: int = 3, - stride: int = 1, - dilation: int = 1, - num_groups: int = 8, - use_norm: bool = True, - use_snake: bool = False - ) -> None: - super().__init__() - - self.groupnorm = ( - nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) - if use_norm - else nn.Identity() - ) - - if use_snake: - self.activation = Snake1d(in_channels) - else: - self.activation = nn.SiLU() - - self.project = Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - ) - - def forward( - self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False - ) -> Tensor: - x = self.groupnorm(x) - if exists(scale_shift): - scale, shift = scale_shift - x = x * (scale + 1) + shift - x = self.activation(x) - return self.project(x, causal=causal) - - -class MappingToScaleShift(nn.Module): - def __init__( - self, - features: int, - channels: int, - ): - super().__init__() - - self.to_scale_shift = nn.Sequential( - nn.SiLU(), - nn.Linear(in_features=features, out_features=channels * 2), - ) - - def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: - scale_shift = self.to_scale_shift(mapping) - scale_shift = rearrange(scale_shift, "b c -> b c 1") - scale, shift = scale_shift.chunk(2, dim=1) - return scale, shift - - -class ResnetBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - kernel_size: int = 3, - stride: int = 1, - dilation: int = 1, - use_norm: bool = True, - use_snake: bool = False, - num_groups: int = 8, - context_mapping_features: Optional[int] = None, - ) -> None: - super().__init__() - - self.use_mapping = exists(context_mapping_features) - - self.block1 = ConvBlock1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - use_norm=use_norm, - num_groups=num_groups, - use_snake=use_snake - ) - - if self.use_mapping: - assert exists(context_mapping_features) - self.to_scale_shift = MappingToScaleShift( - features=context_mapping_features, channels=out_channels - ) - - self.block2 = ConvBlock1d( - in_channels=out_channels, - out_channels=out_channels, - use_norm=use_norm, - num_groups=num_groups, - use_snake=use_snake - ) - - self.to_out = ( - Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) - if in_channels != out_channels - else nn.Identity() - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: - assert_message = "context mapping required if context_mapping_features > 0" - assert not (self.use_mapping ^ exists(mapping)), assert_message - - h = self.block1(x, causal=causal) - - scale_shift = None - if self.use_mapping: - scale_shift = self.to_scale_shift(mapping) - - h = self.block2(h, scale_shift=scale_shift, causal=causal) - - return h + self.to_out(x) - - -class Patcher(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - patch_size: int, - context_mapping_features: Optional[int] = None, - use_snake: bool = False, - ): - super().__init__() - assert_message = f"out_channels must be divisible by patch_size ({patch_size})" - assert out_channels % patch_size == 0, assert_message - self.patch_size = patch_size - - self.block = ResnetBlock1d( - in_channels=in_channels, - out_channels=out_channels // patch_size, - num_groups=1, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: - x = self.block(x, mapping, causal=causal) - x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) - return x - - -class Unpatcher(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - patch_size: int, - context_mapping_features: Optional[int] = None, - use_snake: bool = False - ): - super().__init__() - assert_message = f"in_channels must be divisible by patch_size ({patch_size})" - assert in_channels % patch_size == 0, assert_message - self.patch_size = patch_size - - self.block = ResnetBlock1d( - in_channels=in_channels // patch_size, - out_channels=out_channels, - num_groups=1, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: - x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) - x = self.block(x, mapping, causal=causal) - return x - - -""" -Attention Components -""" -def FeedForward(features: int, multiplier: int) -> nn.Module: - mid_features = features * multiplier - return nn.Sequential( - nn.Linear(in_features=features, out_features=mid_features), - nn.GELU(), - nn.Linear(in_features=mid_features, out_features=features), - ) - -def add_mask(sim: Tensor, mask: Tensor) -> Tensor: - b, ndim = sim.shape[0], mask.ndim - if ndim == 3: - mask = rearrange(mask, "b n m -> b 1 n m") - if ndim == 2: - mask = repeat(mask, "n m -> b 1 n m", b=b) - max_neg_value = -torch.finfo(sim.dtype).max - sim = sim.masked_fill(~mask, max_neg_value) - return sim - -def causal_mask(q: Tensor, k: Tensor) -> Tensor: - b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device - mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) - mask = repeat(mask, "n m -> b n m", b=b) - return mask - -class AttentionBase(nn.Module): - def __init__( - self, - features: int, - *, - head_features: int, - num_heads: int, - out_features: Optional[int] = None, - ): - super().__init__() - self.scale = head_features**-0.5 - self.num_heads = num_heads - mid_features = head_features * num_heads - out_features = default(out_features, features) - - self.to_out = nn.Linear( - in_features=mid_features, out_features=out_features - ) - - self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') - - if not self.use_flash: - return - - device_properties = torch.cuda.get_device_properties(torch.device('cuda')) - - if device_properties.major == 8 and device_properties.minor == 0: - # Use flash attention for A100 GPUs - self.sdp_kernel_config = (True, False, False) - else: - # Don't use flash attention for other GPUs - self.sdp_kernel_config = (False, True, True) - - def forward( - self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False - ) -> Tensor: - # Split heads - q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) - - if not self.use_flash: - if is_causal and not mask: - # Mask out future tokens for causal attention - mask = causal_mask(q, k) - - # Compute similarity matrix and add eventual mask - sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale - sim = add_mask(sim, mask) if exists(mask) else sim - - # Get attention matrix with softmax - attn = sim.softmax(dim=-1, dtype=torch.float32) - - # Compute values - out = einsum("... n m, ... m d -> ... n d", attn, v) - else: - with sdp_kernel(*self.sdp_kernel_config): - out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) - - out = rearrange(out, "b h n d -> b n (h d)") - return self.to_out(out) - -class Attention(nn.Module): - def __init__( - self, - features: int, - *, - head_features: int, - num_heads: int, - out_features: Optional[int] = None, - context_features: Optional[int] = None, - causal: bool = False, - ): - super().__init__() - self.context_features = context_features - self.causal = causal - mid_features = head_features * num_heads - context_features = default(context_features, features) - - self.norm = nn.LayerNorm(features) - self.norm_context = nn.LayerNorm(context_features) - self.to_q = nn.Linear( - in_features=features, out_features=mid_features, bias=False - ) - self.to_kv = nn.Linear( - in_features=context_features, out_features=mid_features * 2, bias=False - ) - self.attention = AttentionBase( - features, - num_heads=num_heads, - head_features=head_features, - out_features=out_features, - ) - - def forward( - self, - x: Tensor, # [b, n, c] - context: Optional[Tensor] = None, # [b, m, d] - context_mask: Optional[Tensor] = None, # [b, m], false is masked, - causal: Optional[bool] = False, - ) -> Tensor: - assert_message = "You must provide a context when using context_features" - assert not self.context_features or exists(context), assert_message - # Use context if provided - context = default(context, x) - # Normalize then compute q from input and k,v from context - x, context = self.norm(x), self.norm_context(context) - - q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) - - if exists(context_mask): - # Mask out cross-attention for padding tokens - mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) - k, v = k * mask, v * mask - - # Compute and return attention - return self.attention(q, k, v, is_causal=self.causal or causal) - - -def FeedForward(features: int, multiplier: int) -> nn.Module: - mid_features = features * multiplier - return nn.Sequential( - nn.Linear(in_features=features, out_features=mid_features), - nn.GELU(), - nn.Linear(in_features=mid_features, out_features=features), - ) - -""" -Transformer Blocks -""" - - -class TransformerBlock(nn.Module): - def __init__( - self, - features: int, - num_heads: int, - head_features: int, - multiplier: int, - context_features: Optional[int] = None, - ): - super().__init__() - - self.use_cross_attention = exists(context_features) and context_features > 0 - - self.attention = Attention( - features=features, - num_heads=num_heads, - head_features=head_features - ) - - if self.use_cross_attention: - self.cross_attention = Attention( - features=features, - num_heads=num_heads, - head_features=head_features, - context_features=context_features - ) - - self.feed_forward = FeedForward(features=features, multiplier=multiplier) - - def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: - x = self.attention(x, causal=causal) + x - if self.use_cross_attention: - x = self.cross_attention(x, context=context, context_mask=context_mask) + x - x = self.feed_forward(x) + x - return x - - -""" -Transformers -""" - - -class Transformer1d(nn.Module): - def __init__( - self, - num_layers: int, - channels: int, - num_heads: int, - head_features: int, - multiplier: int, - context_features: Optional[int] = None, - ): - super().__init__() - - self.to_in = nn.Sequential( - nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), - Conv1d( - in_channels=channels, - out_channels=channels, - kernel_size=1, - ), - Rearrange("b c t -> b t c"), - ) - - self.blocks = nn.ModuleList( - [ - TransformerBlock( - features=channels, - head_features=head_features, - num_heads=num_heads, - multiplier=multiplier, - context_features=context_features, - ) - for i in range(num_layers) - ] - ) - - self.to_out = nn.Sequential( - Rearrange("b t c -> b c t"), - Conv1d( - in_channels=channels, - out_channels=channels, - kernel_size=1, - ), - ) - - def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: - x = self.to_in(x) - for block in self.blocks: - x = block(x, context=context, context_mask=context_mask, causal=causal) - x = self.to_out(x) - return x - - -""" -Time Embeddings -""" - - -class SinusoidalEmbedding(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.dim = dim - - def forward(self, x: Tensor) -> Tensor: - device, half_dim = x.device, self.dim // 2 - emb = torch.tensor(log(10000) / (half_dim - 1), device=device) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") - return torch.cat((emb.sin(), emb.cos()), dim=-1) - - -class LearnedPositionalEmbedding(nn.Module): - """Used for continuous time""" - - def __init__(self, dim: int): - super().__init__() - assert (dim % 2) == 0 - half_dim = dim // 2 - self.weights = nn.Parameter(torch.randn(half_dim)) - - def forward(self, x: Tensor) -> Tensor: - x = rearrange(x, "b -> b 1") - freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi - fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) - fouriered = torch.cat((x, fouriered), dim=-1) - return fouriered - - -def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: - return nn.Sequential( - LearnedPositionalEmbedding(dim), - nn.Linear(in_features=dim + 1, out_features=out_features), - ) - - -""" -Encoder/Decoder Components -""" - - -class DownsampleBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - factor: int, - num_groups: int, - num_layers: int, - kernel_multiplier: int = 2, - use_pre_downsample: bool = True, - use_skip: bool = False, - use_snake: bool = False, - extract_channels: int = 0, - context_channels: int = 0, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - ): - super().__init__() - self.use_pre_downsample = use_pre_downsample - self.use_skip = use_skip - self.use_transformer = num_transformer_blocks > 0 - self.use_extract = extract_channels > 0 - self.use_context = context_channels > 0 - - channels = out_channels if use_pre_downsample else in_channels - - self.downsample = Downsample1d( - in_channels=in_channels, - out_channels=out_channels, - factor=factor, - kernel_multiplier=kernel_multiplier, - ) - - self.blocks = nn.ModuleList( - [ - ResnetBlock1d( - in_channels=channels + context_channels if i == 0 else channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - for i in range(num_layers) - ] - ) - - if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) - - if attention_features is None and attention_heads is not None: - attention_features = channels // attention_heads - - if attention_heads is None and attention_features is not None: - attention_heads = channels // attention_features - - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, - num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features - ) - - if self.use_extract: - num_extract_groups = min(num_groups, extract_channels) - self.to_extracted = ResnetBlock1d( - in_channels=out_channels, - out_channels=extract_channels, - num_groups=num_extract_groups, - use_snake=use_snake - ) - - def forward( - self, - x: Tensor, - *, - mapping: Optional[Tensor] = None, - channels: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False - ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: - - if self.use_pre_downsample: - x = self.downsample(x) - - if self.use_context and exists(channels): - x = torch.cat([x, channels], dim=1) - - skips = [] - for block in self.blocks: - x = block(x, mapping=mapping, causal=causal) - skips += [x] if self.use_skip else [] - - if self.use_transformer: - x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) - skips += [x] if self.use_skip else [] - - if not self.use_pre_downsample: - x = self.downsample(x) - - if self.use_extract: - extracted = self.to_extracted(x) - return x, extracted - - return (x, skips) if self.use_skip else x - - -class UpsampleBlock1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - *, - factor: int, - num_layers: int, - num_groups: int, - use_nearest: bool = False, - use_pre_upsample: bool = False, - use_skip: bool = False, - use_snake: bool = False, - skip_channels: int = 0, - use_skip_scale: bool = False, - extract_channels: int = 0, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - ): - super().__init__() - - self.use_extract = extract_channels > 0 - self.use_pre_upsample = use_pre_upsample - self.use_transformer = num_transformer_blocks > 0 - self.use_skip = use_skip - self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 - - channels = out_channels if use_pre_upsample else in_channels - - self.blocks = nn.ModuleList( - [ - ResnetBlock1d( - in_channels=channels + skip_channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - for _ in range(num_layers) - ] - ) - - if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) - - if attention_features is None and attention_heads is not None: - attention_features = channels // attention_heads - - if attention_heads is None and attention_features is not None: - attention_heads = channels // attention_features - - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, - num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features, - ) - - self.upsample = Upsample1d( - in_channels=in_channels, - out_channels=out_channels, - factor=factor, - use_nearest=use_nearest, - ) - - if self.use_extract: - num_extract_groups = min(num_groups, extract_channels) - self.to_extracted = ResnetBlock1d( - in_channels=out_channels, - out_channels=extract_channels, - num_groups=num_extract_groups, - use_snake=use_snake - ) - - def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: - return torch.cat([x, skip * self.skip_scale], dim=1) - - def forward( - self, - x: Tensor, - *, - skips: Optional[List[Tensor]] = None, - mapping: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False - ) -> Union[Tuple[Tensor, Tensor], Tensor]: - - if self.use_pre_upsample: - x = self.upsample(x) - - for block in self.blocks: - x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x - x = block(x, mapping=mapping, causal=causal) - - if self.use_transformer: - x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) - - if not self.use_pre_upsample: - x = self.upsample(x) - - if self.use_extract: - extracted = self.to_extracted(x) - return x, extracted - - return x - - -class BottleneckBlock1d(nn.Module): - def __init__( - self, - channels: int, - *, - num_groups: int, - num_transformer_blocks: int = 0, - attention_heads: Optional[int] = None, - attention_features: Optional[int] = None, - attention_multiplier: Optional[int] = None, - context_mapping_features: Optional[int] = None, - context_embedding_features: Optional[int] = None, - use_snake: bool = False, - ): - super().__init__() - self.use_transformer = num_transformer_blocks > 0 - - self.pre_block = ResnetBlock1d( - in_channels=channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - if self.use_transformer: - assert ( - (exists(attention_heads) or exists(attention_features)) - and exists(attention_multiplier) - ) - - if attention_features is None and attention_heads is not None: - attention_features = channels // attention_heads - - if attention_heads is None and attention_features is not None: - attention_heads = channels // attention_features - - self.transformer = Transformer1d( - num_layers=num_transformer_blocks, - channels=channels, - num_heads=attention_heads, - head_features=attention_features, - multiplier=attention_multiplier, - context_features=context_embedding_features, - ) - - self.post_block = ResnetBlock1d( - in_channels=channels, - out_channels=channels, - num_groups=num_groups, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - def forward( - self, - x: Tensor, - *, - mapping: Optional[Tensor] = None, - embedding: Optional[Tensor] = None, - embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False - ) -> Tensor: - x = self.pre_block(x, mapping=mapping, causal=causal) - if self.use_transformer: - x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) - x = self.post_block(x, mapping=mapping, causal=causal) - return x - - -""" -UNet -""" - - -class UNet1d(nn.Module): - def __init__( - self, - in_channels: int, - channels: int, - multipliers: Sequence[int], - factors: Sequence[int], - num_blocks: Sequence[int], - attentions: Sequence[int], - patch_size: int = 1, - resnet_groups: int = 8, - use_context_time: bool = True, - kernel_multiplier_downsample: int = 2, - use_nearest_upsample: bool = False, - use_skip_scale: bool = True, - use_snake: bool = False, - use_stft: bool = False, - use_stft_context: bool = False, - out_channels: Optional[int] = None, - context_features: Optional[int] = None, - context_features_multiplier: int = 4, - context_channels: Optional[Sequence[int]] = None, - context_embedding_features: Optional[int] = None, - **kwargs, - ): - super().__init__() - out_channels = default(out_channels, in_channels) - context_channels = list(default(context_channels, [])) - num_layers = len(multipliers) - 1 - use_context_features = exists(context_features) - use_context_channels = len(context_channels) > 0 - context_mapping_features = None - - attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) - - self.num_layers = num_layers - self.use_context_time = use_context_time - self.use_context_features = use_context_features - self.use_context_channels = use_context_channels - self.use_stft = use_stft - self.use_stft_context = use_stft_context - - self.context_features = context_features - context_channels_pad_length = num_layers + 1 - len(context_channels) - context_channels = context_channels + [0] * context_channels_pad_length - self.context_channels = context_channels - self.context_embedding_features = context_embedding_features - - if use_context_channels: - has_context = [c > 0 for c in context_channels] - self.has_context = has_context - self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] - - assert ( - len(factors) == num_layers - and len(attentions) >= num_layers - and len(num_blocks) == num_layers - ) - - if use_context_time or use_context_features: - context_mapping_features = channels * context_features_multiplier - - self.to_mapping = nn.Sequential( - nn.Linear(context_mapping_features, context_mapping_features), - nn.GELU(), - nn.Linear(context_mapping_features, context_mapping_features), - nn.GELU(), - ) - - if use_context_time: - assert exists(context_mapping_features) - self.to_time = nn.Sequential( - TimePositionalEmbedding( - dim=channels, out_features=context_mapping_features - ), - nn.GELU(), - ) - - if use_context_features: - assert exists(context_features) and exists(context_mapping_features) - self.to_features = nn.Sequential( - nn.Linear( - in_features=context_features, out_features=context_mapping_features - ), - nn.GELU(), - ) - - if use_stft: - stft_kwargs, kwargs = groupby("stft_", kwargs) - assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" - stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 - in_channels *= stft_channels - out_channels *= stft_channels - context_channels[0] *= stft_channels if use_stft_context else 1 - assert exists(in_channels) and exists(out_channels) - self.stft = STFT(**stft_kwargs) - - assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" - - self.to_in = Patcher( - in_channels=in_channels + context_channels[0], - out_channels=channels * multipliers[0], - patch_size=patch_size, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - self.downsamples = nn.ModuleList( - [ - DownsampleBlock1d( - in_channels=channels * multipliers[i], - out_channels=channels * multipliers[i + 1], - context_mapping_features=context_mapping_features, - context_channels=context_channels[i + 1], - context_embedding_features=context_embedding_features, - num_layers=num_blocks[i], - factor=factors[i], - kernel_multiplier=kernel_multiplier_downsample, - num_groups=resnet_groups, - use_pre_downsample=True, - use_skip=True, - use_snake=use_snake, - num_transformer_blocks=attentions[i], - **attention_kwargs, - ) - for i in range(num_layers) - ] - ) - - self.bottleneck = BottleneckBlock1d( - channels=channels * multipliers[-1], - context_mapping_features=context_mapping_features, - context_embedding_features=context_embedding_features, - num_groups=resnet_groups, - num_transformer_blocks=attentions[-1], - use_snake=use_snake, - **attention_kwargs, - ) - - self.upsamples = nn.ModuleList( - [ - UpsampleBlock1d( - in_channels=channels * multipliers[i + 1], - out_channels=channels * multipliers[i], - context_mapping_features=context_mapping_features, - context_embedding_features=context_embedding_features, - num_layers=num_blocks[i] + (1 if attentions[i] else 0), - factor=factors[i], - use_nearest=use_nearest_upsample, - num_groups=resnet_groups, - use_skip_scale=use_skip_scale, - use_pre_upsample=False, - use_skip=True, - use_snake=use_snake, - skip_channels=channels * multipliers[i + 1], - num_transformer_blocks=attentions[i], - **attention_kwargs, - ) - for i in reversed(range(num_layers)) - ] - ) - - self.to_out = Unpatcher( - in_channels=channels * multipliers[0], - out_channels=out_channels, - patch_size=patch_size, - context_mapping_features=context_mapping_features, - use_snake=use_snake - ) - - def get_channels( - self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 - ) -> Optional[Tensor]: - """Gets context channels at `layer` and checks that shape is correct""" - use_context_channels = self.use_context_channels and self.has_context[layer] - if not use_context_channels: - return None - assert exists(channels_list), "Missing context" - # Get channels index (skipping zero channel contexts) - channels_id = self.channels_ids[layer] - # Get channels - channels = channels_list[channels_id] - message = f"Missing context for layer {layer} at index {channels_id}" - assert exists(channels), message - # Check channels - num_channels = self.context_channels[layer] - message = f"Expected context with {num_channels} channels at idx {channels_id}" - assert channels.shape[1] == num_channels, message - # STFT channels if requested - channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa - return channels - - def get_mapping( - self, time: Optional[Tensor] = None, features: Optional[Tensor] = None - ) -> Optional[Tensor]: - """Combines context time features and features into mapping""" - items, mapping = [], None - # Compute time features - if self.use_context_time: - assert_message = "use_context_time=True but no time features provided" - assert exists(time), assert_message - items += [self.to_time(time)] - # Compute features - if self.use_context_features: - assert_message = "context_features exists but no features provided" - assert exists(features), assert_message - items += [self.to_features(features)] - # Compute joint mapping - if self.use_context_time or self.use_context_features: - mapping = reduce(torch.stack(items), "n b m -> b m", "sum") - mapping = self.to_mapping(mapping) - return mapping - - def forward( - self, - x: Tensor, - time: Optional[Tensor] = None, - *, - features: Optional[Tensor] = None, - channels_list: Optional[Sequence[Tensor]] = None, - embedding: Optional[Tensor] = None, - embedding_mask: Optional[Tensor] = None, - causal: Optional[bool] = False, - ) -> Tensor: - channels = self.get_channels(channels_list, layer=0) - # Apply stft if required - x = self.stft.encode1d(x) if self.use_stft else x # type: ignore - # Concat context channels at layer 0 if provided - x = torch.cat([x, channels], dim=1) if exists(channels) else x - # Compute mapping from time and features - mapping = self.get_mapping(time, features) - x = self.to_in(x, mapping, causal=causal) - skips_list = [x] - - for i, downsample in enumerate(self.downsamples): - channels = self.get_channels(channels_list, layer=i + 1) - x, skips = downsample( - x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal - ) - skips_list += [skips] - - x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) - - for i, upsample in enumerate(self.upsamples): - skips = skips_list.pop() - x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) - - x += skips_list.pop() - x = self.to_out(x, mapping, causal=causal) - x = self.stft.decode1d(x) if self.use_stft else x - - return x - - -""" Conditioning Modules """ - - -class FixedEmbedding(nn.Module): - def __init__(self, max_length: int, features: int): - super().__init__() - self.max_length = max_length - self.embedding = nn.Embedding(max_length, features) - - def forward(self, x: Tensor) -> Tensor: - batch_size, length, device = *x.shape[0:2], x.device - assert_message = "Input sequence length must be <= max_length" - assert length <= self.max_length, assert_message - position = torch.arange(length, device=device) - fixed_embedding = self.embedding(position) - fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) - return fixed_embedding - - -def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: - if proba == 1: - return torch.ones(shape, device=device, dtype=torch.bool) - elif proba == 0: - return torch.zeros(shape, device=device, dtype=torch.bool) - else: - return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) - - -class UNetCFG1d(UNet1d): - - """UNet1d with Classifier-Free Guidance""" - - def __init__( - self, - context_embedding_max_length: int, - context_embedding_features: int, - use_xattn_time: bool = False, - **kwargs, - ): - super().__init__( - context_embedding_features=context_embedding_features, **kwargs - ) - - self.use_xattn_time = use_xattn_time - - if use_xattn_time: - assert exists(context_embedding_features) - self.to_time_embedding = nn.Sequential( - TimePositionalEmbedding( - dim=kwargs["channels"], out_features=context_embedding_features - ), - nn.GELU(), - ) - - context_embedding_max_length += 1 # Add one for time embedding - - self.fixed_embedding = FixedEmbedding( - max_length=context_embedding_max_length, features=context_embedding_features - ) - - def forward( # type: ignore - self, - x: Tensor, - time: Tensor, - *, - embedding: Tensor, - embedding_mask: Optional[Tensor] = None, - embedding_scale: float = 1.0, - embedding_mask_proba: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - scale_phi: float = 0.4, - negative_embedding: Optional[Tensor] = None, - negative_embedding_mask: Optional[Tensor] = None, - **kwargs, - ) -> Tensor: - b, device = embedding.shape[0], embedding.device - - if self.use_xattn_time: - embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) - - if embedding_mask is not None: - embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) - - fixed_embedding = self.fixed_embedding(embedding) - - if embedding_mask_proba > 0.0: - # Randomly mask embedding - batch_mask = rand_bool( - shape=(b, 1, 1), proba=embedding_mask_proba, device=device - ) - embedding = torch.where(batch_mask, fixed_embedding, embedding) - - if embedding_scale != 1.0: - if batch_cfg: - batch_x = torch.cat([x, x], dim=0) - batch_time = torch.cat([time, time], dim=0) - - if negative_embedding is not None: - if negative_embedding_mask is not None: - negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) - - negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) - - batch_embed = torch.cat([embedding, negative_embedding], dim=0) - - else: - batch_embed = torch.cat([embedding, fixed_embedding], dim=0) - - batch_mask = None - if embedding_mask is not None: - batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) - - batch_features = None - features = kwargs.pop("features", None) - if self.use_context_features: - batch_features = torch.cat([features, features], dim=0) - - batch_channels = None - channels_list = kwargs.pop("channels_list", None) - if self.use_context_channels: - batch_channels = [] - for channels in channels_list: - batch_channels += [torch.cat([channels, channels], dim=0)] - - # Compute both normal and fixed embedding outputs - batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) - out, out_masked = batch_out.chunk(2, dim=0) - - else: - # Compute both normal and fixed embedding outputs - out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) - out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) - - out_cfg = out_masked + (out - out_masked) * embedding_scale - - if rescale_cfg: - - out_std = out.std(dim=1, keepdim=True) - out_cfg_std = out_cfg.std(dim=1, keepdim=True) - - return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg - - else: - - return out_cfg - - else: - return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) - - -class UNetNCCA1d(UNet1d): - - """UNet1d with Noise Channel Conditioning Augmentation""" - - def __init__(self, context_features: int, **kwargs): - super().__init__(context_features=context_features, **kwargs) - self.embedder = NumberEmbedder(features=context_features) - - def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: - x = x if torch.is_tensor(x) else torch.tensor(x) - return x.expand(shape) - - def forward( # type: ignore - self, - x: Tensor, - time: Tensor, - *, - channels_list: Sequence[Tensor], - channels_augmentation: Union[ - bool, Sequence[bool], Sequence[Sequence[bool]], Tensor - ] = False, - channels_scale: Union[ - float, Sequence[float], Sequence[Sequence[float]], Tensor - ] = 0, - **kwargs, - ) -> Tensor: - b, n = x.shape[0], len(channels_list) - channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) - channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) - - # Augmentation (for each channel list item) - for i in range(n): - scale = channels_scale[:, i] * channels_augmentation[:, i] - scale = rearrange(scale, "b -> b 1 1") - item = channels_list[i] - channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa - - # Scale embedding (sum reduction if more than one channel list item) - channels_scale_emb = self.embedder(channels_scale) - channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") - - return super().forward( - x=x, - time=time, - channels_list=channels_list, - features=channels_scale_emb, - **kwargs, - ) - - -class UNetAll1d(UNetCFG1d, UNetNCCA1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, *args, **kwargs): # type: ignore - return UNetCFG1d.forward(self, *args, **kwargs) - - -def XUNet1d(type: str = "base", **kwargs) -> UNet1d: - if type == "base": - return UNet1d(**kwargs) - elif type == "all": - return UNetAll1d(**kwargs) - elif type == "cfg": - return UNetCFG1d(**kwargs) - elif type == "ncca": - return UNetNCCA1d(**kwargs) - else: - raise ValueError(f"Unknown XUNet1d type: {type}") - -class NumberEmbedder(nn.Module): - def __init__( - self, - features: int, - dim: int = 256, - ): - super().__init__() - self.features = features - self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) - - def forward(self, x: Union[List[float], Tensor]) -> Tensor: - if not torch.is_tensor(x): - device = next(self.embedding.parameters()).device - x = torch.tensor(x, device=device) - assert isinstance(x, Tensor) - shape = x.shape - x = rearrange(x, "... -> (...)") - embedding = self.embedding(x) - x = embedding.view(*shape, self.features) - return x # type: ignore - - -""" -Audio Transforms -""" - - -class STFT(nn.Module): - """Helper for torch stft and istft""" - - def __init__( - self, - num_fft: int = 1023, - hop_length: int = 256, - window_length: Optional[int] = None, - length: Optional[int] = None, - use_complex: bool = False, - ): - super().__init__() - self.num_fft = num_fft - self.hop_length = default(hop_length, floor(num_fft // 4)) - self.window_length = default(window_length, num_fft) - self.length = length - self.register_buffer("window", torch.hann_window(self.window_length)) - self.use_complex = use_complex - - def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: - b = wave.shape[0] - wave = rearrange(wave, "b c t -> (b c) t") - - stft = torch.stft( - wave, - n_fft=self.num_fft, - hop_length=self.hop_length, - win_length=self.window_length, - window=self.window, # type: ignore - return_complex=True, - normalized=True, - ) - - if self.use_complex: - # Returns real and imaginary - stft_a, stft_b = stft.real, stft.imag - else: - # Returns magnitude and phase matrices - magnitude, phase = torch.abs(stft), torch.angle(stft) - stft_a, stft_b = magnitude, phase - - return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) - - def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: - b, l = stft_a.shape[0], stft_a.shape[-1] # noqa - length = closest_power_2(l * self.hop_length) - - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") - - if self.use_complex: - real, imag = stft_a, stft_b - else: - magnitude, phase = stft_a, stft_b - real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) - - stft = torch.stack([real, imag], dim=-1) - - wave = torch.istft( - stft, - n_fft=self.num_fft, - hop_length=self.hop_length, - win_length=self.window_length, - window=self.window, # type: ignore - length=default(self.length, length), - normalized=True, - ) - - return rearrange(wave, "(b c) t -> b c t", b=b) - - def encode1d( - self, wave: Tensor, stacked: bool = True - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - stft_a, stft_b = self.encode(wave) - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") - return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) - - def decode1d(self, stft_pair: Tensor) -> Tensor: - f = self.num_fft // 2 + 1 - stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) - stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) - return self.decode(stft_a, stft_b) diff --git a/prismaudio_core/models/autoencoders.py b/prismaudio_core/models/autoencoders.py deleted file mode 100644 index 5a4b45d..0000000 --- a/prismaudio_core/models/autoencoders.py +++ /dev/null @@ -1,821 +0,0 @@ -import torch -import math -import numpy as np - -from torch import nn -from torch.nn import functional as F -from torchaudio import transforms as T -from alias_free_torch import Activation1d -from dac.nn.layers import WNConv1d, WNConvTranspose1d -from typing import Literal, Dict, Any - -from .blocks import SnakeBeta -from .bottleneck import Bottleneck, DiscreteBottleneck -from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper -from .pretransforms import Pretransform -from .utils import checkpoint - - -def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): - """Minimal stub for inference.utils.prepare_audio used by autoencoders.""" - import torchaudio.transforms as T - import torch - - if in_sr != target_sr: - resample_tf = T.Resample(in_sr, target_sr).to(device) - audio = resample_tf(audio) - - if audio.shape[0] > target_channels: - audio = audio[:target_channels] - elif audio.shape[0] < target_channels: - audio = audio.repeat(target_channels // audio.shape[0] + 1, 1)[:target_channels] - - if audio.shape[-1] < target_length: - audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[-1])) - elif audio.shape[-1] > target_length: - audio = audio[..., :target_length] - - return audio.unsqueeze(0) - - -def _lazy_create_pretransform_from_config(pretransform, sample_rate): - from prismaudio_core.factory import create_pretransform_from_config - return create_pretransform_from_config(pretransform, sample_rate) - - -def _lazy_create_bottleneck_from_config(bottleneck): - from prismaudio_core.factory import create_bottleneck_from_config - return create_bottleneck_from_config(bottleneck) - -def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: - if activation == "elu": - act = nn.ELU() - elif activation == "snake": - act = SnakeBeta(channels) - elif activation == "none": - act = nn.Identity() - else: - raise ValueError(f"Unknown activation {activation}") - - if antialias: - act = Activation1d(act) - - return act - -class ResidualUnit(nn.Module): - def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): - super().__init__() - - self.dilation = dilation - - padding = (dilation * (7-1)) // 2 - - self.layers = nn.Sequential( - get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), - WNConv1d(in_channels=in_channels, out_channels=out_channels, - kernel_size=7, dilation=dilation, padding=padding), - get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), - WNConv1d(in_channels=out_channels, out_channels=out_channels, - kernel_size=1) - ) - - def forward(self, x): - res = x - - #x = checkpoint(self.layers, x) - x = self.layers(x) - - return x + res - -class EncoderBlock(nn.Module): - def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): - super().__init__() - - self.layers = nn.Sequential( - ResidualUnit(in_channels=in_channels, - out_channels=in_channels, dilation=1, use_snake=use_snake), - ResidualUnit(in_channels=in_channels, - out_channels=in_channels, dilation=3, use_snake=use_snake), - ResidualUnit(in_channels=in_channels, - out_channels=in_channels, dilation=9, use_snake=use_snake), - get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), - WNConv1d(in_channels=in_channels, out_channels=out_channels, - kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), - ) - - def forward(self, x): - return self.layers(x) - -class DecoderBlock(nn.Module): - def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): - super().__init__() - - if use_nearest_upsample: - upsample_layer = nn.Sequential( - nn.Upsample(scale_factor=stride, mode="nearest"), - WNConv1d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=2*stride, - stride=1, - bias=False, - padding='same') - ) - else: - upsample_layer = WNConvTranspose1d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) - - self.layers = nn.Sequential( - get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), - upsample_layer, - ResidualUnit(in_channels=out_channels, out_channels=out_channels, - dilation=1, use_snake=use_snake), - ResidualUnit(in_channels=out_channels, out_channels=out_channels, - dilation=3, use_snake=use_snake), - ResidualUnit(in_channels=out_channels, out_channels=out_channels, - dilation=9, use_snake=use_snake), - ) - - def forward(self, x): - return self.layers(x) - -class OobleckEncoder(nn.Module): - def __init__(self, - in_channels=2, - channels=128, - latent_dim=32, - c_mults = [1, 2, 4, 8], - strides = [2, 4, 8, 8], - use_snake=False, - antialias_activation=False - ): - super().__init__() - - c_mults = [1] + c_mults - - self.depth = len(c_mults) - - layers = [ - WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) - ] - - for i in range(self.depth-1): - layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] - - layers += [ - get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), - WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) - ] - - self.layers = nn.Sequential(*layers) - - def forward(self, x): - return self.layers(x) - - -class OobleckDecoder(nn.Module): - def __init__(self, - out_channels=2, - channels=128, - latent_dim=32, - c_mults = [1, 2, 4, 8], - strides = [2, 4, 8, 8], - use_snake=False, - antialias_activation=False, - use_nearest_upsample=False, - final_tanh=True): - super().__init__() - - c_mults = [1] + c_mults - - self.depth = len(c_mults) - - layers = [ - WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), - ] - - for i in range(self.depth-1, 0, -1): - layers += [DecoderBlock( - in_channels=c_mults[i]*channels, - out_channels=c_mults[i-1]*channels, - stride=strides[i-1], - use_snake=use_snake, - antialias_activation=antialias_activation, - use_nearest_upsample=use_nearest_upsample - ) - ] - - layers += [ - get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), - WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), - nn.Tanh() if final_tanh else nn.Identity() - ] - - self.layers = nn.Sequential(*layers) - - def forward(self, x): - return self.layers(x) - - -class DACEncoderWrapper(nn.Module): - def __init__(self, in_channels=1, **kwargs): - super().__init__() - - from dac.model.dac import Encoder as DACEncoder - - latent_dim = kwargs.pop("latent_dim", None) - - encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) - self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) - self.latent_dim = latent_dim - - # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility - self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() - - if in_channels != 1: - self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) - - def forward(self, x): - x = self.encoder(x) - x = self.proj_out(x) - return x - -class DACDecoderWrapper(nn.Module): - def __init__(self, latent_dim, out_channels=1, **kwargs): - super().__init__() - - from dac.model.dac import Decoder as DACDecoder - - self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) - - self.latent_dim = latent_dim - - def forward(self, x): - return self.decoder(x) - -class AudioAutoencoder(nn.Module): - def __init__( - self, - encoder, - decoder, - latent_dim, - downsampling_ratio, - sample_rate, - io_channels=2, - bottleneck: Bottleneck = None, - pretransform: Pretransform = None, - in_channels = None, - out_channels = None, - soft_clip = False - ): - super().__init__() - - self.downsampling_ratio = downsampling_ratio - self.sample_rate = sample_rate - - self.latent_dim = latent_dim - self.io_channels = io_channels - self.in_channels = io_channels - self.out_channels = io_channels - - self.min_length = self.downsampling_ratio - - if in_channels is not None: - self.in_channels = in_channels - - if out_channels is not None: - self.out_channels = out_channels - - self.bottleneck = bottleneck - - self.encoder = encoder - - self.decoder = decoder - - self.pretransform = pretransform - - self.soft_clip = soft_clip - - self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete - - def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): - - info = {} - - if self.pretransform is not None and not skip_pretransform: - if self.pretransform.enable_grad: - if iterate_batch: - audios = [] - for i in range(audio.shape[0]): - audios.append(self.pretransform.encode(audio[i:i+1])) - audio = torch.cat(audios, dim=0) - else: - audio = self.pretransform.encode(audio) - else: - with torch.no_grad(): - if iterate_batch: - audios = [] - for i in range(audio.shape[0]): - audios.append(self.pretransform.encode(audio[i:i+1])) - audio = torch.cat(audios, dim=0) - else: - audio = self.pretransform.encode(audio) - - if self.encoder is not None: - if iterate_batch: - latents = [] - for i in range(audio.shape[0]): - latents.append(self.encoder(audio[i:i+1])) - latents = torch.cat(latents, dim=0) - else: - latents = self.encoder(audio) - else: - latents = audio - - if self.bottleneck is not None: - # TODO: Add iterate batch logic, needs to merge the info dicts - latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) - - info.update(bottleneck_info) - - if return_info: - return latents, info - - return latents - - def decode(self, latents, iterate_batch=False, **kwargs): - - if self.bottleneck is not None: - if iterate_batch: - decoded = [] - for i in range(latents.shape[0]): - decoded.append(self.bottleneck.decode(latents[i:i+1])) - latents = torch.cat(decoded, dim=0) - else: - latents = self.bottleneck.decode(latents) - - if iterate_batch: - decoded = [] - for i in range(latents.shape[0]): - decoded.append(self.decoder(latents[i:i+1])) - decoded = torch.cat(decoded, dim=0) - else: - decoded = self.decoder(latents, **kwargs) - - if self.pretransform is not None: - if self.pretransform.enable_grad: - if iterate_batch: - decodeds = [] - for i in range(decoded.shape[0]): - decodeds.append(self.pretransform.decode(decoded[i:i+1])) - decoded = torch.cat(decodeds, dim=0) - else: - decoded = self.pretransform.decode(decoded) - else: - with torch.no_grad(): - if iterate_batch: - decodeds = [] - for i in range(latents.shape[0]): - decodeds.append(self.pretransform.decode(decoded[i:i+1])) - decoded = torch.cat(decodeds, dim=0) - else: - decoded = self.pretransform.decode(decoded) - - if self.soft_clip: - decoded = torch.tanh(decoded) - - return decoded - - def decode_tokens(self, tokens, **kwargs): - ''' - Decode discrete tokens to audio - Only works with discrete autoencoders - ''' - - assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" - - latents = self.bottleneck.decode_tokens(tokens, **kwargs) - - return self.decode(latents, **kwargs) - - - def preprocess_audio_for_encoder(self, audio, in_sr): - ''' - Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. - If the model is mono, stereo audio will be converted to mono. - Audio will be silence-padded to be a multiple of the model's downsampling ratio. - Audio will be resampled to the model's sample rate. - The output will have batch size 1 and be shape (1 x Channels x Length) - ''' - return self.preprocess_audio_list_for_encoder([audio], [in_sr]) - - def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): - ''' - Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. - The audio in that list can be of different lengths and channels. - in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio. - All audio will be resampled to the model's sample rate. - Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. - If the model is mono, all audio will be converted to mono. - The output will be a tensor of shape (Batch x Channels x Length) - ''' - batch_size = len(audio_list) - if isinstance(in_sr_list, int): - in_sr_list = [in_sr_list]*batch_size - assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" - new_audio = [] - max_length = 0 - # resample & find the max length - for i in range(batch_size): - audio = audio_list[i] - in_sr = in_sr_list[i] - if len(audio.shape) == 3 and audio.shape[0] == 1: - # batchsize 1 was given by accident. Just squeeze it. - audio = audio.squeeze(0) - elif len(audio.shape) == 1: - # Mono signal, channel dimension is missing, unsqueeze it in - audio = audio.unsqueeze(0) - assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" - # Resample audio - if in_sr != self.sample_rate: - resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device) - audio = resample_tf(audio) - new_audio.append(audio) - if audio.shape[-1] > max_length: - max_length = audio.shape[-1] - # Pad every audio to the same length, multiple of model's downsampling ratio - padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length - for i in range(batch_size): - # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model - new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, - target_channels=self.in_channels, device=new_audio[i].device).squeeze(0) - # convert to tensor - return torch.stack(new_audio) - - def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs): - ''' - Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder. - If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap. - Overlap and chunk_size params are both measured in number of latents (not audio samples) - # and therefore you likely could use the same values with decode_audio. - A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. - Every autoencoder will have a different receptive field size, and thus ideal overlap. - You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff. - The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. - Smaller chunk_size uses less memory, but more compute. - The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version - For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks - ''' - if not chunked: - # default behavior. Encode the entire audio in parallel - return self.encode(audio, **kwargs) - else: - # CHUNKED ENCODING - # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) - samples_per_latent = self.downsampling_ratio - total_size = audio.shape[2] # in samples - batch_size = audio.shape[0] - chunk_size *= samples_per_latent # converting metric in latents to samples - overlap *= samples_per_latent # converting metric in latents to samples - hop_size = chunk_size - overlap - chunks = [] - for i in range(0, total_size - chunk_size + 1, hop_size): - chunk = audio[:,:,i:i+chunk_size] - chunks.append(chunk) - if i+chunk_size != total_size: - # Final chunk - chunk = audio[:,:,-chunk_size:] - chunks.append(chunk) - chunks = torch.stack(chunks) - num_chunks = chunks.shape[0] - # Note: y_size might be a different value from the latent length used in diffusion training - # because we can encode audio of varying lengths - # However, the audio should've been padded to a multiple of samples_per_latent by now. - y_size = total_size // samples_per_latent - # Create an empty latent, we will populate it with chunks as we encode them - y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) - for i in range(num_chunks): - x_chunk = chunks[i,:] - # encode the chunk - y_chunk = self.encode(x_chunk) - # figure out where to put the audio along the time domain - if i == num_chunks-1: - # final chunk always goes at the end - t_end = y_size - t_start = t_end - y_chunk.shape[2] - else: - t_start = i * hop_size // samples_per_latent - t_end = t_start + chunk_size // samples_per_latent - # remove the edges of the overlaps - ol = overlap//samples_per_latent//2 - chunk_start = 0 - chunk_end = y_chunk.shape[2] - if i > 0: - # no overlap for the start of the first chunk - t_start += ol - chunk_start += ol - if i < num_chunks-1: - # no overlap for the end of the last chunk - t_end -= ol - chunk_end -= ol - # paste the chunked audio into our y_final output audio - y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] - return y_final - - def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): - ''' - Decode latents to audio. - If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. - A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. - Every autoencoder will have a different receptive field size, and thus ideal overlap. - You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. - The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. - Smaller chunk_size uses less memory, but more compute. - The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version - For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks - ''' - if not chunked: - # default behavior. Decode the entire latent in parallel - return self.decode(latents, **kwargs) - else: - # chunked decoding - hop_size = chunk_size - overlap - total_size = latents.shape[2] - batch_size = latents.shape[0] - chunks = [] - for i in range(0, total_size - chunk_size + 1, hop_size): - chunk = latents[:,:,i:i+chunk_size] - chunks.append(chunk) - if i+chunk_size != total_size: - # Final chunk - chunk = latents[:,:,-chunk_size:] - chunks.append(chunk) - chunks = torch.stack(chunks) - num_chunks = chunks.shape[0] - # samples_per_latent is just the downsampling ratio - samples_per_latent = self.downsampling_ratio - # Create an empty waveform, we will populate it with chunks as decode them - y_size = total_size * samples_per_latent - y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) - for i in range(num_chunks): - x_chunk = chunks[i,:] - # decode the chunk - y_chunk = self.decode(x_chunk) - # figure out where to put the audio along the time domain - if i == num_chunks-1: - # final chunk always goes at the end - t_end = y_size - t_start = t_end - y_chunk.shape[2] - else: - t_start = i * hop_size * samples_per_latent - t_end = t_start + chunk_size * samples_per_latent - # remove the edges of the overlaps - ol = (overlap//2) * samples_per_latent - chunk_start = 0 - chunk_end = y_chunk.shape[2] - if i > 0: - # no overlap for the start of the first chunk - t_start += ol - chunk_start += ol - if i < num_chunks-1: - # no overlap for the end of the last chunk - t_end -= ol - chunk_end -= ol - # paste the chunked audio into our y_final output audio - y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] - return y_final - - -class DiffusionAutoencoder(AudioAutoencoder): - def __init__( - self, - diffusion: ConditionedDiffusionModel, - diffusion_downsampling_ratio, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - - self.diffusion = diffusion - - self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio - - if self.encoder is not None: - # Shrink the initial encoder parameters to avoid saturated latents - with torch.no_grad(): - for param in self.encoder.parameters(): - param *= 0.5 - - def decode(self, latents, steps=100): - - upsampled_length = latents.shape[2] * self.downsampling_ratio - - if self.bottleneck is not None: - latents = self.bottleneck.decode(latents) - - if self.decoder is not None: - latents = self.decoder(latents) - - # Upsample latents to match diffusion length - if latents.shape[2] != upsampled_length: - latents = F.interpolate(latents, size=upsampled_length, mode='nearest') - - noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device) - from prismaudio_core.inference.sampling import sample - decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents) - - if self.pretransform is not None: - if self.pretransform.enable_grad: - decoded = self.pretransform.decode(decoded) - else: - with torch.no_grad(): - decoded = self.pretransform.decode(decoded) - - return decoded - -# AE factories - -def create_encoder_from_config(encoder_config: Dict[str, Any]): - encoder_type = encoder_config.get("type", None) - assert encoder_type is not None, "Encoder type must be specified" - - if encoder_type == "oobleck": - encoder = OobleckEncoder( - **encoder_config["config"] - ) - - elif encoder_type == "seanet": - from encodec.modules import SEANetEncoder - seanet_encoder_config = encoder_config["config"] - - #SEANet encoder expects strides in reverse order - seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) - encoder = SEANetEncoder( - **seanet_encoder_config - ) - elif encoder_type == "dac": - dac_config = encoder_config["config"] - - encoder = DACEncoderWrapper(**dac_config) - elif encoder_type == "local_attn": - from .local_attention import TransformerEncoder1D - - local_attn_config = encoder_config["config"] - - encoder = TransformerEncoder1D( - **local_attn_config - ) - else: - raise ValueError(f"Unknown encoder type {encoder_type}") - - requires_grad = encoder_config.get("requires_grad", True) - if not requires_grad: - for param in encoder.parameters(): - param.requires_grad = False - - return encoder - -def create_decoder_from_config(decoder_config: Dict[str, Any]): - decoder_type = decoder_config.get("type", None) - assert decoder_type is not None, "Decoder type must be specified" - - if decoder_type == "oobleck": - decoder = OobleckDecoder( - **decoder_config["config"] - ) - elif decoder_type == "seanet": - from encodec.modules import SEANetDecoder - - decoder = SEANetDecoder( - **decoder_config["config"] - ) - elif decoder_type == "dac": - dac_config = decoder_config["config"] - - decoder = DACDecoderWrapper(**dac_config) - elif decoder_type == "local_attn": - from .local_attention import TransformerDecoder1D - - local_attn_config = decoder_config["config"] - - decoder = TransformerDecoder1D( - **local_attn_config - ) - else: - raise ValueError(f"Unknown decoder type {decoder_type}") - - requires_grad = decoder_config.get("requires_grad", True) - if not requires_grad: - for param in decoder.parameters(): - param.requires_grad = False - - return decoder - -def create_autoencoder_from_config(config: Dict[str, Any]): - - ae_config = config["model"] - - encoder = create_encoder_from_config(ae_config["encoder"]) - decoder = create_decoder_from_config(ae_config["decoder"]) - - bottleneck = ae_config.get("bottleneck", None) - - latent_dim = ae_config.get("latent_dim", None) - assert latent_dim is not None, "latent_dim must be specified in model config" - downsampling_ratio = ae_config.get("downsampling_ratio", None) - assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" - io_channels = ae_config.get("io_channels", None) - assert io_channels is not None, "io_channels must be specified in model config" - sample_rate = config.get("sample_rate", None) - assert sample_rate is not None, "sample_rate must be specified in model config" - - in_channels = ae_config.get("in_channels", None) - out_channels = ae_config.get("out_channels", None) - - pretransform = ae_config.get("pretransform", None) - - if pretransform is not None: - pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate) - - if bottleneck is not None: - bottleneck = _lazy_create_bottleneck_from_config(bottleneck) - - soft_clip = ae_config["decoder"].get("soft_clip", False) - - return AudioAutoencoder( - encoder, - decoder, - io_channels=io_channels, - latent_dim=latent_dim, - downsampling_ratio=downsampling_ratio, - sample_rate=sample_rate, - bottleneck=bottleneck, - pretransform=pretransform, - in_channels=in_channels, - out_channels=out_channels, - soft_clip=soft_clip - ) - -def create_diffAE_from_config(config: Dict[str, Any]): - - diffae_config = config["model"] - - if "encoder" in diffae_config: - encoder = create_encoder_from_config(diffae_config["encoder"]) - else: - encoder = None - - if "decoder" in diffae_config: - decoder = create_decoder_from_config(diffae_config["decoder"]) - else: - decoder = None - - diffusion_model_type = diffae_config["diffusion"]["type"] - - if diffusion_model_type == "DAU1d": - diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"]) - elif diffusion_model_type == "adp_1d": - diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"]) - elif diffusion_model_type == "dit": - diffusion = DiTWrapper(**diffae_config["diffusion"]["config"]) - - latent_dim = diffae_config.get("latent_dim", None) - assert latent_dim is not None, "latent_dim must be specified in model config" - downsampling_ratio = diffae_config.get("downsampling_ratio", None) - assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" - io_channels = diffae_config.get("io_channels", None) - assert io_channels is not None, "io_channels must be specified in model config" - sample_rate = config.get("sample_rate", None) - assert sample_rate is not None, "sample_rate must be specified in model config" - - bottleneck = diffae_config.get("bottleneck", None) - - pretransform = diffae_config.get("pretransform", None) - - if pretransform is not None: - pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate) - - if bottleneck is not None: - bottleneck = _lazy_create_bottleneck_from_config(bottleneck) - - diffusion_downsampling_ratio = None - - if diffusion_model_type == "DAU1d": - diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) - elif diffusion_model_type == "adp_1d": - diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"]) - elif diffusion_model_type == "dit": - diffusion_downsampling_ratio = 1 - - return DiffusionAutoencoder( - encoder=encoder, - decoder=decoder, - diffusion=diffusion, - io_channels=io_channels, - sample_rate=sample_rate, - latent_dim=latent_dim, - downsampling_ratio=downsampling_ratio, - diffusion_downsampling_ratio=diffusion_downsampling_ratio, - bottleneck=bottleneck, - pretransform=pretransform - ) diff --git a/prismaudio_core/models/blocks.py b/prismaudio_core/models/blocks.py deleted file mode 100644 index dfc0466..0000000 --- a/prismaudio_core/models/blocks.py +++ /dev/null @@ -1,331 +0,0 @@ -from functools import reduce -import math -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - -from torch.backends.cuda import sdp_kernel -from packaging import version - -from dac.nn.layers import Snake1d - -class ResidualBlock(nn.Module): - def __init__(self, main, skip=None): - super().__init__() - self.main = nn.Sequential(*main) - self.skip = skip if skip else nn.Identity() - - def forward(self, input): - return self.main(input) + self.skip(input) - -class ResConvBlock(ResidualBlock): - def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): - skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) - super().__init__([ - nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), - nn.GroupNorm(1, c_mid), - Snake1d(c_mid) if use_snake else nn.GELU(), - nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), - nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), - (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), - ], skip) - -class SelfAttention1d(nn.Module): - def __init__(self, c_in, n_head=1, dropout_rate=0.): - super().__init__() - assert c_in % n_head == 0 - self.norm = nn.GroupNorm(1, c_in) - self.n_head = n_head - self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) - self.out_proj = nn.Conv1d(c_in, c_in, 1) - self.dropout = nn.Dropout(dropout_rate, inplace=True) - - self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') - - if not self.use_flash: - return - - device_properties = torch.cuda.get_device_properties(torch.device('cuda')) - - if device_properties.major == 8 and device_properties.minor == 0: - # Use flash attention for A100 GPUs - self.sdp_kernel_config = (True, False, False) - else: - # Don't use flash attention for other GPUs - self.sdp_kernel_config = (False, True, True) - - def forward(self, input): - n, c, s = input.shape - qkv = self.qkv_proj(self.norm(input)) - qkv = qkv.view( - [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) - q, k, v = qkv.chunk(3, dim=1) - scale = k.shape[3]**-0.25 - - if self.use_flash: - with sdp_kernel(*self.sdp_kernel_config): - y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) - else: - att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) - y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) - - - return input + self.dropout(self.out_proj(y)) - -class SkipBlock(nn.Module): - def __init__(self, *main): - super().__init__() - self.main = nn.Sequential(*main) - - def forward(self, input): - return torch.cat([self.main(input), input], dim=1) - -class FourierFeatures(nn.Module): - def __init__(self, in_features, out_features, std=1.): - super().__init__() - assert out_features % 2 == 0 - self.weight = nn.Parameter(torch.randn( - [out_features // 2, in_features]) * std) - - def forward(self, input): - f = 2 * math.pi * input @ self.weight.T - return torch.cat([f.cos(), f.sin()], dim=-1) - -def expand_to_planes(input, shape): - return input[..., None].repeat([1, 1, shape[2]]) - -_kernels = { - 'linear': - [1 / 8, 3 / 8, 3 / 8, 1 / 8], - 'cubic': - [-0.01171875, -0.03515625, 0.11328125, 0.43359375, - 0.43359375, 0.11328125, -0.03515625, -0.01171875], - 'lanczos3': - [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, - -0.066637322306633, 0.13550527393817902, 0.44638532400131226, - 0.44638532400131226, 0.13550527393817902, -0.066637322306633, - -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] -} - -class Downsample1d(nn.Module): - def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor(_kernels[kernel]) - self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer('kernel', kernel_1d) - self.channels_last = channels_last - - def forward(self, x): - if self.channels_last: - x = x.permute(0, 2, 1) - x = F.pad(x, (self.pad,) * 2, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) - indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) - x = F.conv1d(x, weight, stride=2) - if self.channels_last: - x = x.permute(0, 2, 1) - return x - - -class Upsample1d(nn.Module): - def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): - super().__init__() - self.pad_mode = pad_mode - kernel_1d = torch.tensor(_kernels[kernel]) * 2 - self.pad = kernel_1d.shape[0] // 2 - 1 - self.register_buffer('kernel', kernel_1d) - self.channels_last = channels_last - - def forward(self, x): - if self.channels_last: - x = x.permute(0, 2, 1) - x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) - weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) - indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) - x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) - if self.channels_last: - x = x.permute(0, 2, 1) - return x - -def Downsample1d_2( - in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 -) -> nn.Module: - assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" - - return nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * kernel_multiplier + 1, - stride=factor, - padding=factor * (kernel_multiplier // 2), - ) - - -def Upsample1d_2( - in_channels: int, out_channels: int, factor: int, use_nearest: bool = False -) -> nn.Module: - - if factor == 1: - return nn.Conv1d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 - ) - - if use_nearest: - return nn.Sequential( - nn.Upsample(scale_factor=factor, mode="nearest"), - nn.Conv1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - ), - ) - else: - return nn.ConvTranspose1d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=factor * 2, - stride=factor, - padding=factor // 2 + factor % 2, - output_padding=factor % 2, - ) - -def zero_init(layer): - nn.init.zeros_(layer.weight) - if layer.bias is not None: - nn.init.zeros_(layer.bias) - return layer - -class AdaRMSNorm(nn.Module): - def __init__(self, features, cond_features, eps=1e-6): - super().__init__() - self.eps = eps - self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) - - def extra_repr(self): - return f"eps={self.eps}," - - def forward(self, x, cond): - return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) - -def normalize(x, eps=1e-4): - dim = list(range(1, x.ndim)) - n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) - alpha = np.sqrt(n.numel() / x.numel()) - return x / torch.add(eps, n, alpha=alpha) - -class ForcedWNConv1d(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1): - super().__init__() - self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) - - def forward(self, x): - if self.training: - with torch.no_grad(): - self.weight.copy_(normalize(self.weight)) - - fan_in = self.weight[0].numel() - - w = normalize(self.weight) / math.sqrt(fan_in) - - return F.conv1d(x, w, padding='same') - -# Kernels - -use_compile = True - -def compile(function, *args, **kwargs): - if not use_compile: - return function - try: - return torch.compile(function, *args, **kwargs) - except RuntimeError: - return function - - -@compile -def linear_geglu(x, weight, bias=None): - x = x @ weight.mT - if bias is not None: - x = x + bias - x, gate = x.chunk(2, dim=-1) - return x * F.gelu(gate) - - -@compile -def rms_norm(x, scale, eps): - dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) - mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) - scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) - return x * scale.to(x.dtype) - -# Layers - -class LinearGEGLU(nn.Linear): - def __init__(self, in_features, out_features, bias=True): - super().__init__(in_features, out_features * 2, bias=bias) - self.out_features = out_features - - def forward(self, x): - return linear_geglu(x, self.weight, self.bias) - - -class RMSNorm(nn.Module): - def __init__(self, shape, fix_scale = False, eps=1e-6): - super().__init__() - self.eps = eps - - if fix_scale: - self.register_buffer("scale", torch.ones(shape)) - else: - self.scale = nn.Parameter(torch.ones(shape)) - - def extra_repr(self): - return f"shape={tuple(self.scale.shape)}, eps={self.eps}" - - def forward(self, x): - return rms_norm(x, self.scale, self.eps) - -def snake_beta(x, alpha, beta): - return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) - -# try: -# snake_beta = torch.compile(snake_beta) -# except RuntimeError: -# pass - -# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license -# License available in LICENSES/LICENSE_NVIDIA.txt -class SnakeBeta(nn.Module): - - def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): - super(SnakeBeta, self).__init__() - self.in_features = in_features - - # initialize alpha - self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros - self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) - self.beta = nn.Parameter(torch.zeros(in_features) * alpha) - else: # linear scale alphas initialized to ones - self.alpha = nn.Parameter(torch.ones(in_features) * alpha) - self.beta = nn.Parameter(torch.ones(in_features) * alpha) - - self.alpha.requires_grad = alpha_trainable - self.beta.requires_grad = alpha_trainable - - self.no_div_by_zero = 0.000000001 - - def forward(self, x): - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] - beta = self.beta.unsqueeze(0).unsqueeze(-1) - if self.alpha_logscale: - alpha = torch.exp(alpha) - beta = torch.exp(beta) - x = snake_beta(x, alpha, beta) - - return x \ No newline at end of file diff --git a/prismaudio_core/models/bottleneck.py b/prismaudio_core/models/bottleneck.py deleted file mode 100644 index 5e81cab..0000000 --- a/prismaudio_core/models/bottleneck.py +++ /dev/null @@ -1,355 +0,0 @@ -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - -from einops import rearrange -from vector_quantize_pytorch import ResidualVQ, FSQ -from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ - -class Bottleneck(nn.Module): - def __init__(self, is_discrete: bool = False): - super().__init__() - - self.is_discrete = is_discrete - - def encode(self, x, return_info=False, **kwargs): - raise NotImplementedError - - def decode(self, x): - raise NotImplementedError - -class DiscreteBottleneck(Bottleneck): - def __init__(self, num_quantizers, codebook_size, tokens_id): - super().__init__(is_discrete=True) - - self.num_quantizers = num_quantizers - self.codebook_size = codebook_size - self.tokens_id = tokens_id - - def decode_tokens(self, codes, **kwargs): - raise NotImplementedError - -class TanhBottleneck(Bottleneck): - def __init__(self): - super().__init__(is_discrete=False) - self.tanh = nn.Tanh() - - def encode(self, x, return_info=False): - info = {} - - x = torch.tanh(x) - - if return_info: - return x, info - else: - return x - - def decode(self, x): - return x - -def vae_sample(mean, scale): - stdev = nn.functional.softplus(scale) + 1e-4 - var = stdev * stdev - logvar = torch.log(var) - latents = torch.randn_like(mean) * stdev + mean - - kl = (mean * mean + var - logvar - 1).sum(1).mean() - - return latents, kl - -class VAEBottleneck(Bottleneck): - def __init__(self): - super().__init__(is_discrete=False) - - def encode(self, x, return_info=False, **kwargs): - info = {} - - mean, scale = x.chunk(2, dim=1) - - x, kl = vae_sample(mean, scale) - - info["kl"] = kl - - if return_info: - return x, info - else: - return x - - def decode(self, x): - return x - -def compute_mean_kernel(x, y): - kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] - return torch.exp(-kernel_input).mean() - -def compute_mmd(latents): - latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) - noise = torch.randn_like(latents_reshaped) - - latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) - noise_kernel = compute_mean_kernel(noise, noise) - latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) - - mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel - return mmd.mean() - -class WassersteinBottleneck(Bottleneck): - def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False): - super().__init__(is_discrete=False) - - self.noise_augment_dim = noise_augment_dim - self.bypass_mmd = bypass_mmd - - def encode(self, x, return_info=False): - info = {} - - if self.training and return_info: - if self.bypass_mmd: - mmd = torch.tensor(0.0) - else: - mmd = compute_mmd(x) - - info["mmd"] = mmd - - if return_info: - return x, info - - return x - - def decode(self, x): - - if self.noise_augment_dim > 0: - noise = torch.randn(x.shape[0], self.noise_augment_dim, - x.shape[-1]).type_as(x) - x = torch.cat([x, noise], dim=1) - - return x - -class L2Bottleneck(Bottleneck): - def __init__(self): - super().__init__(is_discrete=False) - - def encode(self, x, return_info=False): - info = {} - - x = F.normalize(x, dim=1) - - if return_info: - return x, info - else: - return x - - def decode(self, x): - return F.normalize(x, dim=1) - -class RVQBottleneck(DiscreteBottleneck): - def __init__(self, **quantizer_kwargs): - super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") - self.quantizer = ResidualVQ(**quantizer_kwargs) - self.num_quantizers = quantizer_kwargs["num_quantizers"] - - def encode(self, x, return_info=False, **kwargs): - info = {} - - x = rearrange(x, "b c n -> b n c") - x, indices, loss = self.quantizer(x) - x = rearrange(x, "b n c -> b c n") - - info["quantizer_indices"] = indices - info["quantizer_loss"] = loss.mean() - - if return_info: - return x, info - else: - return x - - def decode(self, x): - return x - - def decode_tokens(self, codes, **kwargs): - latents = self.quantizer.get_outputs_from_indices(codes) - - return self.decode(latents, **kwargs) - -class RVQVAEBottleneck(DiscreteBottleneck): - def __init__(self, **quantizer_kwargs): - super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") - self.quantizer = ResidualVQ(**quantizer_kwargs) - self.num_quantizers = quantizer_kwargs["num_quantizers"] - - def encode(self, x, return_info=False): - info = {} - - x, kl = vae_sample(*x.chunk(2, dim=1)) - - info["kl"] = kl - - x = rearrange(x, "b c n -> b n c") - x, indices, loss = self.quantizer(x) - x = rearrange(x, "b n c -> b c n") - - info["quantizer_indices"] = indices - info["quantizer_loss"] = loss.mean() - - if return_info: - return x, info - else: - return x - - def decode(self, x): - return x - - def decode_tokens(self, codes, **kwargs): - latents = self.quantizer.get_outputs_from_indices(codes) - - return self.decode(latents, **kwargs) - -class DACRVQBottleneck(DiscreteBottleneck): - def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs): - super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") - self.quantizer = DACResidualVQ(**quantizer_kwargs) - self.num_quantizers = quantizer_kwargs["n_codebooks"] - self.quantize_on_decode = quantize_on_decode - self.noise_augment_dim = noise_augment_dim - - def encode(self, x, return_info=False, **kwargs): - info = {} - - info["pre_quantizer"] = x - - if self.quantize_on_decode: - return x, info if return_info else x - - z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) - - output = { - "z": z, - "codes": codes, - "latents": latents, - "vq/commitment_loss": commitment_loss, - "vq/codebook_loss": codebook_loss, - } - - output["vq/commitment_loss"] /= self.num_quantizers - output["vq/codebook_loss"] /= self.num_quantizers - - info.update(output) - - if return_info: - return output["z"], info - - return output["z"] - - def decode(self, x): - - if self.quantize_on_decode: - x = self.quantizer(x)[0] - - if self.noise_augment_dim > 0: - noise = torch.randn(x.shape[0], self.noise_augment_dim, - x.shape[-1]).type_as(x) - x = torch.cat([x, noise], dim=1) - - return x - - def decode_tokens(self, codes, **kwargs): - latents, _, _ = self.quantizer.from_codes(codes) - - return self.decode(latents, **kwargs) - -class DACRVQVAEBottleneck(DiscreteBottleneck): - def __init__(self, quantize_on_decode=False, **quantizer_kwargs): - super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") - self.quantizer = DACResidualVQ(**quantizer_kwargs) - self.num_quantizers = quantizer_kwargs["n_codebooks"] - self.quantize_on_decode = quantize_on_decode - - def encode(self, x, return_info=False, n_quantizers: int = None): - info = {} - - mean, scale = x.chunk(2, dim=1) - - x, kl = vae_sample(mean, scale) - - info["pre_quantizer"] = x - info["kl"] = kl - - if self.quantize_on_decode: - return x, info if return_info else x - - z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) - - output = { - "z": z, - "codes": codes, - "latents": latents, - "vq/commitment_loss": commitment_loss, - "vq/codebook_loss": codebook_loss, - } - - output["vq/commitment_loss"] /= self.num_quantizers - output["vq/codebook_loss"] /= self.num_quantizers - - info.update(output) - - if return_info: - return output["z"], info - - return output["z"] - - def decode(self, x): - - if self.quantize_on_decode: - x = self.quantizer(x)[0] - - return x - - def decode_tokens(self, codes, **kwargs): - latents, _, _ = self.quantizer.from_codes(codes) - - return self.decode(latents, **kwargs) - -class FSQBottleneck(DiscreteBottleneck): - def __init__(self, noise_augment_dim=0, **kwargs): - super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices") - - self.noise_augment_dim = noise_augment_dim - - self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]) - - def encode(self, x, return_info=False): - info = {} - - orig_dtype = x.dtype - x = x.float() - - x = rearrange(x, "b c n -> b n c") - x, indices = self.quantizer(x) - x = rearrange(x, "b n c -> b c n") - - x = x.to(orig_dtype) - - # Reorder indices to match the expected format - indices = rearrange(indices, "b n q -> b q n") - - info["quantizer_indices"] = indices - - if return_info: - return x, info - else: - return x - - def decode(self, x): - - if self.noise_augment_dim > 0: - noise = torch.randn(x.shape[0], self.noise_augment_dim, - x.shape[-1]).type_as(x) - x = torch.cat([x, noise], dim=1) - - return x - - def decode_tokens(self, tokens, **kwargs): - latents = self.quantizer.indices_to_codes(tokens) - - return self.decode(latents, **kwargs) \ No newline at end of file diff --git a/prismaudio_core/models/conditioners.py b/prismaudio_core/models/conditioners.py deleted file mode 100644 index 3351f47..0000000 --- a/prismaudio_core/models/conditioners.py +++ /dev/null @@ -1,1090 +0,0 @@ -#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py - -import torch -import logging, warnings -import string -import typing as tp -import gc -from typing import Literal, Optional -import os -from .adp import NumberEmbedder -from .pretransforms import Pretransform -from .utils import load_ckpt_state_dict - - -# Stub for training utility - only needed for load_state_dict, not inference -def copy_state_dict(model, state_dict): - """Stub replacement for PrismAudio.training.utils.copy_state_dict""" - model.load_state_dict(state_dict, strict=False) - - -def set_audio_channels(audio, target_channels): - """Stub replacement for PrismAudio.inference.utils.set_audio_channels""" - if audio.shape[1] == target_channels: - return audio - if target_channels == 1: - return audio.mean(dim=1, keepdim=True) - if target_channels == 2 and audio.shape[1] == 1: - return audio.repeat(1, 2, 1) - raise ValueError(f"Cannot convert {audio.shape[1]} channels to {target_channels}") -import numpy as np -from einops import rearrange -from transformers import AutoProcessor, AutoModel -from torch import nn -import torch.nn.functional as F -from .mmmodules.model.low_level import ConvMLP, MLP -from torch.nn.utils.rnn import pad_sequence - -class Conditioner(nn.Module): - def __init__( - self, - dim: int, - output_dim: int, - project_out: bool = False - ): - - super().__init__() - - self.dim = dim - self.output_dim = output_dim - self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity() - - def forward(self, x: tp.Any) -> tp.Any: - raise NotImplementedError() - -class Cond_MLP(Conditioner): - def __init__(self, dim, output_dim, dropout = 0.0): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential( - nn.Linear(dim, output_dim, bias=False), - nn.SiLU(), - nn.Linear(output_dim, output_dim, bias=False) - ) - self.dropout = dropout - def forward(self, x, device: tp.Any = "cuda"): - x = pad_sequence(x, batch_first=True).to(device) - # x = torch.stack(x, dim=0).to(device) - - if self.dropout > 0.0: - if self.training: - null_embed = torch.zeros_like(x, device=device) - dropout_mask = torch.bernoulli(torch.full((x.shape[0], 1, 1), self.dropout, device=device)).to(torch.bool) - x = torch.where(dropout_mask, null_embed, x) - elif x.shape[0] < 16: # default test batch size=1 - null_embed = torch.zeros_like(x, device=device) - x = torch.cat([x, null_embed], dim=0) - - x = self.embedder(x) # B x 117 x C - return [x, torch.ones(x.shape[0], 1).to(device)] - -class Global_MLP(Conditioner): - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential( - nn.Linear(dim, output_dim, bias=False), - nn.SiLU(), - nn.Linear(output_dim, output_dim, bias=False) - ) - def forward(self, x, device: tp.Any = "cuda"): - x = torch.stack(x, dim=0).to(device) - x = x.mean(dim=1) - x = self.embedder(x) # B x 117 x C - return [x, torch.ones(x.shape[0], 1).to(device)] - -class Cond_MLP_1(Conditioner): - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential( - nn.Linear(dim, output_dim), - nn.SiLU(), - MLP(output_dim, output_dim * 4), - ) - def forward(self, x, device: tp.Any = "cuda"): - x = torch.stack(x, dim=0).to(device) - - x = self.embedder(x) # B x 117 x C - return [x, torch.ones(x.shape[0], 1).to(device)] - -class Cond_MLP_Global(Conditioner): - def __init__(self, dim, output_dim, dropout = 0.0): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential( - nn.Linear(dim, output_dim, bias=False), - nn.SiLU(), - nn.Linear(output_dim, output_dim, bias=False) - ) - self.global_embedder = nn.Sequential( - nn.Linear(output_dim, output_dim, bias=False), - nn.SiLU(), - nn.Linear(output_dim, output_dim, bias=False) - ) - self.dropout = dropout - def forward(self, x, device: tp.Any = "cuda"): - x = torch.stack(x, dim=0).to(device) - if self.dropout > 0 and self.training: - null_embed = torch.zeros_like(x, device=device) - dropout_mask = torch.bernoulli(torch.full((x.shape[0], 1, 1), self.dropout, device=device)).to(torch.bool) - x = torch.where(dropout_mask, null_embed, x) - x = self.embedder(x) # B x 117 x C - global_x = self.global_embedder(x[:,0,:]) - return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] - -class Cond_MLP_Global_1(Conditioner): - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential( - nn.Linear(dim, output_dim), - nn.SiLU(), - MLP(output_dim, output_dim * 4), - ) - self.global_embedder = nn.Sequential( - nn.Linear(dim, output_dim), - MLP(output_dim, output_dim * 4), - ) - def forward(self, x, device: tp.Any = "cuda"): - x = torch.stack(x, dim=0).to(device) - - x = self.embedder(x) # B x 117 x C - global_x = self.global_embedder(x.mean(dim=1)) - return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] - -class Cond_MLP_Global_2(Conditioner): - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential( - nn.Linear(dim, output_dim, bias=False), - nn.SiLU(), - nn.Linear(output_dim, output_dim, bias=False) - ) - self.global_embedder = nn.Sequential( - nn.Linear(output_dim, output_dim, bias=False), - ) - def forward(self, x, device: tp.Any = "cuda"): - x = torch.stack(x, dim=0).to(device) - - x = self.embedder(x) # B x 117 x C - global_x = self.global_embedder(x.mean(dim=1)) - return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] - -class Sync_MLP(Conditioner): - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential( - nn.Linear(dim, output_dim, bias=False), - nn.SiLU(), - nn.Linear(output_dim, output_dim, bias=False) - ) - self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, dim))) - nn.init.constant_(self.sync_pos_emb, 0) - def forward(self, x, device: tp.Any = "cuda"): - sync_f = torch.stack(x, dim=0).to(device) - bs, length, dim = sync_f.shape - #print(sync_f.shape,flush=True) - # B * num_segments (24) * 8 * 768 - num_sync_segments = length // 8 - sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb - sync_f = sync_f.flatten(1, 2) # (B, VN, D) - x = self.embedder(sync_f) # B x 117 x C - x = x.transpose(1,2) - x = F.interpolate(x, ((int)(194*sync_f.shape[1]/216), ), mode='linear', align_corners=False) - x = x.transpose(1,2) - return [x, torch.ones(x.shape[0], 1).to(device)] - -class Cond_ConvMLP(Conditioner): - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential( - nn.Linear(dim, output_dim), - nn.SiLU(), - ConvMLP(output_dim, output_dim * 4, kernel_size=1, padding=0), - ) - def forward(self, x, device: tp.Any = "cuda"): - x = torch.stack(x, dim=0).to(device) - - x = self.embedder(x) # B x 117 x C - return [x, torch.ones(x.shape[0], 1).to(device)] - -class Video_Global(Conditioner): - """ Transform the video feat encoder""" - - def __init__(self, dim, output_dim, global_dim=1536): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) - self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim)) - - def forward(self, x, device: tp.Any = "cuda"): - if not isinstance(x[0], torch.Tensor): - video_feats = [] - for path in x: - if '.npy' in path: - video_feats.append(torch.from_numpy(np.load(path)).to(device)) - elif '.pth' in path: - data = torch.load(path) - video_feats.append(data['metaclip_features'].to(device)) - else: - video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) - x = torch.stack(video_feats, dim=0).to(device) - else: - # Revise the shape here: - x = torch.stack(x, dim=0).to(device) - - x = self.embedder(x) # B x 117 x C - global_x = self.global_proj(x.mean(dim=1)) - return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] - -class Video_Sync(Conditioner): - """ Transform the video feat encoder""" - - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) - - def forward(self, x, device: tp.Any = "cuda"): - if not isinstance(x[0], torch.Tensor): - video_feats = [] - for path in x: - if '.npy' in path: - video_feats.append(torch.from_numpy(np.load(path)).to(device)) - elif '.pth' in path: - video_feats.append(torch.load(path)['sync_features'].to(device)) - else: - video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) - x = torch.stack(video_feats, dim=0).to(device) - else: - # Revise the shape here: - x = torch.stack(x, dim=0).to(device) - - x = self.embedder(x) # B x 117 x C - return [x, torch.ones(x.shape[0], 1).to(device)] - -class Text_Linear(Conditioner): - """ Transform the video feat encoder""" - - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) - - def forward(self, x, device: tp.Any = "cuda"): - if not isinstance(x[0], torch.Tensor): - video_feats = [] - for path in x: - if '.npy' in path: - video_feats.append(torch.from_numpy(np.load(path)).to(device)) - elif '.pth' in path: - video_feats.append(torch.load(path)['metaclip_text_features'].to(device)) - else: - video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) - x = torch.stack(video_feats, dim=0).to(device) - else: - # Revise the shape here: - x = torch.stack(x, dim=0).to(device) - - x = self.embedder(x) # B x 117 x C - return [x, torch.ones(x.shape[0], 1).to(device)] - -class mm_unchang(Conditioner): - """ Transform the video feat encoder""" - - def __init__(self, dim, output_dim): - super().__init__(dim, output_dim) - - def forward(self, x, device: tp.Any = "cuda"): - if not isinstance(x[0], torch.Tensor): - video_feats = [] - for path in x: - if '.npy' in path: - video_feats.append(torch.from_numpy(np.load(path)).to(device)) - elif '.pth' in path: - video_feats.append(torch.load(path)['metaclip_features'].to(device)) - else: - video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) - x = torch.stack(video_feats, dim=0).to(device) - else: - # Revise the shape here: - x = torch.stack(x, dim=0).to(device) - return [x] - -class CLIPConditioner(Conditioner): - - CLIP_MODELS = ["metaclip-base", "metaclip-b16", "metaclip-large", "metaclip-huge"] - - CLIP_MODEL_DIMS = { - "metaclip-base": 512, - "metaclip-b16": 512, - "metaclip-large": 768, - "metaclip-huge": 1024, - } - - def __init__( - self, - dim: int, - output_dim: int, - clip_model_name: str = "metaclip-huge", - enable_grad: bool = False, - project_out: bool = False - ): - assert clip_model_name in self.CLIP_MODELS, f"Unknown CLIP model name: {clip_model_name}" - super().__init__(self.CLIP_MODEL_DIMS[clip_model_name], output_dim, project_out=project_out) - - self.enable_grad = enable_grad - model = AutoModel.from_pretrained(f"useful_ckpts/{clip_model_name}").train(enable_grad).requires_grad_(enable_grad).to(torch.float16) - - - - if self.enable_grad: - self.model = model - else: - self.__dict__["model"] = model - - - def forward(self, images: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - - self.model.to(device) - self.proj_out.to(device) - - self.model.eval() - if not isinstance(images[0], torch.Tensor): - video_feats = [] - for path in images: - if '.npy' in path: - video_feats.append(torch.from_numpy(np.load(path)).to(device)) - else: - video_feats.append(torch.from_numpy(np.load(path)).to(device)) - images = torch.stack(video_feats, dim=0).to(device) - else: - images = torch.stack(images, dim=0).to(device) - bsz, t, c, h, w = images.shape - # 使用 rearrange 进行维度合并 - images = rearrange(images, 'b t c h w -> (b t) c h w') - with torch.set_grad_enabled(self.enable_grad): - image_features = self.model.get_image_features(images) - image_features = rearrange(image_features, '(b t) d -> b t d', b=bsz, t=t) - image_features = self.proj_out(image_features) - - - return [image_features, torch.ones(image_features.shape[0], 1).to(device)] - -class IntConditioner(Conditioner): - def __init__(self, - output_dim: int, - min_val: int=0, - max_val: int=512 - ): - super().__init__(output_dim, output_dim) - - self.min_val = min_val - self.max_val = max_val - self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True) - - def forward(self, ints: tp.List[int], device=None) -> tp.Any: - - #self.int_embedder.to(device) - - ints = torch.tensor(ints).to(device) - ints = ints.clamp(self.min_val, self.max_val) - - int_embeds = self.int_embedder(ints).unsqueeze(1) - - return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)] - -class NumberConditioner(Conditioner): - ''' - Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings - ''' - def __init__(self, - output_dim: int, - min_val: float=0, - max_val: float=1 - ): - super().__init__(output_dim, output_dim) - - self.min_val = min_val - self.max_val = max_val - - self.embedder = NumberEmbedder(features=output_dim) - - def forward(self, floats: tp.List[float], device=None) -> tp.Any: - - # Cast the inputs to floats - floats = [float(x) for x in floats] - - floats = torch.tensor(floats).to(device) - - floats = floats.clamp(self.min_val, self.max_val) - - normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) - - # Cast floats to same type as embedder - embedder_dtype = next(self.embedder.parameters()).dtype - normalized_floats = normalized_floats.to(embedder_dtype) - - float_embeds = self.embedder(normalized_floats).unsqueeze(1) - - return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] - -class CLAPTextConditioner(Conditioner): - def __init__(self, - output_dim: int, - clap_ckpt_path, - use_text_features = False, - feature_layer_ix: int = -1, - audio_model_type="HTSAT-base", - enable_fusion=True, - project_out: bool = False, - finetune: bool = False): - super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out) - - self.use_text_features = use_text_features - self.feature_layer_ix = feature_layer_ix - self.finetune = finetune - - # Suppress logging from transformers - previous_level = logging.root.manager.disable - logging.disable(logging.ERROR) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - try: - import laion_clap - from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict - - model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') - - if self.finetune: - self.model = model - else: - self.__dict__["model"] = model - - state_dict = clap_load_state_dict(clap_ckpt_path) - self.model.model.load_state_dict(state_dict, strict=False) - - if self.finetune: - self.model.model.text_branch.requires_grad_(True) - self.model.model.text_branch.train() - else: - self.model.model.text_branch.requires_grad_(False) - self.model.model.text_branch.eval() - - finally: - logging.disable(previous_level) - - del self.model.model.audio_branch - - gc.collect() - torch.cuda.empty_cache() - - def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"): - prompt_tokens = self.model.tokenizer(prompts) - attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True) - prompt_features = self.model.model.text_branch( - input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True), - attention_mask=attention_mask, - output_hidden_states=True - )["hidden_states"][layer_ix] - - return prompt_features, attention_mask - - def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any: - self.model.to(device) - - if self.use_text_features: - if len(texts) == 1: - text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device) - text_features = text_features[:1, ...] - text_attention_mask = text_attention_mask[:1, ...] - else: - text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device) - return [self.proj_out(text_features), text_attention_mask] - - # Fix for CLAP bug when only one text is passed - if len(texts) == 1: - text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...] - else: - text_embedding = self.model.get_text_embedding(texts, use_tensor=True) - - text_embedding = text_embedding.unsqueeze(1).to(device) - - return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)] - -class CLAPAudioConditioner(Conditioner): - def __init__(self, - output_dim: int, - clap_ckpt_path, - audio_model_type="HTSAT-base", - enable_fusion=True, - project_out: bool = False): - super().__init__(512, output_dim, project_out=project_out) - - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - # Suppress logging from transformers - previous_level = logging.root.manager.disable - logging.disable(logging.ERROR) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - try: - import laion_clap - from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict - - model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') - - self.model = model - - state_dict = clap_load_state_dict(clap_ckpt_path) - self.model.model.load_state_dict(state_dict, strict=False) - - self.model.model.audio_branch.requires_grad_(False) - self.model.model.audio_branch.eval() - - finally: - logging.disable(previous_level) - - del self.model.model.text_branch - - gc.collect() - torch.cuda.empty_cache() - - def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any: - - self.model.to(device) - - if isinstance(audios, list) or isinstance(audios, tuple): - audios = torch.cat(audios, dim=0) - - # Convert to mono - mono_audios = audios.mean(dim=1) - - with torch.cuda.amp.autocast(enabled=False): - audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True) - - audio_embedding = audio_embedding.unsqueeze(1).to(device) - - return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] - -class T5Conditioner(Conditioner): - - T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", - "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", - "google/flan-t5-xl", "google/flan-t5-xxl", "t5-v1_1-xl", "google/t5-v1_1-xxl"] - - T5_MODEL_DIMS = { - "t5-small": 512, - "t5-base": 768, - "t5-large": 1024, - "t5-3b": 1024, - "t5-11b": 1024, - "t5-v1_1-xl": 2048, - "google/t5-v1_1-xxl": 4096, - "google/flan-t5-small": 512, - "google/flan-t5-base": 768, - "google/flan-t5-large": 1024, - "google/flan-t5-3b": 1024, - "google/flan-t5-11b": 1024, - "google/flan-t5-xl": 2048, - "google/flan-t5-xxl": 4096, - } - - def __init__( - self, - output_dim: int, - t5_model_name: str = "t5-base", - max_length: str = 77, - enable_grad: bool = False, - project_out: bool = False - ): - assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" - super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) - - from transformers import T5EncoderModel, AutoTokenizer - - self.max_length = max_length - self.enable_grad = enable_grad - - # Suppress logging from transformers - previous_level = logging.root.manager.disable - logging.disable(logging.ERROR) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - try: - # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) - # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) - self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('useful_ckpts', t5_model_name)) - model = T5EncoderModel.from_pretrained(os.path.join('useful_ckpts', t5_model_name)).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) - finally: - logging.disable(previous_level) - - if self.enable_grad: - self.model = model - else: - self.__dict__["model"] = model - - - def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - - self.model.to(device) - self.proj_out.to(device) - encoded = self.tokenizer( - texts, - truncation=True, - max_length=self.max_length, - padding="max_length", - return_tensors="pt", - ) - - input_ids = encoded["input_ids"].to(device) - attention_mask = encoded["attention_mask"].to(device).to(torch.bool) - - self.model.eval() - - with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): - embeddings = self.model( - input_ids=input_ids, attention_mask=attention_mask - )["last_hidden_state"] - - embeddings = self.proj_out(embeddings.float()) - - embeddings = embeddings * attention_mask.unsqueeze(-1).float() - - return embeddings, attention_mask - -def patch_clip(clip_model): - # a hack to make it output last hidden states - # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 - def new_encode_text(self, text, normalize: bool = False): - cast_dtype = self.transformer.get_cast_dtype() - - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.to(cast_dtype) - x = self.transformer(x, attn_mask=self.attn_mask) - x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] - return F.normalize(x, dim=-1) if normalize else x - - clip_model.encode_text = new_encode_text.__get__(clip_model) - return clip_model - -class CLIPTextConditioner(Conditioner): - def __init__( - self, - output_dim: int, - max_length: str = 77, - enable_grad: bool = False, - project_out: bool = False - ): - super().__init__(1024, output_dim, project_out=project_out) - - from transformers import T5EncoderModel, AutoTokenizer - import open_clip - from open_clip import create_model_from_pretrained - - self.max_length = max_length - self.enable_grad = enable_grad - - # Suppress logging from transformers - previous_level = logging.root.manager.disable - logging.disable(logging.ERROR) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - try: - model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',cache_dir='useful_ckpts/DFN5B-CLIP-ViT-H-14-384', - return_transform=False).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) - model = patch_clip(model) - self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14' - finally: - logging.disable(previous_level) - - if self.enable_grad: - self.model = model - else: - self.__dict__["model"] = model - - - def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - - self.model.to(device) - self.proj_out.to(device) - - encoded = self.tokenizer( - texts - ).to(device) - - # input_ids = encoded["input_ids"].to(device) - # attention_mask = encoded["attention_mask"].to(device).to(torch.bool) - - self.model.eval() - - with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): - embeddings = self.model.encode_text( - encoded - ) - - embeddings = self.proj_out(embeddings.float()) - - # embeddings = embeddings * attention_mask.unsqueeze(-1).float() - - return embeddings, torch.ones(embeddings.shape[0], 1).to(device) - -def patch_clip(clip_model): - # a hack to make it output last hidden states - # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 - def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - text_outputs = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - last_hidden_state = text_outputs[0] - # pooled_output = text_outputs[1] - # text_features = self.text_projection(pooled_output) - - return last_hidden_state - - clip_model.get_text_features = new_get_text_features.__get__(clip_model) - return clip_model - -class MetaCLIPTextConditioner(Conditioner): - def __init__( - self, - output_dim: int, - max_length: str = 77, - enable_grad: bool = False, - project_out: bool = False - ): - super().__init__(1024, output_dim, project_out=project_out) - - from transformers import AutoModel - from transformers import AutoProcessor - - self.max_length = max_length - self.enable_grad = enable_grad - - # Suppress logging from transformers - previous_level = logging.root.manager.disable - logging.disable(logging.ERROR) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - try: - self.model = AutoModel.from_pretrained("useful_ckpts/metaclip-huge") - self.model = patch_clip(self.model) - self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") - finally: - logging.disable(previous_level) - - - def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - - self.model.to(device) - self.proj_out.to(device) - encoded = self.clip_processor(text=texts, return_tensors="pt", padding=True).to(device) - - # input_ids = encoded["input_ids"].to(device) - attention_mask = encoded["attention_mask"].to(device).to(torch.bool) - - self.model.eval() - - with torch.set_grad_enabled(self.enable_grad): - embeddings = self.model.get_text_features( - **encoded - ) - - embeddings = self.proj_out(embeddings.float()) - - # embeddings = embeddings * attention_mask.unsqueeze(-1).float() - - return embeddings, torch.ones(embeddings.shape[0],1).to(device) - -class PhonemeConditioner(Conditioner): - """ - A conditioner that turns text into phonemes and embeds them using a lookup table - Only works for English text - - Args: - output_dim: the dimension of the output embeddings - max_length: the maximum number of phonemes to embed - project_out: whether to add another linear projection to the output embeddings - """ - - def __init__( - self, - output_dim: int, - max_length: int = 1024, - project_out: bool = False, - ): - super().__init__(output_dim, output_dim, project_out=project_out) - - from g2p_en import G2p - - self.max_length = max_length - - self.g2p = G2p() - - # Reserving 0 for padding, 1 for ignored - self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim) - - def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - - self.phoneme_embedder.to(device) - self.proj_out.to(device) - - batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length] - - phoneme_ignore = [" ", *string.punctuation] - - # Remove ignored phonemes and cut to max length - batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes] - - # Convert to ids - phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes] - - #Pad to match longest and make a mask tensor for the padding - longest = max([len(ids) for ids in phoneme_ids]) - phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids] - - phoneme_ids = torch.tensor(phoneme_ids).to(device) - - # Convert to embeddings - phoneme_embeds = self.phoneme_embedder(phoneme_ids) - - phoneme_embeds = self.proj_out(phoneme_embeds) - - return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device) - -class TokenizerLUTConditioner(Conditioner): - """ - A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary - - Args: - tokenizer_name: the name of the tokenizer from the Hugging Face transformers library - output_dim: the dimension of the output embeddings - max_length: the maximum length of the text to embed - project_out: whether to add another linear projection to the output embeddings - """ - - def __init__( - self, - tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library - output_dim: int, - max_length: int = 1024, - project_out: bool = False, - ): - super().__init__(output_dim, output_dim, project_out=project_out) - - from transformers import AutoTokenizer - - # Suppress logging from transformers - previous_level = logging.root.manager.disable - logging.disable(logging.ERROR) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - try: - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - finally: - logging.disable(previous_level) - - self.max_length = max_length - - self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim) - - def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - self.proj_out.to(device) - - encoded = self.tokenizer( - texts, - truncation=True, - max_length=self.max_length, - padding="max_length", - return_tensors="pt", - ) - - input_ids = encoded["input_ids"].to(device) - attention_mask = encoded["attention_mask"].to(device).to(torch.bool) - - embeddings = self.token_embedder(input_ids) - - embeddings = self.proj_out(embeddings) - - embeddings = embeddings * attention_mask.unsqueeze(-1).float() - - return embeddings, attention_mask - -class PretransformConditioner(Conditioner): - """ - A conditioner that uses a pretransform's encoder for conditioning - - Args: - pretransform: an instantiated pretransform to use for conditioning - output_dim: the dimension of the output embeddings - """ - def __init__(self, pretransform: Pretransform, output_dim: int): - super().__init__(pretransform.encoded_channels, output_dim) - - self.pretransform = pretransform - - def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: - - self.pretransform.to(device) - self.proj_out.to(device) - - if isinstance(audio, list) or isinstance(audio, tuple): - audio = torch.cat(audio, dim=0) - - # Convert audio to pretransform input channels - audio = set_audio_channels(audio, self.pretransform.io_channels) - - latents = self.pretransform.encode(audio) - - latents = self.proj_out(latents) - - return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)] - -class MultiConditioner(nn.Module): - """ - A module that applies multiple conditioners to an input dictionary based on the keys - - Args: - conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt") - default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"}) - """ - def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}): - super().__init__() - - self.conditioners = nn.ModuleDict(conditioners) - self.default_keys = default_keys - - def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]: - output = {} - - for key, conditioner in self.conditioners.items(): - condition_key = key - - conditioner_inputs = [] - - for x in batch_metadata: - - if condition_key not in x: - if condition_key in self.default_keys: - condition_key = self.default_keys[condition_key] - else: - raise ValueError(f"Conditioner key {condition_key} not found in batch metadata") - - #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list - if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1: - conditioner_input = x[condition_key][0] - - else: - conditioner_input = x[condition_key] - - conditioner_inputs.append(conditioner_input) - - cond_output = conditioner(conditioner_inputs, device) - if len(cond_output) == 1: - output[key] = cond_output[0] - elif len(cond_output) == 2: - output[key] = cond_output - elif len(cond_output) == 4: - output[key] = cond_output[:2] - output[f'{key}_g'] = cond_output[2:] - - return output - -def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner: - """ - Create a MultiConditioner from a conditioning config dictionary - - Args: - config: the conditioning config dictionary - device: the device to put the conditioners on - """ - conditioners = {} - cond_dim = config["cond_dim"] - - default_keys = config.get("default_keys", {}) - - for conditioner_info in config["configs"]: - id = conditioner_info["id"] - - conditioner_type = conditioner_info["type"] - - conditioner_config = {"output_dim": cond_dim} - - conditioner_config.update(conditioner_info["config"]) - if conditioner_type == "t5": - conditioners[id] = T5Conditioner(**conditioner_config) - elif conditioner_type == "clap_text": - conditioners[id] = CLAPTextConditioner(**conditioner_config) - elif conditioner_type == "clip_text": - conditioners[id] = CLIPTextConditioner(**conditioner_config) - elif conditioner_type == "metaclip_text": - conditioners[id] = MetaCLIPTextConditioner(**conditioner_config) - elif conditioner_type == "clap_audio": - conditioners[id] = CLAPAudioConditioner(**conditioner_config) - elif conditioner_type == "cond_mlp": - conditioners[id] = Cond_MLP(**conditioner_config) - elif conditioner_type == "global_mlp": - conditioners[id] = Global_MLP(**conditioner_config) - elif conditioner_type == "sync_mlp": - conditioners[id] = Sync_MLP(**conditioner_config) - elif conditioner_type == "cond_mlp_1": - conditioners[id] = Cond_MLP_1(**conditioner_config) - elif conditioner_type == "cond_convmlp": - conditioners[id] = Cond_ConvMLP(**conditioner_config) - elif conditioner_type == "cond_mlp_global": - conditioners[id] = Cond_MLP_Global(**conditioner_config) - elif conditioner_type == "cond_mlp_global_1": - conditioners[id] = Cond_MLP_Global_1(**conditioner_config) - elif conditioner_type == "cond_mlp_global_2": - conditioners[id] = Cond_MLP_Global_2(**conditioner_config) - elif conditioner_type == "video_linear": - conditioners[id] = Video_Linear(**conditioner_config) - elif conditioner_type == "video_global": - conditioners[id] = Video_Global(**conditioner_config) - elif conditioner_type == "video_sync": - conditioners[id] = Video_Sync(**conditioner_config) - elif conditioner_type == "text_linear": - conditioners[id] = Text_Linear(**conditioner_config) - elif conditioner_type == "video_clip": - conditioners[id] = CLIPConditioner(**conditioner_config) - elif conditioner_type == "video_hiera": - conditioners[id] = VideoHieraConditioner(**conditioner_config) - elif conditioner_type == "meta_query": - try: - from .meta_queries.model import MLLMInContext - except ImportError: - raise ImportError("meta_queries module is not available. Cannot create meta_query conditioner.") - conditioners[id] = MLLMInContext(**conditioner_config) - elif conditioner_type == "int": - conditioners[id] = IntConditioner(**conditioner_config) - elif conditioner_type == "number": - conditioners[id] = NumberConditioner(**conditioner_config) - elif conditioner_type == "phoneme": - conditioners[id] = PhonemeConditioner(**conditioner_config) - elif conditioner_type == "lut": - conditioners[id] = TokenizerLUTConditioner(**conditioner_config) - elif conditioner_type == "pretransform": - sample_rate = conditioner_config.pop("sample_rate", None) - assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" - - from prismaudio_core.factory import create_pretransform_from_config - pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) - - if conditioner_config.get("pretransform_ckpt_path", None) is not None: - pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) - - conditioners[id] = PretransformConditioner(pretransform, **conditioner_config) - elif conditioner_type == "mm_unchang": - conditioners[id] = mm_unchang(**conditioner_config) - else: - raise ValueError(f"Unknown conditioner type: {conditioner_type}") - - return MultiConditioner(conditioners, default_keys=default_keys) \ No newline at end of file diff --git a/prismaudio_core/models/diffusion.py b/prismaudio_core/models/diffusion.py deleted file mode 100644 index 8e3aee1..0000000 --- a/prismaudio_core/models/diffusion.py +++ /dev/null @@ -1,884 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F -from functools import partial -import numpy as np -import typing as tp - -from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes -from .conditioners import MultiConditioner -from .dit import DiffusionTransformer -from .pretransforms import Pretransform - -from .adp import UNetCFG1d, UNet1d - -# Lazy imports for factory functions to avoid circular imports -def _get_create_pretransform_from_config(): - from prismaudio_core.factory import create_pretransform_from_config - return create_pretransform_from_config - -def _get_create_multi_conditioner_from_conditioning_config(): - from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config - return create_multi_conditioner_from_conditioning_config - -class DiffusionModel(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, x, t, **kwargs): - raise NotImplementedError() - -class DiffusionModelWrapper(nn.Module): - def __init__( - self, - model: DiffusionModel, - io_channels, - sample_size, - sample_rate, - min_input_length, - pretransform: tp.Optional[Pretransform] = None, - ): - super().__init__() - self.io_channels = io_channels - self.sample_size = sample_size - self.sample_rate = sample_rate - self.min_input_length = min_input_length - - self.model = model - - if pretransform is not None: - self.pretransform = pretransform - else: - self.pretransform = None - - def forward(self, x, t, **kwargs): - return self.model(x, t, **kwargs) - -class ConditionedDiffusionModel(nn.Module): - def __init__(self, - *args, - supports_cross_attention: bool = False, - supports_input_concat: bool = False, - supports_global_cond: bool = False, - supports_prepend_cond: bool = False, - **kwargs): - super().__init__(*args, **kwargs) - self.supports_cross_attention = supports_cross_attention - self.supports_input_concat = supports_input_concat - self.supports_global_cond = supports_global_cond - self.supports_prepend_cond = supports_prepend_cond - - def forward(self, - x: torch.Tensor, - t: torch.Tensor, - cross_attn_cond: torch.Tensor = None, - cross_attn_mask: torch.Tensor = None, - input_concat_cond: torch.Tensor = None, - global_embed: torch.Tensor = None, - prepend_cond: torch.Tensor = None, - prepend_cond_mask: torch.Tensor = None, - cfg_scale: float = 1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - **kwargs): - raise NotImplementedError() - -class ConditionedDiffusionModelWrapper(nn.Module): - """ - A diffusion model that takes in conditioning - """ - def __init__( - self, - model: ConditionedDiffusionModel, - conditioner: MultiConditioner, - io_channels, - sample_rate, - min_input_length: int, - diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", - zero_init: bool = False, - pretransform: tp.Optional[Pretransform] = None, - cross_attn_cond_ids: tp.List[str] = [], - global_cond_ids: tp.List[str] = [], - input_concat_ids: tp.List[str] = [], - prepend_cond_ids: tp.List[str] = [], - add_cond_ids: tp.List[str] = [], - sync_cond_ids: tp.List[str] = [], - ): - super().__init__() - - self.model = model - self.conditioner = conditioner - self.io_channels = io_channels - self.sample_rate = sample_rate - self.diffusion_objective = diffusion_objective - self.pretransform = pretransform - self.cross_attn_cond_ids = cross_attn_cond_ids - self.global_cond_ids = global_cond_ids - self.input_concat_ids = input_concat_ids - self.prepend_cond_ids = prepend_cond_ids - self.add_cond_ids = add_cond_ids - self.sync_cond_ids = sync_cond_ids - self.min_input_length = min_input_length - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - if zero_init is True: - self.conditioner.apply(_basic_init) - self.model.model.initialize_weights() - - - def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False): - cross_attention_input = None - cross_attention_masks = None - global_cond = None - input_concat_cond = None - prepend_cond = None - prepend_cond_mask = None - add_input = None - sync_input = None - - if len(self.cross_attn_cond_ids) > 0: - # Concatenate all cross-attention inputs over the sequence dimension - # Assumes that the cross-attention inputs are of shape (batch, seq, channels) - cross_attention_input = [] - cross_attention_masks = [] - - for key in self.cross_attn_cond_ids: - cross_attn_in, cross_attn_mask = conditioning_tensors[key] - - # Add sequence dimension if it's not there - if len(cross_attn_in.shape) == 2: - cross_attn_in = cross_attn_in.unsqueeze(1) - # cross_attn_mask = cross_attn_mask.unsqueeze(1) - - cross_attention_input.append(cross_attn_in) - cross_attention_masks.append(cross_attn_mask) - - cross_attention_input = torch.cat(cross_attention_input, dim=1) - cross_attention_masks = torch.cat(cross_attention_masks, dim=1) - - if len(self.add_cond_ids) > 0: - # Concatenate all cross-attention inputs over the sequence dimension - # Assumes that the cross-attention inputs are of shape (batch, seq, channels) - add_input = [] - - for key in self.add_cond_ids: - add_in = conditioning_tensors[key][0] - - # Add sequence dimension if it's not there - if len(add_in.shape) == 2: - add_in = add_in.unsqueeze(1) - # add_in = add_in.transpose(1,2) - # add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False) - # add_in = add_in.transpose(1,2) - add_input.append(add_in) - - add_input = torch.cat(add_input, dim=2) - - if len(self.sync_cond_ids) > 0: - # Concatenate all cross-attention inputs over the sequence dimension - # Assumes that the cross-attention inputs are of shape (batch, seq, channels) - sync_input = [] - - for key in self.sync_cond_ids: - sync_in = conditioning_tensors[key][0] - - # Add sequence dimension if it's not there - if len(sync_in.shape) == 2: - sync_in = sync_in.unsqueeze(1) - sync_input.append(sync_in) - - sync_input = torch.cat(sync_input, dim=2) - - if len(self.global_cond_ids) > 0: - # Concatenate all global conditioning inputs over the channel dimension - # Assumes that the global conditioning inputs are of shape (batch, channels) - global_conds = [] - for key in self.global_cond_ids: - global_cond_input = conditioning_tensors[key][0] - if len(global_cond_input.shape) == 2: - global_cond_input = global_cond_input.unsqueeze(1) - global_conds.append(global_cond_input) - - # # Concatenate over the channel dimension - # if global_conds[0].shape[-1] == 768: - # global_cond = torch.cat(global_conds, dim=-1) - # else: - # global_cond = sum(global_conds) - global_cond = sum(global_conds) - # global_cond = torch.cat(global_conds, dim=-1) - - if len(global_cond.shape) == 3: - global_cond = global_cond.squeeze(1) - - if len(self.input_concat_ids) > 0: - # Concatenate all input concat conditioning inputs over the channel dimension - # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq) - input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1) - - if len(self.prepend_cond_ids) > 0: - # Concatenate all prepend conditioning inputs over the sequence dimension - # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) - prepend_conds = [] - prepend_cond_masks = [] - - for key in self.prepend_cond_ids: - prepend_cond_input, prepend_cond_mask = conditioning_tensors[key] - if len(prepend_cond_input.shape) == 2: - prepend_cond_input = prepend_cond_input.unsqueeze(1) - prepend_conds.append(prepend_cond_input) - prepend_cond_masks.append(prepend_cond_mask) - - prepend_cond = torch.cat(prepend_conds, dim=1) - prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1) - - if negative: - return { - "negative_cross_attn_cond": cross_attention_input, - "negative_cross_attn_mask": cross_attention_masks, - "negative_global_cond": global_cond, - "negative_input_concat_cond": input_concat_cond - } - else: - return { - "cross_attn_cond": cross_attention_input, - "cross_attn_mask": cross_attention_masks, - "global_cond": global_cond, - "input_concat_cond": input_concat_cond, - "prepend_cond": prepend_cond, - "prepend_cond_mask": prepend_cond_mask, - "add_cond": add_input, - "sync_cond": sync_input - } - - def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): - return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs) - - def generate(self, *args, **kwargs): - from prismaudio_core.inference.generation import generate_diffusion_cond - return generate_diffusion_cond(self, *args, **kwargs) - -class UNetCFG1DWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): - super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True) - - self.model = UNetCFG1d(*args, **kwargs) - - with torch.no_grad(): - for param in self.model.parameters(): - param *= 0.5 - - def forward(self, - x, - t, - cross_attn_cond=None, - cross_attn_mask=None, - input_concat_cond=None, - global_cond=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - negative_global_cond=None, - negative_input_concat_cond=None, - prepend_cond=None, - prepend_cond_mask=None, - **kwargs): - channels_list = None - if input_concat_cond is not None: - channels_list = [input_concat_cond] - - outputs = self.model( - x, - t, - embedding=cross_attn_cond, - embedding_mask=cross_attn_mask, - features=global_cond, - channels_list=channels_list, - embedding_scale=cfg_scale, - embedding_mask_proba=cfg_dropout_prob, - batch_cfg=batch_cfg, - rescale_cfg=rescale_cfg, - negative_embedding=negative_cross_attn_cond, - negative_embedding_mask=negative_cross_attn_mask, - **kwargs) - - return outputs - -class UNet1DCondWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): - super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True) - - self.model = UNet1d(*args, **kwargs) - - with torch.no_grad(): - for param in self.model.parameters(): - param *= 0.5 - - def forward(self, - x, - t, - input_concat_cond=None, - global_cond=None, - cross_attn_cond=None, - cross_attn_mask=None, - prepend_cond=None, - prepend_cond_mask=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - negative_global_cond=None, - negative_input_concat_cond=None, - **kwargs): - - channels_list = None - if input_concat_cond is not None: - - # Interpolate input_concat_cond to the same length as x - if input_concat_cond.shape[2] != x.shape[2]: - input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') - - channels_list = [input_concat_cond] - - outputs = self.model( - x, - t, - features=global_cond, - channels_list=channels_list, - **kwargs) - - return outputs - -class UNet1DUncondWrapper(DiffusionModel): - def __init__( - self, - in_channels, - *args, - **kwargs - ): - super().__init__() - - self.model = UNet1d(in_channels=in_channels, *args, **kwargs) - - self.io_channels = in_channels - - with torch.no_grad(): - for param in self.model.parameters(): - param *= 0.5 - - def forward(self, x, t, **kwargs): - return self.model(x, t, **kwargs) - -class DAU1DCondWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): - super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True) - - self.model = DiffusionAttnUnet1D(*args, **kwargs) - - with torch.no_grad(): - for param in self.model.parameters(): - param *= 0.5 - - def forward(self, - x, - t, - input_concat_cond=None, - cross_attn_cond=None, - cross_attn_mask=None, - global_cond=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = False, - rescale_cfg: bool = False, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - negative_global_cond=None, - negative_input_concat_cond=None, - prepend_cond=None, - **kwargs): - - return self.model(x, t, cond = input_concat_cond) - -class DiffusionAttnUnet1D(nn.Module): - def __init__( - self, - io_channels = 2, - depth=14, - n_attn_layers = 6, - channels = [128, 128, 256, 256] + [512] * 10, - cond_dim = 0, - cond_noise_aug = False, - kernel_size = 5, - learned_resample = False, - strides = [2] * 13, - conv_bias = True, - use_snake = False - ): - super().__init__() - - self.cond_noise_aug = cond_noise_aug - - self.io_channels = io_channels - - if self.cond_noise_aug: - self.rng = torch.quasirandom.SobolEngine(1, scramble=True) - - self.timestep_embed = FourierFeatures(1, 16) - - attn_layer = depth - n_attn_layers - - strides = [1] + strides - - block = nn.Identity() - - conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake) - - for i in range(depth, 0, -1): - c = channels[i - 1] - stride = strides[i-1] - if stride > 2 and not learned_resample: - raise ValueError("Must have stride 2 without learned resampling") - - if i > 1: - c_prev = channels[i - 2] - add_attn = i >= attn_layer and n_attn_layers > 0 - block = SkipBlock( - Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"), - conv_block(c_prev, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), - conv_block(c, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), - conv_block(c, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), - block, - conv_block(c * 2 if i != depth else c, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), - conv_block(c, c, c), - SelfAttention1d( - c, c // 32) if add_attn else nn.Identity(), - conv_block(c, c, c_prev), - SelfAttention1d(c_prev, c_prev // - 32) if add_attn else nn.Identity(), - Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic") - ) - else: - cond_embed_dim = 16 if not self.cond_noise_aug else 32 - block = nn.Sequential( - conv_block((io_channels + cond_dim) + cond_embed_dim, c, c), - conv_block(c, c, c), - conv_block(c, c, c), - block, - conv_block(c * 2, c, c), - conv_block(c, c, c), - conv_block(c, c, io_channels, is_last=True), - ) - self.net = block - - with torch.no_grad(): - for param in self.net.parameters(): - param *= 0.5 - - def forward(self, x, t, cond=None, cond_aug_scale=None): - - timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape) - - inputs = [x, timestep_embed] - - if cond is not None: - if cond.shape[2] != x.shape[2]: - cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False) - - if self.cond_noise_aug: - # Get a random number between 0 and 1, uniformly sampled - if cond_aug_scale is None: - aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond) - else: - aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond) - - # Add noise to the conditioning signal - cond = cond + torch.randn_like(cond) * aug_level[:, None, None] - - # Get embedding for noise cond level, reusing timestamp_embed - aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape) - - inputs.append(aug_level_embed) - - inputs.append(cond) - - outputs = self.net(torch.cat(inputs, dim=1)) - - return outputs - -class DiTWrapper(ConditionedDiffusionModel): - def __init__( - self, - *args, - **kwargs - ): - super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) - - self.model = DiffusionTransformer(*args, **kwargs) - # with torch.no_grad(): - # for param in self.model.parameters(): - # param *= 0.5 - - def forward(self, - x, - t, - cross_attn_cond=None, - cross_attn_mask=None, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - input_concat_cond=None, - negative_input_concat_cond=None, - global_cond=None, - negative_global_cond=None, - prepend_cond=None, - prepend_cond_mask=None, - cfg_scale=1.0, - cfg_dropout_prob: float = 0.0, - batch_cfg: bool = True, - rescale_cfg: bool = False, - scale_phi: float = 0.0, - **kwargs): - - assert batch_cfg, "batch_cfg must be True for DiTWrapper" - #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" - - return self.model( - x, - t, - cross_attn_cond=cross_attn_cond, - cross_attn_cond_mask=cross_attn_mask, - negative_cross_attn_cond=negative_cross_attn_cond, - negative_cross_attn_mask=negative_cross_attn_mask, - input_concat_cond=input_concat_cond, - prepend_cond=prepend_cond, - prepend_cond_mask=prepend_cond_mask, - cfg_scale=cfg_scale, - cfg_dropout_prob=cfg_dropout_prob, - scale_phi=scale_phi, - global_embed=global_cond, - **kwargs) - -class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel): - """ - A diffusion model that takes in conditioning - """ - def __init__( - self, - model, - conditioner: MultiConditioner, - io_channels, - sample_rate, - min_input_length: int, - diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", - pretransform: tp.Optional[Pretransform] = None, - cross_attn_cond_ids: tp.List[str] = [], - global_cond_ids: tp.List[str] = [], - input_concat_ids: tp.List[str] = [], - prepend_cond_ids: tp.List[str] = [], - add_cond_ids: tp.List[str] = [], - mm_cond_ids: tp.List[str] = [], - ): - super().__init__() - - self.model = model - self.conditioner = conditioner - self.io_channels = io_channels - self.sample_rate = sample_rate - self.diffusion_objective = diffusion_objective - self.pretransform = pretransform - self.cross_attn_cond_ids = cross_attn_cond_ids - self.global_cond_ids = global_cond_ids - self.input_concat_ids = input_concat_ids - self.prepend_cond_ids = prepend_cond_ids - self.add_cond_ids = add_cond_ids - self.min_input_length = min_input_length - self.mm_cond_ids = mm_cond_ids - - assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper" - assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper" - assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper" - assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper" - assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper" - assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper" - assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper" - assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper" - assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper" - # assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper" - - def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False): - assert negative == False, "negative conditioning is not supported for MMDiTWrapper" - cross_attention_input = None - cross_attention_masks = None - global_cond = None - input_concat_cond = None - prepend_cond = None - prepend_cond_mask = None - add_input = None - inpaint_masked_input = None - t5_features = None - metaclip_global_text_features = None - clip_f = conditioning_tensors["metaclip_features"] - sync_f = conditioning_tensors["sync_features"] - text_f = conditioning_tensors["metaclip_text_features"] - if 'inpaint_masked_input' in conditioning_tensors.keys(): - inpaint_masked_input = conditioning_tensors["inpaint_masked_input"] - if 't5_features' in conditioning_tensors.keys(): - t5_features = conditioning_tensors["t5_features"] - if 'metaclip_global_text_features' in conditioning_tensors.keys(): - metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"] - return { - "clip_f": clip_f, - "sync_f": sync_f, - "text_f": text_f, - "inpaint_masked_input": inpaint_masked_input, - "t5_features": t5_features, - "metaclip_global_text_features": metaclip_global_text_features - } - - def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): - return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs) - - def generate(self, *args, **kwargs): - from prismaudio_core.inference.generation import generate_diffusion_cond - return generate_diffusion_cond(self, *args, **kwargs) - -class DiTUncondWrapper(DiffusionModel): - def __init__( - self, - io_channels, - *args, - **kwargs - ): - super().__init__() - - self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs) - - self.io_channels = io_channels - - with torch.no_grad(): - for param in self.model.parameters(): - param *= 0.5 - - def forward(self, x, t, **kwargs): - return self.model(x, t, **kwargs) - -def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]): - diffusion_uncond_config = config["model"] - - model_type = diffusion_uncond_config.get('type', None) - - diffusion_config = diffusion_uncond_config.get('config', {}) - - assert model_type is not None, "Must specify model type in config" - - pretransform = diffusion_uncond_config.get("pretransform", None) - - sample_size = config.get("sample_size", None) - assert sample_size is not None, "Must specify sample size in config" - - sample_rate = config.get("sample_rate", None) - assert sample_rate is not None, "Must specify sample rate in config" - - if pretransform is not None: - pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate) - min_input_length = pretransform.downsampling_ratio - else: - min_input_length = 1 - - if model_type == 'DAU1d': - - model = DiffusionAttnUnet1D( - **diffusion_config - ) - - elif model_type == "adp_uncond_1d": - - model = UNet1DUncondWrapper( - **diffusion_config - ) - - elif model_type == "dit": - model = DiTUncondWrapper( - **diffusion_config - ) - - else: - raise NotImplementedError(f'Unknown model type: {model_type}') - - return DiffusionModelWrapper(model, - io_channels=model.io_channels, - sample_size=sample_size, - sample_rate=sample_rate, - pretransform=pretransform, - min_input_length=min_input_length) - -def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]): - diffusion_uncond_config = config["model"] - - - diffusion_config = diffusion_uncond_config.get('diffusion', {}) - model_type = diffusion_config.get('type', None) - model_config = diffusion_config.get("config",{}) - assert model_type is not None, "Must specify model type in config" - - pretransform = diffusion_uncond_config.get("pretransform", None) - - sample_size = config.get("sample_size", None) - assert sample_size is not None, "Must specify sample size in config" - - sample_rate = config.get("sample_rate", None) - assert sample_rate is not None, "Must specify sample rate in config" - - if pretransform is not None: - pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate) - min_input_length = pretransform.downsampling_ratio - else: - min_input_length = 1 - - if model_type == 'DAU1d': - - model = DiffusionAttnUnet1D( - **model_config - ) - - elif model_type == "adp_uncond_1d": - - io_channels = model_config.get("io_channels", 64) - model = UNet1DUncondWrapper( - io_channels = io_channels, - **model_config - ) - elif model_type == "dit": - model = DiTUncondWrapper( - **model_config - ) - - else: - raise NotImplementedError(f'Unknown model type: {model_type}') - - return DiffusionModelWrapper(model, - io_channels=model.io_channels, - sample_size=sample_size, - sample_rate=sample_rate, - pretransform=pretransform, - min_input_length=min_input_length) - -def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): - - model_config = config["model"] - - model_type = config["model_type"] - - diffusion_config = model_config.get('diffusion', None) - assert diffusion_config is not None, "Must specify diffusion config" - - diffusion_model_type = diffusion_config.get('type', None) - assert diffusion_model_type is not None, "Must specify diffusion model type" - - diffusion_model_config = diffusion_config.get('config', None) - assert diffusion_model_config is not None, "Must specify diffusion model config" - - if diffusion_model_type == 'adp_cfg_1d': - diffusion_model = UNetCFG1DWrapper(**diffusion_model_config) - elif diffusion_model_type == 'adp_1d': - diffusion_model = UNet1DCondWrapper(**diffusion_model_config) - elif diffusion_model_type == 'dit': - diffusion_model = DiTWrapper(**diffusion_model_config) - else: - raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}') - - io_channels = model_config.get('io_channels', None) - assert io_channels is not None, "Must specify io_channels in model config" - - sample_rate = config.get('sample_rate', None) - assert sample_rate is not None, "Must specify sample_rate in config" - - diffusion_objective = diffusion_config.get('diffusion_objective', 'v') - - conditioning_config = model_config.get('conditioning', None) - - conditioner = None - if conditioning_config is not None: - conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config) - - cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) - add_cond_ids = diffusion_config.get('add_cond_ids', []) - sync_cond_ids = diffusion_config.get('sync_cond_ids', []) - global_cond_ids = diffusion_config.get('global_cond_ids', []) - input_concat_ids = diffusion_config.get('input_concat_ids', []) - prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) - mm_cond_ids = diffusion_config.get('mm_cond_ids', []) - zero_init = diffusion_config.get('zero_init', False) - pretransform = model_config.get("pretransform", None) - - if pretransform is not None: - pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate) - min_input_length = pretransform.downsampling_ratio - else: - min_input_length = 1 - - if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d": - min_input_length *= np.prod(diffusion_model_config["factors"]) - elif diffusion_model_type == "dit": - min_input_length *= diffusion_model.model.patch_size - - # Get the proper wrapper class - - extra_kwargs = {} - - if model_type == "mm_diffusion_cond": - wrapper_fn = MMConditionedDiffusionModelWrapper - extra_kwargs["diffusion_objective"] = diffusion_objective - extra_kwargs["mm_cond_ids"] = mm_cond_ids - - elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill': - wrapper_fn = ConditionedDiffusionModelWrapper - extra_kwargs["diffusion_objective"] = diffusion_objective - - else: - raise NotImplementedError(f'Unknown model type: {model_type}') - - return wrapper_fn( - diffusion_model, - conditioner, - min_input_length=min_input_length, - sample_rate=sample_rate, - cross_attn_cond_ids=cross_attention_ids, - global_cond_ids=global_cond_ids, - input_concat_ids=input_concat_ids, - prepend_cond_ids=prepend_cond_ids, - add_cond_ids=add_cond_ids, - sync_cond_ids=sync_cond_ids, - pretransform=pretransform, - io_channels=io_channels, - zero_init=zero_init, - **extra_kwargs - ) \ No newline at end of file diff --git a/prismaudio_core/models/dit.py b/prismaudio_core/models/dit.py deleted file mode 100644 index ec282a9..0000000 --- a/prismaudio_core/models/dit.py +++ /dev/null @@ -1,539 +0,0 @@ -import typing as tp -import math -import torch -# from beartype.typing import Tuple -from einops import rearrange -from torch import nn -from torch.nn import functional as F -from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP -from .blocks import FourierFeatures -from .transformer import ContinuousTransformer -from .utils import mask_from_frac_lengths, resample - -class DiffusionTransformer(nn.Module): - def __init__(self, - io_channels=32, - patch_size=1, - embed_dim=768, - cond_token_dim=0, - project_cond_tokens=True, - global_cond_dim=0, - project_global_cond=True, - input_concat_dim=0, - prepend_cond_dim=0, - cond_ctx_dim=0, - depth=12, - num_heads=8, - transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", - global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", - timestep_cond_type: tp.Literal["global", "input_concat"] = "global", - add_token_dim=0, - sync_token_dim=0, - use_mlp=False, - use_zero_init=False, - **kwargs): - - super().__init__() - - self.cond_token_dim = cond_token_dim - - # Timestep embeddings - timestep_features_dim = 256 - # Timestep embeddings - self.timestep_cond_type = timestep_cond_type - self.timestep_features = FourierFeatures(1, timestep_features_dim) - - if timestep_cond_type == "global": - timestep_embed_dim = embed_dim - elif timestep_cond_type == "input_concat": - assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat" - input_concat_dim += timestep_embed_dim - - self.to_timestep_embed = nn.Sequential( - nn.Linear(timestep_features_dim, embed_dim, bias=True), - nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=True), - ) - self.use_mlp = use_mlp - if cond_token_dim > 0: - # Conditioning tokens - cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim - self.to_cond_embed = nn.Sequential( - nn.Linear(cond_token_dim, cond_embed_dim, bias=False), - nn.SiLU(), - nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) - ) - else: - cond_embed_dim = 0 - - if global_cond_dim > 0: - # Global conditioning - global_embed_dim = global_cond_dim if not project_global_cond else embed_dim - self.to_global_embed = nn.Sequential( - nn.Linear(global_cond_dim, global_embed_dim, bias=False), - nn.SiLU(), - nn.Linear(global_embed_dim, global_embed_dim, bias=False) - ) - if add_token_dim > 0: - # Conditioning tokens - add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim - self.to_add_embed = nn.Sequential( - nn.Linear(add_token_dim, add_embed_dim, bias=False), - nn.SiLU(), - nn.Linear(add_embed_dim, add_embed_dim, bias=False) - ) - else: - add_embed_dim = 0 - - if sync_token_dim > 0: - # Conditioning tokens - sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim - self.to_sync_embed = nn.Sequential( - nn.Linear(sync_token_dim, sync_embed_dim, bias=False), - nn.SiLU(), - nn.Linear(sync_embed_dim, sync_embed_dim, bias=False) - ) - else: - sync_embed_dim = 0 - - - if prepend_cond_dim > 0: - # Prepend conditioning - self.to_prepend_embed = nn.Sequential( - nn.Linear(prepend_cond_dim, embed_dim, bias=False), - nn.SiLU(), - nn.Linear(embed_dim, embed_dim, bias=False) - ) - - self.input_concat_dim = input_concat_dim - - dim_in = io_channels + self.input_concat_dim - - self.patch_size = patch_size - - # Transformer - - self.transformer_type = transformer_type - - self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True) - self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True) - self.global_cond_type = global_cond_type - if self.transformer_type == "continuous_transformer": - - global_dim = None - - if self.global_cond_type == "adaLN": - # The global conditioning is projected to the embed_dim already at this point - global_dim = embed_dim - - self.transformer = ContinuousTransformer( - dim=embed_dim, - depth=depth, - dim_heads=embed_dim // num_heads, - dim_in=dim_in * patch_size, - dim_out=io_channels * patch_size, - cross_attend = cond_token_dim > 0, - cond_token_dim = cond_embed_dim, - global_cond_dim=global_dim, - **kwargs - ) - else: - raise ValueError(f"Unknown transformer type: {self.transformer_type}") - - self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) - self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) - nn.init.zeros_(self.preprocess_conv.weight) - nn.init.zeros_(self.postprocess_conv.weight) - - - def initialize_weights(self): - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - # if isinstance(module, nn.Conv1d): - # if module.bias is not None: - # nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02) - nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02) - - # Zero-out output layers: - if self.global_cond_type == "adaLN": - for block in self.transformer.layers: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - nn.init.constant_(self.empty_clip_feat, 0) - nn.init.constant_(self.empty_sync_feat, 0) - - def _forward( - self, - x, - t, - mask=None, - cross_attn_cond=None, - cross_attn_cond_mask=None, - input_concat_cond=None, - global_embed=None, - prepend_cond=None, - prepend_cond_mask=None, - add_cond=None, - add_masks=None, - sync_cond=None, - return_info=False, - **kwargs): - - if cross_attn_cond is not None: - cross_attn_cond = self.to_cond_embed(cross_attn_cond) - - if global_embed is not None: - # Project the global conditioning to the embedding dimension - global_embed = self.to_global_embed(global_embed) - - prepend_inputs = None - prepend_mask = None - prepend_length = 0 - if prepend_cond is not None: - # Project the prepend conditioning to the embedding dimension - prepend_cond = self.to_prepend_embed(prepend_cond) - - prepend_inputs = prepend_cond - if prepend_cond_mask is not None: - prepend_mask = prepend_cond_mask - - if input_concat_cond is not None: - # reshape from (b, n, c) to (b, c, n) - if input_concat_cond.shape[1] != x.shape[1]: - input_concat_cond = input_concat_cond.transpose(1,2) - # Interpolate input_concat_cond to the same length as x - # if input_concat_cond.shape[1] != x.shape[2]: - # input_concat_cond = input_concat_cond.transpose(1,2) - input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') - # input_concat_cond = input_concat_cond.transpose(1,2) - # if len(global_embed.shape) == 2: - # global_embed = global_embed.unsqueeze(1) - # global_embed = global_embed + input_concat_cond - x = torch.cat([x, input_concat_cond], dim=1) - - # Get the batch of timestep embeddings - timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) - # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists - if self.timestep_cond_type == "global": - if global_embed is not None: - if len(global_embed.shape) == 3: - timestep_embed = timestep_embed.unsqueeze(1) - global_embed = global_embed + timestep_embed - else: - global_embed = timestep_embed - elif self.timestep_cond_type == "input_concat": - x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1) - - # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer - if self.global_cond_type == "prepend" and global_embed is not None: - if prepend_inputs is None: - # Prepend inputs are just the global embed, and the mask is all ones - if len(global_embed.shape) == 2: - prepend_inputs = global_embed.unsqueeze(1) - else: - prepend_inputs = global_embed - prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) - else: - # Prepend inputs are the prepend conditioning + the global embed - if len(global_embed.shape) == 2: - prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) - else: - prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1) - prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) - - prepend_length = prepend_inputs.shape[1] - - x = self.preprocess_conv(x) + x - x = rearrange(x, "b c t -> b t c") - - extra_args = {} - - if self.global_cond_type == "adaLN": - extra_args["global_cond"] = global_embed - - if self.patch_size > 1: - b, seq_len, c = x.shape - - # 计算需要填充的数量 - pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size - - if pad_amount > 0: - # 在时间维度上进行填充 - x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0) - x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) - - if add_cond is not None: - # Interpolate add_cond to the same length as x - # if self.use_mlp: - add_cond = self.to_add_embed(add_cond) - if add_cond.shape[1] != x.shape[1]: - add_cond = add_cond.transpose(1,2) - add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False) - add_cond = add_cond.transpose(1,2) - # add_cond = resample(add_cond, x) - - if sync_cond is not None: - sync_cond = self.to_sync_embed(sync_cond) - - if self.transformer_type == "continuous_transformer": - output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) - - if return_info: - output, info = output - - output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] - - if self.patch_size > 1: - output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) - # 移除之前添加的填充 - if pad_amount > 0: - output = output[:, :, :seq_len] - - output = self.postprocess_conv(output) + output - - if return_info: - return output, info - - return output - - def forward( - self, - x, - t, - cross_attn_cond=None, - cross_attn_cond_mask=None, - negative_cross_attn_cond=None, - negative_cross_attn_mask=None, - input_concat_cond=None, - global_embed=None, - negative_global_embed=None, - prepend_cond=None, - prepend_cond_mask=None, - add_cond=None, - sync_cond=None, - cfg_scale=1.0, - cfg_dropout_prob=0.0, - causal=False, - scale_phi=0.0, - mask=None, - return_info=False, - **kwargs): - - assert causal == False, "Causal mode is not supported for DiffusionTransformer" - bsz, a, b = x.shape - model_dtype = next(self.parameters()).dtype - x = x.to(model_dtype) - t = t.to(model_dtype) - - if cross_attn_cond is not None: - cross_attn_cond = cross_attn_cond.to(model_dtype) - - if negative_cross_attn_cond is not None: - negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype) - - if input_concat_cond is not None: - input_concat_cond = input_concat_cond.to(model_dtype) - - if global_embed is not None: - global_embed = global_embed.to(model_dtype) - - if negative_global_embed is not None: - negative_global_embed = negative_global_embed.to(model_dtype) - - if prepend_cond is not None: - prepend_cond = prepend_cond.to(model_dtype) - - if add_cond is not None: - add_cond = add_cond.to(model_dtype) - - if sync_cond is not None: - sync_cond = sync_cond.to(model_dtype) - - if cross_attn_cond_mask is not None: - cross_attn_cond_mask = cross_attn_cond_mask.bool() - - cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention - - if prepend_cond_mask is not None: - prepend_cond_mask = prepend_cond_mask.bool() - - - # CFG dropout - if cfg_dropout_prob > 0.0 and cfg_scale == 1.0: - if cross_attn_cond is not None: - null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) - dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) - cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) - - if prepend_cond is not None: - null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) - prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) - - if add_cond is not None: - null_embed = torch.zeros_like(add_cond, device=add_cond.device) - dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool) - add_cond = torch.where(dropout_mask, null_embed, add_cond) - - if sync_cond is not None: - null_embed = torch.zeros_like(sync_cond, device=sync_cond.device) - dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool) - sync_cond = torch.where(dropout_mask, null_embed, sync_cond) - - if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None): - # Classifier-free guidance - # Concatenate conditioned and unconditioned inputs on the batch dimension - batch_inputs = torch.cat([x, x], dim=0) - batch_timestep = torch.cat([t, t], dim=0) - if global_embed is not None and global_embed.shape[0] == bsz: - batch_global_cond = torch.cat([global_embed, global_embed], dim=0) - elif global_embed is not None: - batch_global_cond = global_embed - else: - batch_global_cond = None - - if input_concat_cond is not None and input_concat_cond.shape[0] == bsz: - batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) - elif input_concat_cond is not None: - batch_input_concat_cond = input_concat_cond - else: - batch_input_concat_cond = None - - batch_cond = None - batch_cond_masks = None - - # Handle CFG for cross-attention conditioning - if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz: - - null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) - - # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning - if negative_cross_attn_cond is not None: - - # If there's a negative cross-attention mask, set the masked tokens to the null embed - if negative_cross_attn_mask is not None: - negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) - - negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) - - batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) - - else: - batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) - - if cross_attn_cond_mask is not None: - batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) - elif cross_attn_cond is not None: - batch_cond = cross_attn_cond - else: - batch_cond = None - - batch_prepend_cond = None - batch_prepend_cond_mask = None - - if prepend_cond is not None and prepend_cond.shape[0] == bsz: - - null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) - - batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) - - if prepend_cond_mask is not None: - batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) - elif prepend_cond is not None: - batch_prepend_cond = prepend_cond - else: - batch_prepend_cond = None - - batch_add_cond = None - - # Handle CFG for cross-attention conditioning - if add_cond is not None and add_cond.shape[0] == bsz: - - null_embed = torch.zeros_like(add_cond, device=add_cond.device) - - - batch_add_cond = torch.cat([add_cond, null_embed], dim=0) - elif add_cond is not None: - batch_add_cond = add_cond - else: - batch_add_cond = None - - batch_sync_cond = None - - # Handle CFG for cross-attention conditioning - if sync_cond is not None and sync_cond.shape[0] == bsz: - - null_embed = torch.zeros_like(sync_cond, device=sync_cond.device) - - - batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0) - elif sync_cond is not None: - batch_sync_cond = sync_cond - else: - batch_sync_cond = None - - if mask is not None: - batch_masks = torch.cat([mask, mask], dim=0) - else: - batch_masks = None - - batch_output = self._forward( - batch_inputs, - batch_timestep, - cross_attn_cond=batch_cond, - cross_attn_cond_mask=batch_cond_masks, - mask = batch_masks, - input_concat_cond=batch_input_concat_cond, - global_embed = batch_global_cond, - prepend_cond = batch_prepend_cond, - prepend_cond_mask = batch_prepend_cond_mask, - add_cond = batch_add_cond, - sync_cond = batch_sync_cond, - return_info = return_info, - **kwargs) - - if return_info: - batch_output, info = batch_output - - cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) - cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale - - # CFG Rescale - if scale_phi != 0.0: - cond_out_std = cond_output.std(dim=1, keepdim=True) - out_cfg_std = cfg_output.std(dim=1, keepdim=True) - output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output - else: - output = cfg_output - - if return_info: - return output, info - - return output - - else: - return self._forward( - x, - t, - cross_attn_cond=cross_attn_cond, - cross_attn_cond_mask=cross_attn_cond_mask, - input_concat_cond=input_concat_cond, - global_embed=global_embed, - prepend_cond=prepend_cond, - prepend_cond_mask=prepend_cond_mask, - add_cond=add_cond, - sync_cond=sync_cond, - mask=mask, - return_info=return_info, - **kwargs - ) \ No newline at end of file diff --git a/prismaudio_core/models/local_attention.py b/prismaudio_core/models/local_attention.py deleted file mode 100644 index 5d6aa7d..0000000 --- a/prismaudio_core/models/local_attention.py +++ /dev/null @@ -1,275 +0,0 @@ -import torch - -from einops import rearrange -from torch import nn - -from .blocks import AdaRMSNorm -from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm -from .utils import checkpoint - -# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py -class ContinuousLocalTransformer(nn.Module): - def __init__( - self, - *, - dim, - depth, - dim_in = None, - dim_out = None, - causal = False, - local_attn_window_size = 64, - heads = 8, - ff_mult = 2, - cond_dim = 0, - cross_attn_cond_dim = 0, - **kwargs - ): - super().__init__() - - dim_head = dim//heads - - self.layers = nn.ModuleList([]) - - self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() - - self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() - - self.local_attn_window_size = local_attn_window_size - - self.cond_dim = cond_dim - - self.cross_attn_cond_dim = cross_attn_cond_dim - - self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) - - for _ in range(depth): - - self.layers.append(nn.ModuleList([ - AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), - Attention( - dim=dim, - dim_heads=dim_head, - causal=causal, - zero_init_output=True, - natten_kernel_size=local_attn_window_size, - ), - Attention( - dim=dim, - dim_heads=dim_head, - dim_context = cross_attn_cond_dim, - zero_init_output=True - ) if self.cross_attn_cond_dim > 0 else nn.Identity(), - AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), - FeedForward(dim = dim, mult = ff_mult, no_bias=True) - ])) - - def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): - - x = checkpoint(self.project_in, x) - - if prepend_cond is not None: - x = torch.cat([prepend_cond, x], dim=1) - - pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) - - for attn_norm, attn, xattn, ff_norm, ff in self.layers: - - residual = x - if cond is not None: - x = checkpoint(attn_norm, x, cond) - else: - x = checkpoint(attn_norm, x) - - x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual - - if cross_attn_cond is not None: - x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x - - residual = x - - if cond is not None: - x = checkpoint(ff_norm, x, cond) - else: - x = checkpoint(ff_norm, x) - - x = checkpoint(ff, x) + residual - - return checkpoint(self.project_out, x) - -class TransformerDownsampleBlock1D(nn.Module): - def __init__( - self, - in_channels, - embed_dim = 768, - depth = 3, - heads = 12, - downsample_ratio = 2, - local_attn_window_size = 64, - **kwargs - ): - super().__init__() - - self.downsample_ratio = downsample_ratio - - self.transformer = ContinuousLocalTransformer( - dim=embed_dim, - depth=depth, - heads=heads, - local_attn_window_size=local_attn_window_size, - **kwargs - ) - - self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() - - self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) - - - def forward(self, x): - - x = checkpoint(self.project_in, x) - - # Compute - x = self.transformer(x) - - # Trade sequence length for channels - x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) - - # Project back to embed dim - x = checkpoint(self.project_down, x) - - return x - -class TransformerUpsampleBlock1D(nn.Module): - def __init__( - self, - in_channels, - embed_dim, - depth = 3, - heads = 12, - upsample_ratio = 2, - local_attn_window_size = 64, - **kwargs - ): - super().__init__() - - self.upsample_ratio = upsample_ratio - - self.transformer = ContinuousLocalTransformer( - dim=embed_dim, - depth=depth, - heads=heads, - local_attn_window_size = local_attn_window_size, - **kwargs - ) - - self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() - - self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) - - def forward(self, x): - - # Project to embed dim - x = checkpoint(self.project_in, x) - - # Project to increase channel dim - x = checkpoint(self.project_up, x) - - # Trade channels for sequence length - x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) - - # Compute - x = self.transformer(x) - - return x - - -class TransformerEncoder1D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - embed_dims = [96, 192, 384, 768], - heads = [12, 12, 12, 12], - depths = [3, 3, 3, 3], - ratios = [2, 2, 2, 2], - local_attn_window_size = 64, - **kwargs - ): - super().__init__() - - layers = [] - - for layer in range(len(depths)): - prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] - - layers.append( - TransformerDownsampleBlock1D( - in_channels = prev_dim, - embed_dim = embed_dims[layer], - heads = heads[layer], - depth = depths[layer], - downsample_ratio = ratios[layer], - local_attn_window_size = local_attn_window_size, - **kwargs - ) - ) - - self.layers = nn.Sequential(*layers) - - self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) - self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) - - def forward(self, x): - x = rearrange(x, "b c n -> b n c") - x = checkpoint(self.project_in, x) - x = self.layers(x) - x = checkpoint(self.project_out, x) - x = rearrange(x, "b n c -> b c n") - - return x - - -class TransformerDecoder1D(nn.Module): - def __init__( - self, - in_channels, - out_channels, - embed_dims = [768, 384, 192, 96], - heads = [12, 12, 12, 12], - depths = [3, 3, 3, 3], - ratios = [2, 2, 2, 2], - local_attn_window_size = 64, - **kwargs - ): - - super().__init__() - - layers = [] - - for layer in range(len(depths)): - prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] - - layers.append( - TransformerUpsampleBlock1D( - in_channels = prev_dim, - embed_dim = embed_dims[layer], - heads = heads[layer], - depth = depths[layer], - upsample_ratio = ratios[layer], - local_attn_window_size = local_attn_window_size, - **kwargs - ) - ) - - self.layers = nn.Sequential(*layers) - - self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) - self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) - - def forward(self, x): - x = rearrange(x, "b c n -> b n c") - x = checkpoint(self.project_in, x) - x = self.layers(x) - x = checkpoint(self.project_out, x) - x = rearrange(x, "b n c -> b c n") - return x \ No newline at end of file diff --git a/prismaudio_core/models/mmmodules/__init__.py b/prismaudio_core/models/mmmodules/__init__.py deleted file mode 100644 index 520434f..0000000 --- a/prismaudio_core/models/mmmodules/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# mmmodules package diff --git a/prismaudio_core/models/mmmodules/model/__init__.py b/prismaudio_core/models/mmmodules/model/__init__.py deleted file mode 100644 index a86d6e1..0000000 --- a/prismaudio_core/models/mmmodules/model/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# mmmodules.model package diff --git a/prismaudio_core/models/mmmodules/model/low_level.py b/prismaudio_core/models/mmmodules/model/low_level.py deleted file mode 100644 index c8326a8..0000000 --- a/prismaudio_core/models/mmmodules/model/low_level.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - - -class ChannelLastConv1d(nn.Conv1d): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.permute(0, 2, 1) - x = super().forward(x) - x = x.permute(0, 2, 1) - return x - - -# https://github.com/Stability-AI/sd3-ref -class MLP(nn.Module): - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int = 256, - ): - """ - Initialize the FeedForward module. - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - - Attributes: - w1 (ColumnParallelLinear): Linear transformation for the first layer. - w2 (RowParallelLinear): Linear transformation for the second layer. - w3 (ColumnParallelLinear): Linear transformation for the third layer. - - """ - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) - - -class ConvMLP(nn.Module): - - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int = 256, - kernel_size: int = 3, - padding: int = 1, - ): - """ - Initialize the FeedForward module. - - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - - Attributes: - w1 (ColumnParallelLinear): Linear transformation for the first layer. - w2 (RowParallelLinear): Linear transformation for the second layer. - w3 (ColumnParallelLinear): Linear transformation for the third layer. - - """ - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = ChannelLastConv1d(dim, - hidden_dim, - bias=False, - kernel_size=kernel_size, - padding=padding) - self.w2 = ChannelLastConv1d(hidden_dim, - dim, - bias=False, - kernel_size=kernel_size, - padding=padding) - self.w3 = ChannelLastConv1d(dim, - hidden_dim, - bias=False, - kernel_size=kernel_size, - padding=padding) - - def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/prismaudio_core/models/pqmf.py b/prismaudio_core/models/pqmf.py deleted file mode 100644 index 007fdb5..0000000 --- a/prismaudio_core/models/pqmf.py +++ /dev/null @@ -1,393 +0,0 @@ -import math -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from scipy.optimize import fmin -from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord - -class PQMF(nn.Module): - """ - Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. - Uses polyphase representation which is computationally more efficient for real-time. - - Parameters: - - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. - - num_bands (int): Number of desired frequency bands. It must be a power of 2. - """ - - def __init__(self, attenuation, num_bands): - super(PQMF, self).__init__() - - # Ensure num_bands is a power of 2 - is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) - assert is_power_of_2, "'num_bands' must be a power of 2." - - # Create the prototype filter - prototype_filter = design_prototype_filter(attenuation, num_bands) - filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) - padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) - - # Register filters and settings - self.register_buffer("filter_bank", padded_filter_bank) - self.register_buffer("prototype", prototype_filter) - self.num_bands = num_bands - - def forward(self, signal): - """Decompose the signal into multiple frequency bands.""" - # If signal is not a pytorch tensor of Batch x Channels x Length, convert it - signal = prepare_signal_dimensions(signal) - # The signal length must be a multiple of num_bands. Pad it with zeros. - signal = pad_signal(signal, self.num_bands) - # run it - signal = polyphase_analysis(signal, self.filter_bank) - return apply_alias_cancellation(signal) - - def inverse(self, bands): - """Reconstruct the original signal from the frequency bands.""" - bands = apply_alias_cancellation(bands) - return polyphase_synthesis(bands, self.filter_bank) - - -def prepare_signal_dimensions(signal): - """ - Rearrange signal into Batch x Channels x Length. - - Parameters - ---------- - signal : torch.Tensor or numpy.ndarray - The input signal. - - Returns - ------- - torch.Tensor - Preprocessed signal tensor. - """ - # Convert numpy to torch tensor - if isinstance(signal, np.ndarray): - signal = torch.from_numpy(signal) - - # Ensure tensor - if not isinstance(signal, torch.Tensor): - raise ValueError("Input should be either a numpy array or a PyTorch tensor.") - - # Modify dimension of signal to Batch x Channels x Length - if signal.dim() == 1: - # This is just a mono signal. Unsqueeze to 1 x 1 x Length - signal = signal.unsqueeze(0).unsqueeze(0) - elif signal.dim() == 2: - # This is a multi-channel signal (e.g. stereo) - # Rearrange so that larger dimension (Length) is last - if signal.shape[0] > signal.shape[1]: - signal = signal.T - # Unsqueeze to 1 x Channels x Length - signal = signal.unsqueeze(0) - return signal - -def pad_signal(signal, num_bands): - """ - Pads the signal to make its length divisible by the given number of bands. - - Parameters - ---------- - signal : torch.Tensor - The input signal tensor, where the last dimension represents the signal length. - - num_bands : int - The number of bands by which the signal length should be divisible. - - Returns - ------- - torch.Tensor - The padded signal tensor. If the original signal length was already divisible - by num_bands, returns the original signal unchanged. - """ - remainder = signal.shape[-1] % num_bands - if remainder > 0: - padding_size = num_bands - remainder - signal = nn.functional.pad(signal, (0, padding_size)) - return signal - -def generate_modulated_filter_bank(prototype_filter, num_bands): - """ - Generate a QMF bank of cosine modulated filters based on a given prototype filter. - - Parameters - ---------- - prototype_filter : torch.Tensor - The prototype filter used as the basis for modulation. - num_bands : int - The number of desired subbands or filters. - - Returns - ------- - torch.Tensor - A bank of cosine modulated filters. - """ - - # Initialize indices for modulation. - subband_indices = torch.arange(num_bands).reshape(-1, 1) - - # Calculate the length of the prototype filter. - filter_length = prototype_filter.shape[-1] - - # Generate symmetric time indices centered around zero. - time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) - - # Calculate phase offsets to ensure orthogonality between subbands. - phase_offsets = (-1)**subband_indices * np.pi / 4 - - # Compute the cosine modulation function. - modulation = torch.cos( - (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets - ) - - # Apply modulation to the prototype filter. - modulated_filters = 2 * prototype_filter * modulation - - return modulated_filters - - -def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): - """ - Design a lowpass filter using the Kaiser window. - - Parameters - ---------- - angular_cutoff : float - The angular frequency cutoff of the filter. - attenuation : float - The desired stopband attenuation in decibels (dB). - filter_length : int, optional - Desired length of the filter. If not provided, it's computed based on the given specs. - - Returns - ------- - ndarray - The designed lowpass filter coefficients. - """ - - estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) - - # Ensure the estimated length is odd. - estimated_length = 2 * (estimated_length // 2) + 1 - - if filter_length is None: - filter_length = estimated_length - - return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) - - -def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): - """ - Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 - - Parameters - ---------- - angular_cutoff : float - Angular frequency cutoff of the filter. - attenuation : float - Desired stopband attenuation in dB. - num_bands : int - Number of bands for the multiband filter system. - filter_length : int, optional - Desired length of the filter. - - Returns - ------- - float - The computed objective (loss) value for the given filter specs. - """ - - filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) - convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") - - return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) - - -def design_prototype_filter(attenuation, num_bands, filter_length=None): - """ - Design the optimal prototype filter for a multiband system given the desired specs. - - Parameters - ---------- - attenuation : float - The desired stopband attenuation in dB. - num_bands : int - Number of bands for the multiband filter system. - filter_length : int, optional - Desired length of the filter. If not provided, it's computed based on the given specs. - - Returns - ------- - ndarray - The optimal prototype filter coefficients. - """ - - optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), - 1 / num_bands, disp=0)[0] - - prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) - return torch.tensor(prototype_filter, dtype=torch.float32) - -def pad_to_nearest_power_of_two(x): - """ - Pads the input tensor 'x' on both sides such that its last dimension - becomes the nearest larger power of two. - - Parameters: - ----------- - x : torch.Tensor - The input tensor to be padded. - - Returns: - -------- - torch.Tensor - The padded tensor. - """ - current_length = x.shape[-1] - target_length = 2**math.ceil(math.log2(current_length)) - - total_padding = target_length - current_length - left_padding = total_padding // 2 - right_padding = total_padding - left_padding - - return nn.functional.pad(x, (left_padding, right_padding)) - -def apply_alias_cancellation(x): - """ - Applies alias cancellation by inverting the sign of every - second element of every second row, starting from the second - row's first element in a tensor. - - This operation helps ensure that the aliasing introduced in - each band during the decomposition will be counteracted during - the reconstruction. - - Parameters: - ----------- - x : torch.Tensor - The input tensor. - - Returns: - -------- - torch.Tensor - Tensor with specific elements' sign inverted for alias cancellation. - """ - - # Create a mask of the same shape as 'x', initialized with all ones - mask = torch.ones_like(x) - - # Update specific elements in the mask to -1 to perform inversion - mask[..., 1::2, ::2] = -1 - - # Apply the mask to the input tensor 'x' - return x * mask - -def ensure_odd_length(tensor): - """ - Pads the last dimension of a tensor to ensure its size is odd. - - Parameters: - ----------- - tensor : torch.Tensor - Input tensor whose last dimension might need padding. - - Returns: - -------- - torch.Tensor - The original tensor if its last dimension was already odd, - or the padded tensor with an odd-sized last dimension. - """ - - last_dim_size = tensor.shape[-1] - - if last_dim_size % 2 == 0: - tensor = nn.functional.pad(tensor, (0, 1)) - - return tensor - -def polyphase_analysis(signal, filter_bank): - """ - Applies the polyphase method to efficiently analyze the signal using a filter bank. - - Parameters: - ----------- - signal : torch.Tensor - Input signal tensor with shape (Batch x Channels x Length). - - filter_bank : torch.Tensor - Filter bank tensor with shape (Bands x Length). - - Returns: - -------- - torch.Tensor - Signal split into sub-bands. (Batch x Channels x Bands x Length) - """ - - num_bands = filter_bank.shape[0] - num_channels = signal.shape[1] - - # Rearrange signal for polyphase processing. - # Also combine Batch x Channel into one dimension for now. - #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) - signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) - - # Rearrange the filter bank for matching signal shape - filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) - - # Apply convolution with appropriate padding to maintain spatial dimensions - padding = filter_bank.shape[-1] // 2 - filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) - - # Truncate the last dimension post-convolution to adjust the output shape - filtered_signal = filtered_signal[..., :-1] - # Rearrange the first dimension back into Batch x Channels - filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) - - return filtered_signal - -def polyphase_synthesis(signal, filter_bank): - """ - Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. - - Parameters - ---------- - signal : torch.Tensor - Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). - - filter_bank : torch.Tensor - Analysis filter bank (shape: Bands x Length). - - should_rearrange : bool, optional - Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. - - Returns - ------- - torch.Tensor - Reconstructed signal (shape: Batch x Channels X Length) - """ - - num_bands = filter_bank.shape[0] - num_channels = signal.shape[1] - - # Rearrange the filter bank - filter_bank = filter_bank.flip(-1) - filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) - - # Combine Batch x Channels into one dimension for now. - signal = rearrange(signal, "b c n t -> (b c) n t") - - # Apply convolution with appropriate padding - padding_amount = filter_bank.shape[-1] // 2 + 1 - reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) - - # Scale the result - reconstructed_signal = reconstructed_signal[..., :-1] * num_bands - - # Reorganize the output and truncate - reconstructed_signal = reconstructed_signal.flip(1) - reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) - reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] - - return reconstructed_signal \ No newline at end of file diff --git a/prismaudio_core/models/pretransforms.py b/prismaudio_core/models/pretransforms.py deleted file mode 100644 index 89edf3f..0000000 --- a/prismaudio_core/models/pretransforms.py +++ /dev/null @@ -1,239 +0,0 @@ -import torch -from einops import rearrange -from torch import nn - -class Pretransform(nn.Module): - def __init__(self, enable_grad, io_channels, is_discrete): - super().__init__() - - self.is_discrete = is_discrete - self.io_channels = io_channels - self.encoded_channels = None - self.downsampling_ratio = None - - self.enable_grad = enable_grad - - def encode(self, x): - raise NotImplementedError - - def decode(self, z): - raise NotImplementedError - - def tokenize(self, x): - raise NotImplementedError - - def decode_tokens(self, tokens): - raise NotImplementedError - -class AutoencoderPretransform(Pretransform): - def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): - super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) - self.model = model - self.model.requires_grad_(False).eval() - self.scale=scale - self.downsampling_ratio = model.downsampling_ratio - self.io_channels = model.io_channels - self.sample_rate = model.sample_rate - - self.model_half = model_half - self.iterate_batch = iterate_batch - - self.encoded_channels = model.latent_dim - - self.chunked = chunked - self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None - self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None - - if self.model_half: - self.model.half() - - def encode(self, x, **kwargs): - - if self.model_half: - x = x.half() - self.model.to(torch.float16) - - encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) - - if self.model_half: - encoded = encoded.float() - - return encoded / self.scale - - def decode(self, z, **kwargs): - z = z * self.scale - - if self.model_half: - z = z.half() - self.model.to(torch.float16) - - decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) - - if self.model_half: - decoded = decoded.float() - - return decoded - - def tokenize(self, x, **kwargs): - assert self.model.is_discrete, "Cannot tokenize with a continuous model" - - _, info = self.model.encode(x, return_info = True, **kwargs) - - return info[self.model.bottleneck.tokens_id] - - def decode_tokens(self, tokens, **kwargs): - assert self.model.is_discrete, "Cannot decode tokens with a continuous model" - - return self.model.decode_tokens(tokens, **kwargs) - - def load_state_dict(self, state_dict, strict=True): - self.model.load_state_dict(state_dict, strict=strict) - -class PQMFPretransform(Pretransform): - def __init__(self, attenuation=100, num_bands=16): - # TODO: Fix PQMF to take in in-channels - super().__init__(enable_grad=False, io_channels=1, is_discrete=False) - from .pqmf import PQMF - self.pqmf = PQMF(attenuation, num_bands) - - - def encode(self, x): - # x is (Batch x Channels x Time) - x = self.pqmf.forward(x) - # pqmf.forward returns (Batch x Channels x Bands x Time) - # but Pretransform needs Batch x Channels x Time - # so concatenate channels and bands into one axis - return rearrange(x, "b c n t -> b (c n) t") - - def decode(self, x): - # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) - x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) - # returns (Batch x Channels x Time) - return self.pqmf.inverse(x) - -class PretrainedDACPretransform(Pretransform): - def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): - super().__init__(enable_grad=False, io_channels=1, is_discrete=True) - - import dac - - model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) - - self.model = dac.DAC.load(model_path) - - self.quantize_on_decode = quantize_on_decode - - if model_type == "44khz": - self.downsampling_ratio = 512 - else: - self.downsampling_ratio = 320 - - self.io_channels = 1 - - self.scale = scale - - self.chunked = chunked - - self.encoded_channels = self.model.latent_dim - - self.num_quantizers = self.model.n_codebooks - - self.codebook_size = self.model.codebook_size - - def encode(self, x): - - latents = self.model.encoder(x) - - if self.quantize_on_decode: - output = latents - else: - z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) - output = z - - if self.scale != 1.0: - output = output / self.scale - - return output - - def decode(self, z): - - if self.scale != 1.0: - z = z * self.scale - - if self.quantize_on_decode: - z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) - - return self.model.decode(z) - - def tokenize(self, x): - return self.model.encode(x)[1] - - def decode_tokens(self, tokens): - latents = self.model.quantizer.from_codes(tokens) - return self.model.decode(latents) - -class AudiocraftCompressionPretransform(Pretransform): - def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): - super().__init__(enable_grad=False, io_channels=1, is_discrete=True) - - try: - from audiocraft.models import CompressionModel - except ImportError: - raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") - - self.model = CompressionModel.get_pretrained(model_type) - - self.quantize_on_decode = quantize_on_decode - - self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) - - self.sample_rate = self.model.sample_rate - - self.io_channels = self.model.channels - - self.scale = scale - - #self.encoded_channels = self.model.latent_dim - - self.num_quantizers = self.model.num_codebooks - - self.codebook_size = self.model.cardinality - - self.model.to(torch.float16).eval().requires_grad_(False) - - def encode(self, x): - - assert False, "Audiocraft compression models do not support continuous encoding" - - # latents = self.model.encoder(x) - - # if self.quantize_on_decode: - # output = latents - # else: - # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) - # output = z - - # if self.scale != 1.0: - # output = output / self.scale - - # return output - - def decode(self, z): - - assert False, "Audiocraft compression models do not support continuous decoding" - - # if self.scale != 1.0: - # z = z * self.scale - - # if self.quantize_on_decode: - # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) - - # return self.model.decode(z) - - def tokenize(self, x): - with torch.cuda.amp.autocast(enabled=False): - return self.model.encode(x.to(torch.float16))[0] - - def decode_tokens(self, tokens): - with torch.cuda.amp.autocast(enabled=False): - return self.model.decode(tokens) diff --git a/prismaudio_core/models/transformer.py b/prismaudio_core/models/transformer.py deleted file mode 100644 index 16e465a..0000000 --- a/prismaudio_core/models/transformer.py +++ /dev/null @@ -1,989 +0,0 @@ -from functools import reduce, partial -from packaging import version - -from einops import rearrange, repeat -from einops.layers.torch import Rearrange -import torch -import torch.nn.functional as F -from torch import nn, einsum -from torch.cuda.amp import autocast -from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP -from typing import Callable, Literal - -try: - from flash_attn import flash_attn_func, flash_attn_kvpacked_func - HAS_FLASH_ATTN = True -except ImportError: - HAS_FLASH_ATTN = False - flash_attn_kvpacked_func = None - flash_attn_func = None - -from .utils import compile, checkpoint -try: - import natten -except ImportError: - natten = None - -def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): - return x * (1 + scale) + shift - -# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License -# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt - -def create_causal_mask(i, j, device): - return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) - -def or_reduce(masks): - head, *body = masks - for rest in body: - head = head | rest - return head - -# positional embeddings - -class AbsolutePositionalEmbedding(nn.Module): - def __init__(self, dim, max_seq_len): - super().__init__() - self.scale = dim ** -0.5 - self.max_seq_len = max_seq_len - self.emb = nn.Embedding(max_seq_len, dim) - - def forward(self, x, pos = None, seq_start_pos = None): - seq_len, device = x.shape[1], x.device - assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' - - if pos is None: - pos = torch.arange(seq_len, device = device) - - if seq_start_pos is not None: - pos = (pos - seq_start_pos[..., None]).clamp(min = 0) - - pos_emb = self.emb(pos) - pos_emb = pos_emb * self.scale - return pos_emb - -class ScaledSinusoidalEmbedding(nn.Module): - def __init__(self, dim, theta = 10000): - super().__init__() - assert (dim % 2) == 0, 'dimension must be divisible by 2' - self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) - - half_dim = dim // 2 - freq_seq = torch.arange(half_dim).float() / half_dim - inv_freq = theta ** -freq_seq - self.register_buffer('inv_freq', inv_freq, persistent = False) - - def forward(self, x, pos = None, seq_start_pos = None): - seq_len, device = x.shape[1], x.device - - if pos is None: - pos = torch.arange(seq_len, device = device) - - if seq_start_pos is not None: - pos = pos - seq_start_pos[..., None] - - emb = einsum('i, j -> i j', pos, self.inv_freq) - emb = torch.cat((emb.sin(), emb.cos()), dim = -1) - return emb * self.scale - -class RotaryEmbedding(nn.Module): - def __init__( - self, - dim, - use_xpos = False, - scale_base = 512, - interpolation_factor = 1., - base = 10000, - base_rescale_factor = 1. - ): - super().__init__() - # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning - # has some connection to NTK literature - # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - base *= base_rescale_factor ** (dim / (dim - 2)) - - inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) - - assert interpolation_factor >= 1. - self.interpolation_factor = interpolation_factor - - if not use_xpos: - self.register_buffer('scale', None) - return - - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - - self.scale_base = scale_base - self.register_buffer('scale', scale) - - def forward_from_seq_len(self, seq_len): - device = self.inv_freq.device - - t = torch.arange(seq_len, device = device) - return self.forward(t) - - @autocast(enabled = False) - def forward(self, t): - device = self.inv_freq.device - - t = t.to(torch.float32) - - t = t / self.interpolation_factor - - freqs = torch.einsum('i , j -> i j', t, self.inv_freq) - freqs = torch.cat((freqs, freqs), dim = -1) - - if self.scale is None: - return freqs, 1. - - power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base - scale = self.scale ** rearrange(power, 'n -> n 1') - scale = torch.cat((scale, scale), dim = -1) - - return freqs, scale - -def rotate_half(x): - x = rearrange(x, '... (j d) -> ... j d', j = 2) - x1, x2 = x.unbind(dim = -2) - return torch.cat((-x2, x1), dim = -1) - -@autocast(enabled = False) -def apply_rotary_pos_emb(t, freqs, scale = 1): - out_dtype = t.dtype - - # cast to float32 if necessary for numerical stability - dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) - rot_dim, seq_len = freqs.shape[-1], t.shape[-2] - freqs, t = freqs.to(dtype), t.to(dtype) - freqs = freqs[-seq_len:, :] - - if t.ndim == 4 and freqs.ndim == 3: - freqs = rearrange(freqs, 'b n d -> b 1 n d') - - # partial rotary embeddings, Wang et al. GPT-J - t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] - t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) - - t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) - - return torch.cat((t, t_unrotated), dim = -1) - -# norms -class DynamicTanh(nn.Module): - def __init__(self, dim, init_alpha=10.0): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1) * init_alpha) - self.gamma = nn.Parameter(torch.ones(dim)) - self.beta = nn.Parameter(torch.zeros(dim)) - - def forward(self, x): - x = F.tanh(self.alpha * x) - return self.gamma * x + self.beta - -class RunningInstanceNorm(nn.Module): - def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True): - super().__init__() - self.register_buffer("running_mean", torch.zeros(1,1,dim)) - self.register_buffer("running_std", torch.ones(1,1,dim)) - self.saturate = saturate - self.eps = eps - self.momentum = momentum - self.dim = dim - self.trainable_gain = trainable_gain - if self.trainable_gain: - self.gain = nn.Parameter(torch.ones(1)) - - def _update_stats(self, x): - self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum) - self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps) - - def forward(self, x): - if self.training: - self._update_stats(x) - x = (x - self.running_mean) / self.running_std - if self.saturate: - x = torch.asinh(x) - if self.trainable_gain: - x = x * self.gain - return x - -class LayerNorm(nn.Module): - def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5): - """ - bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less - """ - super().__init__() - - if fix_scale: - self.register_buffer("gamma", torch.ones(dim)) - else: - self.gamma = nn.Parameter(torch.ones(dim)) - - if bias: - self.beta = nn.Parameter(torch.zeros(dim)) - else: - self.register_buffer("beta", torch.zeros(dim)) - - self.eps = eps - - self.force_fp32 = force_fp32 - - def forward(self, x): - if not self.force_fp32: - return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps) - else: - output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps) - return output.to(x.dtype) - -class LayerScale(nn.Module): - def __init__(self, dim, init_val = 1e-5): - super().__init__() - self.scale = nn.Parameter(torch.full([dim], init_val)) - def forward(self, x): - return x * self.scale - -class GLU(nn.Module): - def __init__( - self, - dim_in, - dim_out, - activation: Callable, - use_conv = False, - conv_kernel_size = 3, - ): - super().__init__() - self.act = activation - self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) - self.use_conv = use_conv - - def forward(self, x): - if self.use_conv: - x = rearrange(x, 'b n d -> b d n') - x = self.proj(x) - x = rearrange(x, 'b d n -> b n d') - else: - x = self.proj(x) - - x, gate = x.chunk(2, dim = -1) - return x * self.act(gate) - -class FeedForward(nn.Module): - def __init__( - self, - dim, - dim_out = None, - mult = 4, - no_bias = False, - glu = True, - use_conv = False, - conv_kernel_size = 3, - zero_init_output = True, - ): - super().__init__() - inner_dim = int(dim * mult) - - # Default to SwiGLU - - activation = nn.SiLU() - - dim_out = dim if dim_out is None else dim_out - - if glu: - linear_in = GLU(dim, inner_dim, activation) - else: - linear_in = nn.Sequential( - Rearrange('b n d -> b d n') if use_conv else nn.Identity(), - nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), - Rearrange('b n d -> b d n') if use_conv else nn.Identity(), - activation - ) - - linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) - - # init last linear layer to 0 - if zero_init_output: - nn.init.zeros_(linear_out.weight) - if not no_bias: - nn.init.zeros_(linear_out.bias) - - - self.ff = nn.Sequential( - linear_in, - Rearrange('b d n -> b n d') if use_conv else nn.Identity(), - linear_out, - Rearrange('b n d -> b d n') if use_conv else nn.Identity(), - ) - - def forward(self, x): - return self.ff(x) - -class Attention(nn.Module): - def __init__( - self, - dim, - dim_heads = 64, - dim_context = None, - causal = False, - zero_init_output=True, - qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none', - differential = False, - feat_scale = False - ): - super().__init__() - self.dim = dim - self.dim_heads = dim_heads - self.differential = differential - - dim_kv = dim_context if dim_context is not None else dim - - self.num_heads = dim // dim_heads - self.kv_heads = dim_kv // dim_heads - - if dim_context is not None: - if differential: - self.to_q = nn.Linear(dim, dim * 2, bias=False) - self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False) - else: - self.to_q = nn.Linear(dim, dim, bias=False) - self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) - else: - if differential: - self.to_qkv = nn.Linear(dim, dim * 5, bias=False) - else: - self.to_qkv = nn.Linear(dim, dim * 3, bias=False) - - self.to_out = nn.Linear(dim, dim, bias=False) - - if zero_init_output: - nn.init.zeros_(self.to_out.weight) - - if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']: - raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}') - - self.qk_norm = qk_norm - if self.qk_norm == "ln": - self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) - self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) - elif self.qk_norm == 'rns': - self.q_norm = nn.RMSNorm(dim_heads) - self.k_norm = nn.RMSNorm(dim_heads) - elif self.qk_norm == 'dyt': - self.q_norm = DynamicTanh(dim_heads) - self.k_norm = DynamicTanh(dim_heads) - - self.sdp_kwargs = dict( - enable_flash = True, - enable_math = True, - enable_mem_efficient = True - ) - - self.feat_scale = feat_scale - - if self.feat_scale: - self.lambda_dc = nn.Parameter(torch.zeros(dim)) - self.lambda_hf = nn.Parameter(torch.zeros(dim)) - - self.causal = causal - - @compile - def apply_qk_layernorm(self, q, k): - q_type = q.dtype - k_type = k.dtype - q = self.q_norm(q).to(q_type) - k = self.k_norm(k).to(k_type) - return q, k - - - def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None): - - if self.num_heads != self.kv_heads: - # Repeat interleave kv_heads to match q_heads for grouped query attention - heads_per_kv_head = self.num_heads // self.kv_heads - k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - - flash_attn_available = HAS_FLASH_ATTN - - if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None): - flex_attention_block_mask = None - flex_attention_score_mod = None - - if flex_attention_block_mask is not None or flex_attention_score_mod is not None: - raise NotImplementedError( - "FlexAttention is not available in this build. " - "flex_attention_compiled is not defined. Remove flex_attention_block_mask/flex_attention_score_mod arguments." - ) - elif flash_attn_available: - fa_dtype_in = q.dtype - q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v)) - - if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16: - q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v)) - - out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1]) - - out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') - else: - out = F.scaled_dot_product_attention(q, k, v, is_causal = causal) - return out - - - #@compile - def forward( - self, - x, - context = None, - rotary_pos_emb = None, - causal = None, - flex_attention_block_mask = None, - flex_attention_score_mod = None, - flash_attn_sliding_window = None - ): - h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None - - kv_input = context if has_context else x - - if hasattr(self, 'to_q'): - # Use separate linear projections for q and k/v - if self.differential: - q, q_diff = self.to_q(x).chunk(2, dim=-1) - q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff)) - q = torch.stack([q, q_diff], dim = 1) - k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1) - k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v)) - k = torch.stack([k, k_diff], dim = 1) - else: - q = self.to_q(x) - q = rearrange(q, 'b n (h d) -> b h n d', h = h) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) - k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) - else: - # Use fused linear projection - if self.differential: - q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1) - q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff)) - q = torch.stack([q, q_diff], dim = 1) - k = torch.stack([k, k_diff], dim = 1) - else: - q, k, v = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) - - # Normalize q and k for cosine sim attention - if self.qk_norm == "l2": - q = F.normalize(q, dim=-1) - k = F.normalize(k, dim=-1) - elif self.qk_norm != "none": - q, k = self.apply_qk_layernorm(q, k) - - if rotary_pos_emb is not None: - freqs, _ = rotary_pos_emb - q_dtype = q.dtype - k_dtype = k.dtype - q = q.to(torch.float32) - k = k.to(torch.float32) - freqs = freqs.to(torch.float32) - if q.shape[-2] >= k.shape[-2]: - ratio = q.shape[-2] / k.shape[-2] - q_freqs, k_freqs = freqs, ratio * freqs - else: - ratio = k.shape[-2] / q.shape[-2] - q_freqs, k_freqs = ratio * freqs, freqs - q = apply_rotary_pos_emb(q, q_freqs) - k = apply_rotary_pos_emb(k, k_freqs) - q = q.to(v.dtype) - k = k.to(v.dtype) - - n, device = q.shape[-2], q.device - - causal = self.causal if causal is None else causal - - if n == 1 and causal: - causal = False - - if self.differential: - q, q_diff = q.unbind(dim = 1) - k, k_diff = k.unbind(dim = 1) - out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) - out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) - out = out - out_diff - else: - out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) - - # merge heads - out = rearrange(out, ' b h n d -> b n (h d)') - - # Communicate between heads - - # with autocast(enabled = False): - # out_dtype = out.dtype - # out = out.to(torch.float32) - # out = self.to_out(out).to(out_dtype) - out = self.to_out(out) - - if self.feat_scale: - out_dc = out.mean(dim=-2, keepdim=True) - out_hf = out - out_dc - - # Selectively modulate DC and high frequency components - out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf - - return out - -class ConformerModule(nn.Module): - def __init__( - self, - dim, - norm_kwargs = {}, - ): - - super().__init__() - - self.dim = dim - - self.in_norm = LayerNorm(dim, **norm_kwargs) - self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) - self.glu = GLU(dim, dim, nn.SiLU()) - self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) - self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm - self.swish = nn.SiLU() - self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) - - #@compile - def forward(self, x): - x = self.in_norm(x) - x = rearrange(x, 'b n d -> b d n') - x = self.pointwise_conv(x) - x = rearrange(x, 'b d n -> b n d') - x = self.glu(x) - x = rearrange(x, 'b n d -> b d n') - x = self.depthwise_conv(x) - x = rearrange(x, 'b d n -> b n d') - x = self.mid_norm(x) - x = self.swish(x) - x = rearrange(x, 'b n d -> b d n') - x = self.pointwise_conv_2(x) - x = rearrange(x, 'b d n -> b n d') - - return x - -class TransformerBlock(nn.Module): - def __init__( - self, - dim, - dim_heads = 64, - cross_attend = False, - dim_context = None, - global_cond_dim = None, - causal = False, - zero_init_branch_outputs = True, - conformer = False, - layer_ix = -1, - remove_norms = False, - add_rope = False, - layer_scale = False, - use_sync_block_film = False, - attn_kwargs = {}, - ff_kwargs = {}, - norm_kwargs = {} - ): - - super().__init__() - self.dim = dim - self.dim_heads = min(dim_heads,dim) - self.cross_attend = cross_attend - self.dim_context = dim_context - self.causal = causal - if layer_scale and zero_init_branch_outputs: - zero_init_branch_outputs = False - - self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim) - - self.add_rope = add_rope - - self.self_attn = Attention( - dim, - dim_heads = self.dim_heads, - causal = causal, - zero_init_output=zero_init_branch_outputs, - **attn_kwargs - ) - - self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() - - self.cross_attend = cross_attend - if cross_attend: - self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) - self.cross_attn = Attention( - dim, - dim_heads = self.dim_heads, - dim_context=dim_context, - causal = causal, - zero_init_output=zero_init_branch_outputs, - **attn_kwargs - ) - self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() - - self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) - self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) - self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity() - - self.layer_ix = layer_ix - - self.conformer = None - if conformer: - self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) - self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity() - - self.global_cond_dim = global_cond_dim - if global_cond_dim is not None: - self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5) - - self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None - - if use_sync_block_film: - self.sync_film_generator = nn.Sequential( - nn.Linear(dim, dim, bias=False), - nn.SiLU(), - nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift - ) - - @compile - def forward( - self, - x, - context = None, - global_cond=None, - rotary_pos_emb = None, - self_attention_block_mask = None, - self_attention_score_mod = None, - cross_attention_block_mask = None, - cross_attention_score_mod = None, - self_attention_flash_sliding_window = None, - cross_attention_flash_sliding_window = None, - sync_cond = None, - prepend_length=0 - ): - if rotary_pos_emb is None and self.add_rope: - rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2]) - - if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: - if len(global_cond.shape) == 2: - scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1) - else: - - scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1) - - # self-attention with adaLN - residual = x - x = self.pre_norm(x) - x = x * (1 + scale_self) + shift_self - x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window) - x = x * torch.sigmoid(1 - gate_self) - x = self.self_attn_scale(x) - x = x + residual - - if context is not None and self.cross_attend: - x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) - - if self.conformer is not None: - x = x + self.conformer_scale(self.conformer(x)) - - if sync_cond is not None and hasattr(self, 'sync_film_generator'): - scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) - x = x * (1 + scale) + shift - - # feedforward with adaLN - residual = x - x = self.ff_norm(x) - x = x * (1 + scale_ff) + shift_ff - x = self.ff(x) - x = x * torch.sigmoid(1 - gate_ff) - x = self.ff_scale(x) - x = x + residual - - else: - x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)) - - if context is not None and self.cross_attend: - x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) - - if self.conformer is not None: - x = x + self.conformer_scale(self.conformer(x)) - - if sync_cond is not None and hasattr(self, 'sync_film_generator'): - prepend_part = x[:, :prepend_length, :] - audio_part = x[:, prepend_length:, :] - scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) - modulated_audio_part = audio_part * (1 + scale) + shift - x = torch.cat([prepend_part, modulated_audio_part], dim=1) - - x = x + self.ff_scale(self.ff(self.ff_norm(x))) - return x - -class ContinuousTransformer(nn.Module): - def __init__( - self, - dim, - depth, - *, - dim_in = None, - dim_out = None, - dim_heads = 64, - cross_attend=False, - cond_token_dim=None, - pre_cross_attn_ix=-1, - final_cross_attn_ix=-1, - global_cond_dim=None, - causal=False, - rotary_pos_emb=True, - zero_init_branch_outputs=True, - conformer=False, - use_sinusoidal_emb=False, - use_abs_pos_emb=False, - abs_pos_emb_max_length=10000, - num_memory_tokens=0, - sliding_window=None, - use_mlp=False, - use_add_norm=False, - use_gated=False, - use_final_layer=False, - use_zeros=False, - use_conv=False, - use_fusion_mlp=False, - use_film=False, - use_sync_film=False, - use_sync_gated=False, - **kwargs - ): - - super().__init__() - - self.dim = dim - self.depth = depth - self.causal = causal - self.layers = nn.ModuleList([]) - if use_mlp: - self.project_in = nn.Sequential( - nn.Linear(dim_in, dim, bias=False), - nn.SiLU(), - nn.Linear(dim, dim, bias=False) - ) - else: - self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() - self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() - self.video_temporal_conv = None - self.audio_temporal_conv = None - self.fusion_mlp = None - if use_conv: - self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1) - self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1) - if use_fusion_mlp: - self.fusion_mlp = nn.Sequential( - nn.Linear(dim, dim), - nn.SiLU(), - nn.Linear(dim, dim) - ) - - if rotary_pos_emb: - self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) - else: - self.rotary_pos_emb = None - self.num_memory_tokens = num_memory_tokens - if num_memory_tokens > 0: - self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) - - self.use_sinusoidal_emb = use_sinusoidal_emb - if use_sinusoidal_emb: - self.pos_emb = ScaledSinusoidalEmbedding(dim) - - self.use_abs_pos_emb = use_abs_pos_emb - if use_abs_pos_emb: - self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens) - - self.adaLN_modulation = None - if global_cond_dim is not None: - if use_final_layer: - self.norm_final = LayerNorm(dim) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear( - dim, 2 * dim, bias=True - ), - ) - - if use_zeros: - nn.init.constant_(self.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.project_out.weight, 0) - self.global_cond_embedder = nn.Sequential( - nn.Linear(global_cond_dim, dim), - nn.SiLU(), - nn.Linear(dim, dim * 6) - ) - if use_zeros: - nn.init.constant_(self.global_cond_embedder[-1].weight, 0) - nn.init.constant_(self.global_cond_embedder[-1].bias, 0) - nn.init.constant_(self.global_cond_embedder[0].weight, 0) - nn.init.constant_(self.global_cond_embedder[0].bias, 0) - - self.final_cross_attn_ix = final_cross_attn_ix - self.use_gated = use_gated - self.use_film = use_film - self.use_add_norm = use_add_norm - if self.use_add_norm: - self.add_norm = nn.LayerNorm(dim) - if use_gated: - self.gate = nn.Parameter(torch.ones(1, 1, dim)) - - if use_film: - self.film_generator = nn.Sequential( - nn.Linear(dim, dim, bias=False), - nn.SiLU(), - nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift - ) - else: - self.film_generator = None - - if use_sync_film: - self.sync_film_generator = nn.Sequential( - nn.Linear(dim, dim, bias=False), - nn.SiLU(), - nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift - ) - else: - self.sync_film_generator = None - if use_sync_gated: - self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim)) - else: - self.sync_gate = None - - self.sliding_window = sliding_window - - for i in range(depth): - should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix)) - # print(f"Layer {i} cross attends: {should_cross_attend}") - self.layers.append( - TransformerBlock( - dim, - dim_heads = dim_heads, - cross_attend = should_cross_attend, - dim_context = cond_token_dim, - global_cond_dim = global_cond_dim, - causal = causal, - zero_init_branch_outputs = zero_init_branch_outputs, - conformer=conformer, - layer_ix=i, - **kwargs - ) - ) - - def forward( - self, - x, - mask = None, - prepend_embeds = None, - prepend_mask = None, - add_cond = None, - sync_cond = None, - global_cond = None, - return_info = False, - use_checkpointing = True, - exit_layer_ix = None, - video_dropout_prob = 0.0, - **kwargs - ): - batch, seq, device = *x.shape[:2], x.device - model_dtype = next(self.parameters()).dtype - x = x.to(model_dtype) - - prepend_length = 0 - - info = { - "hidden_states": [], - } - - x = self.project_in(x) - if add_cond is not None: - if self.use_gated: - gate = torch.sigmoid(self.gate) - x = x + gate * add_cond - elif self.use_film: - scale, shift = self.film_generator(add_cond).chunk(2, dim=-1) - x = x * (1 + scale) + shift - else: - x = x + add_cond - - if self.use_add_norm: - x = self.add_norm(x) - if self.fusion_mlp is not None: - x = self.fusion_mlp(x) - - if sync_cond is not None: - # Resample sync_cond to match audio sequence length if needed - if sync_cond.shape[1] != x.shape[1]: - sync_cond = torch.nn.functional.interpolate( - sync_cond.transpose(1, 2), size=x.shape[1], - mode='linear', align_corners=False, - ).transpose(1, 2) - if self.sync_film_generator is not None: - scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) - x = x * (1 + scale) + shift - elif self.sync_gate is not None: - gate_value = torch.sigmoid(self.sync_gate) - x = x + gate_value * sync_cond - # else: - # x = x + sync_cond - - if prepend_embeds is not None: - prepend_length, prepend_dim = prepend_embeds.shape[1:] - - assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' - - x = torch.cat((prepend_embeds, x), dim = -2) - - if self.num_memory_tokens > 0: - memory_tokens = self.memory_tokens.expand(batch, -1, -1) - x = torch.cat((memory_tokens, x), dim=1) - - if self.rotary_pos_emb is not None: - rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) - else: - rotary_pos_emb = None - - if self.use_sinusoidal_emb or self.use_abs_pos_emb: - x = x + self.pos_emb(x) - - if global_cond is not None and self.global_cond_embedder is not None: - global_cond_embed = self.global_cond_embedder(global_cond) - else: - global_cond_embed = global_cond - # Iterate over the transformer layers - for layer_ix, layer in enumerate(self.layers): - if use_checkpointing: - x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs) - else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs) - - if return_info: - info["hidden_states"].append(x) - - if exit_layer_ix is not None and layer_ix == exit_layer_ix: - x = x[:, self.num_memory_tokens:, :] - - if return_info: - return x, info - - return x - - x = x[:, self.num_memory_tokens:, :] - if global_cond is not None and self.adaLN_modulation is not None: - if len(global_cond.shape) == 2: - global_cond = global_cond.unsqueeze(1) - shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.project_out(x) - - if return_info: - return x, info - - return x diff --git a/prismaudio_core/models/utils.py b/prismaudio_core/models/utils.py deleted file mode 100644 index c2c4f4f..0000000 --- a/prismaudio_core/models/utils.py +++ /dev/null @@ -1,180 +0,0 @@ -import torch -from safetensors.torch import load_file -from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor -#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline -from torch.nn.utils import remove_weight_norm - -def load_ckpt_state_dict(ckpt_path, prefix=None): - if ckpt_path.endswith(".safetensors"): - state_dict = load_file(ckpt_path) - else: - state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] - - # 过滤特定前缀的state_dict - filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict - - return filtered_state_dict - -def remove_weight_norm_from_model(model): - for module in model.modules(): - if hasattr(module, "weight"): - remove_weight_norm(module) - - return model - -# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license -# License can be found in LICENSES/LICENSE_META.txt - -def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): - """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. - - Args: - input (torch.Tensor): The input tensor containing probabilities. - num_samples (int): Number of samples to draw. - replacement (bool): Whether to draw with replacement or not. - Keywords args: - generator (torch.Generator): A pseudorandom number generator for sampling. - Returns: - torch.Tensor: Last dimension contains num_samples indices - sampled from the multinomial probability distribution - located in the last dimension of tensor input. - """ - - if num_samples == 1: - q = torch.empty_like(input).exponential_(1, generator=generator) - return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) - - input_ = input.reshape(-1, input.shape[-1]) - output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) - output = output_.reshape(*list(input.shape[:-1]), -1) - return output - - -def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: - """Sample next token from top K values along the last dimension of the input probs tensor. - - Args: - probs (torch.Tensor): Input probabilities with token candidates on the last dimension. - k (int): The k in “top-k”. - Returns: - torch.Tensor: Sampled tokens. - """ - top_k_value, _ = torch.topk(probs, k, dim=-1) - min_value_top_k = top_k_value[..., [-1]] - probs *= (probs >= min_value_top_k).float() - probs.div_(probs.sum(dim=-1, keepdim=True)) - next_token = multinomial(probs, num_samples=1) - return next_token - - -def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: - """Sample next token from top P probabilities along the last dimension of the input probs tensor. - - Args: - probs (torch.Tensor): Input probabilities with token candidates on the last dimension. - p (int): The p in “top-p”. - Returns: - torch.Tensor: Sampled tokens. - """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort *= (~mask).float() - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token - -def next_power_of_two(n): - return 2 ** (n - 1).bit_length() - -def next_multiple_of_64(n): - return ((n + 63) // 64) * 64 - - -# mask construction helpers - -def mask_from_start_end_indices( - seq_len: int, - start: Tensor, - end: Tensor -): - assert start.shape == end.shape - device = start.device - - seq = torch.arange(seq_len, device = device, dtype = torch.long) - seq = seq.reshape(*((-1,) * start.ndim), seq_len) - seq = seq.expand(*start.shape, seq_len) - - mask = seq >= start[..., None].long() - mask &= seq < end[..., None].long() - return mask - -def mask_from_frac_lengths( - seq_len: int, - frac_lengths: Tensor -): - device = frac_lengths.device - - lengths = (frac_lengths * seq_len).long() - max_start = seq_len - lengths - - rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1) - start = (max_start * rand).clamp(min = 0) - end = start + lengths - - return mask_from_start_end_indices(seq_len, start, end) - -def _build_spline(video_feat, video_t, target_t): - # 三次样条插值核心实现 - coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1)) - spline = NaturalCubicSpline(coeffs) - return spline.evaluate(target_t).permute(0,2,1) - -def resample(video_feat, audio_latent): - """ - 9s - video_feat: [B, 72, D] - audio_latent: [B, D', 194] or int - """ - B, Tv, D = video_feat.shape - - if isinstance(audio_latent, torch.Tensor): - # audio_latent is a tensor - if audio_latent.shape[1] != 64: - Ta = audio_latent.shape[1] - else: - Ta = audio_latent.shape[2] - elif isinstance(audio_latent, int): - # audio_latent is an int - Ta = audio_latent - else: - raise TypeError("audio_latent must be either a tensor or an int") - - # 构建时间戳 (关键改进点) - video_time = torch.linspace(0, 9, Tv, device=video_feat.device) - audio_time = torch.linspace(0, 9, Ta, device=video_feat.device) - - # 三维化处理 (Batch, Feature, Time) - video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv] - - # 三次样条插值 - aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta] - return aligned_video.permute(0, 2, 1) # [B, Ta, D] - -def checkpoint(function, *args, **kwargs): - kwargs.setdefault("use_reentrant", False) - return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) - -import os -enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1" - -def compile(function, *args, **kwargs): - - if enable_torch_compile: - try: - return torch.compile(function, *args, **kwargs) - except RuntimeError: - return function - - return function \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3b237d9..16f3b94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,5 @@ einops>=0.7.0 -einops-exts -safetensors huggingface_hub transformers>=4.52.3 -k-diffusion>=0.1.1 -alias-free-torch -descript-audio-codec -vector-quantize-pytorch scipy tqdm diff --git a/scripts/environment.yml b/scripts/environment.yml deleted file mode 100644 index f098156..0000000 --- a/scripts/environment.yml +++ /dev/null @@ -1,21 +0,0 @@ -name: prismaudio-extract -channels: - - conda-forge - - defaults -dependencies: - - python=3.10 - - pip - - ffmpeg<7 - - pip: - - torch>=2.6.0 - - torchaudio>=2.6.0 - - torchvision>=0.21.0 - - tensorflow-cpu==2.15.0 - - jax - - jaxlib - - transformers>=4.52.3 - - decord - - einops>=0.7.0 - - numpy - - mediapy - - git+https://github.com/google-deepmind/videoprism.git diff --git a/scripts/extract_features.py b/scripts/extract_features.py deleted file mode 100755 index 00303a5..0000000 --- a/scripts/extract_features.py +++ /dev/null @@ -1,168 +0,0 @@ -#!/usr/bin/env python3 -""" -Standalone PrismAudio feature extraction script. -Runs in a separate Python env with JAX/TF installed (auto-created by PrismAudioFeatureExtractor). - -Usage: - python extract_features.py --video input.mp4 --cot_text "description..." --output features.npz -""" - -import argparse -import os -import sys -import time -import numpy as np -import torch - -# Add plugin root to sys.path so data_utils (and prismaudio_core) are importable -_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) -_PLUGIN_DIR = os.path.dirname(_SCRIPT_DIR) -if _PLUGIN_DIR not in sys.path: - sys.path.insert(0, _PLUGIN_DIR) - - -def _step(n, total, label): - """Print step header and return start time.""" - print(f"[extract] Step {n}/{total} — {label}...", flush=True) - return time.perf_counter() - - -def _done(t0, extra=""): - elapsed = time.perf_counter() - t0 - suffix = f" {extra}" if extra else "" - print(f"[extract] done in {elapsed:.1f}s{suffix}", flush=True) - - -def main(): - t_total = time.perf_counter() - - parser = argparse.ArgumentParser(description="PrismAudio feature extraction") - parser.add_argument("--video", required=True, help="Path to input video") - parser.add_argument("--cot_text", required=True, help="Chain-of-thought description") - parser.add_argument("--output", required=True, help="Output .npz path") - parser.add_argument("--synchformer_ckpt", default=None, help="Path to synchformer checkpoint") - parser.add_argument("--vae_config", default=None, help="Path to VAE config JSON") - parser.add_argument("--source_fps", type=float, default=30.0, help="Original video fps (used when --video is a .npy file)") - parser.add_argument("--clip_fps", type=float, default=4.0) - parser.add_argument("--clip_size", type=int, default=288) - parser.add_argument("--sync_fps", type=float, default=25.0) - parser.add_argument("--sync_size", type=int, default=224) - args = parser.parse_args() - - print(f"[extract] Python : {sys.executable}", flush=True) - print(f"[extract] Video : {args.video}", flush=True) - print(f"[extract] Output : {args.output}", flush=True) - print(f"[extract] CoT text : {args.cot_text[:80]}{'...' if len(args.cot_text) > 80 else ''}", flush=True) - - if not os.path.exists(args.video): - print(f"[extract] ERROR: video not found: {args.video}", flush=True) - sys.exit(1) - - print(f"[extract] Device : {'cuda' if torch.cuda.is_available() else 'cpu'}", flush=True) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # ------------------------------------------------------------------ - t0 = _step(1, 6, "importing dependencies") - from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils - import torchvision.transforms as T - _done(t0) - - # ------------------------------------------------------------------ - t0 = _step(2, 6, "loading models (T5, VideoPrism, Synchformer)") - feat_utils = FeaturesUtils( - vae_config_path=args.vae_config, - synchformer_ckpt=args.synchformer_ckpt, - device=device, - ) - _done(t0) - - # ------------------------------------------------------------------ - t0 = _step(3, 6, "reading and preprocessing video") - if args.video.endswith(".npy"): - all_frames = np.load(args.video) # [T, H, W, C] uint8 - fps = args.source_fps - total_frames = all_frames.shape[0] - duration = total_frames / fps - print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True) - - clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] - clip_indices = [min(i, total_frames - 1) for i in clip_indices] - clip_frames = all_frames[clip_indices] - print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True) - - sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] - sync_indices = [min(i, total_frames - 1) for i in sync_indices] - sync_frames = all_frames[sync_indices] - print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True) - else: - from decord import VideoReader, cpu - vr = VideoReader(args.video, ctx=cpu(0)) - fps = vr.get_avg_fps() - total_frames = len(vr) - duration = total_frames / fps - print(f"[extract] fps={fps:.3f} frames={total_frames} duration={duration:.2f}s", flush=True) - - clip_indices = [int(i * fps / args.clip_fps) for i in range(int(duration * args.clip_fps))] - clip_indices = [min(i, total_frames - 1) for i in clip_indices] - clip_frames = vr.get_batch(clip_indices).asnumpy() - print(f"[extract] CLIP frames : {len(clip_indices)} @ {args.clip_fps}fps → {args.clip_size}×{args.clip_size}", flush=True) - - sync_indices = [int(i * fps / args.sync_fps) for i in range(int(duration * args.sync_fps))] - sync_indices = [min(i, total_frames - 1) for i in sync_indices] - sync_frames = vr.get_batch(sync_indices).asnumpy() - print(f"[extract] Sync frames : {len(sync_indices)} @ {args.sync_fps}fps → {args.sync_size}×{args.sync_size}", flush=True) - - clip_transform = T.Compose([ - T.ToPILImage(), - T.Resize(args.clip_size), - T.CenterCrop(args.clip_size), - T.ToTensor(), - T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ]) - clip_input = torch.stack([clip_transform(f) for f in clip_frames]).unsqueeze(0).to(device) - - sync_transform = T.Compose([ - T.ToPILImage(), - T.Resize(args.sync_size), - T.CenterCrop(args.sync_size), - T.ToTensor(), - T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ]) - sync_input = torch.stack([sync_transform(f) for f in sync_frames]).unsqueeze(0).to(device) - _done(t0) - - # ------------------------------------------------------------------ - t0 = _step(4, 6, "encoding text with T5-Gemma") - text_features = feat_utils.encode_t5_text([args.cot_text]) - _done(t0, f"shape={tuple(text_features.shape)}") - - # ------------------------------------------------------------------ - t0 = _step(5, 6, "encoding video with VideoPrism") - global_video_features, video_features, global_text_features = \ - feat_utils.encode_video_and_text_with_videoprism(clip_input, [args.cot_text]) - _done(t0, f"video={tuple(video_features.shape)} global={tuple(global_video_features.shape)}") - - # ------------------------------------------------------------------ - t0 = _step(6, 6, "encoding video with Synchformer") - sync_features = feat_utils.encode_video_with_sync(sync_input) - _done(t0, f"shape={tuple(sync_features.shape)}") - - # ------------------------------------------------------------------ - t0 = time.perf_counter() - print(f"[extract] Saving features to {args.output} ...", flush=True) - np.savez( - args.output, - video_features=video_features.cpu().float().numpy(), - global_video_features=global_video_features.cpu().float().numpy(), - text_features=text_features.cpu().float().numpy(), - global_text_features=global_text_features.cpu().float().numpy(), - sync_features=sync_features.cpu().float().numpy(), - caption_cot=args.cot_text, - duration=duration, - ) - print(f"[extract] Saved in {time.perf_counter() - t0:.1f}s", flush=True) - print(f"[extract] Total time: {time.perf_counter() - t_total:.1f}s", flush=True) - - -if __name__ == "__main__": - main() diff --git a/scripts/install_extract_env.sh b/scripts/install_extract_env.sh deleted file mode 100755 index 1495e31..0000000 --- a/scripts/install_extract_env.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env bash -# Install the PrismAudio feature-extraction environment using pip venv. -# Use this instead of environment.yml when conda is unavailable (e.g. NVIDIA Docker). -# -# Usage: -# bash scripts/install_extract_env.sh [/path/to/venv] -# -# Default venv path: /opt/prismaudio-extract -# After installation, point the PrismAudioFeatureExtractor node's python_env to: -# /bin/python (Linux/Mac) -# \Scripts\python.exe (Windows) - -set -euo pipefail - -VENV_DIR="${1:-/opt/prismaudio-extract}" - -echo "[PrismAudio] Creating venv at: ${VENV_DIR}" -python3 -m venv "${VENV_DIR}" - -PIP="${VENV_DIR}/bin/pip" - -echo "[PrismAudio] Upgrading pip..." -"${PIP}" install --upgrade pip - -echo "[PrismAudio] Installing PyTorch stack..." -"${PIP}" install torch torchaudio torchvision - -echo "[PrismAudio] Installing feature-extraction dependencies..." -"${PIP}" install \ - "tensorflow-cpu>=2.16.0" \ - "jax[cpu]" \ - "jaxlib" \ - "transformers" \ - "decord" \ - "einops" \ - "numpy" \ - "mediapy" - -echo "[PrismAudio] Installing VideoPrism..." -"${PIP}" install "git+https://github.com/google-deepmind/videoprism.git" - -echo "" -echo "[PrismAudio] Done. Set python_env in PrismAudioFeatureExtractor to:" -echo " ${VENV_DIR}/bin/python" diff --git a/workflows/text_to_audio.json b/workflows/text_to_audio.json deleted file mode 100644 index 5ac0685..0000000 --- a/workflows/text_to_audio.json +++ /dev/null @@ -1,158 +0,0 @@ -{ - "id": "a1c3e5f7-b2d4-4e6a-8c0f-1a3b5c7d9e2f", - "revision": 0, - "last_node_id": 3, - "last_link_id": 2, - "nodes": [ - { - "id": 1, - "type": "PrismAudioModelLoader", - "pos": [ - -160, - -224 - ], - "size": [ - 288, - 96 - ], - "flags": {}, - "order": 0, - "mode": 0, - "inputs": [], - "outputs": [ - { - "name": "model", - "type": "PRISMAUDIO_MODEL", - "slot_index": 0, - "links": [ - 1 - ] - } - ], - "properties": { - "aux_id": "ethanfel/ComfyUI-Prismaudio", - "ver": "62a3c5d", - "Node name for S&R": "PrismAudioModelLoader", - "ue_properties": { - "widget_ue_connectable": {}, - "version": "7.8", - "input_ue_unconnectable": {} - } - }, - "widgets_values": [ - "auto", - "auto" - ] - }, - { - "id": 2, - "type": "PrismAudioTextOnly", - "pos": [ - 192, - -224 - ], - "size": [ - 480, - 222 - ], - "flags": {}, - "order": 1, - "mode": 0, - "inputs": [ - { - "name": "model", - "type": "PRISMAUDIO_MODEL", - "link": 1 - } - ], - "outputs": [ - { - "name": "audio", - "type": "AUDIO", - "slot_index": 0, - "links": [ - 2 - ] - } - ], - "properties": { - "aux_id": "ethanfel/ComfyUI-Prismaudio", - "ver": "62a3c5d", - "Node name for S&R": "PrismAudioTextOnly", - "ue_properties": { - "widget_ue_connectable": {}, - "version": "7.8", - "input_ue_unconnectable": {} - } - }, - "widgets_values": [ - "A large dog barks sharply twice in an outdoor setting, with ambient background noise of rustling leaves and a gentle breeze. The sound is clear and close, recorded at ground level.", - 10.0, - 100, - 7.0, - 0, - "randomize" - ] - }, - { - "id": 3, - "type": "PreviewAudio", - "pos": [ - 736, - -224 - ], - "size": [ - 300, - 76 - ], - "flags": {}, - "order": 2, - "mode": 0, - "inputs": [ - { - "name": "audio", - "type": "AUDIO", - "link": 2 - } - ], - "outputs": [], - "properties": { - "Node name for S&R": "PreviewAudio" - }, - "widgets_values": [] - } - ], - "links": [ - [ - 1, - 1, - 0, - 2, - 0, - "PRISMAUDIO_MODEL" - ], - [ - 2, - 2, - 0, - 3, - 0, - "AUDIO" - ] - ], - "groups": [], - "config": {}, - "extra": { - "ds": { - "scale": 1.1674071890328979, - "offset": [ - 1814.5534800416863, - 500.0421331448515 - ] - }, - "ue_links": [], - "links_added_by_ue": [], - "frontendVersion": "1.42.8" - }, - "version": 0.4 -} diff --git a/workflows/video_to_audio.json b/workflows/video_to_audio.json deleted file mode 100644 index e70eb87..0000000 --- a/workflows/video_to_audio.json +++ /dev/null @@ -1,421 +0,0 @@ -{ - "id": "2481bfbf-ce24-46c5-abdc-1d9163ff78ae", - "revision": 0, - "last_node_id": 12, - "last_link_id": 30, - "nodes": [ - { - "id": 1, - "type": "VHS_LoadVideo", - "pos": [ - -704, - -256 - ], - "size": [ - 288, - 474.7188081936685 - ], - "flags": {}, - "order": 0, - "mode": 0, - "inputs": [ - { - "name": "meta_batch", - "shape": 7, - "type": "VHS_BatchManager", - "link": null - }, - { - "name": "vae", - "shape": 7, - "type": "VAE", - "link": null - } - ], - "outputs": [ - { - "name": "IMAGE", - "type": "IMAGE", - "slot_index": 0, - "links": [ - 12, - 20 - ] - }, - { - "name": "frame_count", - "type": "INT", - "slot_index": 1, - "links": [] - }, - { - "name": "audio", - "type": "AUDIO", - "slot_index": 2, - "links": [] - }, - { - "name": "video_info", - "type": "VHS_VIDEOINFO", - "slot_index": 3, - "links": [ - 21 - ] - } - ], - "properties": { - "cnr_id": "comfyui-videohelpersuite", - "ver": "1.7.9", - "Node name for S&R": "VHS_LoadVideo", - "ue_properties": { - "widget_ue_connectable": {}, - "version": "7.8", - "input_ue_unconnectable": {} - } - }, - "widgets_values": { - "video": "Railtransport_3_479.mp4", - "force_rate": 0, - "custom_width": 0, - "custom_height": 0, - "frame_load_cap": 0, - "skip_first_frames": 0, - "select_every_nth": 1, - "format": "AnimateDiff", - "videopreview": { - "hidden": false, - "paused": false, - "params": { - "force_rate": 0, - "frame_load_cap": 0, - "skip_first_frames": 0, - "select_every_nth": 1, - "filename": "Railtransport_3_479.mp4", - "type": "input", - "format": "video/mp4" - } - } - } - }, - { - "id": 2, - "type": "PrismAudioModelLoader", - "pos": [ - -160, - -224 - ], - "size": [ - 288, - 96 - ], - "flags": {}, - "order": 1, - "mode": 0, - "inputs": [], - "outputs": [ - { - "name": "model", - "type": "PRISMAUDIO_MODEL", - "slot_index": 0, - "links": [ - 26 - ] - } - ], - "properties": { - "aux_id": "ethanfel/ComfyUI-Prismaudio", - "ver": "3894fcc9b40a19d959614d514d5dff65cdfb6eab", - "Node name for S&R": "PrismAudioModelLoader", - "ue_properties": { - "widget_ue_connectable": {}, - "version": "7.8", - "input_ue_unconnectable": {} - } - }, - "widgets_values": [ - "auto", - "auto" - ] - }, - { - "id": 12, - "type": "PrismAudioSampler", - "pos": [ - 256, - -224 - ], - "size": [ - 384, - 224 - ], - "flags": {}, - "order": 3, - "mode": 0, - "inputs": [ - { - "name": "model", - "type": "PRISMAUDIO_MODEL", - "link": 26 - }, - { - "name": "features", - "type": "PRISMAUDIO_FEATURES", - "link": 27 - } - ], - "outputs": [ - { - "name": "audio", - "type": "AUDIO", - "links": [ - 29 - ] - } - ], - "properties": { - "aux_id": "ethanfel/ComfyUI-Prismaudio", - "ver": "30631c0cb4d97cc6aed69a52e3ee4d89df03926c", - "ue_properties": { - "widget_ue_connectable": {}, - "input_ue_unconnectable": {} - }, - "Node name for S&R": "PrismAudioSampler" - }, - "widgets_values": [ - 0, - 100, - 7, - 4096333446, - "randomize" - ] - }, - { - "id": 11, - "type": "PrismAudioFeatureExtractor", - "pos": [ - -384, - -64 - ], - "size": [ - 544, - 288 - ], - "flags": {}, - "order": 2, - "mode": 0, - "inputs": [ - { - "name": "video", - "type": "IMAGE", - "link": 20 - }, - { - "name": "video_info", - "shape": 7, - "type": "VHS_VIDEOINFO", - "link": 21 - } - ], - "outputs": [ - { - "name": "features", - "type": "PRISMAUDIO_FEATURES", - "links": [ - 27 - ] - }, - { - "name": "fps", - "type": "FLOAT", - "links": [ - 30 - ] - } - ], - "properties": { - "aux_id": "ethanfel/ComfyUI-Prismaudio", - "ver": "5b62be04471bf118b2cd3cc71431a302f5730b01", - "Node name for S&R": "PrismAudioFeatureExtractor", - "ue_properties": { - "widget_ue_connectable": {}, - "input_ue_unconnectable": {}, - "version": "7.8" - } - }, - "widgets_values": [ - "Generate ambient countryside sounds with a gentle breeze rustling the leaves of a large tree. From the right, introduce a faint rumble of wheels on a track and a steam engine chugging. Allow the sounds to grow louder and pan from right to left as the steam train travels across the landscape. Include the powerful chugging and clattering of carriages in the soundscape, then gradually recede the sounds to the left. Ensure no additional background noise or music is present.\n", - 30, - "managed_env", - "/media/unraid/comfyui/output/prismaudiocache/", - "" - ] - }, - { - "id": 9, - "type": "VHS_VideoCombine", - "pos": [ - 704, - -256 - ], - "size": [ - 384, - 552.75 - ], - "flags": {}, - "order": 4, - "mode": 0, - "inputs": [ - { - "name": "images", - "type": "IMAGE", - "link": 12 - }, - { - "name": "audio", - "shape": 7, - "type": "AUDIO", - "link": 29 - }, - { - "name": "meta_batch", - "shape": 7, - "type": "VHS_BatchManager", - "link": null - }, - { - "name": "vae", - "shape": 7, - "type": "VAE", - "link": null - }, - { - "name": "frame_rate", - "type": "FLOAT", - "widget": { - "name": "frame_rate" - }, - "link": 30 - } - ], - "outputs": [ - { - "name": "Filenames", - "type": "VHS_FILENAMES", - "links": null - } - ], - "properties": { - "cnr_id": "comfyui-videohelpersuite", - "ver": "1.7.9", - "Node name for S&R": "VHS_VideoCombine", - "ue_properties": { - "widget_ue_connectable": {}, - "input_ue_unconnectable": {}, - "version": "7.8" - } - }, - "widgets_values": { - "frame_rate": 30, - "loop_count": 0, - "filename_prefix": "AnimateDiff", - "format": "video/h264-mp4", - "pix_fmt": "yuv420p", - "crf": 19, - "save_metadata": true, - "trim_to_audio": false, - "pingpong": false, - "save_output": false, - "videopreview": { - "hidden": false, - "paused": false, - "params": { - "filename": "AnimateDiff_00001-audio.mp4", - "subfolder": "", - "type": "temp", - "format": "video/h264-mp4", - "frame_rate": 30, - "workflow": "AnimateDiff_00001.png", - "fullpath": "/basedir/temp/AnimateDiff_00001-audio.mp4" - } - } - } - } - ], - "links": [ - [ - 12, - 1, - 0, - 9, - 0, - "IMAGE" - ], - [ - 20, - 1, - 0, - 11, - 0, - "IMAGE" - ], - [ - 21, - 1, - 3, - 11, - 1, - "VHS_VIDEOINFO" - ], - [ - 26, - 2, - 0, - 12, - 0, - "PRISMAUDIO_MODEL" - ], - [ - 27, - 11, - 0, - 12, - 1, - "PRISMAUDIO_FEATURES" - ], - [ - 29, - 12, - 0, - 9, - 1, - "AUDIO" - ], - [ - 30, - 11, - 1, - 9, - 4, - "FLOAT" - ] - ], - "groups": [], - "config": {}, - "extra": { - "ds": { - "scale": 1.1674071890328979, - "offset": [ - 1814.5534800416863, - 500.0421331448515 - ] - }, - "ue_links": [], - "links_added_by_ue": [], - "frontendVersion": "1.42.8", - "VHS_latentpreview": true, - "VHS_latentpreviewrate": 0, - "VHS_MetadataImage": true, - "VHS_KeepIntermediate": true - }, - "version": 0.4 -} \ No newline at end of file