feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)

Fetch and adapt inference-critical model modules from upstream PrismAudio repo:
- dit.py: DiffusionTransformer with debug prints removed
- diffusion.py: ConditionedDiffusionModelWrapper, DiTWrapper, MMDiTWrapper
- conditioners.py: Cond_MLP, Sync_MLP, MultiConditioner with stubbed training imports
- autoencoders.py: AudioAutoencoder, OobleckEncoder/Decoder
- transformer.py: ContinuousTransformer, Attention with flash_attn fallback to SDPA
- blocks.py, utils.py, bottleneck.py, pretransforms.py, local_attention.py, pqmf.py
- adp.py: UNetCFG1d, UNet1d, NumberEmbedder
- mmmodules/model/low_level.py: MLP, ChannelLastConv1d, ConvMLP

All internal imports rewritten from PrismAudio.* to prismaudio_core.*,
training-only imports stubbed, flash_attn made optional with HAS_FLASH_ATTN flag.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-27 17:31:22 +01:00
parent b60ff4111b
commit 84c81e0e55
16 changed files with 7923 additions and 0 deletions
+9
View File
@@ -0,0 +1,9 @@
"""
PrismAudio model modules for inference.
Re-exports create_model_from_config from the factory module.
"""
from prismaudio_core.factory import create_model_from_config
__all__ = ["create_model_from_config"]
File diff suppressed because it is too large Load Diff
+830
View File
@@ -0,0 +1,830 @@
import torch
import math
import numpy as np
from torch import nn
from torch.nn import functional as F
from torchaudio import transforms as T
from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d
from typing import Literal, Dict, Any
from .blocks import SnakeBeta
from .bottleneck import Bottleneck, DiscreteBottleneck
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
from .pretransforms import Pretransform
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
"""Minimal stub for inference.utils.prepare_audio used by autoencoders."""
import torchaudio.transforms as T
import torch
if in_sr != target_sr:
resample_tf = T.Resample(in_sr, target_sr).to(device)
audio = resample_tf(audio)
if audio.shape[0] > target_channels:
audio = audio[:target_channels]
elif audio.shape[0] < target_channels:
audio = audio.repeat(target_channels // audio.shape[0] + 1, 1)[:target_channels]
if audio.shape[-1] < target_length:
audio = torch.nn.functional.pad(audio, (0, target_length - audio.shape[-1]))
elif audio.shape[-1] > target_length:
audio = audio[..., :target_length]
return audio.unsqueeze(0)
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
from prismaudio_core.factory import create_pretransform_from_config
return create_pretransform_from_config(pretransform, sample_rate)
def _lazy_create_bottleneck_from_config(bottleneck):
from prismaudio_core.factory import create_bottleneck_from_config
return create_bottleneck_from_config(bottleneck)
def checkpoint(function, *args, **kwargs):
kwargs.setdefault("use_reentrant", False)
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":
act = nn.ELU()
elif activation == "snake":
act = SnakeBeta(channels)
elif activation == "none":
act = nn.Identity()
else:
raise ValueError(f"Unknown activation {activation}")
if antialias:
act = Activation1d(act)
return act
class ResidualUnit(nn.Module):
def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
super().__init__()
self.dilation = dilation
padding = (dilation * (7-1)) // 2
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=7, dilation=dilation, padding=padding),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
WNConv1d(in_channels=out_channels, out_channels=out_channels,
kernel_size=1)
)
def forward(self, x):
res = x
#x = checkpoint(self.layers, x)
x = self.layers(x)
return x + res
class EncoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
super().__init__()
self.layers = nn.Sequential(
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=in_channels,
out_channels=in_channels, dilation=9, use_snake=use_snake),
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
WNConv1d(in_channels=in_channels, out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
)
def forward(self, x):
return self.layers(x)
class DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
super().__init__()
if use_nearest_upsample:
upsample_layer = nn.Sequential(
nn.Upsample(scale_factor=stride, mode="nearest"),
WNConv1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride,
stride=1,
bias=False,
padding='same')
)
else:
upsample_layer = WNConvTranspose1d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
self.layers = nn.Sequential(
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
upsample_layer,
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=1, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=3, use_snake=use_snake),
ResidualUnit(in_channels=out_channels, out_channels=out_channels,
dilation=9, use_snake=use_snake),
)
def forward(self, x):
return self.layers(x)
class OobleckEncoder(nn.Module):
def __init__(self,
in_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False
):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
]
for i in range(self.depth-1):
layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class OobleckDecoder(nn.Module):
def __init__(self,
out_channels=2,
channels=128,
latent_dim=32,
c_mults = [1, 2, 4, 8],
strides = [2, 4, 8, 8],
use_snake=False,
antialias_activation=False,
use_nearest_upsample=False,
final_tanh=True):
super().__init__()
c_mults = [1] + c_mults
self.depth = len(c_mults)
layers = [
WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
]
for i in range(self.depth-1, 0, -1):
layers += [DecoderBlock(
in_channels=c_mults[i]*channels,
out_channels=c_mults[i-1]*channels,
stride=strides[i-1],
use_snake=use_snake,
antialias_activation=antialias_activation,
use_nearest_upsample=use_nearest_upsample
)
]
layers += [
get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
nn.Tanh() if final_tanh else nn.Identity()
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class DACEncoderWrapper(nn.Module):
def __init__(self, in_channels=1, **kwargs):
super().__init__()
from dac.model.dac import Encoder as DACEncoder
latent_dim = kwargs.pop("latent_dim", None)
encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
self.latent_dim = latent_dim
# Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
if in_channels != 1:
self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
def forward(self, x):
x = self.encoder(x)
x = self.proj_out(x)
return x
class DACDecoderWrapper(nn.Module):
def __init__(self, latent_dim, out_channels=1, **kwargs):
super().__init__()
from dac.model.dac import Decoder as DACDecoder
self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
self.latent_dim = latent_dim
def forward(self, x):
return self.decoder(x)
class AudioAutoencoder(nn.Module):
def __init__(
self,
encoder,
decoder,
latent_dim,
downsampling_ratio,
sample_rate,
io_channels=2,
bottleneck: Bottleneck = None,
pretransform: Pretransform = None,
in_channels = None,
out_channels = None,
soft_clip = False
):
super().__init__()
self.downsampling_ratio = downsampling_ratio
self.sample_rate = sample_rate
self.latent_dim = latent_dim
self.io_channels = io_channels
self.in_channels = io_channels
self.out_channels = io_channels
self.min_length = self.downsampling_ratio
if in_channels is not None:
self.in_channels = in_channels
if out_channels is not None:
self.out_channels = out_channels
self.bottleneck = bottleneck
self.encoder = encoder
self.decoder = decoder
self.pretransform = pretransform
self.soft_clip = soft_clip
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
info = {}
# import ipdb
# ipdb.set_trace()
if self.pretransform is not None and not skip_pretransform:
if self.pretransform.enable_grad:
if iterate_batch:
audios = []
for i in range(audio.shape[0]):
audios.append(self.pretransform.encode(audio[i:i+1]))
audio = torch.cat(audios, dim=0)
else:
audio = self.pretransform.encode(audio)
else:
with torch.no_grad():
if iterate_batch:
audios = []
for i in range(audio.shape[0]):
audios.append(self.pretransform.encode(audio[i:i+1]))
audio = torch.cat(audios, dim=0)
else:
audio = self.pretransform.encode(audio)
if self.encoder is not None:
if iterate_batch:
latents = []
for i in range(audio.shape[0]):
latents.append(self.encoder(audio[i:i+1]))
latents = torch.cat(latents, dim=0)
else:
latents = self.encoder(audio)
else:
latents = audio
if self.bottleneck is not None:
# TODO: Add iterate batch logic, needs to merge the info dicts
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
info.update(bottleneck_info)
if return_info:
return latents, info
return latents
def decode(self, latents, iterate_batch=False, **kwargs):
if self.bottleneck is not None:
if iterate_batch:
decoded = []
for i in range(latents.shape[0]):
decoded.append(self.bottleneck.decode(latents[i:i+1]))
latents = torch.cat(decoded, dim=0)
else:
latents = self.bottleneck.decode(latents)
if iterate_batch:
decoded = []
for i in range(latents.shape[0]):
decoded.append(self.decoder(latents[i:i+1]))
decoded = torch.cat(decoded, dim=0)
else:
decoded = self.decoder(latents, **kwargs)
if self.pretransform is not None:
if self.pretransform.enable_grad:
if iterate_batch:
decodeds = []
for i in range(decoded.shape[0]):
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
decoded = torch.cat(decodeds, dim=0)
else:
decoded = self.pretransform.decode(decoded)
else:
with torch.no_grad():
if iterate_batch:
decodeds = []
for i in range(latents.shape[0]):
decodeds.append(self.pretransform.decode(decoded[i:i+1]))
decoded = torch.cat(decodeds, dim=0)
else:
decoded = self.pretransform.decode(decoded)
if self.soft_clip:
decoded = torch.tanh(decoded)
return decoded
def decode_tokens(self, tokens, **kwargs):
'''
Decode discrete tokens to audio
Only works with discrete autoencoders
'''
assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
latents = self.bottleneck.decode_tokens(tokens, **kwargs)
return self.decode(latents, **kwargs)
def preprocess_audio_for_encoder(self, audio, in_sr):
'''
Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
If the model is mono, stereo audio will be converted to mono.
Audio will be silence-padded to be a multiple of the model's downsampling ratio.
Audio will be resampled to the model's sample rate.
The output will have batch size 1 and be shape (1 x Channels x Length)
'''
return self.preprocess_audio_list_for_encoder([audio], [in_sr])
def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
'''
Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder.
The audio in that list can be of different lengths and channels.
in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
All audio will be resampled to the model's sample rate.
Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
If the model is mono, all audio will be converted to mono.
The output will be a tensor of shape (Batch x Channels x Length)
'''
batch_size = len(audio_list)
if isinstance(in_sr_list, int):
in_sr_list = [in_sr_list]*batch_size
assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
new_audio = []
max_length = 0
# resample & find the max length
for i in range(batch_size):
audio = audio_list[i]
in_sr = in_sr_list[i]
if len(audio.shape) == 3 and audio.shape[0] == 1:
# batchsize 1 was given by accident. Just squeeze it.
audio = audio.squeeze(0)
elif len(audio.shape) == 1:
# Mono signal, channel dimension is missing, unsqueeze it in
audio = audio.unsqueeze(0)
assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
# Resample audio
if in_sr != self.sample_rate:
resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
audio = resample_tf(audio)
new_audio.append(audio)
if audio.shape[-1] > max_length:
max_length = audio.shape[-1]
# Pad every audio to the same length, multiple of model's downsampling ratio
padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
for i in range(batch_size):
# Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
# convert to tensor
return torch.stack(new_audio)
def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
'''
Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder.
If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
Overlap and chunk_size params are both measured in number of latents (not audio samples)
# and therefore you likely could use the same values with decode_audio.
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
Every autoencoder will have a different receptive field size, and thus ideal overlap.
You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
Smaller chunk_size uses less memory, but more compute.
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
'''
if not chunked:
# default behavior. Encode the entire audio in parallel
return self.encode(audio, **kwargs)
else:
# CHUNKED ENCODING
# samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
# import ipdb
# ipdb.set_trace()
samples_per_latent = self.downsampling_ratio
total_size = audio.shape[2] # in samples
print(f'audio shape: {audio.shape}')
batch_size = audio.shape[0]
chunk_size *= samples_per_latent # converting metric in latents to samples
overlap *= samples_per_latent # converting metric in latents to samples
hop_size = chunk_size - overlap
chunks = []
for i in range(0, total_size - chunk_size + 1, hop_size):
chunk = audio[:,:,i:i+chunk_size]
chunks.append(chunk)
if i+chunk_size != total_size:
# Final chunk
chunk = audio[:,:,-chunk_size:]
chunks.append(chunk)
chunks = torch.stack(chunks)
num_chunks = chunks.shape[0]
# Note: y_size might be a different value from the latent length used in diffusion training
# because we can encode audio of varying lengths
# However, the audio should've been padded to a multiple of samples_per_latent by now.
y_size = total_size // samples_per_latent
# Create an empty latent, we will populate it with chunks as we encode them
y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
print(f'y_final shape: {y_final.shape}')
for i in range(num_chunks):
x_chunk = chunks[i,:]
# encode the chunk
y_chunk = self.encode(x_chunk)
print(f'y_chunk shape: {y_chunk.shape}')
# figure out where to put the audio along the time domain
if i == num_chunks-1:
# final chunk always goes at the end
t_end = y_size
t_start = t_end - y_chunk.shape[2]
else:
t_start = i * hop_size // samples_per_latent
t_end = t_start + chunk_size // samples_per_latent
# remove the edges of the overlaps
ol = overlap//samples_per_latent//2
chunk_start = 0
chunk_end = y_chunk.shape[2]
if i > 0:
# no overlap for the start of the first chunk
t_start += ol
chunk_start += ol
if i < num_chunks-1:
# no overlap for the end of the last chunk
t_end -= ol
chunk_end -= ol
# paste the chunked audio into our y_final output audio
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
return y_final
def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
'''
Decode latents to audio.
If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
Every autoencoder will have a different receptive field size, and thus ideal overlap.
You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
Smaller chunk_size uses less memory, but more compute.
The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
'''
if not chunked:
# default behavior. Decode the entire latent in parallel
return self.decode(latents, **kwargs)
else:
# chunked decoding
hop_size = chunk_size - overlap
total_size = latents.shape[2]
batch_size = latents.shape[0]
chunks = []
for i in range(0, total_size - chunk_size + 1, hop_size):
chunk = latents[:,:,i:i+chunk_size]
chunks.append(chunk)
if i+chunk_size != total_size:
# Final chunk
chunk = latents[:,:,-chunk_size:]
chunks.append(chunk)
chunks = torch.stack(chunks)
num_chunks = chunks.shape[0]
# samples_per_latent is just the downsampling ratio
samples_per_latent = self.downsampling_ratio
# Create an empty waveform, we will populate it with chunks as decode them
y_size = total_size * samples_per_latent
y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
for i in range(num_chunks):
x_chunk = chunks[i,:]
# decode the chunk
y_chunk = self.decode(x_chunk)
# figure out where to put the audio along the time domain
if i == num_chunks-1:
# final chunk always goes at the end
t_end = y_size
t_start = t_end - y_chunk.shape[2]
else:
t_start = i * hop_size * samples_per_latent
t_end = t_start + chunk_size * samples_per_latent
# remove the edges of the overlaps
ol = (overlap//2) * samples_per_latent
chunk_start = 0
chunk_end = y_chunk.shape[2]
if i > 0:
# no overlap for the start of the first chunk
t_start += ol
chunk_start += ol
if i < num_chunks-1:
# no overlap for the end of the last chunk
t_end -= ol
chunk_end -= ol
# paste the chunked audio into our y_final output audio
y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
return y_final
class DiffusionAutoencoder(AudioAutoencoder):
def __init__(
self,
diffusion: ConditionedDiffusionModel,
diffusion_downsampling_ratio,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.diffusion = diffusion
self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
if self.encoder is not None:
# Shrink the initial encoder parameters to avoid saturated latents
with torch.no_grad():
for param in self.encoder.parameters():
param *= 0.5
def decode(self, latents, steps=100):
upsampled_length = latents.shape[2] * self.downsampling_ratio
if self.bottleneck is not None:
latents = self.bottleneck.decode(latents)
if self.decoder is not None:
latents = self.decode(latents)
# Upsample latents to match diffusion length
if latents.shape[2] != upsampled_length:
latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
from prismaudio_core.inference.sampling import sample
decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
if self.pretransform is not None:
if self.pretransform.enable_grad:
decoded = self.pretransform.decode(decoded)
else:
with torch.no_grad():
decoded = self.pretransform.decode(decoded)
return decoded
# AE factories
def create_encoder_from_config(encoder_config: Dict[str, Any]):
encoder_type = encoder_config.get("type", None)
assert encoder_type is not None, "Encoder type must be specified"
if encoder_type == "oobleck":
encoder = OobleckEncoder(
**encoder_config["config"]
)
elif encoder_type == "seanet":
from encodec.modules import SEANetEncoder
seanet_encoder_config = encoder_config["config"]
#SEANet encoder expects strides in reverse order
seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
encoder = SEANetEncoder(
**seanet_encoder_config
)
elif encoder_type == "dac":
dac_config = encoder_config["config"]
encoder = DACEncoderWrapper(**dac_config)
elif encoder_type == "local_attn":
from .local_attention import TransformerEncoder1D
local_attn_config = encoder_config["config"]
encoder = TransformerEncoder1D(
**local_attn_config
)
else:
raise ValueError(f"Unknown encoder type {encoder_type}")
requires_grad = encoder_config.get("requires_grad", True)
if not requires_grad:
for param in encoder.parameters():
param.requires_grad = False
return encoder
def create_decoder_from_config(decoder_config: Dict[str, Any]):
decoder_type = decoder_config.get("type", None)
assert decoder_type is not None, "Decoder type must be specified"
if decoder_type == "oobleck":
decoder = OobleckDecoder(
**decoder_config["config"]
)
elif decoder_type == "seanet":
from encodec.modules import SEANetDecoder
decoder = SEANetDecoder(
**decoder_config["config"]
)
elif decoder_type == "dac":
dac_config = decoder_config["config"]
decoder = DACDecoderWrapper(**dac_config)
elif decoder_type == "local_attn":
from .local_attention import TransformerDecoder1D
local_attn_config = decoder_config["config"]
decoder = TransformerDecoder1D(
**local_attn_config
)
else:
raise ValueError(f"Unknown decoder type {decoder_type}")
requires_grad = decoder_config.get("requires_grad", True)
if not requires_grad:
for param in decoder.parameters():
param.requires_grad = False
return decoder
def create_autoencoder_from_config(config: Dict[str, Any]):
ae_config = config["model"]
encoder = create_encoder_from_config(ae_config["encoder"])
decoder = create_decoder_from_config(ae_config["decoder"])
bottleneck = ae_config.get("bottleneck", None)
latent_dim = ae_config.get("latent_dim", None)
assert latent_dim is not None, "latent_dim must be specified in model config"
downsampling_ratio = ae_config.get("downsampling_ratio", None)
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
io_channels = ae_config.get("io_channels", None)
assert io_channels is not None, "io_channels must be specified in model config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "sample_rate must be specified in model config"
in_channels = ae_config.get("in_channels", None)
out_channels = ae_config.get("out_channels", None)
pretransform = ae_config.get("pretransform", None)
if pretransform is not None:
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
if bottleneck is not None:
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
soft_clip = ae_config["decoder"].get("soft_clip", False)
return AudioAutoencoder(
encoder,
decoder,
io_channels=io_channels,
latent_dim=latent_dim,
downsampling_ratio=downsampling_ratio,
sample_rate=sample_rate,
bottleneck=bottleneck,
pretransform=pretransform,
in_channels=in_channels,
out_channels=out_channels,
soft_clip=soft_clip
)
def create_diffAE_from_config(config: Dict[str, Any]):
diffae_config = config["model"]
if "encoder" in diffae_config:
encoder = create_encoder_from_config(diffae_config["encoder"])
else:
encoder = None
if "decoder" in diffae_config:
decoder = create_decoder_from_config(diffae_config["decoder"])
else:
decoder = None
diffusion_model_type = diffae_config["diffusion"]["type"]
if diffusion_model_type == "DAU1d":
diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
elif diffusion_model_type == "adp_1d":
diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
elif diffusion_model_type == "dit":
diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
latent_dim = diffae_config.get("latent_dim", None)
assert latent_dim is not None, "latent_dim must be specified in model config"
downsampling_ratio = diffae_config.get("downsampling_ratio", None)
assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
io_channels = diffae_config.get("io_channels", None)
assert io_channels is not None, "io_channels must be specified in model config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "sample_rate must be specified in model config"
bottleneck = diffae_config.get("bottleneck", None)
pretransform = diffae_config.get("pretransform", None)
if pretransform is not None:
pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)
if bottleneck is not None:
bottleneck = _lazy_create_bottleneck_from_config(bottleneck)
diffusion_downsampling_ratio = None,
if diffusion_model_type == "DAU1d":
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
elif diffusion_model_type == "adp_1d":
diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
elif diffusion_model_type == "dit":
diffusion_downsampling_ratio = 1
return DiffusionAutoencoder(
encoder=encoder,
decoder=decoder,
diffusion=diffusion,
io_channels=io_channels,
sample_rate=sample_rate,
latent_dim=latent_dim,
downsampling_ratio=downsampling_ratio,
diffusion_downsampling_ratio=diffusion_downsampling_ratio,
bottleneck=bottleneck,
pretransform=pretransform
)
+339
View File
@@ -0,0 +1,339 @@
from functools import reduce
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.backends.cuda import sdp_kernel
from packaging import version
from dac.nn.layers import Snake1d
class ResidualBlock(nn.Module):
def __init__(self, main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
class ResConvBlock(ResidualBlock):
def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
super().__init__([
nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
nn.GroupNorm(1, c_mid),
Snake1d(c_mid) if use_snake else nn.GELU(),
nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
(Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
], skip)
class SelfAttention1d(nn.Module):
def __init__(self, c_in, n_head=1, dropout_rate=0.):
super().__init__()
assert c_in % n_head == 0
self.norm = nn.GroupNorm(1, c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv1d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
if not self.use_flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
# Use flash attention for A100 GPUs
self.sdp_kernel_config = (True, False, False)
else:
# Don't use flash attention for other GPUs
self.sdp_kernel_config = (False, True, True)
def forward(self, input):
n, c, s = input.shape
qkv = self.qkv_proj(self.norm(input))
qkv = qkv.view(
[n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3]**-0.25
if self.use_flash:
with sdp_kernel(*self.sdp_kernel_config):
y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
else:
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
return input + self.dropout(self.out_proj(y))
class SkipBlock(nn.Module):
def __init__(self, *main):
super().__init__()
self.main = nn.Sequential(*main)
def forward(self, input):
return torch.cat([self.main(input), input], dim=1)
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.):
super().__init__()
assert out_features % 2 == 0
self.weight = nn.Parameter(torch.randn(
[out_features // 2, in_features]) * std)
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
def expand_to_planes(input, shape):
return input[..., None].repeat([1, 1, shape[2]])
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
class Downsample1d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer('kernel', kernel_1d)
self.channels_last = channels_last
def forward(self, x):
if self.channels_last:
x = x.permute(0, 2, 1)
x = F.pad(x, (self.pad,) * 2, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
x = F.conv1d(x, weight, stride=2)
if self.channels_last:
x = x.permute(0, 2, 1)
return x
class Upsample1d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer('kernel', kernel_1d)
self.channels_last = channels_last
def forward(self, x):
if self.channels_last:
x = x.permute(0, 2, 1)
x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
if self.channels_last:
x = x.permute(0, 2, 1)
return x
def Downsample1d_2(
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
return nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * kernel_multiplier + 1,
stride=factor,
padding=factor * (kernel_multiplier // 2),
)
def Upsample1d_2(
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
if factor == 1:
return nn.Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
)
if use_nearest:
return nn.Sequential(
nn.Upsample(scale_factor=factor, mode="nearest"),
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
),
)
else:
return nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * 2,
stride=factor,
padding=factor // 2 + factor % 2,
output_padding=factor % 2,
)
def zero_init(layer):
nn.init.zeros_(layer.weight)
if layer.bias is not None:
nn.init.zeros_(layer.bias)
return layer
def rms_norm(x, scale, eps):
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
return x * scale.to(x.dtype)
#rms_norm = torch.compile(rms_norm)
class AdaRMSNorm(nn.Module):
def __init__(self, features, cond_features, eps=1e-6):
super().__init__()
self.eps = eps
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
def extra_repr(self):
return f"eps={self.eps},"
def forward(self, x, cond):
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
def normalize(x, eps=1e-4):
dim = list(range(1, x.ndim))
n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
alpha = np.sqrt(n.numel() / x.numel())
return x / torch.add(eps, n, alpha=alpha)
class ForcedWNConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1):
super().__init__()
self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
def forward(self, x):
if self.training:
with torch.no_grad():
self.weight.copy_(normalize(self.weight))
fan_in = self.weight[0].numel()
w = normalize(self.weight) / math.sqrt(fan_in)
return F.conv1d(x, w, padding='same')
# Kernels
use_compile = True
def compile(function, *args, **kwargs):
if not use_compile:
return function
try:
return torch.compile(function, *args, **kwargs)
except RuntimeError:
return function
@compile
def linear_geglu(x, weight, bias=None):
x = x @ weight.mT
if bias is not None:
x = x + bias
x, gate = x.chunk(2, dim=-1)
return x * F.gelu(gate)
@compile
def rms_norm(x, scale, eps):
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
return x * scale.to(x.dtype)
# Layers
class LinearGEGLU(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super().__init__(in_features, out_features * 2, bias=bias)
self.out_features = out_features
def forward(self, x):
return linear_geglu(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, shape, fix_scale = False, eps=1e-6):
super().__init__()
self.eps = eps
if fix_scale:
self.register_buffer("scale", torch.ones(shape))
else:
self.scale = nn.Parameter(torch.ones(shape))
def extra_repr(self):
return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
def forward(self, x):
return rms_norm(x, self.scale, self.eps)
def snake_beta(x, alpha, beta):
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
# try:
# snake_beta = torch.compile(snake_beta)
# except RuntimeError:
# pass
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = snake_beta(x, alpha, beta)
return x
+355
View File
@@ -0,0 +1,355 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from vector_quantize_pytorch import ResidualVQ, FSQ
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
class Bottleneck(nn.Module):
def __init__(self, is_discrete: bool = False):
super().__init__()
self.is_discrete = is_discrete
def encode(self, x, return_info=False, **kwargs):
raise NotImplementedError
def decode(self, x):
raise NotImplementedError
class DiscreteBottleneck(Bottleneck):
def __init__(self, num_quantizers, codebook_size, tokens_id):
super().__init__(is_discrete=True)
self.num_quantizers = num_quantizers
self.codebook_size = codebook_size
self.tokens_id = tokens_id
def decode_tokens(self, codes, **kwargs):
raise NotImplementedError
class TanhBottleneck(Bottleneck):
def __init__(self):
super().__init__(is_discrete=False)
self.tanh = nn.Tanh()
def encode(self, x, return_info=False):
info = {}
x = torch.tanh(x)
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def vae_sample(mean, scale):
stdev = nn.functional.softplus(scale) + 1e-4
var = stdev * stdev
logvar = torch.log(var)
latents = torch.randn_like(mean) * stdev + mean
kl = (mean * mean + var - logvar - 1).sum(1).mean()
return latents, kl
class VAEBottleneck(Bottleneck):
def __init__(self):
super().__init__(is_discrete=False)
def encode(self, x, return_info=False, **kwargs):
info = {}
mean, scale = x.chunk(2, dim=1)
x, kl = vae_sample(mean, scale)
info["kl"] = kl
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def compute_mean_kernel(x, y):
kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
return torch.exp(-kernel_input).mean()
def compute_mmd(latents):
latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
noise = torch.randn_like(latents_reshaped)
latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
noise_kernel = compute_mean_kernel(noise, noise)
latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
return mmd.mean()
class WassersteinBottleneck(Bottleneck):
def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
super().__init__(is_discrete=False)
self.noise_augment_dim = noise_augment_dim
self.bypass_mmd = bypass_mmd
def encode(self, x, return_info=False):
info = {}
if self.training and return_info:
if self.bypass_mmd:
mmd = torch.tensor(0.0)
else:
mmd = compute_mmd(x)
info["mmd"] = mmd
if return_info:
return x, info
return x
def decode(self, x):
if self.noise_augment_dim > 0:
noise = torch.randn(x.shape[0], self.noise_augment_dim,
x.shape[-1]).type_as(x)
x = torch.cat([x, noise], dim=1)
return x
class L2Bottleneck(Bottleneck):
def __init__(self):
super().__init__(is_discrete=False)
def encode(self, x, return_info=False):
info = {}
x = F.normalize(x, dim=1)
if return_info:
return x, info
else:
return x
def decode(self, x):
return F.normalize(x, dim=1)
class RVQBottleneck(DiscreteBottleneck):
def __init__(self, **quantizer_kwargs):
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
self.quantizer = ResidualVQ(**quantizer_kwargs)
self.num_quantizers = quantizer_kwargs["num_quantizers"]
def encode(self, x, return_info=False, **kwargs):
info = {}
x = rearrange(x, "b c n -> b n c")
x, indices, loss = self.quantizer(x)
x = rearrange(x, "b n c -> b c n")
info["quantizer_indices"] = indices
info["quantizer_loss"] = loss.mean()
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def decode_tokens(self, codes, **kwargs):
latents = self.quantizer.get_outputs_from_indices(codes)
return self.decode(latents, **kwargs)
class RVQVAEBottleneck(DiscreteBottleneck):
def __init__(self, **quantizer_kwargs):
super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
self.quantizer = ResidualVQ(**quantizer_kwargs)
self.num_quantizers = quantizer_kwargs["num_quantizers"]
def encode(self, x, return_info=False):
info = {}
x, kl = vae_sample(*x.chunk(2, dim=1))
info["kl"] = kl
x = rearrange(x, "b c n -> b n c")
x, indices, loss = self.quantizer(x)
x = rearrange(x, "b n c -> b c n")
info["quantizer_indices"] = indices
info["quantizer_loss"] = loss.mean()
if return_info:
return x, info
else:
return x
def decode(self, x):
return x
def decode_tokens(self, codes, **kwargs):
latents = self.quantizer.get_outputs_from_indices(codes)
return self.decode(latents, **kwargs)
class DACRVQBottleneck(DiscreteBottleneck):
def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
self.quantizer = DACResidualVQ(**quantizer_kwargs)
self.num_quantizers = quantizer_kwargs["n_codebooks"]
self.quantize_on_decode = quantize_on_decode
self.noise_augment_dim = noise_augment_dim
def encode(self, x, return_info=False, **kwargs):
info = {}
info["pre_quantizer"] = x
if self.quantize_on_decode:
return x, info if return_info else x
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
output = {
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
output["vq/commitment_loss"] /= self.num_quantizers
output["vq/codebook_loss"] /= self.num_quantizers
info.update(output)
if return_info:
return output["z"], info
return output["z"]
def decode(self, x):
if self.quantize_on_decode:
x = self.quantizer(x)[0]
if self.noise_augment_dim > 0:
noise = torch.randn(x.shape[0], self.noise_augment_dim,
x.shape[-1]).type_as(x)
x = torch.cat([x, noise], dim=1)
return x
def decode_tokens(self, codes, **kwargs):
latents, _, _ = self.quantizer.from_codes(codes)
return self.decode(latents, **kwargs)
class DACRVQVAEBottleneck(DiscreteBottleneck):
def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
self.quantizer = DACResidualVQ(**quantizer_kwargs)
self.num_quantizers = quantizer_kwargs["n_codebooks"]
self.quantize_on_decode = quantize_on_decode
def encode(self, x, return_info=False, n_quantizers: int = None):
info = {}
mean, scale = x.chunk(2, dim=1)
x, kl = vae_sample(mean, scale)
info["pre_quantizer"] = x
info["kl"] = kl
if self.quantize_on_decode:
return x, info if return_info else x
z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
output = {
"z": z,
"codes": codes,
"latents": latents,
"vq/commitment_loss": commitment_loss,
"vq/codebook_loss": codebook_loss,
}
output["vq/commitment_loss"] /= self.num_quantizers
output["vq/codebook_loss"] /= self.num_quantizers
info.update(output)
if return_info:
return output["z"], info
return output["z"]
def decode(self, x):
if self.quantize_on_decode:
x = self.quantizer(x)[0]
return x
def decode_tokens(self, codes, **kwargs):
latents, _, _ = self.quantizer.from_codes(codes)
return self.decode(latents, **kwargs)
class FSQBottleneck(DiscreteBottleneck):
def __init__(self, noise_augment_dim=0, **kwargs):
super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
self.noise_augment_dim = noise_augment_dim
self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
def encode(self, x, return_info=False):
info = {}
orig_dtype = x.dtype
x = x.float()
x = rearrange(x, "b c n -> b n c")
x, indices = self.quantizer(x)
x = rearrange(x, "b n c -> b c n")
x = x.to(orig_dtype)
# Reorder indices to match the expected format
indices = rearrange(indices, "b n q -> b q n")
info["quantizer_indices"] = indices
if return_info:
return x, info
else:
return x
def decode(self, x):
if self.noise_augment_dim > 0:
noise = torch.randn(x.shape[0], self.noise_augment_dim,
x.shape[-1]).type_as(x)
x = torch.cat([x, noise], dim=1)
return x
def decode_tokens(self, tokens, **kwargs):
latents = self.quantizer.indices_to_codes(tokens)
return self.decode(latents, **kwargs)
File diff suppressed because it is too large Load Diff
+965
View File
@@ -0,0 +1,965 @@
import torch
from torch import nn
from torch.nn import functional as F
from functools import partial
import numpy as np
import typing as tp
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
from .conditioners import MultiConditioner
from .dit import DiffusionTransformer
from .pretransforms import Pretransform
from .adp import UNetCFG1d, UNet1d
# Lazy imports for factory functions to avoid circular imports
def _get_create_pretransform_from_config():
from prismaudio_core.factory import create_pretransform_from_config
return create_pretransform_from_config
def _get_create_multi_conditioner_from_conditioning_config():
from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
return create_multi_conditioner_from_conditioning_config
from time import time
class Profiler:
def __init__(self):
self.ticks = [[time(), None]]
def tick(self, msg):
self.ticks.append([time(), msg])
def __repr__(self):
rep = 80 * "=" + "\n"
for i in range(1, len(self.ticks)):
msg = self.ticks[i][1]
ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
rep += msg + f": {ellapsed*1000:.2f}ms\n"
rep += 80 * "=" + "\n\n\n"
return rep
class DiffusionModel(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x, t, **kwargs):
raise NotImplementedError()
class DiffusionModelWrapper(nn.Module):
def __init__(
self,
model: DiffusionModel,
io_channels,
sample_size,
sample_rate,
min_input_length,
pretransform: tp.Optional[Pretransform] = None,
):
super().__init__()
self.io_channels = io_channels
self.sample_size = sample_size
self.sample_rate = sample_rate
self.min_input_length = min_input_length
self.model = model
if pretransform is not None:
self.pretransform = pretransform
else:
self.pretransform = None
def forward(self, x, t, **kwargs):
return self.model(x, t, **kwargs)
class ConditionedDiffusionModel(nn.Module):
def __init__(self,
*args,
supports_cross_attention: bool = False,
supports_input_concat: bool = False,
supports_global_cond: bool = False,
supports_prepend_cond: bool = False,
**kwargs):
super().__init__(*args, **kwargs)
self.supports_cross_attention = supports_cross_attention
self.supports_input_concat = supports_input_concat
self.supports_global_cond = supports_global_cond
self.supports_prepend_cond = supports_prepend_cond
def forward(self,
x: torch.Tensor,
t: torch.Tensor,
cross_attn_cond: torch.Tensor = None,
cross_attn_mask: torch.Tensor = None,
input_concat_cond: torch.Tensor = None,
global_embed: torch.Tensor = None,
prepend_cond: torch.Tensor = None,
prepend_cond_mask: torch.Tensor = None,
cfg_scale: float = 1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
**kwargs):
raise NotImplementedError()
class ConditionedDiffusionModelWrapper(nn.Module):
"""
A diffusion model that takes in conditioning
"""
def __init__(
self,
model: ConditionedDiffusionModel,
conditioner: MultiConditioner,
io_channels,
sample_rate,
min_input_length: int,
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
zero_init: bool = False,
pretransform: tp.Optional[Pretransform] = None,
cross_attn_cond_ids: tp.List[str] = [],
global_cond_ids: tp.List[str] = [],
input_concat_ids: tp.List[str] = [],
prepend_cond_ids: tp.List[str] = [],
add_cond_ids: tp.List[str] = [],
sync_cond_ids: tp.List[str] = [],
):
super().__init__()
self.model = model
self.conditioner = conditioner
self.io_channels = io_channels
self.sample_rate = sample_rate
self.diffusion_objective = diffusion_objective
self.pretransform = pretransform
self.cross_attn_cond_ids = cross_attn_cond_ids
self.global_cond_ids = global_cond_ids
self.input_concat_ids = input_concat_ids
self.prepend_cond_ids = prepend_cond_ids
self.add_cond_ids = add_cond_ids
self.sync_cond_ids = sync_cond_ids
self.min_input_length = min_input_length
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
if zero_init is True:
self.conditioner.apply(_basic_init)
self.model.model.initialize_weights()
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
cross_attention_input = None
cross_attention_masks = None
global_cond = None
input_concat_cond = None
prepend_cond = None
prepend_cond_mask = None
add_input = None
sync_input = None
if len(self.cross_attn_cond_ids) > 0:
# Concatenate all cross-attention inputs over the sequence dimension
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
cross_attention_input = []
cross_attention_masks = []
for key in self.cross_attn_cond_ids:
cross_attn_in, cross_attn_mask = conditioning_tensors[key]
# Add sequence dimension if it's not there
if len(cross_attn_in.shape) == 2:
cross_attn_in = cross_attn_in.unsqueeze(1)
# cross_attn_mask = cross_attn_mask.unsqueeze(1)
cross_attention_input.append(cross_attn_in)
cross_attention_masks.append(cross_attn_mask)
# import ipdb
# ipdb.set_trace()
cross_attention_input = torch.cat(cross_attention_input, dim=1)
cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
if len(self.add_cond_ids) > 0:
# Concatenate all cross-attention inputs over the sequence dimension
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
add_input = []
for key in self.add_cond_ids:
add_in = conditioning_tensors[key][0]
# Add sequence dimension if it's not there
if len(add_in.shape) == 2:
add_in = add_in.unsqueeze(1)
# add_in = add_in.transpose(1,2)
# add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False)
# add_in = add_in.transpose(1,2)
add_input.append(add_in)
add_input = torch.cat(add_input, dim=2)
if len(self.sync_cond_ids) > 0:
# Concatenate all cross-attention inputs over the sequence dimension
# Assumes that the cross-attention inputs are of shape (batch, seq, channels)
sync_input = []
for key in self.sync_cond_ids:
sync_in = conditioning_tensors[key][0]
# Add sequence dimension if it's not there
if len(sync_in.shape) == 2:
sync_in = sync_in.unsqueeze(1)
sync_input.append(sync_in)
sync_input = torch.cat(sync_input, dim=2)
if len(self.global_cond_ids) > 0:
# Concatenate all global conditioning inputs over the channel dimension
# Assumes that the global conditioning inputs are of shape (batch, channels)
global_conds = []
for key in self.global_cond_ids:
global_cond_input = conditioning_tensors[key][0]
if len(global_cond_input.shape) == 2:
global_cond_input = global_cond_input.unsqueeze(1)
global_conds.append(global_cond_input)
# # Concatenate over the channel dimension
# if global_conds[0].shape[-1] == 768:
# global_cond = torch.cat(global_conds, dim=-1)
# else:
# global_cond = sum(global_conds)
global_cond = sum(global_conds)
# global_cond = torch.cat(global_conds, dim=-1)
if len(global_cond.shape) == 3:
global_cond = global_cond.squeeze(1)
if len(self.input_concat_ids) > 0:
# Concatenate all input concat conditioning inputs over the channel dimension
# Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
if len(self.prepend_cond_ids) > 0:
# Concatenate all prepend conditioning inputs over the sequence dimension
# Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
prepend_conds = []
prepend_cond_masks = []
for key in self.prepend_cond_ids:
prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
if len(prepend_cond_input.shape) == 2:
prepend_cond_input = prepend_cond_input.unsqueeze(1)
prepend_conds.append(prepend_cond_input)
prepend_cond_masks.append(prepend_cond_mask)
prepend_cond = torch.cat(prepend_conds, dim=1)
prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
if negative:
return {
"negative_cross_attn_cond": cross_attention_input,
"negative_cross_attn_mask": cross_attention_masks,
"negative_global_cond": global_cond,
"negative_input_concat_cond": input_concat_cond
}
else:
return {
"cross_attn_cond": cross_attention_input,
"cross_attn_mask": cross_attention_masks,
"global_cond": global_cond,
"input_concat_cond": input_concat_cond,
"prepend_cond": prepend_cond,
"prepend_cond_mask": prepend_cond_mask,
"add_cond": add_input,
"sync_cond": sync_input
}
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
def generate(self, *args, **kwargs):
from prismaudio_core.inference.generation import generate_diffusion_cond
return generate_diffusion_cond(self, *args, **kwargs)
class UNetCFG1DWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
self.model = UNetCFG1d(*args, **kwargs)
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self,
x,
t,
cross_attn_cond=None,
cross_attn_mask=None,
input_concat_cond=None,
global_cond=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
negative_global_cond=None,
negative_input_concat_cond=None,
prepend_cond=None,
prepend_cond_mask=None,
**kwargs):
p = Profiler()
p.tick("start")
channels_list = None
if input_concat_cond is not None:
channels_list = [input_concat_cond]
outputs = self.model(
x,
t,
embedding=cross_attn_cond,
embedding_mask=cross_attn_mask,
features=global_cond,
channels_list=channels_list,
embedding_scale=cfg_scale,
embedding_mask_proba=cfg_dropout_prob,
batch_cfg=batch_cfg,
rescale_cfg=rescale_cfg,
negative_embedding=negative_cross_attn_cond,
negative_embedding_mask=negative_cross_attn_mask,
**kwargs)
p.tick("UNetCFG1D forward")
#print(f"Profiler: {p}")
return outputs
class UNet1DCondWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
self.model = UNet1d(*args, **kwargs)
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self,
x,
t,
input_concat_cond=None,
global_cond=None,
cross_attn_cond=None,
cross_attn_mask=None,
prepend_cond=None,
prepend_cond_mask=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
negative_global_cond=None,
negative_input_concat_cond=None,
**kwargs):
channels_list = None
if input_concat_cond is not None:
# Interpolate input_concat_cond to the same length as x
if input_concat_cond.shape[2] != x.shape[2]:
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
channels_list = [input_concat_cond]
outputs = self.model(
x,
t,
features=global_cond,
channels_list=channels_list,
**kwargs)
return outputs
class UNet1DUncondWrapper(DiffusionModel):
def __init__(
self,
in_channels,
*args,
**kwargs
):
super().__init__()
self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
self.io_channels = in_channels
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self, x, t, **kwargs):
return self.model(x, t, **kwargs)
class DAU1DCondWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
self.model = DiffusionAttnUnet1D(*args, **kwargs)
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self,
x,
t,
input_concat_cond=None,
cross_attn_cond=None,
cross_attn_mask=None,
global_cond=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
negative_global_cond=None,
negative_input_concat_cond=None,
prepend_cond=None,
**kwargs):
return self.model(x, t, cond = input_concat_cond)
class DiffusionAttnUnet1D(nn.Module):
def __init__(
self,
io_channels = 2,
depth=14,
n_attn_layers = 6,
channels = [128, 128, 256, 256] + [512] * 10,
cond_dim = 0,
cond_noise_aug = False,
kernel_size = 5,
learned_resample = False,
strides = [2] * 13,
conv_bias = True,
use_snake = False
):
super().__init__()
self.cond_noise_aug = cond_noise_aug
self.io_channels = io_channels
if self.cond_noise_aug:
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
self.timestep_embed = FourierFeatures(1, 16)
attn_layer = depth - n_attn_layers
strides = [1] + strides
block = nn.Identity()
conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
for i in range(depth, 0, -1):
c = channels[i - 1]
stride = strides[i-1]
if stride > 2 and not learned_resample:
raise ValueError("Must have stride 2 without learned resampling")
if i > 1:
c_prev = channels[i - 2]
add_attn = i >= attn_layer and n_attn_layers > 0
block = SkipBlock(
Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
conv_block(c_prev, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
conv_block(c, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
conv_block(c, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
block,
conv_block(c * 2 if i != depth else c, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
conv_block(c, c, c),
SelfAttention1d(
c, c // 32) if add_attn else nn.Identity(),
conv_block(c, c, c_prev),
SelfAttention1d(c_prev, c_prev //
32) if add_attn else nn.Identity(),
Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
)
else:
cond_embed_dim = 16 if not self.cond_noise_aug else 32
block = nn.Sequential(
conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
conv_block(c, c, c),
conv_block(c, c, c),
block,
conv_block(c * 2, c, c),
conv_block(c, c, c),
conv_block(c, c, io_channels, is_last=True),
)
self.net = block
with torch.no_grad():
for param in self.net.parameters():
param *= 0.5
def forward(self, x, t, cond=None, cond_aug_scale=None):
timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
inputs = [x, timestep_embed]
if cond is not None:
if cond.shape[2] != x.shape[2]:
cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
if self.cond_noise_aug:
# Get a random number between 0 and 1, uniformly sampled
if cond_aug_scale is None:
aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
else:
aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
# Add noise to the conditioning signal
cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
# Get embedding for noise cond level, reusing timestamp_embed
aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
inputs.append(aug_level_embed)
inputs.append(cond)
outputs = self.net(torch.cat(inputs, dim=1))
return outputs
class DiTWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
self.model = DiffusionTransformer(*args, **kwargs)
# with torch.no_grad():
# for param in self.model.parameters():
# param *= 0.5
def forward(self,
x,
t,
cross_attn_cond=None,
cross_attn_mask=None,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
input_concat_cond=None,
negative_input_concat_cond=None,
global_cond=None,
negative_global_cond=None,
prepend_cond=None,
prepend_cond_mask=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = True,
rescale_cfg: bool = False,
scale_phi: float = 0.0,
**kwargs):
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
return self.model(
x,
t,
cross_attn_cond=cross_attn_cond,
cross_attn_cond_mask=cross_attn_mask,
negative_cross_attn_cond=negative_cross_attn_cond,
negative_cross_attn_mask=negative_cross_attn_mask,
input_concat_cond=input_concat_cond,
prepend_cond=prepend_cond,
prepend_cond_mask=prepend_cond_mask,
cfg_scale=cfg_scale,
cfg_dropout_prob=cfg_dropout_prob,
scale_phi=scale_phi,
global_embed=global_cond,
**kwargs)
class MMDiTWrapper(ConditionedDiffusionModel):
def __init__(
self,
*args,
**kwargs
):
super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
self.model = MMAudio(*args, **kwargs)
# with torch.no_grad():
# for param in self.model.parameters():
# param *= 0.5
def forward(self,
x,
t,
clip_f,
sync_f,
text_f,
inpaint_masked_input=None,
t5_features=None,
metaclip_global_text_features=None,
cfg_scale=1.0,
cfg_dropout_prob: float = 0.0,
batch_cfg: bool = True,
rescale_cfg: bool = False,
scale_phi: float = 0.0,
**kwargs):
# breakpoint()
assert batch_cfg, "batch_cfg must be True for DiTWrapper"
#assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
return self.model(
latent=x,
t=t,
clip_f=clip_f,
sync_f=sync_f,
text_f=text_f,
inpaint_masked_input=inpaint_masked_input,
t5_features=t5_features,
metaclip_global_text_features=metaclip_global_text_features,
cfg_scale=cfg_scale,
cfg_dropout_prob=cfg_dropout_prob,
scale_phi=scale_phi,
**kwargs)
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
"""
A diffusion model that takes in conditioning
"""
def __init__(
self,
model,
conditioner: MultiConditioner,
io_channels,
sample_rate,
min_input_length: int,
diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
pretransform: tp.Optional[Pretransform] = None,
cross_attn_cond_ids: tp.List[str] = [],
global_cond_ids: tp.List[str] = [],
input_concat_ids: tp.List[str] = [],
prepend_cond_ids: tp.List[str] = [],
add_cond_ids: tp.List[str] = [],
mm_cond_ids: tp.List[str] = [],
):
super().__init__()
self.model = model
self.conditioner = conditioner
self.io_channels = io_channels
self.sample_rate = sample_rate
self.diffusion_objective = diffusion_objective
self.pretransform = pretransform
self.cross_attn_cond_ids = cross_attn_cond_ids
self.global_cond_ids = global_cond_ids
self.input_concat_ids = input_concat_ids
self.prepend_cond_ids = prepend_cond_ids
self.add_cond_ids = add_cond_ids
self.min_input_length = min_input_length
self.mm_cond_ids = mm_cond_ids
assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper"
assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"
# assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper"
def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
assert negative == False, "negative conditioning is not supported for MMDiTWrapper"
cross_attention_input = None
cross_attention_masks = None
global_cond = None
input_concat_cond = None
prepend_cond = None
prepend_cond_mask = None
add_input = None
inpaint_masked_input = None
t5_features = None
metaclip_global_text_features = None
clip_f = conditioning_tensors["metaclip_features"]
sync_f = conditioning_tensors["sync_features"]
text_f = conditioning_tensors["metaclip_text_features"]
if 'inpaint_masked_input' in conditioning_tensors.keys():
inpaint_masked_input = conditioning_tensors["inpaint_masked_input"]
if 't5_features' in conditioning_tensors.keys():
t5_features = conditioning_tensors["t5_features"]
if 'metaclip_global_text_features' in conditioning_tensors.keys():
metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"]
return {
"clip_f": clip_f,
"sync_f": sync_f,
"text_f": text_f,
"inpaint_masked_input": inpaint_masked_input,
"t5_features": t5_features,
"metaclip_global_text_features": metaclip_global_text_features
}
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
# breakpoint()
# print(kwargs)
return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
def generate(self, *args, **kwargs):
from prismaudio_core.inference.generation import generate_diffusion_cond
return generate_diffusion_cond(self, *args, **kwargs)
class DiTUncondWrapper(DiffusionModel):
def __init__(
self,
io_channels,
*args,
**kwargs
):
super().__init__()
self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs)
self.io_channels = io_channels
with torch.no_grad():
for param in self.model.parameters():
param *= 0.5
def forward(self, x, t, **kwargs):
return self.model(x, t, **kwargs)
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
diffusion_uncond_config = config["model"]
model_type = diffusion_uncond_config.get('type', None)
diffusion_config = diffusion_uncond_config.get('config', {})
assert model_type is not None, "Must specify model type in config"
pretransform = diffusion_uncond_config.get("pretransform", None)
sample_size = config.get("sample_size", None)
assert sample_size is not None, "Must specify sample size in config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "Must specify sample rate in config"
if pretransform is not None:
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
min_input_length = pretransform.downsampling_ratio
else:
min_input_length = 1
if model_type == 'DAU1d':
model = DiffusionAttnUnet1D(
**diffusion_config
)
elif model_type == "adp_uncond_1d":
model = UNet1DUncondWrapper(
**diffusion_config
)
elif model_type == "dit":
model = DiTUncondWrapper(
**diffusion_config
)
else:
raise NotImplementedError(f'Unknown model type: {model_type}')
return DiffusionModelWrapper(model,
io_channels=model.io_channels,
sample_size=sample_size,
sample_rate=sample_rate,
pretransform=pretransform,
min_input_length=min_input_length)
def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
diffusion_uncond_config = config["model"]
diffusion_config = diffusion_uncond_config.get('diffusion', {})
model_type = diffusion_config.get('type', None)
model_config = diffusion_config.get("config",{})
assert model_type is not None, "Must specify model type in config"
pretransform = diffusion_uncond_config.get("pretransform", None)
sample_size = config.get("sample_size", None)
assert sample_size is not None, "Must specify sample size in config"
sample_rate = config.get("sample_rate", None)
assert sample_rate is not None, "Must specify sample rate in config"
if pretransform is not None:
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
min_input_length = pretransform.downsampling_ratio
else:
min_input_length = 1
if model_type == 'DAU1d':
model = DiffusionAttnUnet1D(
**model_config
)
elif model_type == "adp_uncond_1d":
model = UNet1DUncondWrapper(
io_channels = io_channels,
**model_config
)
elif model_type == "dit":
model = DiTUncondWrapper(
**model_config
)
else:
raise NotImplementedError(f'Unknown model type: {model_type}')
return DiffusionModelWrapper(model,
io_channels=model.io_channels,
sample_size=sample_size,
sample_rate=sample_rate,
pretransform=pretransform,
min_input_length=min_input_length)
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
model_config = config["model"]
model_type = config["model_type"]
diffusion_config = model_config.get('diffusion', None)
assert diffusion_config is not None, "Must specify diffusion config"
diffusion_model_type = diffusion_config.get('type', None)
assert diffusion_model_type is not None, "Must specify diffusion model type"
diffusion_model_config = diffusion_config.get('config', None)
assert diffusion_model_config is not None, "Must specify diffusion model config"
if diffusion_model_type == 'adp_cfg_1d':
diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
elif diffusion_model_type == 'adp_1d':
diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
elif diffusion_model_type == 'dit':
diffusion_model = DiTWrapper(**diffusion_model_config)
elif diffusion_model_type == 'mmdit':
diffusion_model = MMDiTWrapper(**diffusion_model_config)
io_channels = model_config.get('io_channels', None)
assert io_channels is not None, "Must specify io_channels in model config"
sample_rate = config.get('sample_rate', None)
assert sample_rate is not None, "Must specify sample_rate in config"
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
conditioning_config = model_config.get('conditioning', None)
conditioner = None
if conditioning_config is not None:
conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config)
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
add_cond_ids = diffusion_config.get('add_cond_ids', [])
sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
global_cond_ids = diffusion_config.get('global_cond_ids', [])
input_concat_ids = diffusion_config.get('input_concat_ids', [])
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
zero_init = diffusion_config.get('zero_init', False)
pretransform = model_config.get("pretransform", None)
if pretransform is not None:
pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
min_input_length = pretransform.downsampling_ratio
else:
min_input_length = 1
if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
min_input_length *= np.prod(diffusion_model_config["factors"])
elif diffusion_model_type == "dit":
min_input_length *= diffusion_model.model.patch_size
# Get the proper wrapper class
extra_kwargs = {}
if model_type == "mm_diffusion_cond":
wrapper_fn = MMConditionedDiffusionModelWrapper
extra_kwargs["diffusion_objective"] = diffusion_objective
extra_kwargs["mm_cond_ids"] = mm_cond_ids
if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
wrapper_fn = ConditionedDiffusionModelWrapper
extra_kwargs["diffusion_objective"] = diffusion_objective
elif model_type == "diffusion_prior":
prior_type = model_config.get("prior_type", None)
assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
if prior_type == "mono_stereo":
from .diffusion_prior import MonoToStereoDiffusionPrior
wrapper_fn = MonoToStereoDiffusionPrior
return wrapper_fn(
diffusion_model,
conditioner,
min_input_length=min_input_length,
sample_rate=sample_rate,
cross_attn_cond_ids=cross_attention_ids,
global_cond_ids=global_cond_ids,
input_concat_ids=input_concat_ids,
prepend_cond_ids=prepend_cond_ids,
add_cond_ids=add_cond_ids,
sync_cond_ids=sync_cond_ids,
pretransform=pretransform,
io_channels=io_channels,
zero_init=zero_init,
**extra_kwargs
)
+541
View File
@@ -0,0 +1,541 @@
import typing as tp
import math
import torch
# from beartype.typing import Tuple
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from .blocks import FourierFeatures
from .transformer import ContinuousTransformer
from .utils import mask_from_frac_lengths, resample
class DiffusionTransformer(nn.Module):
def __init__(self,
io_channels=32,
patch_size=1,
embed_dim=768,
cond_token_dim=0,
project_cond_tokens=True,
global_cond_dim=0,
project_global_cond=True,
input_concat_dim=0,
prepend_cond_dim=0,
cond_ctx_dim=0,
depth=12,
num_heads=8,
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
add_token_dim=0,
sync_token_dim=0,
use_mlp=False,
use_zero_init=False,
**kwargs):
super().__init__()
self.cond_token_dim = cond_token_dim
# Timestep embeddings
timestep_features_dim = 256
# Timestep embeddings
self.timestep_cond_type = timestep_cond_type
self.timestep_features = FourierFeatures(1, timestep_features_dim)
if timestep_cond_type == "global":
timestep_embed_dim = embed_dim
elif timestep_cond_type == "input_concat":
assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat"
input_concat_dim += timestep_embed_dim
self.to_timestep_embed = nn.Sequential(
nn.Linear(timestep_features_dim, embed_dim, bias=True),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim, bias=True),
)
self.use_mlp = use_mlp
if cond_token_dim > 0:
# Conditioning tokens
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
self.to_cond_embed = nn.Sequential(
nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
)
else:
cond_embed_dim = 0
if global_cond_dim > 0:
# Global conditioning
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
self.to_global_embed = nn.Sequential(
nn.Linear(global_cond_dim, global_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(global_embed_dim, global_embed_dim, bias=False)
)
if add_token_dim > 0:
# Conditioning tokens
add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
self.to_add_embed = nn.Sequential(
nn.Linear(add_token_dim, add_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(add_embed_dim, add_embed_dim, bias=False)
)
else:
add_embed_dim = 0
if sync_token_dim > 0:
# Conditioning tokens
sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
self.to_sync_embed = nn.Sequential(
nn.Linear(sync_token_dim, sync_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(sync_embed_dim, sync_embed_dim, bias=False)
)
else:
sync_embed_dim = 0
if prepend_cond_dim > 0:
# Prepend conditioning
self.to_prepend_embed = nn.Sequential(
nn.Linear(prepend_cond_dim, embed_dim, bias=False),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim, bias=False)
)
self.input_concat_dim = input_concat_dim
dim_in = io_channels + self.input_concat_dim
self.patch_size = patch_size
# Transformer
self.transformer_type = transformer_type
self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
self.global_cond_type = global_cond_type
if self.transformer_type == "continuous_transformer":
global_dim = None
if self.global_cond_type == "adaLN":
# The global conditioning is projected to the embed_dim already at this point
global_dim = embed_dim
self.transformer = ContinuousTransformer(
dim=embed_dim,
depth=depth,
dim_heads=embed_dim // num_heads,
dim_in=dim_in * patch_size,
dim_out=io_channels * patch_size,
cross_attend = cond_token_dim > 0,
cond_token_dim = cond_embed_dim,
global_cond_dim=global_dim,
**kwargs
)
else:
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
nn.init.zeros_(self.preprocess_conv.weight)
nn.init.zeros_(self.postprocess_conv.weight)
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
# if isinstance(module, nn.Conv1d):
# if module.bias is not None:
# nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02)
nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02)
# Zero-out output layers:
if self.global_cond_type == "adaLN":
for block in self.transformer.layers:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.empty_clip_feat, 0)
nn.init.constant_(self.empty_sync_feat, 0)
def _forward(
self,
x,
t,
mask=None,
cross_attn_cond=None,
cross_attn_cond_mask=None,
input_concat_cond=None,
global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
add_cond=None,
add_masks=None,
sync_cond=None,
return_info=False,
**kwargs):
if cross_attn_cond is not None:
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
if global_embed is not None:
# Project the global conditioning to the embedding dimension
global_embed = self.to_global_embed(global_embed)
prepend_inputs = None
prepend_mask = None
prepend_length = 0
if prepend_cond is not None:
# Project the prepend conditioning to the embedding dimension
prepend_cond = self.to_prepend_embed(prepend_cond)
prepend_inputs = prepend_cond
if prepend_cond_mask is not None:
prepend_mask = prepend_cond_mask
if input_concat_cond is not None:
# reshape from (b, n, c) to (b, c, n)
if input_concat_cond.shape[1] != x.shape[1]:
input_concat_cond = input_concat_cond.transpose(1,2)
# Interpolate input_concat_cond to the same length as x
# if input_concat_cond.shape[1] != x.shape[2]:
# input_concat_cond = input_concat_cond.transpose(1,2)
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
# input_concat_cond = input_concat_cond.transpose(1,2)
# if len(global_embed.shape) == 2:
# global_embed = global_embed.unsqueeze(1)
# global_embed = global_embed + input_concat_cond
x = torch.cat([x, input_concat_cond], dim=1)
# Get the batch of timestep embeddings
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
# import ipdb
# ipdb.set_trace()
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
if self.timestep_cond_type == "global":
if global_embed is not None:
if len(global_embed.shape) == 3:
timestep_embed = timestep_embed.unsqueeze(1)
global_embed = global_embed + timestep_embed
else:
global_embed = timestep_embed
elif self.timestep_cond_type == "input_concat":
x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1)
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
if self.global_cond_type == "prepend" and global_embed is not None:
if prepend_inputs is None:
# Prepend inputs are just the global embed, and the mask is all ones
if len(global_embed.shape) == 2:
prepend_inputs = global_embed.unsqueeze(1)
else:
prepend_inputs = global_embed
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
else:
# Prepend inputs are the prepend conditioning + the global embed
if len(global_embed.shape) == 2:
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
else:
prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1)
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
prepend_length = prepend_inputs.shape[1]
x = self.preprocess_conv(x) + x
x = rearrange(x, "b c t -> b t c")
extra_args = {}
if self.global_cond_type == "adaLN":
extra_args["global_cond"] = global_embed
if self.patch_size > 1:
b, seq_len, c = x.shape
# 计算需要填充的数量
pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size
if pad_amount > 0:
# 在时间维度上进行填充
x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
if add_cond is not None:
# Interpolate add_cond to the same length as x
# if self.use_mlp:
add_cond = self.to_add_embed(add_cond)
if add_cond.shape[1] != x.shape[1]:
add_cond = add_cond.transpose(1,2)
add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
add_cond = add_cond.transpose(1,2)
# add_cond = resample(add_cond, x)
if sync_cond is not None:
sync_cond = self.to_sync_embed(sync_cond)
if self.transformer_type == "continuous_transformer":
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
if return_info:
output, info = output
output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
if self.patch_size > 1:
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
# 移除之前添加的填充
if pad_amount > 0:
output = output[:, :, :seq_len]
output = self.postprocess_conv(output) + output
if return_info:
return output, info
return output
def forward(
self,
x,
t,
cross_attn_cond=None,
cross_attn_cond_mask=None,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
input_concat_cond=None,
global_embed=None,
negative_global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
add_cond=None,
sync_cond=None,
cfg_scale=1.0,
cfg_dropout_prob=0.0,
causal=False,
scale_phi=0.0,
mask=None,
return_info=False,
**kwargs):
assert causal == False, "Causal mode is not supported for DiffusionTransformer"
bsz, a, b = x.shape
model_dtype = next(self.parameters()).dtype
x = x.to(model_dtype)
t = t.to(model_dtype)
if cross_attn_cond is not None:
cross_attn_cond = cross_attn_cond.to(model_dtype)
if negative_cross_attn_cond is not None:
negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype)
if input_concat_cond is not None:
input_concat_cond = input_concat_cond.to(model_dtype)
if global_embed is not None:
global_embed = global_embed.to(model_dtype)
if negative_global_embed is not None:
negative_global_embed = negative_global_embed.to(model_dtype)
if prepend_cond is not None:
prepend_cond = prepend_cond.to(model_dtype)
if add_cond is not None:
add_cond = add_cond.to(model_dtype)
if sync_cond is not None:
sync_cond = sync_cond.to(model_dtype)
if cross_attn_cond_mask is not None:
cross_attn_cond_mask = cross_attn_cond_mask.bool()
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
if prepend_cond_mask is not None:
prepend_cond_mask = prepend_cond_mask.bool()
# CFG dropout
if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
if cross_attn_cond is not None:
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
if prepend_cond is not None:
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
if add_cond is not None:
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool)
add_cond = torch.where(dropout_mask, null_embed, add_cond)
if sync_cond is not None:
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool)
sync_cond = torch.where(dropout_mask, null_embed, sync_cond)
if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
# Classifier-free guidance
# Concatenate conditioned and unconditioned inputs on the batch dimension
batch_inputs = torch.cat([x, x], dim=0)
batch_timestep = torch.cat([t, t], dim=0)
if global_embed is not None and global_embed.shape[0] == bsz:
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
elif global_embed is not None:
batch_global_cond = global_embed
else:
batch_global_cond = None
if input_concat_cond is not None and input_concat_cond.shape[0] == bsz:
batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
elif input_concat_cond is not None:
batch_input_concat_cond = input_concat_cond
else:
batch_input_concat_cond = None
batch_cond = None
batch_cond_masks = None
# Handle CFG for cross-attention conditioning
if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz:
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
if negative_cross_attn_cond is not None:
# If there's a negative cross-attention mask, set the masked tokens to the null embed
if negative_cross_attn_mask is not None:
negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
else:
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
if cross_attn_cond_mask is not None:
batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
elif cross_attn_cond is not None:
batch_cond = cross_attn_cond
else:
batch_cond = None
batch_prepend_cond = None
batch_prepend_cond_mask = None
if prepend_cond is not None and prepend_cond.shape[0] == bsz:
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
if prepend_cond_mask is not None:
batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
elif prepend_cond is not None:
batch_prepend_cond = prepend_cond
else:
batch_prepend_cond = None
batch_add_cond = None
# Handle CFG for cross-attention conditioning
if add_cond is not None and add_cond.shape[0] == bsz:
null_embed = torch.zeros_like(add_cond, device=add_cond.device)
batch_add_cond = torch.cat([add_cond, null_embed], dim=0)
elif add_cond is not None:
batch_add_cond = add_cond
else:
batch_add_cond = None
batch_sync_cond = None
# Handle CFG for cross-attention conditioning
if sync_cond is not None and sync_cond.shape[0] == bsz:
null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0)
elif sync_cond is not None:
batch_sync_cond = sync_cond
else:
batch_sync_cond = None
if mask is not None:
batch_masks = torch.cat([mask, mask], dim=0)
else:
batch_masks = None
batch_output = self._forward(
batch_inputs,
batch_timestep,
cross_attn_cond=batch_cond,
cross_attn_cond_mask=batch_cond_masks,
mask = batch_masks,
input_concat_cond=batch_input_concat_cond,
global_embed = batch_global_cond,
prepend_cond = batch_prepend_cond,
prepend_cond_mask = batch_prepend_cond_mask,
add_cond = batch_add_cond,
sync_cond = batch_sync_cond,
return_info = return_info,
**kwargs)
if return_info:
batch_output, info = batch_output
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
# CFG Rescale
if scale_phi != 0.0:
cond_out_std = cond_output.std(dim=1, keepdim=True)
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
else:
output = cfg_output
if return_info:
return output, info
return output
else:
return self._forward(
x,
t,
cross_attn_cond=cross_attn_cond,
cross_attn_cond_mask=cross_attn_cond_mask,
input_concat_cond=input_concat_cond,
global_embed=global_embed,
prepend_cond=prepend_cond,
prepend_cond_mask=prepend_cond_mask,
add_cond=add_cond,
sync_cond=sync_cond,
mask=mask,
return_info=return_info,
**kwargs
)
+278
View File
@@ -0,0 +1,278 @@
import torch
from einops import rearrange
from torch import nn
from .blocks import AdaRMSNorm
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
def checkpoint(function, *args, **kwargs):
kwargs.setdefault("use_reentrant", False)
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
class ContinuousLocalTransformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
dim_in = None,
dim_out = None,
causal = False,
local_attn_window_size = 64,
heads = 8,
ff_mult = 2,
cond_dim = 0,
cross_attn_cond_dim = 0,
**kwargs
):
super().__init__()
dim_head = dim//heads
self.layers = nn.ModuleList([])
self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
self.local_attn_window_size = local_attn_window_size
self.cond_dim = cond_dim
self.cross_attn_cond_dim = cross_attn_cond_dim
self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
for _ in range(depth):
self.layers.append(nn.ModuleList([
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
Attention(
dim=dim,
dim_heads=dim_head,
causal=causal,
zero_init_output=True,
natten_kernel_size=local_attn_window_size,
),
Attention(
dim=dim,
dim_heads=dim_head,
dim_context = cross_attn_cond_dim,
zero_init_output=True
) if self.cross_attn_cond_dim > 0 else nn.Identity(),
AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
FeedForward(dim = dim, mult = ff_mult, no_bias=True)
]))
def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
x = checkpoint(self.project_in, x)
if prepend_cond is not None:
x = torch.cat([prepend_cond, x], dim=1)
pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
for attn_norm, attn, xattn, ff_norm, ff in self.layers:
residual = x
if cond is not None:
x = checkpoint(attn_norm, x, cond)
else:
x = checkpoint(attn_norm, x)
x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
if cross_attn_cond is not None:
x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
residual = x
if cond is not None:
x = checkpoint(ff_norm, x, cond)
else:
x = checkpoint(ff_norm, x)
x = checkpoint(ff, x) + residual
return checkpoint(self.project_out, x)
class TransformerDownsampleBlock1D(nn.Module):
def __init__(
self,
in_channels,
embed_dim = 768,
depth = 3,
heads = 12,
downsample_ratio = 2,
local_attn_window_size = 64,
**kwargs
):
super().__init__()
self.downsample_ratio = downsample_ratio
self.transformer = ContinuousLocalTransformer(
dim=embed_dim,
depth=depth,
heads=heads,
local_attn_window_size=local_attn_window_size,
**kwargs
)
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
def forward(self, x):
x = checkpoint(self.project_in, x)
# Compute
x = self.transformer(x)
# Trade sequence length for channels
x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
# Project back to embed dim
x = checkpoint(self.project_down, x)
return x
class TransformerUpsampleBlock1D(nn.Module):
def __init__(
self,
in_channels,
embed_dim,
depth = 3,
heads = 12,
upsample_ratio = 2,
local_attn_window_size = 64,
**kwargs
):
super().__init__()
self.upsample_ratio = upsample_ratio
self.transformer = ContinuousLocalTransformer(
dim=embed_dim,
depth=depth,
heads=heads,
local_attn_window_size = local_attn_window_size,
**kwargs
)
self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
def forward(self, x):
# Project to embed dim
x = checkpoint(self.project_in, x)
# Project to increase channel dim
x = checkpoint(self.project_up, x)
# Trade channels for sequence length
x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
# Compute
x = self.transformer(x)
return x
class TransformerEncoder1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
embed_dims = [96, 192, 384, 768],
heads = [12, 12, 12, 12],
depths = [3, 3, 3, 3],
ratios = [2, 2, 2, 2],
local_attn_window_size = 64,
**kwargs
):
super().__init__()
layers = []
for layer in range(len(depths)):
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
layers.append(
TransformerDownsampleBlock1D(
in_channels = prev_dim,
embed_dim = embed_dims[layer],
heads = heads[layer],
depth = depths[layer],
downsample_ratio = ratios[layer],
local_attn_window_size = local_attn_window_size,
**kwargs
)
)
self.layers = nn.Sequential(*layers)
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
def forward(self, x):
x = rearrange(x, "b c n -> b n c")
x = checkpoint(self.project_in, x)
x = self.layers(x)
x = checkpoint(self.project_out, x)
x = rearrange(x, "b n c -> b c n")
return x
class TransformerDecoder1D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
embed_dims = [768, 384, 192, 96],
heads = [12, 12, 12, 12],
depths = [3, 3, 3, 3],
ratios = [2, 2, 2, 2],
local_attn_window_size = 64,
**kwargs
):
super().__init__()
layers = []
for layer in range(len(depths)):
prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
layers.append(
TransformerUpsampleBlock1D(
in_channels = prev_dim,
embed_dim = embed_dims[layer],
heads = heads[layer],
depth = depths[layer],
upsample_ratio = ratios[layer],
local_attn_window_size = local_attn_window_size,
**kwargs
)
)
self.layers = nn.Sequential(*layers)
self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
def forward(self, x):
x = rearrange(x, "b c n -> b n c")
x = checkpoint(self.project_in, x)
x = self.layers(x)
x = checkpoint(self.project_out, x)
x = rearrange(x, "b n c -> b c n")
return x
@@ -0,0 +1 @@
# mmmodules package
@@ -0,0 +1 @@
# mmmodules.model package
@@ -0,0 +1,95 @@
import torch
from torch import nn
from torch.nn import functional as F
class ChannelLastConv1d(nn.Conv1d):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 1)
x = super().forward(x)
x = x.permute(0, 2, 1)
return x
# https://github.com/Stability-AI/sd3-ref
class MLP(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
):
"""
Initialize the FeedForward module.
Args:
dim (int): Input dimension.
hidden_dim (int): Hidden dimension of the feedforward layer.
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
Attributes:
w1 (ColumnParallelLinear): Linear transformation for the first layer.
w2 (RowParallelLinear): Linear transformation for the second layer.
w3 (ColumnParallelLinear): Linear transformation for the third layer.
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class ConvMLP(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int = 256,
kernel_size: int = 3,
padding: int = 1,
):
"""
Initialize the FeedForward module.
Args:
dim (int): Input dimension.
hidden_dim (int): Hidden dimension of the feedforward layer.
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
Attributes:
w1 (ColumnParallelLinear): Linear transformation for the first layer.
w2 (RowParallelLinear): Linear transformation for the second layer.
w3 (ColumnParallelLinear): Linear transformation for the third layer.
"""
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ChannelLastConv1d(dim,
hidden_dim,
bias=False,
kernel_size=kernel_size,
padding=padding)
self.w2 = ChannelLastConv1d(hidden_dim,
dim,
bias=False,
kernel_size=kernel_size,
padding=padding)
self.w3 = ChannelLastConv1d(dim,
hidden_dim,
bias=False,
kernel_size=kernel_size,
padding=padding)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
+393
View File
@@ -0,0 +1,393 @@
import math
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from scipy.optimize import fmin
from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord
class PQMF(nn.Module):
"""
Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction.
Uses polyphase representation which is computationally more efficient for real-time.
Parameters:
- attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
- num_bands (int): Number of desired frequency bands. It must be a power of 2.
"""
def __init__(self, attenuation, num_bands):
super(PQMF, self).__init__()
# Ensure num_bands is a power of 2
is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands)))
assert is_power_of_2, "'num_bands' must be a power of 2."
# Create the prototype filter
prototype_filter = design_prototype_filter(attenuation, num_bands)
filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands)
padded_filter_bank = pad_to_nearest_power_of_two(filter_bank)
# Register filters and settings
self.register_buffer("filter_bank", padded_filter_bank)
self.register_buffer("prototype", prototype_filter)
self.num_bands = num_bands
def forward(self, signal):
"""Decompose the signal into multiple frequency bands."""
# If signal is not a pytorch tensor of Batch x Channels x Length, convert it
signal = prepare_signal_dimensions(signal)
# The signal length must be a multiple of num_bands. Pad it with zeros.
signal = pad_signal(signal, self.num_bands)
# run it
signal = polyphase_analysis(signal, self.filter_bank)
return apply_alias_cancellation(signal)
def inverse(self, bands):
"""Reconstruct the original signal from the frequency bands."""
bands = apply_alias_cancellation(bands)
return polyphase_synthesis(bands, self.filter_bank)
def prepare_signal_dimensions(signal):
"""
Rearrange signal into Batch x Channels x Length.
Parameters
----------
signal : torch.Tensor or numpy.ndarray
The input signal.
Returns
-------
torch.Tensor
Preprocessed signal tensor.
"""
# Convert numpy to torch tensor
if isinstance(signal, np.ndarray):
signal = torch.from_numpy(signal)
# Ensure tensor
if not isinstance(signal, torch.Tensor):
raise ValueError("Input should be either a numpy array or a PyTorch tensor.")
# Modify dimension of signal to Batch x Channels x Length
if signal.dim() == 1:
# This is just a mono signal. Unsqueeze to 1 x 1 x Length
signal = signal.unsqueeze(0).unsqueeze(0)
elif signal.dim() == 2:
# This is a multi-channel signal (e.g. stereo)
# Rearrange so that larger dimension (Length) is last
if signal.shape[0] > signal.shape[1]:
signal = signal.T
# Unsqueeze to 1 x Channels x Length
signal = signal.unsqueeze(0)
return signal
def pad_signal(signal, num_bands):
"""
Pads the signal to make its length divisible by the given number of bands.
Parameters
----------
signal : torch.Tensor
The input signal tensor, where the last dimension represents the signal length.
num_bands : int
The number of bands by which the signal length should be divisible.
Returns
-------
torch.Tensor
The padded signal tensor. If the original signal length was already divisible
by num_bands, returns the original signal unchanged.
"""
remainder = signal.shape[-1] % num_bands
if remainder > 0:
padding_size = num_bands - remainder
signal = nn.functional.pad(signal, (0, padding_size))
return signal
def generate_modulated_filter_bank(prototype_filter, num_bands):
"""
Generate a QMF bank of cosine modulated filters based on a given prototype filter.
Parameters
----------
prototype_filter : torch.Tensor
The prototype filter used as the basis for modulation.
num_bands : int
The number of desired subbands or filters.
Returns
-------
torch.Tensor
A bank of cosine modulated filters.
"""
# Initialize indices for modulation.
subband_indices = torch.arange(num_bands).reshape(-1, 1)
# Calculate the length of the prototype filter.
filter_length = prototype_filter.shape[-1]
# Generate symmetric time indices centered around zero.
time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1)
# Calculate phase offsets to ensure orthogonality between subbands.
phase_offsets = (-1)**subband_indices * np.pi / 4
# Compute the cosine modulation function.
modulation = torch.cos(
(2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets
)
# Apply modulation to the prototype filter.
modulated_filters = 2 * prototype_filter * modulation
return modulated_filters
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
"""
Design a lowpass filter using the Kaiser window.
Parameters
----------
angular_cutoff : float
The angular frequency cutoff of the filter.
attenuation : float
The desired stopband attenuation in decibels (dB).
filter_length : int, optional
Desired length of the filter. If not provided, it's computed based on the given specs.
Returns
-------
ndarray
The designed lowpass filter coefficients.
"""
estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)
# Ensure the estimated length is odd.
estimated_length = 2 * (estimated_length // 2) + 1
if filter_length is None:
filter_length = estimated_length
return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi)
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
"""
Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427
Parameters
----------
angular_cutoff : float
Angular frequency cutoff of the filter.
attenuation : float
Desired stopband attenuation in dB.
num_bands : int
Number of bands for the multiband filter system.
filter_length : int, optional
Desired length of the filter.
Returns
-------
float
The computed objective (loss) value for the given filter specs.
"""
filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)
convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full")
return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:]))
def design_prototype_filter(attenuation, num_bands, filter_length=None):
"""
Design the optimal prototype filter for a multiband system given the desired specs.
Parameters
----------
attenuation : float
The desired stopband attenuation in dB.
num_bands : int
Number of bands for the multiband filter system.
filter_length : int, optional
Desired length of the filter. If not provided, it's computed based on the given specs.
Returns
-------
ndarray
The optimal prototype filter coefficients.
"""
optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length),
1 / num_bands, disp=0)[0]
prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length)
return torch.tensor(prototype_filter, dtype=torch.float32)
def pad_to_nearest_power_of_two(x):
"""
Pads the input tensor 'x' on both sides such that its last dimension
becomes the nearest larger power of two.
Parameters:
-----------
x : torch.Tensor
The input tensor to be padded.
Returns:
--------
torch.Tensor
The padded tensor.
"""
current_length = x.shape[-1]
target_length = 2**math.ceil(math.log2(current_length))
total_padding = target_length - current_length
left_padding = total_padding // 2
right_padding = total_padding - left_padding
return nn.functional.pad(x, (left_padding, right_padding))
def apply_alias_cancellation(x):
"""
Applies alias cancellation by inverting the sign of every
second element of every second row, starting from the second
row's first element in a tensor.
This operation helps ensure that the aliasing introduced in
each band during the decomposition will be counteracted during
the reconstruction.
Parameters:
-----------
x : torch.Tensor
The input tensor.
Returns:
--------
torch.Tensor
Tensor with specific elements' sign inverted for alias cancellation.
"""
# Create a mask of the same shape as 'x', initialized with all ones
mask = torch.ones_like(x)
# Update specific elements in the mask to -1 to perform inversion
mask[..., 1::2, ::2] = -1
# Apply the mask to the input tensor 'x'
return x * mask
def ensure_odd_length(tensor):
"""
Pads the last dimension of a tensor to ensure its size is odd.
Parameters:
-----------
tensor : torch.Tensor
Input tensor whose last dimension might need padding.
Returns:
--------
torch.Tensor
The original tensor if its last dimension was already odd,
or the padded tensor with an odd-sized last dimension.
"""
last_dim_size = tensor.shape[-1]
if last_dim_size % 2 == 0:
tensor = nn.functional.pad(tensor, (0, 1))
return tensor
def polyphase_analysis(signal, filter_bank):
"""
Applies the polyphase method to efficiently analyze the signal using a filter bank.
Parameters:
-----------
signal : torch.Tensor
Input signal tensor with shape (Batch x Channels x Length).
filter_bank : torch.Tensor
Filter bank tensor with shape (Bands x Length).
Returns:
--------
torch.Tensor
Signal split into sub-bands. (Batch x Channels x Bands x Length)
"""
num_bands = filter_bank.shape[0]
num_channels = signal.shape[1]
# Rearrange signal for polyphase processing.
# Also combine Batch x Channel into one dimension for now.
#signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands)
signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands)
# Rearrange the filter bank for matching signal shape
filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands)
# Apply convolution with appropriate padding to maintain spatial dimensions
padding = filter_bank.shape[-1] // 2
filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding)
# Truncate the last dimension post-convolution to adjust the output shape
filtered_signal = filtered_signal[..., :-1]
# Rearrange the first dimension back into Batch x Channels
filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels)
return filtered_signal
def polyphase_synthesis(signal, filter_bank):
"""
Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal.
Parameters
----------
signal : torch.Tensor
Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).
filter_bank : torch.Tensor
Analysis filter bank (shape: Bands x Length).
should_rearrange : bool, optional
Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True.
Returns
-------
torch.Tensor
Reconstructed signal (shape: Batch x Channels X Length)
"""
num_bands = filter_bank.shape[0]
num_channels = signal.shape[1]
# Rearrange the filter bank
filter_bank = filter_bank.flip(-1)
filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands)
# Combine Batch x Channels into one dimension for now.
signal = rearrange(signal, "b c n t -> (b c) n t")
# Apply convolution with appropriate padding
padding_amount = filter_bank.shape[-1] // 2 + 1
reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount))
# Scale the result
reconstructed_signal = reconstructed_signal[..., :-1] * num_bands
# Reorganize the output and truncate
reconstructed_signal = reconstructed_signal.flip(1)
reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands)
reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:]
return reconstructed_signal
+258
View File
@@ -0,0 +1,258 @@
import torch
from einops import rearrange
from torch import nn
class Pretransform(nn.Module):
def __init__(self, enable_grad, io_channels, is_discrete):
super().__init__()
self.is_discrete = is_discrete
self.io_channels = io_channels
self.encoded_channels = None
self.downsampling_ratio = None
self.enable_grad = enable_grad
def encode(self, x):
raise NotImplementedError
def decode(self, z):
raise NotImplementedError
def tokenize(self, x):
raise NotImplementedError
def decode_tokens(self, tokens):
raise NotImplementedError
class AutoencoderPretransform(Pretransform):
def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
self.model = model
self.model.requires_grad_(False).eval()
self.scale=scale
self.downsampling_ratio = model.downsampling_ratio
self.io_channels = model.io_channels
self.sample_rate = model.sample_rate
self.model_half = model_half
self.iterate_batch = iterate_batch
self.encoded_channels = model.latent_dim
self.chunked = chunked
self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
if self.model_half:
self.model.half()
def encode(self, x, **kwargs):
if self.model_half:
x = x.half()
self.model.to(torch.float16)
encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
if self.model_half:
encoded = encoded.float()
return encoded / self.scale
def decode(self, z, **kwargs):
z = z * self.scale
if self.model_half:
z = z.half()
self.model.to(torch.float16)
decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
if self.model_half:
decoded = decoded.float()
return decoded
def tokenize(self, x, **kwargs):
assert self.model.is_discrete, "Cannot tokenize with a continuous model"
_, info = self.model.encode(x, return_info = True, **kwargs)
return info[self.model.bottleneck.tokens_id]
def decode_tokens(self, tokens, **kwargs):
assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
return self.model.decode_tokens(tokens, **kwargs)
def load_state_dict(self, state_dict, strict=True):
self.model.load_state_dict(state_dict, strict=strict)
class WaveletPretransform(Pretransform):
def __init__(self, channels, levels, wavelet):
super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
from .wavelets import WaveletEncode1d, WaveletDecode1d
self.encoder = WaveletEncode1d(channels, levels, wavelet)
self.decoder = WaveletDecode1d(channels, levels, wavelet)
self.downsampling_ratio = 2 ** levels
self.io_channels = channels
self.encoded_channels = channels * self.downsampling_ratio
def encode(self, x):
return self.encoder(x)
def decode(self, z):
return self.decoder(z)
class PQMFPretransform(Pretransform):
def __init__(self, attenuation=100, num_bands=16):
# TODO: Fix PQMF to take in in-channels
super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
from .pqmf import PQMF
self.pqmf = PQMF(attenuation, num_bands)
def encode(self, x):
# x is (Batch x Channels x Time)
x = self.pqmf.forward(x)
# pqmf.forward returns (Batch x Channels x Bands x Time)
# but Pretransform needs Batch x Channels x Time
# so concatenate channels and bands into one axis
return rearrange(x, "b c n t -> b (c n) t")
def decode(self, x):
# x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
# returns (Batch x Channels x Time)
return self.pqmf.inverse(x)
class PretrainedDACPretransform(Pretransform):
def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
import dac
model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
self.model = dac.DAC.load(model_path)
self.quantize_on_decode = quantize_on_decode
if model_type == "44khz":
self.downsampling_ratio = 512
else:
self.downsampling_ratio = 320
self.io_channels = 1
self.scale = scale
self.chunked = chunked
self.encoded_channels = self.model.latent_dim
self.num_quantizers = self.model.n_codebooks
self.codebook_size = self.model.codebook_size
def encode(self, x):
latents = self.model.encoder(x)
if self.quantize_on_decode:
output = latents
else:
z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
output = z
if self.scale != 1.0:
output = output / self.scale
return output
def decode(self, z):
if self.scale != 1.0:
z = z * self.scale
if self.quantize_on_decode:
z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
return self.model.decode(z)
def tokenize(self, x):
return self.model.encode(x)[1]
def decode_tokens(self, tokens):
latents = self.model.quantizer.from_codes(tokens)
return self.model.decode(latents)
class AudiocraftCompressionPretransform(Pretransform):
def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
try:
from audiocraft.models import CompressionModel
except ImportError:
raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
self.model = CompressionModel.get_pretrained(model_type)
self.quantize_on_decode = quantize_on_decode
self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
self.sample_rate = self.model.sample_rate
self.io_channels = self.model.channels
self.scale = scale
#self.encoded_channels = self.model.latent_dim
self.num_quantizers = self.model.num_codebooks
self.codebook_size = self.model.cardinality
self.model.to(torch.float16).eval().requires_grad_(False)
def encode(self, x):
assert False, "Audiocraft compression models do not support continuous encoding"
# latents = self.model.encoder(x)
# if self.quantize_on_decode:
# output = latents
# else:
# z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
# output = z
# if self.scale != 1.0:
# output = output / self.scale
# return output
def decode(self, z):
assert False, "Audiocraft compression models do not support continuous decoding"
# if self.scale != 1.0:
# z = z * self.scale
# if self.quantize_on_decode:
# z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
# return self.model.decode(z)
def tokenize(self, x):
with torch.cuda.amp.autocast(enabled=False):
return self.model.encode(x.to(torch.float16))[0]
def decode_tokens(self, tokens):
with torch.cuda.amp.autocast(enabled=False):
return self.model.decode(tokens)
+993
View File
@@ -0,0 +1,993 @@
from functools import reduce, partial
from packaging import version
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.cuda.amp import autocast
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
from typing import Callable, Literal
try:
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
HAS_FLASH_ATTN = True
except ImportError:
HAS_FLASH_ATTN = False
flash_attn_kvpacked_func = None
flash_attn_func = None
from .utils import compile
try:
import natten
except ImportError:
natten = None
def checkpoint(function, *args, **kwargs):
kwargs.setdefault("use_reentrant", False)
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
return x * (1 + scale) + shift
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
def create_causal_mask(i, j, device):
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
def or_reduce(masks):
head, *body = masks
for rest in body:
head = head | rest
return head
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.max_seq_len = max_seq_len
self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x, pos = None, seq_start_pos = None):
seq_len, device = x.shape[1], x.device
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
if pos is None:
pos = torch.arange(seq_len, device = device)
if seq_start_pos is not None:
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
pos_emb = self.emb(pos)
pos_emb = pos_emb * self.scale
return pos_emb
class ScaledSinusoidalEmbedding(nn.Module):
def __init__(self, dim, theta = 10000):
super().__init__()
assert (dim % 2) == 0, 'dimension must be divisible by 2'
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
half_dim = dim // 2
freq_seq = torch.arange(half_dim).float() / half_dim
inv_freq = theta ** -freq_seq
self.register_buffer('inv_freq', inv_freq, persistent = False)
def forward(self, x, pos = None, seq_start_pos = None):
seq_len, device = x.shape[1], x.device
if pos is None:
pos = torch.arange(seq_len, device = device)
if seq_start_pos is not None:
pos = pos - seq_start_pos[..., None]
emb = einsum('i, j -> i j', pos, self.inv_freq)
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
return emb * self.scale
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
use_xpos = False,
scale_base = 512,
interpolation_factor = 1.,
base = 10000,
base_rescale_factor = 1.
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2))
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor
if not use_xpos:
self.register_buffer('scale', None)
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.scale_base = scale_base
self.register_buffer('scale', scale)
def forward_from_seq_len(self, seq_len):
device = self.inv_freq.device
t = torch.arange(seq_len, device = device)
return self.forward(t)
@autocast(enabled = False)
def forward(self, t):
device = self.inv_freq.device
t = t.to(torch.float32)
t = t / self.interpolation_factor
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
if self.scale is None:
return freqs, 1.
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j = 2)
x1, x2 = x.unbind(dim = -2)
return torch.cat((-x2, x1), dim = -1)
@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
out_dtype = t.dtype
# cast to float32 if necessary for numerical stability
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
freqs, t = freqs.to(dtype), t.to(dtype)
freqs = freqs[-seq_len:, :]
if t.ndim == 4 and freqs.ndim == 3:
freqs = rearrange(freqs, 'b n d -> b 1 n d')
# partial rotary embeddings, Wang et al. GPT-J
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
return torch.cat((t, t_unrotated), dim = -1)
# norms
class DynamicTanh(nn.Module):
def __init__(self, dim, init_alpha=10.0):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
self.gamma = nn.Parameter(torch.ones(dim))
self.beta = nn.Parameter(torch.zeros(dim))
def forward(self, x):
x = F.tanh(self.alpha * x)
return self.gamma * x + self.beta
class RunningInstanceNorm(nn.Module):
def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
super().__init__()
self.register_buffer("running_mean", torch.zeros(1,1,dim))
self.register_buffer("running_std", torch.ones(1,1,dim))
self.saturate = saturate
self.eps = eps
self.momentum = momentum
self.dim = dim
self.trainable_gain = trainable_gain
if self.trainable_gain:
self.gain = nn.Parameter(torch.ones(1))
def _update_stats(self, x):
self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)
self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps)
def forward(self, x):
if self.training:
self._update_stats(x)
x = (x - self.running_mean) / self.running_std
if self.saturate:
x = torch.asinh(x)
if self.trainable_gain:
x = x * self.gain
return x
class LayerNorm(nn.Module):
def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
"""
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
"""
super().__init__()
if fix_scale:
self.register_buffer("gamma", torch.ones(dim))
else:
self.gamma = nn.Parameter(torch.ones(dim))
if bias:
self.beta = nn.Parameter(torch.zeros(dim))
else:
self.register_buffer("beta", torch.zeros(dim))
self.eps = eps
self.force_fp32 = force_fp32
def forward(self, x):
if not self.force_fp32:
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps)
else:
output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
return output.to(x.dtype)
class LayerScale(nn.Module):
def __init__(self, dim, init_val = 1e-5):
super().__init__()
self.scale = nn.Parameter(torch.full([dim], init_val))
def forward(self, x):
return x * self.scale
class GLU(nn.Module):
def __init__(
self,
dim_in,
dim_out,
activation: Callable,
use_conv = False,
conv_kernel_size = 3,
):
super().__init__()
self.act = activation
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
self.use_conv = use_conv
def forward(self, x):
if self.use_conv:
x = rearrange(x, 'b n d -> b d n')
x = self.proj(x)
x = rearrange(x, 'b d n -> b n d')
else:
x = self.proj(x)
x, gate = x.chunk(2, dim = -1)
return x * self.act(gate)
class FeedForward(nn.Module):
def __init__(
self,
dim,
dim_out = None,
mult = 4,
no_bias = False,
glu = True,
use_conv = False,
conv_kernel_size = 3,
zero_init_output = True,
):
super().__init__()
inner_dim = int(dim * mult)
# Default to SwiGLU
activation = nn.SiLU()
dim_out = dim if dim_out is None else dim_out
if glu:
linear_in = GLU(dim, inner_dim, activation)
else:
linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
# init last linear layer to 0
if zero_init_output:
nn.init.zeros_(linear_out.weight)
if not no_bias:
nn.init.zeros_(linear_out.bias)
self.ff = nn.Sequential(
linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):
return self.ff(x)
class Attention(nn.Module):
def __init__(
self,
dim,
dim_heads = 64,
dim_context = None,
causal = False,
zero_init_output=True,
qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
differential = False,
feat_scale = False
):
super().__init__()
self.dim = dim
self.dim_heads = dim_heads
self.differential = differential
dim_kv = dim_context if dim_context is not None else dim
self.num_heads = dim // dim_heads
self.kv_heads = dim_kv // dim_heads
if dim_context is not None:
if differential:
self.to_q = nn.Linear(dim, dim * 2, bias=False)
self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
else:
self.to_q = nn.Linear(dim, dim, bias=False)
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
else:
if differential:
self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
else:
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.to_out = nn.Linear(dim, dim, bias=False)
if zero_init_output:
nn.init.zeros_(self.to_out.weight)
if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}')
self.qk_norm = qk_norm
if self.qk_norm == "ln":
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
elif self.qk_norm == 'rns':
self.q_norm = nn.RMSNorm(dim_heads)
self.k_norm = nn.RMSNorm(dim_heads)
elif self.qk_norm == 'dyt':
self.q_norm = DynamicTanh(dim_heads)
self.k_norm = DynamicTanh(dim_heads)
self.sdp_kwargs = dict(
enable_flash = True,
enable_math = True,
enable_mem_efficient = True
)
self.feat_scale = feat_scale
if self.feat_scale:
self.lambda_dc = nn.Parameter(torch.zeros(dim))
self.lambda_hf = nn.Parameter(torch.zeros(dim))
self.causal = causal
if causal:
print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.')
@compile
def apply_qk_layernorm(self, q, k):
q_type = q.dtype
k_type = k.dtype
q = self.q_norm(q).to(q_type)
k = self.k_norm(k).to(k_type)
return q, k
def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
if self.num_heads != self.kv_heads:
# Repeat interleave kv_heads to match q_heads for grouped query attention
heads_per_kv_head = self.num_heads // self.kv_heads
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
flash_attn_available = HAS_FLASH_ATTN
if flash_attn_sliding_window is not None and (not flash_attn_available):
print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available")
if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None:
print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention")
if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None):
print(f"Disabling FlexAttention because causal is set")
flex_attention_block_mask = None
flex_attention_score_mod = None
if flex_attention_block_mask is not None or flex_attention_score_mod is not None:
out = flex_attention_compiled(q,k,v,
block_mask = flex_attention_block_mask,
score_mod = flex_attention_score_mod)
elif flash_attn_available:
fa_dtype_in = q.dtype
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))
if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))
out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
else:
out = F.scaled_dot_product_attention(q, k, v, is_causal = causal)
return out
#@compile
def forward(
self,
x,
context = None,
rotary_pos_emb = None,
causal = None,
flex_attention_block_mask = None,
flex_attention_score_mod = None,
flash_attn_sliding_window = None
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
kv_input = context if has_context else x
if hasattr(self, 'to_q'):
# Use separate linear projections for q and k/v
if self.differential:
q, q_diff = self.to_q(x).chunk(2, dim=-1)
q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
q = torch.stack([q, q_diff], dim = 1)
k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
k = torch.stack([k, k_diff], dim = 1)
else:
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
else:
# Use fused linear projection
if self.differential:
q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
q = torch.stack([q, q_diff], dim = 1)
k = torch.stack([k, k_diff], dim = 1)
else:
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# Normalize q and k for cosine sim attention
if self.qk_norm == "l2":
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
elif self.qk_norm != "none":
q, k = self.apply_qk_layernorm(q, k)
if rotary_pos_emb is not None:
freqs, _ = rotary_pos_emb
q_dtype = q.dtype
k_dtype = k.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
freqs = freqs.to(torch.float32)
if q.shape[-2] >= k.shape[-2]:
ratio = q.shape[-2] / k.shape[-2]
q_freqs, k_freqs = freqs, ratio * freqs
else:
ratio = k.shape[-2] / q.shape[-2]
q_freqs, k_freqs = ratio * freqs, freqs
q = apply_rotary_pos_emb(q, q_freqs)
k = apply_rotary_pos_emb(k, k_freqs)
q = q.to(v.dtype)
k = k.to(v.dtype)
n, device = q.shape[-2], q.device
causal = self.causal if causal is None else causal
if n == 1 and causal:
causal = False
if self.differential:
q, q_diff = q.unbind(dim = 1)
k, k_diff = k.unbind(dim = 1)
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
out = out - out_diff
else:
out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
# merge heads
out = rearrange(out, ' b h n d -> b n (h d)')
# Communicate between heads
# with autocast(enabled = False):
# out_dtype = out.dtype
# out = out.to(torch.float32)
# out = self.to_out(out).to(out_dtype)
out = self.to_out(out)
if self.feat_scale:
out_dc = out.mean(dim=-2, keepdim=True)
out_hf = out - out_dc
# Selectively modulate DC and high frequency components
out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf
return out
class ConformerModule(nn.Module):
def __init__(
self,
dim,
norm_kwargs = {},
):
super().__init__()
self.dim = dim
self.in_norm = LayerNorm(dim, **norm_kwargs)
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
self.glu = GLU(dim, dim, nn.SiLU())
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
self.swish = nn.SiLU()
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
#@compile
def forward(self, x):
x = self.in_norm(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.glu(x)
x = rearrange(x, 'b n d -> b d n')
x = self.depthwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.mid_norm(x)
x = self.swish(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv_2(x)
x = rearrange(x, 'b d n -> b n d')
return x
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
dim_heads = 64,
cross_attend = False,
dim_context = None,
global_cond_dim = None,
causal = False,
zero_init_branch_outputs = True,
conformer = False,
layer_ix = -1,
remove_norms = False,
add_rope = False,
layer_scale = False,
use_sync_block_film = False,
attn_kwargs = {},
ff_kwargs = {},
norm_kwargs = {}
):
super().__init__()
self.dim = dim
self.dim_heads = min(dim_heads,dim)
self.cross_attend = cross_attend
self.dim_context = dim_context
self.causal = causal
if layer_scale and zero_init_branch_outputs:
print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False')
zero_init_branch_outputs = False
self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)
self.add_rope = add_rope
self.self_attn = Attention(
dim,
dim_heads = self.dim_heads,
causal = causal,
zero_init_output=zero_init_branch_outputs,
**attn_kwargs
)
self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
self.cross_attend = cross_attend
if cross_attend:
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
self.cross_attn = Attention(
dim,
dim_heads = self.dim_heads,
dim_context=dim_context,
causal = causal,
zero_init_output=zero_init_branch_outputs,
**attn_kwargs
)
self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()
self.layer_ix = layer_ix
self.conformer = None
if conformer:
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()
self.global_cond_dim = global_cond_dim
if global_cond_dim is not None:
self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)
self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None
if use_sync_block_film:
self.sync_film_generator = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
)
@compile
def forward(
self,
x,
context = None,
global_cond=None,
rotary_pos_emb = None,
self_attention_block_mask = None,
self_attention_score_mod = None,
cross_attention_block_mask = None,
cross_attention_score_mod = None,
self_attention_flash_sliding_window = None,
cross_attention_flash_sliding_window = None,
sync_cond = None,
prepend_length=0
):
if rotary_pos_emb is None and self.add_rope:
rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
if len(global_cond.shape) == 2:
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1)
else:
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1)
# self-attention with adaLN
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
x = x * torch.sigmoid(1 - gate_self)
x = self.self_attn_scale(x)
x = x + residual
if context is not None and self.cross_attend:
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
if self.conformer is not None:
x = x + self.conformer_scale(self.conformer(x))
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
x = x * (1 + scale) + shift
# feedforward with adaLN
residual = x
x = self.ff_norm(x)
x = x * (1 + scale_ff) + shift_ff
x = self.ff(x)
x = x * torch.sigmoid(1 - gate_ff)
x = self.ff_scale(x)
x = x + residual
else:
x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))
if context is not None and self.cross_attend:
x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))
if self.conformer is not None:
x = x + self.conformer_scale(self.conformer(x))
if sync_cond is not None and hasattr(self, 'sync_film_generator'):
prepend_part = x[:, :prepend_length, :]
audio_part = x[:, prepend_length:, :]
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
modulated_audio_part = audio_part * (1 + scale) + shift
x = torch.cat([prepend_part, modulated_audio_part], dim=1)
x = x + self.ff_scale(self.ff(self.ff_norm(x)))
return x
class ContinuousTransformer(nn.Module):
def __init__(
self,
dim,
depth,
*,
dim_in = None,
dim_out = None,
dim_heads = 64,
cross_attend=False,
cond_token_dim=None,
pre_cross_attn_ix=-1,
final_cross_attn_ix=-1,
global_cond_dim=None,
causal=False,
rotary_pos_emb=True,
zero_init_branch_outputs=True,
conformer=False,
use_sinusoidal_emb=False,
use_abs_pos_emb=False,
abs_pos_emb_max_length=10000,
num_memory_tokens=0,
sliding_window=None,
use_mlp=False,
use_add_norm=False,
use_gated=False,
use_final_layer=False,
use_zeros=False,
use_conv=False,
use_fusion_mlp=False,
use_film=False,
use_sync_film=False,
use_sync_gated=False,
**kwargs
):
super().__init__()
self.dim = dim
self.depth = depth
self.causal = causal
self.layers = nn.ModuleList([])
if use_mlp:
self.project_in = nn.Sequential(
nn.Linear(dim_in, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim, bias=False)
)
else:
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
self.video_temporal_conv = None
self.audio_temporal_conv = None
self.fusion_mlp = None
if use_conv:
self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
if use_fusion_mlp:
self.fusion_mlp = nn.Sequential(
nn.Linear(dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
if rotary_pos_emb:
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
else:
self.rotary_pos_emb = None
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
self.use_sinusoidal_emb = use_sinusoidal_emb
if use_sinusoidal_emb:
self.pos_emb = ScaledSinusoidalEmbedding(dim)
self.use_abs_pos_emb = use_abs_pos_emb
if use_abs_pos_emb:
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)
self.adaLN_modulation = None
if global_cond_dim is not None:
if use_final_layer:
self.norm_final = LayerNorm(dim)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(
dim, 2 * dim, bias=True
),
)
if use_zeros:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.project_out.weight, 0)
self.global_cond_embedder = nn.Sequential(
nn.Linear(global_cond_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim * 6)
)
if use_zeros:
nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
nn.init.constant_(self.global_cond_embedder[0].weight, 0)
nn.init.constant_(self.global_cond_embedder[0].bias, 0)
self.final_cross_attn_ix = final_cross_attn_ix
self.use_gated = use_gated
self.use_film = use_film
self.use_add_norm = use_add_norm
if self.use_add_norm:
self.add_norm = nn.LayerNorm(dim)
if use_gated:
self.gate = nn.Parameter(torch.ones(1, 1, dim))
if use_film:
self.film_generator = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
)
else:
self.film_generator = None
if use_sync_film:
self.sync_film_generator = nn.Sequential(
nn.Linear(dim, dim, bias=False),
nn.SiLU(),
nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift
)
else:
self.sync_film_generator = None
if use_sync_gated:
self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
else:
self.sync_gate = None
self.sliding_window = sliding_window
for i in range(depth):
should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
# print(f"Layer {i} cross attends: {should_cross_attend}")
self.layers.append(
TransformerBlock(
dim,
dim_heads = dim_heads,
cross_attend = should_cross_attend,
dim_context = cond_token_dim,
global_cond_dim = global_cond_dim,
causal = causal,
zero_init_branch_outputs = zero_init_branch_outputs,
conformer=conformer,
layer_ix=i,
**kwargs
)
)
def forward(
self,
x,
mask = None,
prepend_embeds = None,
prepend_mask = None,
add_cond = None,
sync_cond = None,
global_cond = None,
return_info = False,
use_checkpointing = True,
exit_layer_ix = None,
video_dropout_prob = 0.0,
**kwargs
):
batch, seq, device = *x.shape[:2], x.device
model_dtype = next(self.parameters()).dtype
x = x.to(model_dtype)
info = {
"hidden_states": [],
}
x = self.project_in(x)
if add_cond is not None:
if self.use_gated:
gate = torch.sigmoid(self.gate)
x = x + gate * add_cond
elif self.use_film:
scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
x = x * (1 + scale) + shift
else:
x = x + add_cond
if self.use_add_norm:
x = self.add_norm(x)
if self.fusion_mlp is not None:
x = self.fusion_mlp(x)
if sync_cond is not None:
if self.sync_film_generator is not None:
scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
x = x * (1 + scale) + shift
elif self.sync_gate is not None:
gate_value = torch.sigmoid(self.sync_gate)
x = x + gate_value * sync_cond
# else:
# x = x + sync_cond
if prepend_embeds is not None:
prepend_length, prepend_dim = prepend_embeds.shape[1:]
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
x = torch.cat((prepend_embeds, x), dim = -2)
if self.num_memory_tokens > 0:
memory_tokens = self.memory_tokens.expand(batch, -1, -1)
x = torch.cat((memory_tokens, x), dim=1)
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
else:
rotary_pos_emb = None
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
if global_cond is not None and self.global_cond_embedder is not None:
global_cond_embed = self.global_cond_embedder(global_cond)
else:
global_cond_embed = global_cond
# Iterate over the transformer layers
for layer_ix, layer in enumerate(self.layers):
if use_checkpointing:
x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
if return_info:
info["hidden_states"].append(x)
if exit_layer_ix is not None and layer_ix == exit_layer_ix:
x = x[:, self.num_memory_tokens:, :]
if return_info:
return x, info
return x
x = x[:, self.num_memory_tokens:, :]
if global_cond is not None and self.adaLN_modulation is not None:
if len(global_cond.shape) == 2:
global_cond = global_cond.unsqueeze(1)
shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.project_out(x)
if return_info:
return x, info
return x
+177
View File
@@ -0,0 +1,177 @@
import torch
from safetensors.torch import load_file
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
from torch.nn.utils import remove_weight_norm
def load_ckpt_state_dict(ckpt_path, prefix=None):
if ckpt_path.endswith(".safetensors"):
state_dict = load_file(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
# 过滤特定前缀的state_dict
filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict
return filtered_state_dict
def remove_weight_norm_from_model(model):
for module in model.modules():
if hasattr(module, "weight"):
print(f"Removing weight norm from {module}")
remove_weight_norm(module)
return model
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
# License can be found in LICENSES/LICENSE_META.txt
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
"""torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
Args:
input (torch.Tensor): The input tensor containing probabilities.
num_samples (int): Number of samples to draw.
replacement (bool): Whether to draw with replacement or not.
Keywords args:
generator (torch.Generator): A pseudorandom number generator for sampling.
Returns:
torch.Tensor: Last dimension contains num_samples indices
sampled from the multinomial probability distribution
located in the last dimension of tensor input.
"""
if num_samples == 1:
q = torch.empty_like(input).exponential_(1, generator=generator)
return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
input_ = input.reshape(-1, input.shape[-1])
output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
output = output_.reshape(*list(input.shape[:-1]), -1)
return output
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
"""Sample next token from top K values along the last dimension of the input probs tensor.
Args:
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
k (int): The k in “top-k”.
Returns:
torch.Tensor: Sampled tokens.
"""
top_k_value, _ = torch.topk(probs, k, dim=-1)
min_value_top_k = top_k_value[..., [-1]]
probs *= (probs >= min_value_top_k).float()
probs.div_(probs.sum(dim=-1, keepdim=True))
next_token = multinomial(probs, num_samples=1)
return next_token
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
"""Sample next token from top P probabilities along the last dimension of the input probs tensor.
Args:
probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
p (int): The p in “top-p”.
Returns:
torch.Tensor: Sampled tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort *= (~mask).float()
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
def next_power_of_two(n):
return 2 ** (n - 1).bit_length()
def next_multiple_of_64(n):
return ((n + 63) // 64) * 64
# mask construction helpers
def mask_from_start_end_indices(
seq_len: int,
start: Tensor,
end: Tensor
):
assert start.shape == end.shape
device = start.device
seq = torch.arange(seq_len, device = device, dtype = torch.long)
seq = seq.reshape(*((-1,) * start.ndim), seq_len)
seq = seq.expand(*start.shape, seq_len)
mask = seq >= start[..., None].long()
mask &= seq < end[..., None].long()
return mask
def mask_from_frac_lengths(
seq_len: int,
frac_lengths: Tensor
):
device = frac_lengths.device
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
start = (max_start * rand).clamp(min = 0)
end = start + lengths
return mask_from_start_end_indices(seq_len, start, end)
def _build_spline(video_feat, video_t, target_t):
# 三次样条插值核心实现
coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1))
spline = NaturalCubicSpline(coeffs)
return spline.evaluate(target_t).permute(0,2,1)
def resample(video_feat, audio_latent):
"""
9s
video_feat: [B, 72, D]
audio_latent: [B, D', 194] or int
"""
B, Tv, D = video_feat.shape
if isinstance(audio_latent, torch.Tensor):
# audio_latent is a tensor
if audio_latent.shape[1] != 64:
Ta = audio_latent.shape[1]
else:
Ta = audio_latent.shape[2]
elif isinstance(audio_latent, int):
# audio_latent is an int
Ta = audio_latent
else:
raise TypeError("audio_latent must be either a tensor or an int")
# 构建时间戳 (关键改进点)
video_time = torch.linspace(0, 9, Tv, device=video_feat.device)
audio_time = torch.linspace(0, 9, Ta, device=video_feat.device)
# 三维化处理 (Batch, Feature, Time)
video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv]
# 三次样条插值
aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta]
return aligned_video.permute(0, 2, 1) # [B, Ta, D]
import os
enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"
def compile(function, *args, **kwargs):
if enable_torch_compile:
try:
return torch.compile(function, *args, **kwargs)
except RuntimeError:
return function
return function