feat: extract prismaudio_core model modules (DiT, conditioners, VAE, diffusion)
Fetch and adapt inference-critical model modules from the upstream PrismAudio repo:
- dit.py: DiffusionTransformer with debug prints removed
- diffusion.py: ConditionedDiffusionModelWrapper, DiTWrapper, MMDiTWrapper
- conditioners.py: Cond_MLP, Sync_MLP, MultiConditioner with stubbed training imports
- autoencoders.py: AudioAutoencoder, OobleckEncoder/Decoder
- transformer.py: ContinuousTransformer, Attention with flash_attn fallback to SDPA
- blocks.py, utils.py, bottleneck.py, pretransforms.py, local_attention.py, pqmf.py
- adp.py: UNetCFG1d, UNet1d, NumberEmbedder
- mmmodules/model/low_level.py: MLP, ChannelLastConv1d, ConvMLP

All internal imports are rewritten from PrismAudio.* to prismaudio_core.*, training-only imports are stubbed, and flash_attn is made optional via a HAS_FLASH_ATTN flag.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
PrismAudio model modules for inference.
|
||||
|
||||
Re-exports create_model_from_config from the factory module.
|
||||
"""
|
||||
|
||||
from prismaudio_core.factory import create_model_from_config
|
||||
|
||||
__all__ = ["create_model_from_config"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,830 @@
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchaudio import transforms as T
|
||||
from alias_free_torch import Activation1d
|
||||
from dac.nn.layers import WNConv1d, WNConvTranspose1d
|
||||
from typing import Literal, Dict, Any
|
||||
|
||||
from .blocks import SnakeBeta
|
||||
from .bottleneck import Bottleneck, DiscreteBottleneck
|
||||
from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
|
||||
from .pretransforms import Pretransform
|
||||
|
||||
|
||||
def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
    """Resample, channel-match and pad/crop one clip, returning it with a batch axis.

    Minimal inference-time stand-in for ``inference.utils.prepare_audio``.

    Args:
        audio: Tensor of shape (channels, length). A 1-D tensor is treated as a
            single-channel signal.
        in_sr: Sample rate of ``audio``.
        target_sr: Sample rate to resample to (skipped when equal to ``in_sr``).
        target_length: Output length in samples. Shorter input is zero-padded
            on the right, longer input is truncated.
        target_channels: Output channel count. Surplus channels are dropped
            (no mixdown); missing channels are filled by tiling the existing
            ones.
        device: Device the audio and the resampler are moved to.

    Returns:
        Tensor of shape (1, target_channels, target_length).
    """
    # Keep the signal and the resampling kernel on the same device; otherwise
    # the resampler call fails whenever the caller's tensor lives elsewhere.
    audio = audio.to(device)

    # A bare waveform has no channel axis yet.
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)

    if in_sr != target_sr:
        # Imported lazily so torchaudio is only required when resampling.
        from torchaudio import transforms as resample_transforms

        resampler = resample_transforms.Resample(in_sr, target_sr).to(device)
        audio = resampler(audio)

    channels = audio.shape[0]
    if channels > target_channels:
        audio = audio[:target_channels]
    elif channels < target_channels:
        # Tile the existing channels, then trim to the exact count.
        audio = audio.repeat(target_channels // channels + 1, 1)[:target_channels]

    length = audio.shape[-1]
    if length < target_length:
        audio = torch.nn.functional.pad(audio, (0, target_length - length))
    elif length > target_length:
        audio = audio[..., :target_length]

    return audio.unsqueeze(0)
|
||||
|
||||
|
||||
def _lazy_create_pretransform_from_config(pretransform, sample_rate):
    """Build a pretransform via the factory module, importing it only when called.

    The factory import is deferred to call time rather than module import time.
    """
    from prismaudio_core.factory import create_pretransform_from_config as _build

    return _build(pretransform, sample_rate)
|
||||
|
||||
|
||||
def _lazy_create_bottleneck_from_config(bottleneck):
    """Build a bottleneck via the factory module, importing it only when called.

    The factory import is deferred to call time rather than module import time.
    """
    from prismaudio_core.factory import create_bottleneck_from_config as _build

    return _build(bottleneck)
|
||||
|
||||
def checkpoint(function, *args, **kwargs):
    """Call ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False`` by default.

    An explicit ``use_reentrant`` keyword supplied by the caller is respected.
    """
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||
|
||||
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Return the activation module named by ``activation``.

    Args:
        activation: ``"elu"``, ``"snake"`` (``SnakeBeta(channels)``) or
            ``"none"`` (identity).
        antialias: When true, the module is wrapped in ``Activation1d``.
        channels: Passed to ``SnakeBeta``; unused by the other activations.

    Raises:
        ValueError: If ``activation`` is not one of the known names.
    """
    # Factories, so that only the requested module is ever constructed.
    builders = {
        "elu": nn.ELU,
        "snake": lambda: SnakeBeta(channels),
        "none": nn.Identity,
    }

    build = builders.get(activation)
    if build is None:
        raise ValueError(f"Unknown activation {activation}")

    module = build()
    return Activation1d(module) if antialias else module
|
||||
|
||||
class ResidualUnit(nn.Module):
    """Residual block: activation, dilated kernel-7 conv, activation, 1x1 conv, plus a skip.

    The dilated convolution is padded by ``dilation * 3`` on each side, which
    keeps the time axis unchanged so the skip connection can be added
    (assuming ``WNConv1d`` follows ``nn.Conv1d`` semantics).
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()

        self.dilation = dilation

        act_name = "snake" if use_snake else "elu"
        kernel_size = 7
        same_padding = (dilation * (kernel_size - 1)) // 2

        # Both activations are sized with out_channels, including the one
        # applied before the first convolution.
        pre_act = get_activation(act_name, antialias=antialias_activation, channels=out_channels)
        dilated_conv = WNConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=same_padding,
        )
        post_act = get_activation(act_name, antialias=antialias_activation, channels=out_channels)
        pointwise_conv = WNConv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1)

        self.layers = nn.Sequential(pre_act, dilated_conv, post_act, pointwise_conv)

    def forward(self, x):
        """Return ``layers(x) + x``."""
        return self.layers(x) + x
|
||||
|
||||
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9), an activation, then a strided conv.

    The strided convolution uses ``kernel_size = 2 * stride`` and changes the
    channel count from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()

        # The residual units are built without antialias_activation; only the
        # activation ahead of the strided conv receives it.
        stages = [
            ResidualUnit(in_channels=in_channels, out_channels=in_channels, dilation=dilation, use_snake=use_snake)
            for dilation in (1, 3, 9)
        ]
        stages.append(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels)
        )
        stages.append(
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block."""
        return self.layers(x)
|
||||
|
||||
class DecoderBlock(nn.Module):
    """Activation, upsampling layer, then three residual units (dilations 1, 3, 9).

    Upsampling is either a weight-normed transposed convolution (default) or,
    with ``use_nearest_upsample``, nearest-neighbour upsampling followed by a
    bias-free ``'same'``-padded convolution.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()

        # The upsampler is constructed first, matching the original build order.
        if use_nearest_upsample:
            upsampler = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding='same',
                ),
            )
        else:
            upsampler = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )

        stages = [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsampler,
        ]
        # The residual units are built without antialias_activation.
        for dilation in (1, 3, 9):
            stages.append(
                ResidualUnit(in_channels=out_channels, out_channels=out_channels, dilation=dilation, use_snake=use_snake)
            )

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block."""
        return self.layers(x)
|
||||
|
||||
class OobleckEncoder(nn.Module):
    """Convolutional encoder: input conv, one ``EncoderBlock`` per stride, output conv.

    Channel widths are ``channels`` times each entry of ``[1] + c_mults``; the
    final convolution projects to ``latent_dim`` channels.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False
        ):
        super().__init__()

        # Builds a new list, so the shared default argument is never mutated.
        widths = [mult * channels for mult in [1] + c_mults]

        self.depth = len(widths)

        stack = [WNConv1d(in_channels=in_channels, out_channels=widths[0], kernel_size=7, padding=3)]

        for level in range(self.depth - 1):
            stack.append(
                EncoderBlock(
                    in_channels=widths[level],
                    out_channels=widths[level + 1],
                    stride=strides[level],
                    use_snake=use_snake,
                )
            )

        stack.append(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=widths[-1])
        )
        stack.append(WNConv1d(in_channels=widths[-1], out_channels=latent_dim, kernel_size=3, padding=1))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Encode ``x`` with the layer stack."""
        return self.layers(x)
|
||||
|
||||
|
||||
class OobleckDecoder(nn.Module):
    """Convolutional decoder mirroring ``OobleckEncoder``.

    An input conv lifts ``latent_dim`` to the widest channel count, one
    ``DecoderBlock`` per stride (visited in reverse) narrows and upsamples,
    and a bias-free output conv produces ``out_channels``. A ``Tanh`` follows
    when ``final_tanh`` is true.
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults = [1, 2, 4, 8],
                 strides = [2, 4, 8, 8],
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()

        # Builds a new list, so the shared default argument is never mutated.
        widths = [mult * channels for mult in [1] + c_mults]

        self.depth = len(widths)

        stack = [WNConv1d(in_channels=latent_dim, out_channels=widths[-1], kernel_size=7, padding=3)]

        # Walk from the widest level back down to the narrowest.
        for level in range(self.depth - 1, 0, -1):
            stack.append(
                DecoderBlock(
                    in_channels=widths[level],
                    out_channels=widths[level - 1],
                    stride=strides[level - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                )
            )

        stack.append(
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=widths[0])
        )
        stack.append(
            WNConv1d(in_channels=widths[0], out_channels=out_channels, kernel_size=7, padding=3, bias=False)
        )
        stack.append(nn.Tanh() if final_tanh else nn.Identity())

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Decode ``x`` with the layer stack."""
        return self.layers(x)
|
||||
|
||||
|
||||
class DACEncoderWrapper(nn.Module):
    """Wrap the DAC ``Encoder`` and optionally project its output to ``latent_dim``.

    ``kwargs`` must contain ``d_model`` and ``strides``; everything left after
    removing ``latent_dim`` is forwarded to the DAC encoder.
    """

    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        # The DAC encoder's own latent width is d_model * 2 ** (number of strides).
        dac_latent_width = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=dac_latent_width, **kwargs)
        self.latent_dim = latent_dim

        # Latent-dim support was added to DAC after this was first written, and
        # implemented differently, so this projection is kept for backwards
        # compatibility.
        if latent_dim is not None:
            self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1)
        else:
            self.proj_out = nn.Identity()

        # Replace the first encoder layer when the input is not mono.
        if in_channels != 1:
            self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)

    def forward(self, x):
        """Encode ``x`` and apply the output projection."""
        return self.proj_out(self.encoder(x))
|
||||
|
||||
class DACDecoderWrapper(nn.Module):
    """Wrap the DAC ``Decoder``, mapping ``latent_dim`` / ``out_channels`` onto its arguments."""

    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Decoder as DACDecoder

        # latent_dim becomes the decoder's input_channel and out_channels its
        # d_out; the remaining kwargs are passed through untouched.
        self.decoder = DACDecoder(**kwargs, input_channel=latent_dim, d_out=out_channels)
        self.latent_dim = latent_dim

    def forward(self, x):
        """Decode ``x`` with the wrapped DAC decoder."""
        return self.decoder(x)
|
||||
|
||||
class AudioAutoencoder(nn.Module):
    """Encoder/decoder pair with an optional bottleneck and an optional pretransform.

    ``encode`` runs pretransform -> encoder -> bottleneck; ``decode`` runs
    bottleneck -> decoder -> pretransform (-> tanh when ``soft_clip`` is set).
    ``encode_audio`` / ``decode_audio`` add optional chunked processing.
    """

    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: "Bottleneck" = None,
        pretransform: "Pretransform" = None,
        in_channels = None,
        out_channels = None,
        soft_clip = False
    ):
        """Store the sub-modules and configuration.

        ``in_channels`` / ``out_channels`` default to ``io_channels`` when not
        given. ``min_length`` (the multiple audio is padded to) starts out
        equal to ``downsampling_ratio``.
        """
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels if in_channels is None else in_channels
        self.out_channels = io_channels if out_channels is None else out_channels

        self.min_length = self.downsampling_ratio

        # Sub-module registration order is kept as in the original.
        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform

        self.soft_clip = soft_clip

        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

    @staticmethod
    def _map_batch(fn, x, iterate_batch):
        """Apply ``fn`` to ``x`` at once, or item by item along dim 0 when ``iterate_batch`` is set."""
        if not iterate_batch:
            return fn(x)
        return torch.cat([fn(x[i:i + 1]) for i in range(x.shape[0])], dim=0)

    def _run_pretransform(self, fn, x, iterate_batch):
        """Run a pretransform method, under ``torch.no_grad()`` unless it has ``enable_grad`` set."""
        if self.pretransform.enable_grad:
            return self._map_batch(fn, x, iterate_batch)
        with torch.no_grad():
            return self._map_batch(fn, x, iterate_batch)

    def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
        """Encode audio into latents.

        Args:
            audio: Input batch; indexed along dim 0 when ``iterate_batch`` is set.
            return_info: When true, also return the info dict collected from
                the bottleneck.
            skip_pretransform: Skip the pretransform even if one is configured.
            iterate_batch: Run the pretransform and encoder one item at a time.
            **kwargs: Forwarded to ``bottleneck.encode``.

        Returns:
            ``latents``, or ``(latents, info)`` when ``return_info`` is true.
        """
        info = {}

        if self.pretransform is not None and not skip_pretransform:
            audio = self._run_pretransform(self.pretransform.encode, audio, iterate_batch)

        if self.encoder is not None:
            latents = self._map_batch(self.encoder, audio, iterate_batch)
        else:
            latents = audio

        if self.bottleneck is not None:
            # TODO: Add iterate batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):
        """Decode latents back to audio.

        Args:
            latents: Latent batch; indexed along dim 0 when ``iterate_batch`` is set.
            iterate_batch: Run bottleneck, decoder and pretransform one item at
                a time.
            **kwargs: Forwarded to the decoder on the non-iterating path only.

        Returns:
            The decoded tensor, passed through ``tanh`` when ``soft_clip`` is set.
        """
        if self.bottleneck is not None:
            latents = self._map_batch(self.bottleneck.decode, latents, iterate_batch)

        if iterate_batch:
            # The per-item path has never forwarded kwargs; kept for
            # compatibility with existing callers.
            decoded = self._map_batch(self.decoder, latents, True)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            decoded = self._run_pretransform(self.pretransform.decode, decoded, iterate_batch)

        if self.soft_clip:
            decoded = torch.tanh(decoded)

        return decoded

    def decode_tokens(self, tokens, **kwargs):
        '''
        Decode discrete tokens to audio

        Only works with discrete autoencoders
        '''
        assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"

        # The same kwargs go to both the token lookup and the decoder.
        latents = self.bottleneck.decode_tokens(tokens, **kwargs)

        return self.decode(latents, **kwargs)

    def preprocess_audio_for_encoder(self, audio, in_sr):
        '''
        Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.

        The audio is resampled to the model's sample rate, channel-matched via
        prepare_audio, and silence-padded to a multiple of self.min_length.
        The output has batch size 1 and shape (1 x Channels x Length).
        '''
        return self.preprocess_audio_list_for_encoder([audio], [in_sr])

    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
        '''
        Preprocess a list of audio (Channels x Length) into a batch tensor compatible with the encoder.

        The audio in the list can differ in length and channel count.
        in_sr_list can be an integer or a list; an integer is taken as the
        input sample rate of every item.
        All audio is resampled to the model's sample rate, silence-padded to
        the longest length (rounded up to a multiple of self.min_length), and
        channel-matched to self.in_channels via prepare_audio.
        The output is a tensor of shape (Batch x Channels x Length).
        '''
        batch_size = len(audio_list)
        if isinstance(in_sr_list, int):
            in_sr_list = [in_sr_list] * batch_size
        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"

        resampled = []
        max_length = 0
        # Resample and find the max length.
        for audio, in_sr in zip(audio_list, in_sr_list):
            if len(audio.shape) == 3 and audio.shape[0] == 1:
                # Batch size 1 was given by accident; squeeze it.
                audio = audio.squeeze(0)
            elif len(audio.shape) == 1:
                # Mono signal with the channel dimension missing; add it.
                audio = audio.unsqueeze(0)
            assert len(audio.shape) == 2, "Audio should be shape (Channels x Length) with no batch dimension"

            if in_sr != self.sample_rate:
                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
                audio = resample_tf(audio)

            resampled.append(audio)
            max_length = max(max_length, audio.shape[-1])

        # Round the longest length up to a multiple of min_length.
        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length

        # Everything is already at the model rate, so equal in/target rates
        # are passed to make prepare_audio skip resampling. (The original
        # reused a stale loop variable here, with the same effect.)
        batch = [
            prepare_audio(
                clip,
                in_sr=self.sample_rate,
                target_sr=self.sample_rate,
                target_length=padded_audio_length,
                target_channels=self.in_channels,
                device=clip.device,
            ).squeeze(0)
            for clip in resampled
        ]

        return torch.stack(batch)

    def _process_in_chunks(self, x, fn, out_channels, in_scale, out_scale, chunk_size, overlap):
        """Apply ``fn`` to overlapping windows of ``x`` along dim 2 and stitch the results.

        ``chunk_size`` and ``overlap`` are measured in latents. ``in_scale`` /
        ``out_scale`` convert latents to the units of the input / output time
        axis (samples-per-latent for audio, 1 for latents).

        Half of the overlap is trimmed from every interior chunk edge. The
        last chunk is aligned to the end of the input and may therefore
        overlap its predecessor by more than ``overlap``.

        Raises:
            ValueError: If ``overlap`` is not smaller than ``chunk_size``.
        """
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")

        hop = chunk_size - overlap
        in_chunk = chunk_size * in_scale
        in_hop = hop * in_scale
        total = x.shape[2]

        # Input shorter than one chunk: nothing to split, process it whole.
        # (The original crashed here on an unbound loop variable.)
        if total < in_chunk:
            return fn(x)

        starts = range(0, total - in_chunk + 1, in_hop)
        chunks = [x[:, :, start:start + in_chunk] for start in starts]
        if starts[-1] + in_chunk != total:
            # Final chunk, aligned to the end of the input.
            chunks.append(x[:, :, -in_chunk:])
        chunks = torch.stack(chunks)
        num_chunks = chunks.shape[0]

        # The output length follows the input length, which may differ from
        # the latent length used in diffusion training.
        out_len = total * out_scale // in_scale
        out = torch.zeros((x.shape[0], out_channels, out_len), device=x.device)

        trim = (overlap // 2) * out_scale

        for idx in range(num_chunks):
            y_chunk = fn(chunks[idx])

            if idx == num_chunks - 1:
                # The final chunk always goes at the end.
                t_end = out_len
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = idx * hop * out_scale
                t_end = t_start + chunk_size * out_scale

            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if idx > 0:
                # No trimming at the start of the first chunk.
                t_start += trim
                chunk_start += trim
            if idx < num_chunks - 1:
                # No trimming at the end of the last chunk.
                t_end -= trim
                chunk_end -= trim

            out[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]

        return out

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Encode audio into latents. Audio should already be preprocessed by preprocess_audio_for_encoder.

        If chunked is True, the audio is split into chunks of chunk_size with
        the given overlap, both measured in latents (not audio samples), so
        the same values can likely be reused with decode_audio.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
        Every autoencoder has a different receptive field size, and thus a different ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked output and looking at the maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.

        kwargs are forwarded to encode only when chunked is False. Audio
        shorter than one chunk is encoded in a single call.
        '''
        if not chunked:
            # Default behavior: encode the entire audio in parallel.
            return self.encode(audio, **kwargs)

        # samples_per_latent is just the downsampling ratio (which is also the
        # upsampling ratio). The audio should have been padded to a multiple
        # of it by now.
        return self._process_in_chunks(
            audio,
            self.encode,
            out_channels=self.latent_dim,
            in_scale=self.downsampling_ratio,
            out_scale=1,
            chunk_size=chunk_size,
            overlap=overlap,
        )

    def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
        '''
        Decode latents to audio.

        If chunked is True, the latents are split into chunks of chunk_size with the given overlap, both measured in latents.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
        Every autoencoder has a different receptive field size, and thus a different ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked audio and looking at the maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        Smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.

        kwargs are forwarded to decode only when chunked is False. Latents
        shorter than one chunk are decoded in a single call.
        '''
        if not chunked:
            # Default behavior: decode the entire latent in parallel.
            return self.decode(latents, **kwargs)

        return self._process_in_chunks(
            latents,
            self.decode,
            out_channels=self.out_channels,
            in_scale=1,
            out_scale=self.downsampling_ratio,
            chunk_size=chunk_size,
            overlap=overlap,
        )
|
||||
|
||||
|
||||
class DiffusionAutoencoder(AudioAutoencoder):
    """Autoencoder whose decoding stage is a conditioned diffusion model.

    Latents (optionally passed through the bottleneck and the decoder) are
    upsampled to audio length and given to the diffusion sampler as
    ``input_concat_cond``.
    """

    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs
    ):
        """Set up the base autoencoder and attach the diffusion model.

        ``min_length`` becomes ``downsampling_ratio * diffusion_downsampling_ratio``.
        """
        super().__init__(*args, **kwargs)

        self.diffusion = diffusion

        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio

        if self.encoder is not None:
            # Shrink the initial encoder parameters to avoid saturated latents.
            with torch.no_grad():
                for param in self.encoder.parameters():
                    param *= 0.5

    def decode(self, latents, steps=100):
        """Decode latents by sampling from the diffusion model.

        Args:
            latents: Latent batch; dim 2 is the latent time axis.
            steps: Number of sampling steps passed to ``sample``.

        Returns:
            The sampled output, passed through ``pretransform.decode`` when a
            pretransform is configured.
        """
        upsampled_length = latents.shape[2] * self.downsampling_ratio

        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)

        if self.decoder is not None:
            # Call the decoder module. The original called self.decode() here,
            # which recursed without end whenever a decoder was configured.
            latents = self.decoder(latents)

        # Upsample latents to match diffusion length.
        if latents.shape[2] != upsampled_length:
            latents = F.interpolate(latents, size=upsampled_length, mode='nearest')

        noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)

        # Imported lazily, at call time rather than module import time.
        from prismaudio_core.inference.sampling import sample

        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    decoded = self.pretransform.decode(decoded)

        return decoded
|
||||
|
||||
# AE factories
|
||||
|
||||
def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Instantiate an encoder from its config dict.

    Args:
        encoder_config: Dict with a ``"type"`` key (``"oobleck"``,
            ``"seanet"``, ``"dac"`` or ``"local_attn"``), a ``"config"`` dict
            of constructor keyword arguments, and an optional
            ``"requires_grad"`` flag (default True).

    Returns:
        The encoder module. All parameters are frozen when ``requires_grad``
        is false.

    Raises:
        AssertionError: If ``"type"`` is missing.
        ValueError: If ``"type"`` is not a known encoder type.
    """
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(**encoder_config["config"])

    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder

        # Work on a shallow copy: reversing "ratios" in place would flip the
        # caller's config again on every call.
        seanet_encoder_config = dict(encoder_config["config"])

        # SEANet encoder expects strides in reverse order.
        seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
        encoder = SEANetEncoder(**seanet_encoder_config)

    elif encoder_type == "dac":
        encoder = DACEncoderWrapper(**encoder_config["config"])

    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        encoder = TransformerEncoder1D(**encoder_config["config"])

    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder
|
||||
|
||||
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Instantiate a decoder from its config dict.

    ``decoder_config["type"]`` selects the implementation (``"oobleck"``,
    ``"seanet"``, ``"dac"`` or ``"local_attn"``) and ``decoder_config["config"]``
    supplies its keyword arguments. When ``"requires_grad"`` is present and
    false, every parameter of the decoder is frozen.

    Raises:
        AssertionError: If ``"type"`` is missing.
        ValueError: If ``"type"`` is not a known decoder type.
    """
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(**decoder_config["config"])
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(**decoder_config["config"])
    elif decoder_type == "dac":
        decoder_kwargs = decoder_config["config"]
        decoder = DACDecoderWrapper(**decoder_kwargs)
    elif decoder_type == "local_attn":
        from .local_attention import TransformerDecoder1D

        decoder_kwargs = decoder_config["config"]
        decoder = TransformerDecoder1D(**decoder_kwargs)
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    trainable = decoder_config.get("requires_grad", True)
    if not trainable:
        for weight in decoder.parameters():
            weight.requires_grad = False

    return decoder
|
||||
|
||||
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an ``AudioAutoencoder`` from a full model config.

    Reads the encoder, decoder, optional bottleneck and optional pretransform
    from ``config["model"]``, and ``sample_rate`` from the top level of
    ``config``. ``latent_dim``, ``downsampling_ratio``, ``io_channels`` and
    ``sample_rate`` are mandatory. ``soft_clip`` is read from the decoder
    config and defaults to False.

    Raises:
        AssertionError: If a mandatory setting is missing.
    """
    ae_config = config["model"]

    # The encoder and decoder are built before the settings are validated.
    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    # (key, dict that holds it), in the order the checks are made.
    mandatory = (
        ("latent_dim", ae_config),
        ("downsampling_ratio", ae_config),
        ("io_channels", ae_config),
        ("sample_rate", config),
    )
    settings = {}
    for key, section in mandatory:
        value = section.get(key, None)
        assert value is not None, f"{key} must be specified in model config"
        settings[key] = value

    pretransform = ae_config.get("pretransform", None)
    if pretransform is not None:
        pretransform = _lazy_create_pretransform_from_config(pretransform, settings["sample_rate"])

    bottleneck = ae_config.get("bottleneck", None)
    if bottleneck is not None:
        bottleneck = _lazy_create_bottleneck_from_config(bottleneck)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=settings["io_channels"],
        latent_dim=settings["latent_dim"],
        downsampling_ratio=settings["downsampling_ratio"],
        sample_rate=settings["sample_rate"],
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=ae_config.get("in_channels", None),
        out_channels=ae_config.get("out_channels", None),
        soft_clip=ae_config["decoder"].get("soft_clip", False),
    )
|
||||
|
||||
def create_diffAE_from_config(config: Dict[str, Any]):
    """Build a ``DiffusionAutoencoder`` from a full model configuration.

    Args:
        config: Configuration dict. ``config["model"]`` must contain a
            ``diffusion`` entry (``type`` and ``config``) plus ``latent_dim``,
            ``downsampling_ratio`` and ``io_channels``; ``encoder``,
            ``decoder``, ``bottleneck`` and ``pretransform`` are optional.
            ``config["sample_rate"]`` is required.

    Returns:
        The constructed ``DiffusionAutoencoder``.

    Raises:
        ValueError: If ``diffusion.type`` is not ``"DAU1d"``, ``"adp_1d"`` or
            ``"dit"``.
        AssertionError: If a required scalar setting is missing.
    """
    diffae_config = config["model"]

    encoder = None
    if "encoder" in diffae_config:
        encoder = create_encoder_from_config(diffae_config["encoder"])

    decoder = None
    if "decoder" in diffae_config:
        decoder = create_decoder_from_config(diffae_config["decoder"])

    diffusion_model_type = diffae_config["diffusion"]["type"]
    diffusion_config = diffae_config["diffusion"]["config"]

    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffusion_config)
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffusion_config)
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffusion_config)
    else:
        # Previously an unknown type fell through and crashed later with an
        # unbound-local error on ``diffusion``; fail early with a clear message.
        raise ValueError(f"Unknown diffusion model type: {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"

    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"

    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)
    pretransform = diffae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = _lazy_create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = _lazy_create_bottleneck_from_config(bottleneck)

    # The original initialised this with ``None,`` (trailing comma), which
    # silently created the tuple ``(None,)``. Every supported type assigns a
    # real value below.
    if diffusion_model_type == "DAU1d":
        diffusion_downsampling_ratio = np.prod(diffusion_config["strides"])
    elif diffusion_model_type == "adp_1d":
        diffusion_downsampling_ratio = np.prod(diffusion_config["factors"])
    else:  # "dit": the transformer does not downsample
        diffusion_downsampling_ratio = 1

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform
    )
|
||||
@@ -0,0 +1,339 @@
|
||||
from functools import reduce
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch.backends.cuda import sdp_kernel
|
||||
from packaging import version
|
||||
|
||||
from dac.nn.layers import Snake1d
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Sum of a sequential main branch and a skip branch.

    Args:
        main: Iterable of modules, wrapped in ``nn.Sequential``.
        skip: Optional module for the skip path; a falsy value selects
            ``nn.Identity``.
    """

    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        # ``or`` keeps the original truthiness test on ``skip``.
        self.skip = skip or nn.Identity()

    def forward(self, input):
        """Return ``main(input) + skip(input)``."""
        main_out = self.main(input)
        skip_out = self.skip(input)
        return main_out + skip_out
|
||||
|
||||
class ResConvBlock(ResidualBlock):
    """Residual block of two 1-D convolutions with GroupNorm and activation.

    The main branch always holds six modules (conv, norm, act, conv, norm,
    act); when ``is_last`` is true the final norm and activation are
    ``nn.Identity``. The skip path is a bias-free 1x1 convolution when the
    channel count changes, otherwise the parent's identity default.

    Args:
        c_in: Input channels.
        c_mid: Channels after the first convolution.
        c_out: Output channels.
        is_last: Replace the trailing norm/activation with identities.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        conv_bias: Whether the two main convolutions have a bias.
        use_snake: Use ``Snake1d`` instead of ``nn.GELU`` as activation.
    """

    def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
        # Built first so module construction order matches the original.
        if c_in == c_out:
            skip = None
        else:
            skip = nn.Conv1d(c_in, c_out, 1, bias=False)

        pad = kernel_size // 2

        def activation(channels):
            return Snake1d(channels) if use_snake else nn.GELU()

        layers = [
            nn.Conv1d(c_in, c_mid, kernel_size, padding=pad, bias=conv_bias),
            nn.GroupNorm(1, c_mid),
            activation(c_mid),
            nn.Conv1d(c_mid, c_out, kernel_size, padding=pad, bias=conv_bias),
        ]
        if is_last:
            layers.append(nn.Identity())
            layers.append(nn.Identity())
        else:
            layers.append(nn.GroupNorm(1, c_out))
            layers.append(activation(c_out))

        super().__init__(layers, skip)
|
||||
|
||||
class SelfAttention1d(nn.Module):
    """Multi-head self-attention over the last axis of a ``(n, c, s)`` tensor.

    The input is group-normalised, projected to q/k/v with a 1x1 convolution,
    attended over the ``s`` axis and added back to the input (residual).

    Args:
        c_in: Number of channels; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        dropout_rate: Dropout applied (in place) to the output projection.
    """

    def __init__(self, c_in, n_head=1, dropout_rate=0.):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv1d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)

        # Use torch's fused scaled-dot-product attention only with CUDA and
        # torch >= 2.0; otherwise fall back to the explicit matmul path.
        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        if not self.use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # Arguments are passed positionally to ``sdp_kernel``
        # (enable_flash, enable_math, enable_mem_efficient).
        if device_properties.major == 8 and device_properties.minor == 0:
            # Use flash attention for A100 GPUs
            self.sdp_kernel_config = (True, False, False)
        else:
            # Don't use flash attention for other GPUs
            self.sdp_kernel_config = (False, True, True)

    def forward(self, input):
        """Apply residual self-attention.

        Args:
            input: Tensor of shape ``(n, c, s)``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        n, c, s = input.shape
        qkv = self.qkv_proj(self.norm(input))
        # (n, 3 * heads, head_dim, s) -> (n, 3 * heads, s, head_dim)
        qkv = qkv.view(
            [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)

        if self.use_flash:
            with sdp_kernel(*self.sdp_kernel_config):
                y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            # Scaling q and k by d**-0.25 each equals the usual 1/sqrt(d).
            scale = k.shape[3]**-0.25
            att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
            y = att @ v

        # Both paths yield (n, heads, s, head_dim); move head_dim back in
        # front of s before merging heads into channels. The fused path
        # previously skipped this transpose and scrambled channel/time layout,
        # so its output disagreed with the matmul path.
        # NOTE(review): checkpoints trained through the old fused path may
        # depend on the previous layout - confirm before loading such weights.
        y = y.transpose(2, 3).contiguous().view([n, c, s])

        return input + self.dropout(self.out_proj(y))
|
||||
|
||||
class SkipBlock(nn.Module):
    """Run a sequential branch and concatenate its output with the input.

    Args:
        *main: Modules forming the branch, wrapped in ``nn.Sequential``.
    """

    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, input):
        """Return ``cat([main(input), input])`` along dimension 1."""
        branch = self.main(input)
        return torch.cat((branch, input), dim=1)
|
||||
|
||||
class FourierFeatures(nn.Module):
    """Random Fourier feature embedding with a learnable projection.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the output embedding; must be even.
        std: Standard deviation used to scale the initial random weights.
    """

    def __init__(self, in_features, out_features, std=1.):
        super().__init__()
        assert out_features % 2 == 0
        half = out_features // 2
        self.weight = nn.Parameter(torch.randn([half, in_features]) * std)

    def forward(self, input):
        """Return ``[cos(f), sin(f)]`` with ``f = 2 * pi * input @ weight.T``."""
        angles = (2 * math.pi * input) @ self.weight.T
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
|
||||
|
||||
def expand_to_planes(input, shape):
    """Append a trailing axis to ``input`` and tile it ``shape[2]`` times.

    The ``repeat`` call uses three factors, so ``input`` is expected to be
    2-D here, giving a result of shape ``(input.shape[0], input.shape[1],
    shape[2])``.
    """
    length = shape[2]
    return input.unsqueeze(-1).repeat(1, 1, length)
|
||||
|
||||
# 1-D filter coefficient tables, looked up by name in Downsample1d and
# Upsample1d (``_kernels[kernel]``). Each table is symmetric; Upsample1d
# multiplies the selected table by 2 before use.
_kernels = {
    'linear':
        [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    'cubic':
        [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
        0.43359375, 0.11328125, -0.03515625, -0.01171875],
    'lanczos3':
        [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
        -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
        0.44638532400131226, 0.13550527393817902, -0.066637322306633,
        -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
|
||||
|
||||
class Downsample1d(nn.Module):
    """Fixed-kernel, stride-2 downsampling applied independently per channel.

    Args:
        kernel: Key into ``_kernels`` selecting the filter coefficients.
        pad_mode: Padding mode passed to ``F.pad``.
        channels_last: If true, the input is permuted with ``(0, 2, 1)``
            before filtering and permuted back afterwards.
    """

    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel])
        self.pad = taps.shape[0] // 2 - 1
        self.register_buffer('kernel', taps)
        self.channels_last = channels_last

    def forward(self, x):
        """Pad, filter with the stored kernel and decimate by two."""
        if self.channels_last:
            x = x.permute(0, 2, 1)

        padded = F.pad(x, (self.pad, self.pad), self.pad_mode)

        # Dense conv weight whose only non-zero entries sit on the channel
        # diagonal, so every channel is filtered with the same kernel and
        # channels never mix.
        n_channels = padded.shape[1]
        weight = padded.new_zeros([n_channels, n_channels, self.kernel.shape[0]])
        diag = torch.arange(n_channels, device=padded.device)
        weight[diag, diag] = self.kernel.to(weight)

        out = F.conv1d(padded, weight, stride=2)

        if self.channels_last:
            out = out.permute(0, 2, 1)
        return out
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
    """Fixed-kernel, stride-2 upsampling applied independently per channel.

    Args:
        kernel: Key into ``_kernels``; the coefficients are multiplied by 2.
        pad_mode: Padding mode passed to ``F.pad``.
        channels_last: If true, the input is permuted with ``(0, 2, 1)``
            before filtering and permuted back afterwards.
    """

    def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel]) * 2
        self.pad = taps.shape[0] // 2 - 1
        self.register_buffer('kernel', taps)
        self.channels_last = channels_last

    def forward(self, x):
        """Pad, then apply a stride-2 transposed convolution with the kernel."""
        if self.channels_last:
            x = x.permute(0, 2, 1)

        edge = (self.pad + 1) // 2
        padded = F.pad(x, (edge, edge), self.pad_mode)

        # Diagonal weight: each channel gets the same kernel, with no mixing.
        n_channels = padded.shape[1]
        weight = padded.new_zeros([n_channels, n_channels, self.kernel.shape[0]])
        diag = torch.arange(n_channels, device=padded.device)
        weight[diag, diag] = self.kernel.to(weight)

        out = F.conv_transpose1d(padded, weight, stride=2, padding=self.pad * 2 + 1)

        if self.channels_last:
            out = out.permute(0, 2, 1)
        return out
|
||||
|
||||
def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    """Create a strided ``nn.Conv1d`` that downsamples by ``factor``.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        factor: Stride of the convolution.
        kernel_multiplier: Even integer; the kernel size is
            ``factor * kernel_multiplier + 1``.

    Returns:
        The configured ``nn.Conv1d``.

    Raises:
        AssertionError: If ``kernel_multiplier`` is odd.
    """
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    kernel_size = factor * kernel_multiplier + 1
    padding = factor * (kernel_multiplier // 2)

    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=factor,
        padding=padding,
    )
|
||||
|
||||
|
||||
def Upsample1d_2(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    """Create a learnable 1-D upsampling layer.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        factor: Upsampling factor. ``1`` yields a plain 3-tap convolution.
        use_nearest: For ``factor != 1``, use nearest-neighbour upsampling
            followed by a 3-tap convolution instead of a transposed
            convolution.

    Returns:
        An ``nn.Conv1d``, an ``nn.Sequential`` or an ``nn.ConvTranspose1d``.
    """

    def smoothing_conv():
        return nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )

    # No resampling needed: only map channels.
    if factor == 1:
        return smoothing_conv()

    if not use_nearest:
        odd = factor % 2
        return nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
            padding=factor // 2 + odd,
            output_padding=odd,
        )

    return nn.Sequential(
        nn.Upsample(scale_factor=factor, mode="nearest"),
        smoothing_conv(),
    )
|
||||
|
||||
def zero_init(layer):
    """Zero the weight (and bias, when present) of ``layer`` in place.

    Returns:
        The same ``layer`` object, so the call can be used inline.
    """
    nn.init.zeros_(layer.weight)
    bias = layer.bias
    if bias is not None:
        nn.init.zeros_(bias)
    return layer
|
||||
|
||||
def rms_norm(x, scale, eps):
    """Scale ``x`` by ``scale / sqrt(mean(x**2) + eps)`` over the last axis.

    The statistics are computed in a dtype promoted from ``x``, ``scale`` and
    float32; the final multiply is carried out in ``x``'s own dtype.
    """
    compute_dtype = torch.promote_types(
        torch.promote_types(x.dtype, scale.dtype), torch.float32
    )
    mean_square = torch.mean(x.to(compute_dtype)**2, dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    factor = scale.to(compute_dtype) * inv_rms
    return x * factor.to(x.dtype)
|
||||
|
||||
#rms_norm = torch.compile(rms_norm)
|
||||
|
||||
class AdaRMSNorm(nn.Module):
    """RMS normalisation whose scale is predicted from a conditioning vector.

    The scale is ``linear(cond) + 1``; the linear layer is zero-initialised
    and bias-free, so the initial scale is exactly 1.

    Args:
        features: Size of the normalised (last) dimension.
        cond_features: Size of the conditioning vector.
        eps: Value added to the mean square before the reciprocal root.
    """

    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        projection = nn.Linear(cond_features, features, bias=False)
        self.linear = zero_init(projection)

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        """Normalise ``x`` with a scale derived from ``cond``."""
        # A middle axis is inserted into the predicted scale so it broadcasts
        # against ``x`` (presumably over a sequence axis - confirm with callers).
        scale = self.linear(cond)[:, None, :] + 1
        return rms_norm(x, scale, self.eps)
|
||||
|
||||
def normalize(x, eps=1e-4):
    """Divide ``x`` by its rescaled per-sample vector norm.

    The norm is taken over every dimension except the first and multiplied by
    ``sqrt(norm.numel() / x.numel())``; ``eps`` is added to the denominator.
    """
    reduce_dims = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True)
    alpha = np.sqrt(norm.numel() / x.numel())
    denom = torch.add(eps, norm, alpha=alpha)
    return x / denom
|
||||
|
||||
class ForcedWNConv1d(nn.Module):
    """Bias-free 1-D convolution whose weight is re-normalised on every call.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Convolution kernel size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        weight_shape = [out_channels, in_channels, kernel_size]
        self.weight = nn.Parameter(torch.randn(weight_shape))

    def forward(self, x):
        """Convolve ``x`` with the normalised weight using 'same' padding."""
        # While training, overwrite the stored parameter with its normalised
        # value (outside autograd) before it is used.
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(self.weight))

        # Number of weights feeding a single output channel.
        fan_in = self.weight[0].numel()
        kernel = normalize(self.weight) / math.sqrt(fan_in)

        return F.conv1d(x, kernel, padding='same')
|
||||
|
||||
# Kernels
|
||||
|
||||
# Module-level switch: when False, ``compile`` returns functions unchanged.
use_compile = True


def compile(function, *args, **kwargs):
    """Wrap ``function`` with ``torch.compile`` when compilation is enabled.

    Extra positional and keyword arguments are forwarded to ``torch.compile``.
    If compilation is disabled via ``use_compile`` or ``torch.compile`` raises
    ``RuntimeError``, the original function is returned untouched.
    """
    if use_compile:
        try:
            return torch.compile(function, *args, **kwargs)
        except RuntimeError:
            pass
    return function
|
||||
|
||||
|
||||
@compile
def linear_geglu(x, weight, bias=None):
    """Linear projection followed by a GELU-gated split.

    Computes ``x @ weight.mT`` (plus ``bias`` when given), splits the result
    in half along the last axis and returns ``first_half * gelu(second_half)``.
    """
    projected = torch.matmul(x, weight.mT)
    if bias is not None:
        projected = projected + bias
    value, gate = projected.chunk(2, dim=-1)
    return value * F.gelu(gate)
|
||||
|
||||
|
||||
@compile
def rms_norm(x, scale, eps):
    """Scale ``x`` by ``scale / sqrt(mean(x**2) + eps)`` over the last axis.

    This definition rebinds the module-level name ``rms_norm`` to a version
    passed through ``compile``. Statistics are computed in a dtype promoted
    from ``x``, ``scale`` and float32; the final multiply uses ``x``'s dtype.
    """
    compute_dtype = torch.promote_types(
        torch.promote_types(x.dtype, scale.dtype), torch.float32
    )
    mean_square = torch.mean(x.to(compute_dtype)**2, dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    factor = scale.to(compute_dtype) * inv_rms
    return x * factor.to(x.dtype)
|
||||
|
||||
# Layers
|
||||
|
||||
class LinearGEGLU(nn.Linear):
    """Linear layer with a GEGLU activation fused into its forward pass.

    The underlying ``nn.Linear`` projects to ``2 * out_features``; the
    ``out_features`` attribute is then reset to the requested size.

    Args:
        in_features: Input feature size.
        out_features: Output feature size after gating.
        bias: Whether the projection has a bias.
    """

    def __init__(self, in_features, out_features, bias=True):
        doubled = out_features * 2
        super().__init__(in_features, doubled, bias=bias)
        # Report the post-gating width rather than the doubled projection.
        self.out_features = out_features

    def forward(self, x):
        """Apply ``linear_geglu`` with this layer's weight and bias."""
        return linear_geglu(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
    """RMS normalisation with a per-feature scale initialised to ones.

    Args:
        shape: Shape of the scale tensor.
        fix_scale: If true the scale is a buffer; otherwise it is a learnable
            parameter.
        eps: Value added to the mean square before the reciprocal root.
    """

    def __init__(self, shape, fix_scale = False, eps=1e-6):
        super().__init__()
        self.eps = eps

        ones = torch.ones(shape)
        if not fix_scale:
            self.scale = nn.Parameter(ones)
        else:
            self.register_buffer("scale", ones)

    def extra_repr(self):
        shape = tuple(self.scale.shape)
        return f"shape={shape}, eps={self.eps}"

    def forward(self, x):
        """Normalise ``x`` via ``rms_norm`` using the stored scale and eps."""
        return rms_norm(x, self.scale, self.eps)
|
||||
|
||||
def snake_beta(x, alpha, beta):
    """Return ``x + sin(alpha * x)**2 / (beta + 1e-9)``.

    The small constant keeps the division finite when ``beta`` is zero.
    """
    inv_beta = 1.0 / (beta + 0.000000001)
    periodic = torch.sin(x * alpha) ** 2
    return x + inv_beta * periodic
|
||||
|
||||
# try:
|
||||
# snake_beta = torch.compile(snake_beta)
|
||||
# except RuntimeError:
|
||||
# pass
|
||||
|
||||
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
|
||||
# License available in LICENSES/LICENSE_NVIDIA.txt
|
||||
class SnakeBeta(nn.Module):
    """Snake activation with separate per-channel ``alpha`` and ``beta``.

    Args:
        in_features: Number of channels; one ``alpha``/``beta`` per channel.
        alpha: Multiplier applied to the initial parameter values.
        alpha_trainable: Whether ``alpha`` and ``beta`` receive gradients.
        alpha_logscale: If true, parameters are stored in log space
            (initialised to zeros and exponentiated in ``forward``);
            otherwise they are initialised to ones and used directly.
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super().__init__()
        self.in_features = in_features

        # Log-scale parameters start at zero (exp(0) == 1); linear ones at one.
        self.alpha_logscale = alpha_logscale
        base = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = nn.Parameter(base(in_features) * alpha)
        self.beta = nn.Parameter(base(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """Apply ``snake_beta`` to ``x``."""
        # Reshape the per-channel parameters to (1, C, 1) so they line up
        # with x laid out as [B, C, T].
        alpha = self.alpha[None, :, None]
        beta = self.beta[None, :, None]
        if self.alpha_logscale:
            alpha, beta = torch.exp(alpha), torch.exp(beta)
        return snake_beta(x, alpha, beta)
|
||||
@@ -0,0 +1,355 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from vector_quantize_pytorch import ResidualVQ, FSQ
|
||||
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
|
||||
|
||||
class Bottleneck(nn.Module):
    """Abstract base class for latent bottlenecks.

    Args:
        is_discrete: Stored on the instance; subclasses set it to indicate
            whether they produce discrete tokens.
    """

    def __init__(self, is_discrete: bool = False):
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        """Transform encoder output; must be implemented by subclasses."""
        raise NotImplementedError()

    def decode(self, x):
        """Transform latents before decoding; must be implemented by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
class DiscreteBottleneck(Bottleneck):
    """Abstract base class for bottlenecks that produce discrete tokens.

    Args:
        num_quantizers: Number of quantizers / codebooks.
        codebook_size: Number of entries per codebook.
        tokens_id: Key under which the subclass stores token indices in the
            ``info`` dict returned by ``encode``.
    """

    def __init__(self, num_quantizers, codebook_size, tokens_id):
        super().__init__(is_discrete=True)
        self.num_quantizers, self.codebook_size = num_quantizers, codebook_size
        self.tokens_id = tokens_id

    def decode_tokens(self, codes, **kwargs):
        """Turn token indices back into latents; implemented by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
class TanhBottleneck(Bottleneck):
    """Continuous bottleneck that squashes latents with ``tanh``."""

    def __init__(self):
        super().__init__(is_discrete=False)
        # Registered as a sub-module; ``encode`` calls ``torch.tanh`` directly.
        self.tanh = nn.Tanh()

    def encode(self, x, return_info=False):
        """Return ``tanh(x)``, plus an empty info dict when requested."""
        squashed = torch.tanh(x)
        if not return_info:
            return squashed
        return squashed, {}

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
|
||||
|
||||
def vae_sample(mean, scale):
    """Draw a reparameterised Gaussian sample and compute its KL term.

    Args:
        mean: Mean of the latent distribution.
        scale: Unconstrained scale; ``softplus(scale) + 1e-4`` is used as the
            standard deviation.

    Returns:
        ``(latents, kl)`` where ``latents = mean + stdev * noise`` and ``kl``
        is ``(mean**2 + var - log(var) - 1)`` summed over dimension 1 and
        averaged over the remaining dimensions.
    """
    stdev = F.softplus(scale) + 1e-4
    var = stdev * stdev

    noise = torch.randn_like(mean)
    latents = noise * stdev + mean

    kl_terms = mean * mean + var - torch.log(var) - 1
    kl = kl_terms.sum(1).mean()

    return latents, kl
|
||||
|
||||
class VAEBottleneck(Bottleneck):
    """Continuous bottleneck that samples latents from a diagonal Gaussian."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False, **kwargs):
        """Split ``x`` into mean and scale along dim 1 and sample latents.

        Returns:
            The sampled latents, and when ``return_info`` is true also a dict
            holding the KL term under ``"kl"``.
        """
        mean, scale = x.chunk(2, dim=1)
        sampled, kl = vae_sample(mean, scale)

        if not return_info:
            return sampled
        return sampled, {"kl": kl}

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
|
||||
|
||||
def compute_mean_kernel(x, y):
    """Mean of an exponential kernel over all pairs of rows of ``x`` and ``y``.

    For each pair the squared differences are averaged over the feature axis,
    divided by the feature size again, and mapped through ``exp(-d)``.
    """
    pairwise_diff = x.unsqueeze(1) - y.unsqueeze(0)
    distance = pairwise_diff.pow(2).mean(2) / x.shape[-1]
    return torch.exp(-distance).mean()
|
||||
|
||||
def compute_mmd(latents):
    """Kernel MMD estimate between ``latents`` and standard normal noise.

    ``latents`` is permuted with ``(0, 2, 1)`` and flattened to rows of size
    ``latents.shape[1]`` before being compared with Gaussian noise of the same
    shape.
    """
    feature_dim = latents.shape[1]
    samples = latents.permute(0, 2, 1).reshape(-1, feature_dim)
    noise = torch.randn_like(samples)

    k_samples = compute_mean_kernel(samples, samples)
    k_noise = compute_mean_kernel(noise, noise)
    k_cross = compute_mean_kernel(samples, noise)

    mmd = k_samples + k_noise - 2 * k_cross
    return mmd.mean()
|
||||
|
||||
class WassersteinBottleneck(Bottleneck):
    """Continuous bottleneck that reports an MMD penalty during training.

    Args:
        noise_augment_dim: Number of Gaussian noise channels concatenated to
            the latents in ``decode`` (``0`` disables this).
        bypass_mmd: If true, report a constant zero instead of computing MMD.
    """

    def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
        super().__init__(is_discrete=False)
        self.noise_augment_dim = noise_augment_dim
        self.bypass_mmd = bypass_mmd

    def encode(self, x, return_info=False):
        """Return ``x`` unchanged, optionally with an info dict.

        ``info["mmd"]`` is only populated in training mode when
        ``return_info`` is true.
        """
        if not return_info:
            return x

        info = {}
        if self.training:
            info["mmd"] = torch.tensor(0.0) if self.bypass_mmd else compute_mmd(x)
        return x, info

    def decode(self, x):
        """Optionally append noise channels along dim 1, then return ``x``."""
        if self.noise_augment_dim <= 0:
            return x

        noise_shape = (x.shape[0], self.noise_augment_dim, x.shape[-1])
        noise = torch.randn(*noise_shape).type_as(x)
        return torch.cat([x, noise], dim=1)
|
||||
|
||||
class L2Bottleneck(Bottleneck):
    """Continuous bottleneck that normalises latents along dim 1."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        """Return ``F.normalize(x, dim=1)``, plus an empty info dict if asked."""
        normed = F.normalize(x, dim=1)
        if not return_info:
            return normed
        return normed, {}

    def decode(self, x):
        """Normalise ``x`` along dim 1 again before decoding."""
        return F.normalize(x, dim=1)
|
||||
|
||||
class RVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by ``ResidualVQ``.

    Args:
        **quantizer_kwargs: Forwarded to ``ResidualVQ``; must include
            ``num_quantizers`` and ``codebook_size``.
    """

    def __init__(self, **quantizer_kwargs):
        n_quantizers = quantizer_kwargs["num_quantizers"]
        super().__init__(
            num_quantizers=n_quantizers,
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["num_quantizers"]

    def encode(self, x, return_info=False, **kwargs):
        """Quantise ``x`` and optionally return indices and loss.

        The quantizer is called on ``x`` rearranged from ``b c n`` to
        ``b n c``; its output is rearranged back.
        """
        channels_last = rearrange(x, "b c n -> b n c")
        quantized, indices, loss = self.quantizer(channels_last)
        quantized = rearrange(quantized, "b n c -> b c n")

        if not return_info:
            return quantized

        info = {
            "quantizer_indices": indices,
            "quantizer_loss": loss.mean(),
        }
        return quantized, info

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` and pass them through ``decode``."""
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)
|
||||
|
||||
class RVQVAEBottleneck(DiscreteBottleneck):
    """Discrete bottleneck: VAE sampling followed by ``ResidualVQ``.

    Args:
        **quantizer_kwargs: Forwarded to ``ResidualVQ``; must include
            ``num_quantizers`` and ``codebook_size``.
    """

    def __init__(self, **quantizer_kwargs):
        n_quantizers = quantizer_kwargs["num_quantizers"]
        super().__init__(
            num_quantizers=n_quantizers,
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["num_quantizers"]

    def encode(self, x, return_info=False):
        """Sample latents from ``x`` (mean/scale halves on dim 1), then quantise.

        Returns:
            The quantised latents, and when ``return_info`` is true also a
            dict with ``"kl"``, ``"quantizer_indices"`` and
            ``"quantizer_loss"``.
        """
        mean, scale = x.chunk(2, dim=1)
        sampled, kl = vae_sample(mean, scale)

        channels_last = rearrange(sampled, "b c n -> b n c")
        quantized, indices, loss = self.quantizer(channels_last)
        quantized = rearrange(quantized, "b n c -> b c n")

        if not return_info:
            return quantized

        info = {
            "kl": kl,
            "quantizer_indices": indices,
            "quantizer_loss": loss.mean(),
        }
        return quantized, info

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` and pass them through ``decode``."""
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)
|
||||
|
||||
class DACRVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by the DAC residual vector quantizer.

    Args:
        quantize_on_decode: If true, ``encode`` passes latents through
            unquantised and quantisation happens in ``decode`` instead.
        noise_augment_dim: Number of Gaussian noise channels concatenated to
            the latents in ``decode`` (``0`` disables this).
        **quantizer_kwargs: Forwarded to ``DACResidualVQ``; must include
            ``n_codebooks`` and ``codebook_size``.
    """

    def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
        super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False, **kwargs):
        """Quantise ``x`` (unless deferred to ``decode``).

        Args:
            x: Latents to quantise.
            return_info: Whether to also return the info dict.
            **kwargs: Forwarded to the quantizer.

        Returns:
            The (possibly quantised) latents, or ``(latents, info)`` when
            ``return_info`` is true.
        """
        info = {}

        info["pre_quantizer"] = x

        if self.quantize_on_decode:
            # Parenthesised on purpose: ``return x, info if return_info else x``
            # parses as ``(x, (info if ... else x))`` and always returned a
            # tuple, even when no info was requested.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        # Average both losses over the number of codebooks.
        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        """Quantise (if deferred) and optionally append noise channels."""
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        if self.noise_augment_dim > 0:
            noise = torch.randn(x.shape[0], self.noise_augment_dim,
                                x.shape[-1]).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, codes, **kwargs):
        """Reconstruct latents from ``codes`` and pass them through ``decode``."""
        latents, _, _ = self.quantizer.from_codes(codes)

        # NOTE(review): ``decode`` accepts no keyword arguments, so any
        # non-empty ``kwargs`` raises TypeError here - confirm caller intent.
        return self.decode(latents, **kwargs)
|
||||
|
||||
class DACRVQVAEBottleneck(DiscreteBottleneck):
    """Discrete bottleneck: VAE sampling followed by the DAC residual VQ.

    Args:
        quantize_on_decode: If true, ``encode`` returns the sampled latents
            unquantised and quantisation happens in ``decode`` instead.
        **quantizer_kwargs: Forwarded to ``DACResidualVQ``; must include
            ``n_codebooks`` and ``codebook_size``.
    """

    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, n_quantizers: int = None):
        """Sample latents from ``x`` and quantise them (unless deferred).

        Args:
            x: Tensor whose dim 1 holds the mean and scale halves.
            return_info: Whether to also return the info dict.
            n_quantizers: Forwarded to the quantizer.

        Returns:
            The (possibly quantised) latents, or ``(latents, info)`` when
            ``return_info`` is true.
        """
        info = {}

        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["pre_quantizer"] = x
        info["kl"] = kl

        if self.quantize_on_decode:
            # Parenthesised on purpose: ``return x, info if return_info else x``
            # parses as ``(x, (info if ... else x))`` and always returned a
            # tuple, even when no info was requested.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        # Average both losses over the number of codebooks.
        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        """Quantise ``x`` when quantisation was deferred, else return it as is."""
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        return x

    def decode_tokens(self, codes, **kwargs):
        """Reconstruct latents from ``codes`` and pass them through ``decode``."""
        latents, _, _ = self.quantizer.from_codes(codes)

        # NOTE(review): ``decode`` accepts no keyword arguments, so any
        # non-empty ``kwargs`` raises TypeError here - confirm caller intent.
        return self.decode(latents, **kwargs)
|
||||
|
||||
class FSQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by finite scalar quantisation (``FSQ``).

    Args:
        noise_augment_dim: Number of Gaussian noise channels concatenated to
            the latents in ``decode`` (``0`` disables this).
        **kwargs: Forwarded to ``FSQ``; must include ``levels`` and may
            include ``num_codebooks`` (defaults to 1 for bookkeeping).
    """

    def __init__(self, noise_augment_dim=0, **kwargs):
        super().__init__(
            num_quantizers=kwargs.get("num_codebooks", 1),
            codebook_size=np.prod(kwargs["levels"]),
            tokens_id="quantizer_indices",
        )

        self.noise_augment_dim = noise_augment_dim

        self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])

    def encode(self, x, return_info=False):
        """Quantise ``x`` in float32 and cast the result back to its dtype.

        Returns:
            The quantised latents, and when ``return_info`` is true also a
            dict with the indices under ``"quantizer_indices"``.
        """
        input_dtype = x.dtype

        channels_last = rearrange(x.float(), "b c n -> b n c")
        quantized, indices = self.quantizer(channels_last)
        quantized = rearrange(quantized, "b n c -> b c n").to(input_dtype)

        # Reorder indices to match the expected format
        indices = rearrange(indices, "b n q -> b q n")

        if not return_info:
            return quantized
        return quantized, {"quantizer_indices": indices}

    def decode(self, x):
        """Optionally append noise channels along dim 1, then return ``x``."""
        if self.noise_augment_dim <= 0:
            return x

        noise_shape = (x.shape[0], self.noise_augment_dim, x.shape[-1])
        noise = torch.randn(*noise_shape).type_as(x)
        return torch.cat([x, noise], dim=1)

    def decode_tokens(self, tokens, **kwargs):
        """Map ``tokens`` to codes and pass them through ``decode``."""
        latents = self.quantizer.indices_to_codes(tokens)
        return self.decode(latents, **kwargs)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,965 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import typing as tp
|
||||
|
||||
from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
|
||||
from .conditioners import MultiConditioner
|
||||
from .dit import DiffusionTransformer
|
||||
from .pretransforms import Pretransform
|
||||
|
||||
from .adp import UNetCFG1d, UNet1d
|
||||
|
||||
# Lazy imports for factory functions to avoid circular imports
|
||||
def _get_create_pretransform_from_config():
    """Return ``prismaudio_core.factory.create_pretransform_from_config``.

    The import happens at call time rather than at module import, which
    avoids a circular import between this module and the factory module.
    """
    from prismaudio_core.factory import create_pretransform_from_config
    return create_pretransform_from_config
|
||||
|
||||
def _get_create_multi_conditioner_from_conditioning_config():
    """Return ``prismaudio_core.factory.create_multi_conditioner_from_conditioning_config``.

    The import happens at call time rather than at module import, which
    avoids a circular import between this module and the factory module.
    """
    from prismaudio_core.factory import create_multi_conditioner_from_conditioning_config
    return create_multi_conditioner_from_conditioning_config
|
||||
|
||||
from time import time
|
||||
|
||||
class Profiler:
    """Minimal wall-clock profiler that records labelled timestamps.

    ``ticks`` is a list of ``[timestamp, message]`` pairs; the first entry is
    created at construction with a ``None`` message and serves as the start
    reference.
    """

    def __init__(self):
        self.ticks = [[time(), None]]

    def tick(self, msg):
        """Record the current time under the label ``msg``."""
        self.ticks.append([time(), msg])

    def __repr__(self):
        """Render the time between consecutive ticks, in milliseconds."""
        bar = 80 * "="
        parts = [bar + "\n"]
        for previous, current in zip(self.ticks, self.ticks[1:]):
            elapsed = current[0] - previous[0]
            parts.append(current[1] + f": {elapsed*1000:.2f}ms\n")
        parts.append(bar + "\n\n\n")
        return "".join(parts)
|
||||
|
||||
class DiffusionModel(nn.Module):
    """Abstract base class for unconditioned diffusion models.

    Subclasses implement ``forward(x, t, **kwargs)``.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are passed straight through to ``nn.Module``.
        super().__init__(*args, **kwargs)

    def forward(self, x, t, **kwargs):
        """Must be overridden by subclasses."""
        raise NotImplementedError
|
||||
|
||||
class DiffusionModelWrapper(nn.Module):
    """Bundles a diffusion model with its audio metadata and pretransform.

    Args:
        model: The wrapped diffusion model.
        io_channels: Stored as ``self.io_channels``.
        sample_size: Stored as ``self.sample_size``.
        sample_rate: Stored as ``self.sample_rate``.
        min_input_length: Stored as ``self.min_input_length``.
        pretransform: Optional pretransform module; ``None`` when absent.
    """

    def __init__(
        self,
        model: DiffusionModel,
        io_channels,
        sample_size,
        sample_rate,
        min_input_length,
        pretransform: tp.Optional[Pretransform] = None,
    ):
        super().__init__()

        # Plain metadata attributes.
        self.io_channels = io_channels
        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.model = model

        # Either the given module or None; both are stored as-is.
        self.pretransform = pretransform

    def forward(self, x, t, **kwargs):
        """Delegate to the wrapped model."""
        return self.model(x, t, **kwargs)
|
||||
|
||||
class ConditionedDiffusionModel(nn.Module):
    """Abstract base class for diffusion models that accept conditioning.

    The four keyword-only flags record which conditioning mechanisms a
    subclass supports; they are stored as attributes of the same name.
    Remaining positional and keyword arguments go to ``nn.Module``.
    """

    def __init__(self,
                 *args,
                 supports_cross_attention: bool = False,
                 supports_input_concat: bool = False,
                 supports_global_cond: bool = False,
                 supports_prepend_cond: bool = False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        capabilities = (
            ("supports_cross_attention", supports_cross_attention),
            ("supports_input_concat", supports_input_concat),
            ("supports_global_cond", supports_global_cond),
            ("supports_prepend_cond", supports_prepend_cond),
        )
        for attr_name, flag in capabilities:
            setattr(self, attr_name, flag)

    def forward(self,
                x: torch.Tensor,
                t: torch.Tensor,
                cross_attn_cond: torch.Tensor = None,
                cross_attn_mask: torch.Tensor = None,
                input_concat_cond: torch.Tensor = None,
                global_embed: torch.Tensor = None,
                prepend_cond: torch.Tensor = None,
                prepend_cond_mask: torch.Tensor = None,
                cfg_scale: float = 1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                **kwargs):
        """Must be overridden by subclasses."""
        raise NotImplementedError
|
||||
|
||||
class ConditionedDiffusionModelWrapper(nn.Module):
    """
    A diffusion model that takes in conditioning.

    Holds a conditioned diffusion backbone (``model``) together with its
    ``conditioner`` and maps conditioner outputs, looked up by id, onto the
    keyword arguments handed to the backbone.

    Args:
        model: the conditioned diffusion backbone that ``forward`` calls.
        conditioner: module producing the conditioning tensors.
        io_channels: stored as-is on the instance.
        sample_rate: stored as-is on the instance.
        min_input_length: stored as-is on the instance.
        diffusion_objective: ``"v"`` or ``"rectified_flow"``.
        zero_init: when exactly ``True``, re-initialise the conditioner's
            linear layers and call ``model.model.initialize_weights()``.
        pretransform: optional pretransform, stored as-is.
        cross_attn_cond_ids, global_cond_ids, input_concat_ids,
        prepend_cond_ids, add_cond_ids, sync_cond_ids: ids of the conditioner
            outputs routed to each kind of conditioning. ``None`` means "no
            ids"; a given sequence is copied.
    """
    def __init__(
        self,
        model: ConditionedDiffusionModel,
        conditioner: MultiConditioner,
        io_channels,
        sample_rate,
        min_input_length: int,
        diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
        zero_init: bool = False,
        pretransform: tp.Optional[Pretransform] = None,
        cross_attn_cond_ids: tp.Optional[tp.List[str]] = None,
        global_cond_ids: tp.Optional[tp.List[str]] = None,
        input_concat_ids: tp.Optional[tp.List[str]] = None,
        prepend_cond_ids: tp.Optional[tp.List[str]] = None,
        add_cond_ids: tp.Optional[tp.List[str]] = None,
        sync_cond_ids: tp.Optional[tp.List[str]] = None,
    ):
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        # Every id list is a fresh copy so that instances never share a
        # default list or alias a list owned by the caller.
        self.cross_attn_cond_ids = self._as_id_list(cross_attn_cond_ids)
        self.global_cond_ids = self._as_id_list(global_cond_ids)
        self.input_concat_ids = self._as_id_list(input_concat_ids)
        self.prepend_cond_ids = self._as_id_list(prepend_cond_ids)
        self.add_cond_ids = self._as_id_list(add_cond_ids)
        self.sync_cond_ids = self._as_id_list(sync_cond_ids)
        self.min_input_length = min_input_length

        if zero_init is True:
            self.conditioner.apply(self._basic_init)
            self.model.model.initialize_weights()

    @staticmethod
    def _as_id_list(ids):
        """Return a new list of ids; ``None`` becomes an empty list."""
        return [] if ids is None else list(ids)

    @staticmethod
    def _basic_init(module):
        """Xavier-initialise an ``nn.Linear`` weight and zero its bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    @staticmethod
    def _with_seq_dim(tensor):
        """Insert a length-1 axis at position 1 when ``tensor`` is 2-D."""
        if len(tensor.shape) == 2:
            return tensor.unsqueeze(1)
        return tensor

    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        """
        Collect conditioner outputs into the keyword arguments for the backbone.

        Each entry of ``conditioning_tensors`` is indexed as a pair whose
        first element is the conditioning tensor; the second element is used
        as a mask for the cross-attention and prepend groups.

        Args:
            conditioning_tensors: mapping from conditioning id to its pair.
            negative: when true, return only the cross-attention, global and
                input-concat entries under ``negative_``-prefixed keys.

        Returns:
            A dict of keyword arguments; groups with no configured ids map
            to ``None``.
        """
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None
        add_input = None
        sync_input = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention inputs (and their masks) along dim 1.
            # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
            embeds = []
            masks = []
            for key in self.cross_attn_cond_ids:
                cross_attn_in, cross_attn_mask = conditioning_tensors[key]
                # Only the embedding gets a sequence axis; the mask is used as given.
                embeds.append(self._with_seq_dim(cross_attn_in))
                masks.append(cross_attn_mask)

            cross_attention_input = torch.cat(embeds, dim=1)
            cross_attention_masks = torch.cat(masks, dim=1)

        if len(self.add_cond_ids) > 0:
            # Concatenate all additive conditioning inputs along dim 2
            # (the last axis once 2-D inputs have gained a sequence axis).
            add_input = torch.cat(
                [self._with_seq_dim(conditioning_tensors[key][0]) for key in self.add_cond_ids],
                dim=2,
            )

        if len(self.sync_cond_ids) > 0:
            # Same layout as the additive inputs: concatenate along dim 2.
            sync_input = torch.cat(
                [self._with_seq_dim(conditioning_tensors[key][0]) for key in self.sync_cond_ids],
                dim=2,
            )

        if len(self.global_cond_ids) > 0:
            # Global conditioning inputs are summed element-wise (not concatenated).
            # Assumes that the global conditioning inputs are of shape (batch, channels)
            global_conds = [
                self._with_seq_dim(conditioning_tensors[key][0]) for key in self.global_cond_ids
            ]
            global_cond = sum(global_conds)

            # Drop the axis added above (squeeze is a no-op if dim 1 is not of size 1).
            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if len(self.input_concat_ids) > 0:
            # Concatenate all input concat conditioning inputs over the channel dimension
            # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
            input_concat_cond = torch.cat(
                [conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1
            )

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs over the sequence dimension
            # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
            prepend_conds = []
            prepend_cond_masks = []
            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_mask = conditioning_tensors[key]
                prepend_conds.append(self._with_seq_dim(prepend_cond_input))
                prepend_cond_masks.append(prepend_mask)

            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond
            }

        return {
            "cross_attn_cond": cross_attention_input,
            "cross_attn_mask": cross_attention_masks,
            "global_cond": global_cond,
            "input_concat_cond": input_concat_cond,
            "prepend_cond": prepend_cond,
            "prepend_cond_mask": prepend_cond_mask,
            "add_cond": add_input,
            "sync_cond": sync_input
        }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        """Call the backbone on ``x`` and ``t`` with the conditioning built from ``cond``."""
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to ``generate_diffusion_cond`` with this wrapper as the model."""
        # Imported lazily, on first use, rather than at module import time.
        from prismaudio_core.inference.generation import generate_diffusion_cond
        return generate_diffusion_cond(self, *args, **kwargs)
|
||||
|
||||
class UNetCFG1DWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``UNetCFG1d`` through the conditioned-diffusion call signature."""
    def __init__(
        self,
        *args,
        **kwargs
    ):
        """Build the wrapped ``UNetCFG1d`` from ``*args`` / ``**kwargs``."""
        super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)

        self.model = UNetCFG1d(*args, **kwargs)

        # Halve every freshly created parameter in place.
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                input_concat_cond=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                **kwargs):
        """
        Run the wrapped U-Net, translating argument names to the ones it uses.

        ``negative_global_cond``, ``negative_input_concat_cond``,
        ``prepend_cond`` and ``prepend_cond_mask`` are accepted for signature
        compatibility but are not forwarded to the wrapped model.

        Returns:
            Whatever the wrapped model returns.
        """
        # The wrapped model takes concat conditioning as a list of tensors.
        channels_list = None
        if input_concat_cond is not None:
            channels_list = [input_concat_cond]

        return self.model(
            x,
            t,
            embedding=cross_attn_cond,
            embedding_mask=cross_attn_mask,
            features=global_cond,
            channels_list=channels_list,
            embedding_scale=cfg_scale,
            embedding_mask_proba=cfg_dropout_prob,
            batch_cfg=batch_cfg,
            rescale_cfg=rescale_cfg,
            negative_embedding=negative_cross_attn_cond,
            negative_embedding_mask=negative_cross_attn_mask,
            **kwargs)
|
||||
|
||||
class UNet1DCondWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``UNet1d`` through the conditioned-diffusion call signature."""
    def __init__(self, *args, **kwargs):
        """Build the wrapped ``UNet1d`` from ``*args`` / ``**kwargs``."""
        super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)

        self.model = UNet1d(*args, **kwargs)

        # Halve every freshly created parameter in place.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                global_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                **kwargs):
        """
        Run the wrapped U-Net with global and input-concat conditioning only.

        All cross-attention, prepend, CFG and negative-conditioning arguments
        are accepted for signature compatibility and ignored.
        """
        if input_concat_cond is None:
            channels_list = None
        else:
            # Resize input_concat_cond to the same length as x (dim 2), nearest-neighbour.
            target_length = x.shape[2]
            if input_concat_cond.shape[2] != target_length:
                input_concat_cond = F.interpolate(input_concat_cond, (target_length, ), mode='nearest')
            channels_list = [input_concat_cond]

        return self.model(
            x,
            t,
            features=global_cond,
            channels_list=channels_list,
            **kwargs)
|
||||
|
||||
class UNet1DUncondWrapper(DiffusionModel):
    """Unconditional wrapper around ``UNet1d`` that records its channel count."""
    def __init__(self, in_channels, *args, **kwargs):
        """
        Args:
            in_channels: passed to ``UNet1d`` as the ``in_channels`` keyword
                and stored as ``io_channels``.
            *args, **kwargs: forwarded to ``UNet1d``.
        """
        super().__init__()

        self.model = UNet1d(*args, in_channels=in_channels, **kwargs)
        self.io_channels = in_channels

        # Halve every freshly created parameter in place.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self, x, t, **kwargs):
        """Call the wrapped model on ``x`` and ``t``, forwarding extra keywords."""
        return self.model(x, t, **kwargs)
|
||||
|
||||
class DAU1DCondWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``DiffusionAttnUnet1D`` through the conditioned-diffusion call signature."""
    def __init__(self, *args, **kwargs):
        """Build the wrapped ``DiffusionAttnUnet1D`` from ``*args`` / ``**kwargs``."""
        super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)

        self.model = DiffusionAttnUnet1D(*args, **kwargs)

        # Halve every parameter of the wrapped model in place.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                **kwargs):
        """
        Run the wrapped model with ``input_concat_cond`` as its ``cond``.

        Every other conditioning argument, and any extra keyword, is accepted
        for signature compatibility and ignored.
        """
        output = self.model(x, t, cond=input_concat_cond)
        return output
|
||||
|
||||
class DiffusionAttnUnet1D(nn.Module):
    """
    1-D U-Net built from nested skip blocks, with optional self-attention.

    The network is assembled from the innermost level outwards: each level
    wraps the block built so far between a downsampling stage and an
    upsampling stage.
    """
    def __init__(
        self,
        io_channels = 2,
        depth=14,
        n_attn_layers = 6,
        channels = [128, 128, 256, 256] + [512] * 10,
        cond_dim = 0,
        cond_noise_aug = False,
        kernel_size = 5,
        learned_resample = False,
        strides = [2] * 13,
        conv_bias = True,
        use_snake = False
    ):
        """
        Args:
            io_channels: channel count of the network input and output.
            depth: number of levels to build.
            n_attn_layers: levels ``>= depth - n_attn_layers`` get
                self-attention (none when this is 0).
            channels: channel width per level, indexed by ``level - 1``.
            cond_dim: channel count of the optional conditioning signal.
            cond_noise_aug: add noise to the conditioning signal in
                ``forward`` and feed an embedding of the noise level.
            kernel_size, conv_bias, use_snake: forwarded to ``ResConvBlock``.
            learned_resample: use the learned down/up-sampling modules.
            strides: stride per level after the first; a leading 1 is
                prepended internally.

        Raises:
            ValueError: if a stride above 2 is used without ``learned_resample``.
        """
        super().__init__()

        self.cond_noise_aug = cond_noise_aug
        self.io_channels = io_channels

        if self.cond_noise_aug:
            self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.timestep_embed = FourierFeatures(1, 16)

        first_attn_level = depth - n_attn_layers

        # Builds a new list, so the default argument is never mutated.
        strides = [1] + strides

        make_conv = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)

        def maybe_attention(width, enabled):
            # Self-attention when enabled, a pass-through otherwise.
            return SelfAttention1d(width, width // 32) if enabled else nn.Identity()

        inner = nn.Identity()

        for level in reversed(range(1, depth + 1)):
            c = channels[level - 1]
            stride = strides[level - 1]
            if stride > 2 and not learned_resample:
                raise ValueError("Must have stride 2 without learned resampling")

            if level > 1:
                c_prev = channels[level - 2]
                add_attn = level >= first_attn_level and n_attn_layers > 0

                # Sub-modules are created in call order, matching the
                # original construction order (and so its random init draws).
                if learned_resample or stride == 1:
                    down = Downsample1d_2(c_prev, c_prev, stride)
                else:
                    down = Downsample1d("cubic")

                stages = [
                    down,
                    make_conv(c_prev, c, c),
                    maybe_attention(c, add_attn),
                    make_conv(c, c, c),
                    maybe_attention(c, add_attn),
                    make_conv(c, c, c),
                    maybe_attention(c, add_attn),
                    inner,
                    # The deepest level wraps an identity, so its input is not doubled.
                    make_conv(c * 2 if level != depth else c, c, c),
                    maybe_attention(c, add_attn),
                    make_conv(c, c, c),
                    maybe_attention(c, add_attn),
                    make_conv(c, c, c_prev),
                    maybe_attention(c_prev, add_attn),
                    Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic"),
                ]
                inner = SkipBlock(*stages)
            else:
                # 16 embedding channels for the timestep, 32 when a
                # noise-level embedding is appended as well.
                cond_embed_dim = 32 if self.cond_noise_aug else 16
                inner = nn.Sequential(
                    make_conv((io_channels + cond_dim) + cond_embed_dim, c, c),
                    make_conv(c, c, c),
                    make_conv(c, c, c),
                    inner,
                    make_conv(c * 2, c, c),
                    make_conv(c, c, c),
                    make_conv(c, c, io_channels, is_last=True),
                )

        self.net = inner

        # Halve every parameter of the assembled network in place.
        with torch.no_grad():
            for weight in self.net.parameters():
                weight *= 0.5

    def forward(self, x, t, cond=None, cond_aug_scale=None):
        """
        Run the network on ``x`` at timestep ``t``.

        Args:
            x: input tensor; ``x.shape[2]`` is treated as its length.
            t: timesteps, indexed as ``t[:, None]`` before embedding.
            cond: optional conditioning tensor, linearly resized along dim 2
                to match ``x`` and concatenated to the input on dim 1.
            cond_aug_scale: fixed noise level for conditioning augmentation;
                when ``None`` a level is drawn from the Sobol engine. Only
                used if ``cond_noise_aug`` was enabled.

        Returns:
            The output of the assembled network.
        """
        t_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
        pieces = [x, t_embed]

        if cond is not None:
            if cond.shape[2] != x.shape[2]:
                cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)

            if self.cond_noise_aug:
                batch = cond.shape[0]
                if cond_aug_scale is None:
                    # One quasi-random level per batch element.
                    aug_level = self.rng.draw(batch)[:, 0].to(cond)
                else:
                    aug_level = torch.tensor([cond_aug_scale]).repeat([batch]).to(cond)

                # Add noise to the conditioning signal
                cond = cond + torch.randn_like(cond) * aug_level[:, None, None]

                # The noise level is embedded with the timestep embedder.
                pieces.append(expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape))

            pieces.append(cond)

        return self.net(torch.cat(pieces, dim=1))
|
||||
|
||||
class DiTWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``DiffusionTransformer`` through the conditioned-diffusion call signature."""
    def __init__(self, *args, **kwargs):
        """Build the wrapped ``DiffusionTransformer`` from ``*args`` / ``**kwargs``."""
        super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)

        self.model = DiffusionTransformer(*args, **kwargs)

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                input_concat_cond=None,
                negative_input_concat_cond=None,
                global_cond=None,
                negative_global_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):
        """
        Run the wrapped transformer, renaming arguments to the names it uses.

        ``cross_attn_mask`` is passed as ``cross_attn_cond_mask`` and
        ``global_cond`` as ``global_embed``. ``negative_input_concat_cond``,
        ``negative_global_cond`` and ``rescale_cfg`` are accepted but not
        forwarded.

        Raises:
            AssertionError: if ``batch_cfg`` is false.
        """
        assert batch_cfg, "batch_cfg must be True for DiTWrapper"

        routed = {
            "cross_attn_cond": cross_attn_cond,
            "cross_attn_cond_mask": cross_attn_mask,
            "negative_cross_attn_cond": negative_cross_attn_cond,
            "negative_cross_attn_mask": negative_cross_attn_mask,
            "input_concat_cond": input_concat_cond,
            "prepend_cond": prepend_cond,
            "prepend_cond_mask": prepend_cond_mask,
            "cfg_scale": cfg_scale,
            "cfg_dropout_prob": cfg_dropout_prob,
            "scale_phi": scale_phi,
            "global_embed": global_cond,
        }
        return self.model(x, t, **routed, **kwargs)
|
||||
|
||||
class MMDiTWrapper(ConditionedDiffusionModel):
    """Adapter exposing ``MMAudio`` through the conditioned-diffusion call signature."""
    def __init__(
        self,
        *args,
        **kwargs
    ):
        """Build the wrapped ``MMAudio`` model from ``*args`` / ``**kwargs``."""
        super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)

        self.model = MMAudio(*args, **kwargs)

    def forward(self,
                x,
                t,
                clip_f,
                sync_f,
                text_f,
                inpaint_masked_input=None,
                t5_features=None,
                metaclip_global_text_features=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):
        """
        Run the wrapped model; ``x`` is passed to it as ``latent``.

        ``rescale_cfg`` is accepted for signature compatibility but is not
        forwarded.

        Raises:
            AssertionError: if ``batch_cfg`` is false.
        """
        # The message previously named DiTWrapper, pointing at the wrong class.
        assert batch_cfg, "batch_cfg must be True for MMDiTWrapper"

        return self.model(
            latent=x,
            t=t,
            clip_f=clip_f,
            sync_f=sync_f,
            text_f=text_f,
            inpaint_masked_input=inpaint_masked_input,
            t5_features=t5_features,
            metaclip_global_text_features=metaclip_global_text_features,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            **kwargs)
|
||||
|
||||
class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
    """
    A diffusion model that takes in multimodal conditioning.

    Only ``mm_cond_ids`` conditioning is supported: every other id list must
    be empty, and ``mm_cond_ids`` must contain ``"metaclip_features"``,
    ``"sync_features"`` and ``"metaclip_text_features"``.
    """
    def __init__(
        self,
        model,
        conditioner: MultiConditioner,
        io_channels,
        sample_rate,
        min_input_length: int,
        diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
        pretransform: tp.Optional[Pretransform] = None,
        cross_attn_cond_ids: tp.Optional[tp.List[str]] = None,
        global_cond_ids: tp.Optional[tp.List[str]] = None,
        input_concat_ids: tp.Optional[tp.List[str]] = None,
        prepend_cond_ids: tp.Optional[tp.List[str]] = None,
        add_cond_ids: tp.Optional[tp.List[str]] = None,
        mm_cond_ids: tp.Optional[tp.List[str]] = None,
    ):
        """
        Store the configuration and validate the conditioning id lists.

        ``None`` for an id list means "no ids"; a given sequence is copied.

        Raises:
            AssertionError: if an unsupported id list is non-empty or a
                required id is missing from ``mm_cond_ids``.
        """
        super().__init__()

        def _ids(value):
            # Fresh list per instance: no shared default, no caller aliasing.
            return [] if value is None else list(value)

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        self.cross_attn_cond_ids = _ids(cross_attn_cond_ids)
        self.global_cond_ids = _ids(global_cond_ids)
        self.input_concat_ids = _ids(input_concat_ids)
        self.prepend_cond_ids = _ids(prepend_cond_ids)
        self.add_cond_ids = _ids(add_cond_ids)
        self.min_input_length = min_input_length
        self.mm_cond_ids = _ids(mm_cond_ids)

        assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
        assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
        assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
        assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
        assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
        assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
        # The message now names the key that is actually checked.
        assert "metaclip_features" in self.mm_cond_ids, "metaclip_features must be specified in mm_cond_ids for MMDiTWrapper"
        assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
        assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"

    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        """
        Map conditioner outputs onto the keyword arguments of the backbone.

        Args:
            conditioning_tensors: must contain ``"metaclip_features"``,
                ``"sync_features"`` and ``"metaclip_text_features"``; may
                contain ``"inpaint_masked_input"``, ``"t5_features"`` and
                ``"metaclip_global_text_features"``.
            negative: must be false; negative conditioning is unsupported.

        Returns:
            A dict with keys ``clip_f``, ``sync_f``, ``text_f``,
            ``inpaint_masked_input``, ``t5_features`` and
            ``metaclip_global_text_features``; absent optional entries are
            ``None``.

        Raises:
            AssertionError: if ``negative`` is not false.
            KeyError: if a required entry is missing.
        """
        assert negative == False, "negative conditioning is not supported for MMDiTWrapper"

        return {
            "clip_f": conditioning_tensors["metaclip_features"],
            "sync_f": conditioning_tensors["sync_features"],
            "text_f": conditioning_tensors["metaclip_text_features"],
            "inpaint_masked_input": conditioning_tensors.get("inpaint_masked_input"),
            "t5_features": conditioning_tensors.get("t5_features"),
            "metaclip_global_text_features": conditioning_tensors.get("metaclip_global_text_features"),
        }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        """Call the backbone on ``x`` and ``t`` with the conditioning built from ``cond``."""
        return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        """Delegate to ``generate_diffusion_cond`` with this wrapper as the model."""
        # Imported lazily, on first use, rather than at module import time.
        from prismaudio_core.inference.generation import generate_diffusion_cond
        return generate_diffusion_cond(self, *args, **kwargs)
|
||||
|
||||
class DiTUncondWrapper(DiffusionModel):
    """Unconditional wrapper around ``DiffusionTransformer`` that records its channel count."""
    def __init__(self, io_channels, *args, **kwargs):
        """
        Args:
            io_channels: passed to ``DiffusionTransformer`` as the
                ``io_channels`` keyword and stored on the instance.
            *args, **kwargs: forwarded to ``DiffusionTransformer``.
        """
        super().__init__()

        self.model = DiffusionTransformer(*args, io_channels=io_channels, **kwargs)
        self.io_channels = io_channels

        # Halve every freshly created parameter in place.
        with torch.no_grad():
            for weight in self.model.parameters():
                weight *= 0.5

    def forward(self, x, t, **kwargs):
        """Call the wrapped model on ``x`` and ``t``, forwarding extra keywords."""
        return self.model(x, t, **kwargs)
|
||||
|
||||
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
    """
    Build an unconditional diffusion model wrapper from a config dict.

    Reads ``config["model"]`` (keys ``type``, ``config``, ``pretransform``)
    plus the top-level ``sample_size`` and ``sample_rate``.

    Returns:
        A ``DiffusionModelWrapper`` around the selected backbone.

    Raises:
        AssertionError: if the model type, sample size or sample rate is missing.
        NotImplementedError: if the model type is not recognised.
    """
    model_section = config["model"]

    model_type = model_section.get('type', None)
    backbone_kwargs = model_section.get('config', {})
    assert model_type is not None, "Must specify model type in config"

    pretransform = model_section.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    # Without a pretransform the minimum input length is a single sample.
    min_input_length = 1
    if pretransform is not None:
        pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(**backbone_kwargs)
    elif model_type == "adp_uncond_1d":
        model = UNet1DUncondWrapper(**backbone_kwargs)
    elif model_type == "dit":
        model = DiTUncondWrapper(**backbone_kwargs)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(
        model,
        io_channels=model.io_channels,
        sample_size=sample_size,
        sample_rate=sample_rate,
        pretransform=pretransform,
        min_input_length=min_input_length,
    )
|
||||
|
||||
def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
    """
    Build a diffusion infill model wrapper from a config dict.

    Reads ``config["model"]["diffusion"]`` (keys ``type`` and ``config``),
    ``config["model"]["pretransform"]`` and the top-level ``sample_size`` and
    ``sample_rate``.

    Returns:
        A ``DiffusionModelWrapper`` around the selected backbone.

    Raises:
        AssertionError: if the model type, sample size or sample rate is missing.
        NotImplementedError: if the model type is not recognised.
    """
    diffusion_uncond_config = config["model"]

    diffusion_config = diffusion_uncond_config.get('diffusion', {})
    model_type = diffusion_config.get('type', None)
    model_config = diffusion_config.get("config", {})
    assert model_type is not None, "Must specify model type in config"

    pretransform = diffusion_uncond_config.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    if pretransform is not None:
        pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(**model_config)
    elif model_type == "adp_uncond_1d":
        # This branch used to pass ``io_channels=io_channels``, a name never
        # defined in this function, so it always raised NameError. The
        # backbone now takes all of its arguments from the model config, as
        # the other branches do.
        model = UNet1DUncondWrapper(**model_config)
    elif model_type == "dit":
        model = DiTUncondWrapper(**model_config)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(model,
                                 io_channels=model.io_channels,
                                 sample_size=sample_size,
                                 sample_rate=sample_rate,
                                 pretransform=pretransform,
                                 min_input_length=min_input_length)
|
||||
|
||||
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
    """
    Build a conditioned diffusion model wrapper from a config dict.

    Reads ``config["model_type"]``, ``config["sample_rate"]`` and
    ``config["model"]`` (keys ``diffusion``, ``io_channels``,
    ``conditioning``, ``pretransform`` and, for priors, ``prior_type``).

    Returns:
        The wrapper selected by ``model_type``, holding the diffusion
        backbone selected by ``model.diffusion.type``.

    Raises:
        AssertionError: if a required config entry is missing.
        NotImplementedError: if the diffusion model type, model type or
            prior type is not recognised.
    """
    model_config = config["model"]
    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'mmdit':
        diffusion_model = MMDiTWrapper(**diffusion_model_config)
    else:
        # Fail here with a clear message instead of an UnboundLocalError later.
        raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = _get_create_multi_conditioner_from_conditioning_config()(conditioning_config)

    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    add_cond_ids = diffusion_config.get('add_cond_ids', [])
    sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
    mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
    zero_init = diffusion_config.get('zero_init', False)
    pretransform = model_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = _get_create_pretransform_from_config()(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    # Scale the minimum input length by the backbone's own reduction factor.
    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    # Keyword arguments accepted by every wrapper class.
    wrapper_kwargs = {
        "min_input_length": min_input_length,
        "sample_rate": sample_rate,
        "cross_attn_cond_ids": cross_attention_ids,
        "global_cond_ids": global_cond_ids,
        "input_concat_ids": input_concat_ids,
        "prepend_cond_ids": prepend_cond_ids,
        "add_cond_ids": add_cond_ids,
        "pretransform": pretransform,
        "io_channels": io_channels,
    }

    # Get the proper wrapper class
    if model_type == "mm_diffusion_cond":
        wrapper_fn = MMConditionedDiffusionModelWrapper
        # This wrapper's __init__ has no sync_cond_ids or zero_init
        # parameters; passing them used to raise TypeError on every call.
        wrapper_kwargs["diffusion_objective"] = diffusion_objective
        wrapper_kwargs["mm_cond_ids"] = mm_cond_ids
    elif model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
        wrapper_fn = ConditionedDiffusionModelWrapper
        wrapper_kwargs["diffusion_objective"] = diffusion_objective
        wrapper_kwargs["sync_cond_ids"] = sync_cond_ids
        wrapper_kwargs["zero_init"] = zero_init
    elif model_type == "diffusion_prior":
        prior_type = model_config.get("prior_type", None)
        assert prior_type is not None, "Must specify prior_type in diffusion prior model config"

        if prior_type == "mono_stereo":
            from .diffusion_prior import MonoToStereoDiffusionPrior
            wrapper_fn = MonoToStereoDiffusionPrior
        else:
            raise NotImplementedError(f'Unknown prior type: {prior_type}')

        # Same arguments as before for the prior wrapper, whose signature is
        # not visible here.
        wrapper_kwargs["sync_cond_ids"] = sync_cond_ids
        wrapper_kwargs["zero_init"] = zero_init
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return wrapper_fn(diffusion_model, conditioner, **wrapper_kwargs)
|
||||
@@ -0,0 +1,541 @@
|
||||
import typing as tp
|
||||
import math
|
||||
import torch
|
||||
# from beartype.typing import Tuple
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||
from .blocks import FourierFeatures
|
||||
from .transformer import ContinuousTransformer
|
||||
from .utils import mask_from_frac_lengths, resample
|
||||
|
||||
class DiffusionTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
io_channels=32,
|
||||
patch_size=1,
|
||||
embed_dim=768,
|
||||
cond_token_dim=0,
|
||||
project_cond_tokens=True,
|
||||
global_cond_dim=0,
|
||||
project_global_cond=True,
|
||||
input_concat_dim=0,
|
||||
prepend_cond_dim=0,
|
||||
cond_ctx_dim=0,
|
||||
depth=12,
|
||||
num_heads=8,
|
||||
transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
|
||||
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
|
||||
timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
|
||||
add_token_dim=0,
|
||||
sync_token_dim=0,
|
||||
use_mlp=False,
|
||||
use_zero_init=False,
|
||||
**kwargs):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.cond_token_dim = cond_token_dim
|
||||
|
||||
# Timestep embeddings
|
||||
timestep_features_dim = 256
|
||||
# Timestep embeddings
|
||||
self.timestep_cond_type = timestep_cond_type
|
||||
self.timestep_features = FourierFeatures(1, timestep_features_dim)
|
||||
|
||||
if timestep_cond_type == "global":
|
||||
timestep_embed_dim = embed_dim
|
||||
elif timestep_cond_type == "input_concat":
|
||||
assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat"
|
||||
input_concat_dim += timestep_embed_dim
|
||||
|
||||
self.to_timestep_embed = nn.Sequential(
|
||||
nn.Linear(timestep_features_dim, embed_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim, bias=True),
|
||||
)
|
||||
self.use_mlp = use_mlp
|
||||
if cond_token_dim > 0:
|
||||
# Conditioning tokens
|
||||
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
|
||||
self.to_cond_embed = nn.Sequential(
|
||||
nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
|
||||
)
|
||||
else:
|
||||
cond_embed_dim = 0
|
||||
|
||||
if global_cond_dim > 0:
|
||||
# Global conditioning
|
||||
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
|
||||
self.to_global_embed = nn.Sequential(
|
||||
nn.Linear(global_cond_dim, global_embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(global_embed_dim, global_embed_dim, bias=False)
|
||||
)
|
||||
if add_token_dim > 0:
|
||||
# Conditioning tokens
|
||||
add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
|
||||
self.to_add_embed = nn.Sequential(
|
||||
nn.Linear(add_token_dim, add_embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(add_embed_dim, add_embed_dim, bias=False)
|
||||
)
|
||||
else:
|
||||
add_embed_dim = 0
|
||||
|
||||
if sync_token_dim > 0:
|
||||
# Conditioning tokens
|
||||
sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
|
||||
self.to_sync_embed = nn.Sequential(
|
||||
nn.Linear(sync_token_dim, sync_embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(sync_embed_dim, sync_embed_dim, bias=False)
|
||||
)
|
||||
else:
|
||||
sync_embed_dim = 0
|
||||
|
||||
|
||||
if prepend_cond_dim > 0:
|
||||
# Prepend conditioning
|
||||
self.to_prepend_embed = nn.Sequential(
|
||||
nn.Linear(prepend_cond_dim, embed_dim, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
)
|
||||
|
||||
self.input_concat_dim = input_concat_dim
|
||||
|
||||
dim_in = io_channels + self.input_concat_dim
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Transformer
|
||||
|
||||
self.transformer_type = transformer_type
|
||||
|
||||
self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
|
||||
self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
|
||||
self.global_cond_type = global_cond_type
|
||||
if self.transformer_type == "continuous_transformer":
|
||||
|
||||
global_dim = None
|
||||
|
||||
if self.global_cond_type == "adaLN":
|
||||
# The global conditioning is projected to the embed_dim already at this point
|
||||
global_dim = embed_dim
|
||||
|
||||
self.transformer = ContinuousTransformer(
|
||||
dim=embed_dim,
|
||||
depth=depth,
|
||||
dim_heads=embed_dim // num_heads,
|
||||
dim_in=dim_in * patch_size,
|
||||
dim_out=io_channels * patch_size,
|
||||
cross_attend = cond_token_dim > 0,
|
||||
cond_token_dim = cond_embed_dim,
|
||||
global_cond_dim=global_dim,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
|
||||
|
||||
self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
|
||||
self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
|
||||
nn.init.zeros_(self.preprocess_conv.weight)
|
||||
nn.init.zeros_(self.postprocess_conv.weight)
|
||||
|
||||
|
||||
def initialize_weights(self):
    """Re-initialize the parameters of this module.

    Steps, in order:
      1. Every ``nn.Linear`` reachable through ``self.apply`` gets a
         Xavier-uniform weight and, when it has one, a zero bias.
      2. The weights of entries 0 and 2 of ``self.to_timestep_embed`` are
         redrawn from a normal distribution with ``std=0.02``.
      3. When ``self.global_cond_type == "adaLN"``, the last entry of each
         transformer layer's ``adaLN_modulation`` has weight and bias zeroed.
      4. ``self.empty_clip_feat`` and ``self.empty_sync_feat`` are zeroed.
    """
    def _init_linear(layer):
        # Only linear layers are touched; everything else keeps its own init.
        if not isinstance(layer, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)

    self.apply(_init_linear)

    # Timestep embedding MLP: small-variance normal init on both projections.
    for position in (0, 2):
        nn.init.normal_(self.to_timestep_embed[position].weight, std=0.02)

    # Zero the final adaLN modulation projection of every transformer layer.
    if self.global_cond_type == "adaLN":
        for layer in self.transformer.layers:
            final_projection = layer.adaLN_modulation[-1]
            nn.init.constant_(final_projection.weight, 0)
            nn.init.constant_(final_projection.bias, 0)

    # Placeholder features always start at zero.
    for placeholder in (self.empty_clip_feat, self.empty_sync_feat):
        nn.init.constant_(placeholder, 0)
|
||||
|
||||
def _forward(
    self,
    x,
    t,
    mask=None,
    cross_attn_cond=None,
    cross_attn_cond_mask=None,
    input_concat_cond=None,
    global_embed=None,
    prepend_cond=None,
    prepend_cond_mask=None,
    add_cond=None,
    add_masks=None,
    sync_cond=None,
    return_info=False,
    **kwargs):
    """Run a single pass of the diffusion transformer.

    The method projects each provided conditioning signal with its
    ``to_*_embed`` module, merges the timestep embedding according to
    ``self.timestep_cond_type``, optionally prepends tokens, patches the
    sequence when ``self.patch_size > 1``, calls ``self.transformer`` and
    finally restores the channel-first layout of the input.

    Args:
        x: Input tensor. It is concatenated with conditioning along dim 1
            and rearranged with ``"b c t -> b t c"``, i.e. it is treated
            as (batch, channels, time).
        t: Timesteps, indexed as ``t[:, None]`` before being embedded.
        mask: Forwarded unchanged to the transformer.
        cross_attn_cond: Cross-attention tokens; projected with
            ``self.to_cond_embed`` and passed as ``context``.
        cross_attn_cond_mask: Accepted for interface compatibility; not
            used by this method.
        input_concat_cond: Conditioning concatenated to ``x`` on dim 1
            after nearest-neighbour interpolation to ``x.shape[2]``.
        global_embed: Global conditioning; projected with
            ``self.to_global_embed``.
        prepend_cond: Tokens to prepend; projected with
            ``self.to_prepend_embed``.
        prepend_cond_mask: Optional mask for ``prepend_cond``.
        add_cond: Projected with ``self.to_add_embed`` and linearly
            interpolated to the transformer sequence length if needed.
        add_masks: Accepted for interface compatibility; not used here.
        sync_cond: Projected with ``self.to_sync_embed``.
        return_info: When true the transformer is expected to return
            ``(output, info)`` and this method returns the same pair.
        **kwargs: Forwarded to the transformer.

    Returns:
        The output tensor in ``"b c t"`` layout with prepended positions
        and patch padding removed, or ``(output, info)`` when
        ``return_info`` is true.
    """

    if cross_attn_cond is not None:
        cross_attn_cond = self.to_cond_embed(cross_attn_cond)

    if global_embed is not None:
        # Project the global conditioning to the embedding dimension
        global_embed = self.to_global_embed(global_embed)

    prepend_inputs = None
    prepend_mask = None
    prepend_length = 0
    if prepend_cond is not None:
        # Project the prepend conditioning to the embedding dimension
        prepend_cond = self.to_prepend_embed(prepend_cond)

        prepend_inputs = prepend_cond
        if prepend_cond_mask is not None:
            prepend_mask = prepend_cond_mask

    if input_concat_cond is not None:
        # If dim 1 does not already match x, swap dims 1 and 2
        # (presumably (b, n, c) -> (b, c, n) — confirm against callers).
        if input_concat_cond.shape[1] != x.shape[1]:
            input_concat_cond = input_concat_cond.transpose(1,2)
        # Interpolate input_concat_cond to the same length as x
        input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
        x = torch.cat([x, input_concat_cond], dim=1)

    # Get the batch of timestep embeddings
    timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)

    # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
    if self.timestep_cond_type == "global":
        if global_embed is not None:
            if len(global_embed.shape) == 3:
                # Broadcast the per-sample timestep over the token axis.
                timestep_embed = timestep_embed.unsqueeze(1)
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed
    elif self.timestep_cond_type == "input_concat":
        # NOTE(review): unsqueeze(1) gives (b, 1, embed_dim), which expand() can
        # only stretch to x.shape[2] when embed_dim is 1 or already equals it.
        # unsqueeze(-1) may have been intended — confirm against the constructor's
        # channel bookkeeping before changing.
        x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1)

    # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
    if self.global_cond_type == "prepend" and global_embed is not None:
        # Normalise the global embedding to (b, tokens, dim) and build an all-ones
        # mask that covers every one of its tokens, not just a single column.
        global_tokens = global_embed.unsqueeze(1) if len(global_embed.shape) == 2 else global_embed
        global_mask = torch.ones((x.shape[0], global_tokens.shape[1]), device=x.device, dtype=torch.bool)

        if prepend_inputs is None:
            # Prepend inputs are just the global embed, and the mask is all ones
            prepend_inputs = global_tokens
            prepend_mask = global_mask
        else:
            # Prepend inputs are the prepend conditioning + the global embed
            if prepend_mask is None:
                # No mask was supplied for prepend_cond: treat all of its tokens
                # as valid so the concatenation below has something to extend.
                prepend_mask = torch.ones(prepend_inputs.shape[:2], device=x.device, dtype=torch.bool)
            prepend_inputs = torch.cat([prepend_inputs, global_tokens], dim=1)
            prepend_mask = torch.cat([prepend_mask, global_mask], dim=1)

    if prepend_inputs is not None:
        # Number of leading positions to strip from the transformer output. This
        # must be set whenever tokens are prepended, including when they come only
        # from prepend_cond.
        prepend_length = prepend_inputs.shape[1]

    x = self.preprocess_conv(x) + x
    x = rearrange(x, "b c t -> b t c")

    extra_args = {}

    if self.global_cond_type == "adaLN":
        extra_args["global_cond"] = global_embed

    if self.patch_size > 1:
        seq_len = x.shape[1]

        # Amount of padding needed to reach a multiple of patch_size
        pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size

        if pad_amount > 0:
            # Zero-pad along the time dimension
            x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
        x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

    if add_cond is not None:
        add_cond = self.to_add_embed(add_cond)
        if add_cond.shape[1] != x.shape[1]:
            # Interpolate add_cond to the same length as x; F.interpolate works on
            # the last dim, hence the transpose before and after.
            add_cond = add_cond.transpose(1,2)
            add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
            add_cond = add_cond.transpose(1,2)

    if sync_cond is not None:
        sync_cond = self.to_sync_embed(sync_cond)

    if self.transformer_type == "continuous_transformer":
        output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

        if return_info:
            output, info = output

    # Back to channel-first layout and drop the prepended positions.
    output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]

    if self.patch_size > 1:
        output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
        # Remove the padding added before patching
        if pad_amount > 0:
            output = output[:, :, :seq_len]

    output = self.postprocess_conv(output) + output

    if return_info:
        return output, info

    return output
|
||||
|
||||
def forward(
    self,
    x,
    t,
    cross_attn_cond=None,
    cross_attn_cond_mask=None,
    negative_cross_attn_cond=None,
    negative_cross_attn_mask=None,
    input_concat_cond=None,
    global_embed=None,
    negative_global_embed=None,
    prepend_cond=None,
    prepend_cond_mask=None,
    add_cond=None,
    sync_cond=None,
    cfg_scale=1.0,
    cfg_dropout_prob=0.0,
    causal=False,
    scale_phi=0.0,
    mask=None,
    return_info=False,
    **kwargs):
    """Public entry point: cast inputs, apply CFG dropout or guidance, call ``_forward``.

    Behaviour visible in this method:
      * All tensor inputs are cast to the dtype of the first model parameter.
      * ``cross_attn_cond_mask`` is always discarded (set to ``None``).
      * ``negative_global_embed`` is cast but not otherwise used.
      * When ``cfg_dropout_prob > 0`` and ``cfg_scale == 1``, each of
        ``cross_attn_cond``, ``prepend_cond``, ``add_cond`` and ``sync_cond``
        is independently zeroed per sample with that probability.
      * When ``cfg_scale != 1`` and at least one of ``cross_attn_cond``,
        ``prepend_cond`` or ``add_cond`` is given, a doubled batch is run
        (conditioned half first, null or negative half second). The two halves
        are combined as ``uncond + (cond - uncond) * cfg_scale``, with an
        optional std-based rescale controlled by ``scale_phi``.
      * Otherwise ``_forward`` is called once with the inputs as given.

    Returns:
        The model output, or ``(output, info)`` when ``return_info`` is true.
    """

    assert causal == False, "Causal mode is not supported for DiffusionTransformer"
    bsz, _, _ = x.shape
    model_dtype = next(self.parameters()).dtype

    def _to_model_dtype(tensor):
        # Optional inputs stay None; everything else follows the model dtype.
        return None if tensor is None else tensor.to(model_dtype)

    x = x.to(model_dtype)
    t = t.to(model_dtype)
    cross_attn_cond = _to_model_dtype(cross_attn_cond)
    negative_cross_attn_cond = _to_model_dtype(negative_cross_attn_cond)
    input_concat_cond = _to_model_dtype(input_concat_cond)
    global_embed = _to_model_dtype(global_embed)
    negative_global_embed = _to_model_dtype(negative_global_embed)
    prepend_cond = _to_model_dtype(prepend_cond)
    add_cond = _to_model_dtype(add_cond)
    sync_cond = _to_model_dtype(sync_cond)

    if cross_attn_cond_mask is not None:
        cross_attn_cond_mask = cross_attn_cond_mask.bool()

    cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention

    if prepend_cond_mask is not None:
        prepend_cond_mask = prepend_cond_mask.bool()

    def _randomly_null(cond):
        # Replace whole samples by zeros with probability cfg_dropout_prob.
        if cond is None:
            return None
        zeros = torch.zeros_like(cond, device=cond.device)
        keep_prob_shape = (cond.shape[0], 1, 1)
        drop = torch.bernoulli(torch.full(keep_prob_shape, cfg_dropout_prob, device=cond.device)).to(torch.bool)
        return torch.where(drop, zeros, cond)

    # CFG dropout
    if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
        cross_attn_cond = _randomly_null(cross_attn_cond)
        prepend_cond = _randomly_null(prepend_cond)
        add_cond = _randomly_null(add_cond)
        sync_cond = _randomly_null(sync_cond)

    has_guidable_cond = cross_attn_cond is not None or prepend_cond is not None or add_cond is not None

    if not (cfg_scale != 1.0 and has_guidable_cond):
        # Plain single pass, no guidance.
        return self._forward(
            x,
            t,
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_cond_mask,
            input_concat_cond=input_concat_cond,
            global_embed=global_embed,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            add_cond=add_cond,
            sync_cond=sync_cond,
            mask=mask,
            return_info=return_info,
            **kwargs
        )

    # Classifier-free guidance
    # Concatenate conditioned and unconditioned inputs on the batch dimension

    def _matches_batch(tensor):
        return tensor is not None and tensor.shape[0] == bsz

    def _repeat_for_both_halves(tensor):
        # Same conditioning in both halves; tensors that do not match the
        # batch size are passed through untouched.
        if _matches_batch(tensor):
            return torch.cat([tensor, tensor], dim=0)
        return tensor

    def _pair_with_zeros(tensor):
        # Conditioned half followed by an all-zero "null" half.
        if _matches_batch(tensor):
            return torch.cat([tensor, torch.zeros_like(tensor, device=tensor.device)], dim=0)
        return tensor

    batch_inputs = torch.cat([x, x], dim=0)
    batch_timestep = torch.cat([t, t], dim=0)
    batch_global_cond = _repeat_for_both_halves(global_embed)
    batch_input_concat_cond = _repeat_for_both_halves(input_concat_cond)

    # Handle CFG for cross-attention conditioning
    batch_cond_masks = None
    if _matches_batch(cross_attn_cond):
        null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)

        if negative_cross_attn_cond is None:
            uncond_half = null_embed
        else:
            # A negative prompt replaces the null embedding; positions where the
            # negative mask is false fall back to the null embedding.
            if negative_cross_attn_mask is not None:
                negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
                negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
            uncond_half = negative_cross_attn_cond

        batch_cond = torch.cat([cross_attn_cond, uncond_half], dim=0)

        if cross_attn_cond_mask is not None:
            batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
    else:
        batch_cond = cross_attn_cond

    # Handle CFG for prepend conditioning
    batch_prepend_cond_mask = None
    if _matches_batch(prepend_cond):
        batch_prepend_cond = _pair_with_zeros(prepend_cond)
        if prepend_cond_mask is not None:
            batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
    else:
        batch_prepend_cond = prepend_cond

    # Handle CFG for additive and sync conditioning
    batch_add_cond = _pair_with_zeros(add_cond)
    batch_sync_cond = _pair_with_zeros(sync_cond)

    batch_masks = None if mask is None else torch.cat([mask, mask], dim=0)

    batch_output = self._forward(
        batch_inputs,
        batch_timestep,
        cross_attn_cond=batch_cond,
        cross_attn_cond_mask=batch_cond_masks,
        mask = batch_masks,
        input_concat_cond=batch_input_concat_cond,
        global_embed = batch_global_cond,
        prepend_cond = batch_prepend_cond,
        prepend_cond_mask = batch_prepend_cond_mask,
        add_cond = batch_add_cond,
        sync_cond = batch_sync_cond,
        return_info = return_info,
        **kwargs)

    if return_info:
        batch_output, info = batch_output

    cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
    cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

    # CFG Rescale
    if scale_phi != 0.0:
        cond_out_std = cond_output.std(dim=1, keepdim=True)
        out_cfg_std = cfg_output.std(dim=1, keepdim=True)
        output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
    else:
        output = cfg_output

    if return_info:
        return output, info

    return output
|
||||
@@ -0,0 +1,278 @@
|
||||
import torch
|
||||
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from .blocks import AdaRMSNorm
|
||||
from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
|
||||
|
||||
def checkpoint(function, *args, **kwargs):
    """Call ``torch.utils.checkpoint.checkpoint`` with ``use_reentrant=False`` by default.

    A caller-supplied ``use_reentrant`` keyword takes precedence over the
    default. All other positional and keyword arguments are forwarded as-is.
    """
    # Caller-provided keywords come last so they override the default.
    options = {"use_reentrant": False, **kwargs}
    return torch.utils.checkpoint.checkpoint(function, *args, **options)
|
||||
|
||||
# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
|
||||
class ContinuousLocalTransformer(nn.Module):
    """Transformer stack over continuous features with optional conditioning.

    Each of the ``depth`` layers holds, in this order: a norm, a self-attention
    module (given ``local_attn_window_size`` as ``natten_kernel_size``), a
    cross-attention module (or ``nn.Identity`` when ``cross_attn_cond_dim`` is
    0), a second norm and a feed-forward module. The norms are ``AdaRMSNorm``
    when ``cond_dim > 0`` and ``LayerNorm`` otherwise.

    Optional linear projections map ``dim_in -> dim`` on the way in and
    ``dim -> dim_out`` on the way out. Extra keyword arguments are accepted
    and ignored.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_in = None,
        dim_out = None,
        causal = False,
        local_attn_window_size = 64,
        heads = 8,
        ff_mult = 2,
        cond_dim = 0,
        cross_attn_cond_dim = 0,
        **kwargs
    ):
        super().__init__()

        head_width = dim // heads

        self.layers = nn.ModuleList([])

        if dim_in is not None:
            self.project_in = nn.Linear(dim_in, dim)
        else:
            self.project_in = nn.Identity()

        if dim_out is not None:
            self.project_out = nn.Linear(dim, dim_out)
        else:
            self.project_out = nn.Identity()

        self.local_attn_window_size = local_attn_window_size
        self.cond_dim = cond_dim
        self.cross_attn_cond_dim = cross_attn_cond_dim

        self.rotary_pos_emb = RotaryEmbedding(max(head_width // 2, 32))

        def _build_norm():
            # Conditioned norm when a conditioning width is configured.
            if cond_dim > 0:
                return AdaRMSNorm(dim, cond_dim, eps=1e-8)
            return LayerNorm(dim)

        def _build_cross_attention():
            if self.cross_attn_cond_dim > 0:
                return Attention(
                    dim=dim,
                    dim_heads=head_width,
                    dim_context = cross_attn_cond_dim,
                    zero_init_output=True
                )
            return nn.Identity()

        for _ in range(depth):
            # Sub-modules are created in the same order they are stored.
            attn_norm = _build_norm()
            self_attention = Attention(
                dim=dim,
                dim_heads=head_width,
                causal=causal,
                zero_init_output=True,
                natten_kernel_size=local_attn_window_size,
            )
            cross_attention = _build_cross_attention()
            ff_norm = _build_norm()
            feed_forward = FeedForward(dim = dim, mult = ff_mult, no_bias=True)

            self.layers.append(nn.ModuleList([
                attn_norm,
                self_attention,
                cross_attention,
                ff_norm,
                feed_forward,
            ]))

    def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
        """Apply the input projection, every layer, then the output projection.

        ``prepend_cond``, when given, is concatenated in front of ``x`` along
        dim 1 after the input projection. ``cond`` is passed to both norms of
        every layer when it is not ``None``. Every sub-module call goes through
        ``checkpoint``.
        """
        x = checkpoint(self.project_in, x)

        if prepend_cond is not None:
            x = torch.cat([prepend_cond, x], dim=1)

        pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])

        # The norms take the conditioning as an extra positional argument only
        # when it was supplied.
        norm_extra = () if cond is None else (cond,)

        for attn_norm, attn, xattn, ff_norm, ff in self.layers:
            # Self-attention sub-block with residual connection.
            skip = x
            x = checkpoint(attn_norm, x, *norm_extra)
            x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + skip

            # Cross-attention is applied without a preceding norm.
            if cross_attn_cond is not None:
                x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x

            # Feed-forward sub-block with residual connection.
            skip = x
            x = checkpoint(ff_norm, x, *norm_extra)
            x = checkpoint(ff, x) + skip

        return checkpoint(self.project_out, x)
|
||||
|
||||
class TransformerDownsampleBlock1D(nn.Module):
    """Transformer block that shortens the sequence by ``downsample_ratio``.

    The input is projected to ``embed_dim`` (identity when the widths already
    match) and processed by a ``ContinuousLocalTransformer``. Groups of
    ``downsample_ratio`` consecutive positions are then folded into the channel
    axis and projected back to ``embed_dim``.
    """

    def __init__(
        self,
        in_channels,
        embed_dim = 768,
        depth = 3,
        heads = 12,
        downsample_ratio = 2,
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        self.downsample_ratio = downsample_ratio

        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size=local_attn_window_size,
            **kwargs
        )

        # A projection is only needed when the incoming width differs.
        if in_channels != embed_dim:
            self.project_in = nn.Linear(in_channels, embed_dim, bias=False)
        else:
            self.project_in = nn.Identity()

        folded_width = embed_dim * self.downsample_ratio
        self.project_down = nn.Linear(folded_width, embed_dim, bias=False)

    def forward(self, x):
        """Project, run the transformer, fold time into channels, project down."""
        hidden = checkpoint(self.project_in, x)
        hidden = self.transformer(hidden)

        # Trade sequence length for channels
        folded = rearrange(hidden, "b (n r) c -> b n (c r)", r=self.downsample_ratio)

        # Project back to embed dim
        return checkpoint(self.project_down, folded)
|
||||
|
||||
class TransformerUpsampleBlock1D(nn.Module):
    """Transformer block that lengthens the sequence by ``upsample_ratio``.

    The input is projected to ``embed_dim`` (identity when the widths already
    match) and widened to ``embed_dim * upsample_ratio`` channels. The extra
    channels are then unfolded into additional sequence positions and the
    result is processed by a ``ContinuousLocalTransformer``.
    """

    def __init__(
        self,
        in_channels,
        embed_dim,
        depth = 3,
        heads = 12,
        upsample_ratio = 2,
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        self.upsample_ratio = upsample_ratio

        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size = local_attn_window_size,
            **kwargs
        )

        # A projection is only needed when the incoming width differs.
        if in_channels != embed_dim:
            self.project_in = nn.Linear(in_channels, embed_dim, bias=False)
        else:
            self.project_in = nn.Identity()

        widened = embed_dim * self.upsample_ratio
        self.project_up = nn.Linear(embed_dim, widened, bias=False)

    def forward(self, x):
        """Project, widen, unfold channels into time, then run the transformer."""
        # Project to embed dim
        hidden = checkpoint(self.project_in, x)

        # Project to increase channel dim
        hidden = checkpoint(self.project_up, hidden)

        # Trade channels for sequence length
        unfolded = rearrange(hidden, "b n (c r) -> b (n r) c", r=self.upsample_ratio)

        return self.transformer(unfolded)
|
||||
|
||||
|
||||
class TransformerEncoder1D(nn.Module):
    """Chain of ``TransformerDownsampleBlock1D`` stages between two linear projections.

    One stage is built per entry of ``depths``; stage ``i`` reads
    ``embed_dims[i - 1]`` channels (``embed_dims[0]`` for the first stage) and
    produces ``embed_dims[i]`` channels. Input and output are channel-first
    (``"b c n"``); the sequence axis is moved to dim 1 internally.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dims = [96, 192, 384, 768],
        heads = [12, 12, 12, 12],
        depths = [3, 3, 3, 3],
        ratios = [2, 2, 2, 2],
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        stages = [
            TransformerDownsampleBlock1D(
                # The first stage has no predecessor, so it reads its own width.
                in_channels = embed_dims[max(index - 1, 0)],
                embed_dim = embed_dims[index],
                heads = heads[index],
                depth = depths[index],
                downsample_ratio = ratios[index],
                local_attn_window_size = local_attn_window_size,
                **kwargs
            )
            for index in range(len(depths))
        ]

        self.layers = nn.Sequential(*stages)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Encode a channel-first tensor and return a channel-first tensor."""
        tokens = rearrange(x, "b c n -> b n c")
        tokens = checkpoint(self.project_in, tokens)
        tokens = self.layers(tokens)
        tokens = checkpoint(self.project_out, tokens)
        return rearrange(tokens, "b n c -> b c n")
|
||||
|
||||
|
||||
class TransformerDecoder1D(nn.Module):
    """Chain of ``TransformerUpsampleBlock1D`` stages between two linear projections.

    One stage is built per entry of ``depths``; stage ``i`` reads
    ``embed_dims[i - 1]`` channels (``embed_dims[0]`` for the first stage) and
    produces ``embed_dims[i]`` channels. Input and output are channel-first
    (``"b c n"``); the sequence axis is moved to dim 1 internally.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dims = [768, 384, 192, 96],
        heads = [12, 12, 12, 12],
        depths = [3, 3, 3, 3],
        ratios = [2, 2, 2, 2],
        local_attn_window_size = 64,
        **kwargs
    ):

        super().__init__()

        stages = [
            TransformerUpsampleBlock1D(
                # The first stage has no predecessor, so it reads its own width.
                in_channels = embed_dims[max(index - 1, 0)],
                embed_dim = embed_dims[index],
                heads = heads[index],
                depth = depths[index],
                upsample_ratio = ratios[index],
                local_attn_window_size = local_attn_window_size,
                **kwargs
            )
            for index in range(len(depths))
        ]

        self.layers = nn.Sequential(*stages)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Decode a channel-first tensor and return a channel-first tensor."""
        tokens = rearrange(x, "b c n -> b n c")
        tokens = checkpoint(self.project_in, tokens)
        tokens = self.layers(tokens)
        tokens = checkpoint(self.project_out, tokens)
        return rearrange(tokens, "b n c -> b c n")
|
||||
@@ -0,0 +1 @@
|
||||
# mmmodules package
|
||||
@@ -0,0 +1 @@
|
||||
# mmmodules.model package
|
||||
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class ChannelLastConv1d(nn.Conv1d):
    """``nn.Conv1d`` that swaps dims 1 and 2 before and after the convolution.

    The input is permuted with ``(0, 2, 1)``, convolved by the parent class,
    and permuted back, so callers can keep the channel axis last.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_first = x.permute(0, 2, 1)
        convolved = super().forward(channels_first)
        return convolved.permute(0, 2, 1)
|
||||
|
||||
|
||||
# https://github.com/Stability-AI/sd3-ref
|
||||
class MLP(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
    ):
        """
        Build the three bias-free linear layers of the block.

        Args:
            dim (int): Input and output dimension.
            hidden_dim (int): Requested hidden dimension. The width actually
                used is ``int(2 * hidden_dim / 3)`` rounded up to the next
                multiple of ``multiple_of``.
            multiple_of (int): Granularity the hidden width is rounded up to.

        Attributes:
            w1 (nn.Linear): ``dim -> hidden`` projection fed through SiLU.
            w2 (nn.Linear): ``hidden -> dim`` output projection.
            w3 (nn.Linear): ``dim -> hidden`` projection used as the multiplier.
        """
        super().__init__()

        two_thirds = int(2 * hidden_dim / 3)
        # Round up to a whole number of `multiple_of`-sized chunks.
        chunk_count = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * chunk_count

        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        activated = F.silu(self.w1(x))
        multiplier = self.w3(x)
        return self.w2(activated * multiplier)
|
||||
|
||||
|
||||
class ConvMLP(nn.Module):
    """Gated feed-forward block built from channel-last 1-D convolutions.

    Computes ``w2(silu(w1(x)) * w3(x))`` where all three layers are
    ``ChannelLastConv1d`` modules without bias.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        """
        Build the three bias-free convolution layers of the block.

        Args:
            dim (int): Input and output channel count.
            hidden_dim (int): Requested hidden width. The width actually used
                is ``int(2 * hidden_dim / 3)`` rounded up to the next multiple
                of ``multiple_of``.
            multiple_of (int): Granularity the hidden width is rounded up to.
            kernel_size (int): Kernel size shared by all three convolutions.
            padding (int): Padding shared by all three convolutions.

        Attributes:
            w1 (ChannelLastConv1d): ``dim -> hidden`` convolution fed through SiLU.
            w2 (ChannelLastConv1d): ``hidden -> dim`` output convolution.
            w3 (ChannelLastConv1d): ``dim -> hidden`` convolution used as the multiplier.
        """
        super().__init__()

        two_thirds = int(2 * hidden_dim / 3)
        # Round up to a whole number of `multiple_of`-sized chunks.
        chunk_count = (two_thirds + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * chunk_count

        # Settings common to the three convolutions.
        conv_options = dict(bias=False, kernel_size=kernel_size, padding=padding)

        self.w1 = ChannelLastConv1d(dim, inner_dim, **conv_options)
        self.w2 = ChannelLastConv1d(inner_dim, dim, **conv_options)
        self.w3 = ChannelLastConv1d(dim, inner_dim, **conv_options)

    def forward(self, x):
        activated = F.silu(self.w1(x))
        multiplier = self.w3(x)
        return self.w2(activated * multiplier)
|
||||
@@ -0,0 +1,393 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from scipy.optimize import fmin
|
||||
from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord
|
||||
|
||||
class PQMF(nn.Module):
    """
    Pseudo Quadrature Mirror Filter (PQMF) bank for splitting a signal into
    frequency bands and reconstructing it.

    The filter bank is built once at construction time from a prototype filter
    and stored, together with the prototype, as module buffers.

    Parameters:
        - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB.
        - num_bands (int): Number of desired frequency bands. It must be a power of 2.
    """

    def __init__(self, attenuation, num_bands):
        super().__init__()

        # num_bands is a power of two exactly when its base-2 log is integral.
        band_exponent = math.log2(num_bands)
        assert band_exponent == int(band_exponent), "'num_bands' must be a power of 2."

        # Prototype -> modulated bank -> padded bank.
        prototype = design_prototype_filter(attenuation, num_bands)
        bank = pad_to_nearest_power_of_two(
            generate_modulated_filter_bank(prototype, num_bands)
        )

        # Buffers follow the module across devices and into the state dict.
        self.register_buffer("filter_bank", bank)
        self.register_buffer("prototype", prototype)
        self.num_bands = num_bands

    def forward(self, signal):
        """Decompose the signal into multiple frequency bands."""
        # Bring the input to Batch x Channels x Length, converting numpy if needed.
        shaped = prepare_signal_dimensions(signal)
        # The length must be a multiple of num_bands; zero-pad the tail.
        padded = pad_signal(shaped, self.num_bands)
        bands = polyphase_analysis(padded, self.filter_bank)
        return apply_alias_cancellation(bands)

    def inverse(self, bands):
        """Reconstruct the original signal from the frequency bands."""
        corrected = apply_alias_cancellation(bands)
        return polyphase_synthesis(corrected, self.filter_bank)
|
||||
|
||||
|
||||
def prepare_signal_dimensions(signal):
    """
    Bring a signal to Batch x Channels x Length layout.

    Parameters
    ----------
    signal : torch.Tensor or numpy.ndarray
        The input signal. 1-D input is treated as a single mono signal; 2-D
        input is treated as multi-channel, with the larger dimension taken as
        the length. Input of any other rank is returned as-is.

    Returns
    -------
    torch.Tensor
        The signal as a tensor, with leading singleton dimensions added for
        1-D and 2-D input.

    Raises
    ------
    ValueError
        If the input is neither a numpy array nor a PyTorch tensor.
    """
    tensor = torch.from_numpy(signal) if isinstance(signal, np.ndarray) else signal

    if not isinstance(tensor, torch.Tensor):
        raise ValueError("Input should be either a numpy array or a PyTorch tensor.")

    rank = tensor.dim()

    if rank == 1:
        # Mono signal: Length -> 1 x 1 x Length
        return tensor[None, None, :]

    if rank == 2:
        rows, cols = tensor.shape
        # Make the larger dimension (the length) the last one.
        if rows > cols:
            tensor = tensor.t()
        # Channels x Length -> 1 x Channels x Length
        return tensor[None]

    return tensor
|
||||
|
||||
def pad_signal(signal, num_bands):
    """
    Zero-pad the end of the last dimension to a multiple of ``num_bands``.

    Parameters
    ----------
    signal : torch.Tensor
        The input tensor; its last dimension is the signal length.
    num_bands : int
        The number the signal length should be divisible by.

    Returns
    -------
    torch.Tensor
        The padded tensor, or the very same tensor when its length is already
        divisible by ``num_bands``.
    """
    # Distance from the current length up to the next multiple of num_bands.
    shortfall = -signal.shape[-1] % num_bands
    if shortfall > 0:
        signal = nn.functional.pad(signal, (0, shortfall))
    return signal
|
||||
|
||||
def generate_modulated_filter_bank(prototype_filter, num_bands):
    """
    Build a bank of cosine-modulated filters from a prototype filter.

    Band ``k`` is ``2 * prototype * cos((2k + 1) * pi / (2 * num_bands) * n + phase_k)``
    where ``n`` runs over integer positions centred on zero and
    ``phase_k = (-1)**k * pi / 4``.

    Parameters
    ----------
    prototype_filter : torch.Tensor
        The prototype filter used as the basis for modulation.
    num_bands : int
        The number of desired subbands or filters.

    Returns
    -------
    torch.Tensor
        One modulated filter per band, stacked along the first dimension.
    """
    # Integer positions -half .. +half, centred on zero.
    half_span = prototype_filter.shape[-1] // 2
    positions = torch.arange(-half_span, half_span + 1)

    # Band indices as a column so they broadcast against the positions.
    bands = torch.arange(num_bands).reshape(-1, 1)

    # Alternating +pi/4 / -pi/4 phase per band.
    phase = (-1) ** bands * np.pi / 4

    # Per-band modulation frequency.
    angular_step = (2 * bands + 1) * np.pi / (2 * num_bands)

    carrier = torch.cos(angular_step * positions + phase)

    return 2 * prototype_filter * carrier
|
||||
|
||||
|
||||
def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None):
    """
    Design a lowpass FIR filter using the Kaiser window method.

    Parameters
    ----------
    angular_cutoff : float
        Cutoff expressed as an angular frequency, where pi is the Nyquist frequency.
    attenuation : float
        The desired stopband attenuation in decibels (dB).
    filter_length : int, optional
        Number of taps. When omitted, the length estimated by ``kaiserord``
        (forced to an odd number) is used.

    Returns
    -------
    ndarray
        The unscaled lowpass filter coefficients.
    """

    # kaiserord expects the transition width as a fraction of Nyquist.
    estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi)

    # Force the estimate to an odd tap count.
    estimated_length = 2 * (estimated_length // 2) + 1

    if filter_length is None:
        filter_length = estimated_length

    # `fs=2*pi` places Nyquist at pi, exactly what the former `nyq=np.pi` did.
    # The `nyq` keyword was removed from firwin in SciPy 1.12, so passing it
    # raises TypeError on current SciPy releases.
    return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, fs=2 * np.pi)
|
||||
|
||||
|
||||
def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length):
    """
    Score a candidate cutoff using the criterion from https://ieeexplore.ieee.org/document/681427

    Parameters
    ----------
    angular_cutoff : float
        Angular frequency cutoff of the candidate filter.
    attenuation : float
        Desired stopband attenuation in dB.
    num_bands : int
        Number of bands of the multiband filter system.
    filter_length : int or None
        Filter length handed through to ``design_kaiser_lowpass``.

    Returns
    -------
    float
        The objective (loss) value; smaller is better for the minimiser that calls this.
    """
    taps = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length)

    # Convolving the taps with their own reversal yields their autocorrelation.
    autocorrelation = np.convolve(taps, taps[::-1], "full")

    # Sample from the centre in steps of 2 * num_bands, then drop the centre
    # sample itself so only the off-centre samples are scored.
    centre = autocorrelation.shape[-1] // 2
    stride = 2 * num_bands
    off_centre = autocorrelation[centre::stride][1:]

    return np.max(np.abs(off_centre))
|
||||
|
||||
|
||||
def design_prototype_filter(attenuation, num_bands, filter_length=None):
    """
    Search for the cutoff that minimises the filter objective and build the prototype filter.

    Parameters
    ----------
    attenuation : float
        The desired stopband attenuation in dB.
    num_bands : int
        Number of bands of the multiband filter system.
    filter_length : int, optional
        Desired filter length; when omitted it is derived from the specs.

    Returns
    -------
    torch.Tensor
        The prototype filter coefficients as a float32 tensor.
    """

    def objective(candidate_cutoff):
        return evaluate_filter_objective(candidate_cutoff, attenuation, num_bands, filter_length)

    # The search starts from 1 / num_bands; fmin returns an array, of which
    # only the first (and only) coordinate is needed.
    starting_cutoff = 1 / num_bands
    best_cutoff = fmin(objective, starting_cutoff, disp=0)[0]

    taps = design_kaiser_lowpass(best_cutoff, attenuation, filter_length)
    return torch.tensor(taps, dtype=torch.float32)
|
||||
|
||||
def pad_to_nearest_power_of_two(x):
    """
    Zero-pad the last dimension of 'x' up to the next power of two.

    The padding is split between both sides; when the amount is odd the
    extra sample goes on the right. A length that is already a power of
    two is left unchanged.

    Parameters:
    -----------
    x : torch.Tensor
        The input tensor to be padded.

    Returns:
    --------
    torch.Tensor
        The padded tensor.
    """
    length = x.shape[-1]
    exponent = math.ceil(math.log2(length))
    deficit = 2**exponent - length

    pad_left = deficit // 2
    pad_right = deficit - pad_left

    return nn.functional.pad(x, (pad_left, pad_right))
|
||||
|
||||
def apply_alias_cancellation(x):
    """
    Flip the sign of every second element (starting at index 0) in every
    second row (starting at row 1) of the last two dimensions of 'x'.

    The sign pattern is intended to make the aliasing introduced in each
    band during decomposition cancel out during reconstruction.

    Parameters:
    -----------
    x : torch.Tensor
        The input tensor.

    Returns:
    --------
    torch.Tensor
        A new tensor with the selected elements negated.
    """
    # Start from an all-ones sign pattern shaped like the input.
    signs = torch.ones_like(x)

    # The slice is a view, so negating it in place edits the pattern itself.
    signs[..., 1::2, ::2].neg_()

    return x * signs
|
||||
|
||||
def ensure_odd_length(tensor):
    """
    Make the last dimension of a tensor odd-sized.

    Parameters:
    -----------
    tensor : torch.Tensor
        Input tensor whose last dimension might need padding.

    Returns:
    --------
    torch.Tensor
        The tensor itself when its last dimension is already odd; otherwise
        a copy with a single zero appended at the end of that dimension.
    """
    if tensor.shape[-1] % 2 != 0:
        return tensor

    return nn.functional.pad(tensor, (0, 1))
|
||||
|
||||
def polyphase_analysis(signal, filter_bank):
    """
    Split a signal into sub-bands with a polyphase implementation of the filter bank.

    Parameters:
    -----------
    signal : torch.Tensor
        Input signal tensor with shape (Batch x Channels x Length); Length
        must be a multiple of the number of bands.

    filter_bank : torch.Tensor
        Filter bank tensor with shape (Bands x Length); Length must be a
        multiple of the number of bands.

    Returns:
    --------
    torch.Tensor
        Signal split into sub-bands. (Batch x Channels x Bands x Length)
    """
    band_count = filter_bank.shape[0]
    channel_count = signal.shape[1]

    # Fold batch and channels together and split time into `band_count`
    # interleaved phases, which act as the convolution's input channels.
    phases = rearrange(signal, "b c (t n) -> (b c) n t", n=band_count)

    # Split every filter into the same interleaved phases.
    phase_filters = rearrange(filter_bank, "c (t n) -> c n t", n=band_count)

    half_width = phase_filters.shape[-1] // 2
    subbands = nn.functional.conv1d(phases, phase_filters, padding=half_width)

    # Drop the final output sample produced by the padding.
    subbands = subbands[..., :-1]

    # Unfold the leading axis back into Batch x Channels.
    return rearrange(subbands, "(b c) n t -> b c n t", c=channel_count)
|
||||
|
||||
def polyphase_synthesis(signal, filter_bank):
    """
    Polyphase inverse: rebuild a signal from its sub-bands.

    Parameters
    ----------
    signal : torch.Tensor
        Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length).

    filter_bank : torch.Tensor
        Analysis filter bank (shape: Bands x Length); it is time-reversed
        here to obtain the synthesis filters.

    Returns
    -------
    torch.Tensor
        Reconstructed signal (shape: Batch x Channels X Length)
    """
    band_count = filter_bank.shape[0]
    channel_count = signal.shape[1]

    # Time-reverse the filters, then regroup them so each polyphase
    # component becomes one output channel of the convolution.
    synthesis_filters = rearrange(filter_bank.flip(-1), "c (t n) -> n c t", n=band_count)

    # Fold batch and channels into a single leading axis.
    flat_bands = rearrange(signal, "b c n t -> (b c) n t")

    pad = int(synthesis_filters.shape[-1] // 2 + 1)
    phases = nn.functional.conv1d(flat_bands, synthesis_filters, padding=pad)

    # Drop the last sample and rescale by the number of bands.
    phases = phases[..., :-1] * band_count

    # Reverse the phase order, interleave the phases back into one time
    # axis, and trim the leading samples.
    phases = phases.flip(1)
    interleaved = rearrange(phases, "(b c) n t -> b c (t n)", c=channel_count, n=band_count)

    return interleaved[..., 2 * synthesis_filters.shape[1]:]
|
||||
@@ -0,0 +1,258 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
class Pretransform(nn.Module):
    """
    Base class for transforms applied to audio before (and inverted after) the main model.

    Subclasses override whichever of `encode`/`decode`/`tokenize`/`decode_tokens`
    they support; the versions here all raise NotImplementedError.
    """

    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        self.enable_grad = enable_grad
        self.io_channels = io_channels
        self.is_discrete = is_discrete

        # Left unset here; the subclasses in this file assign real values.
        self.downsampling_ratio = None
        self.encoded_channels = None

    def encode(self, x):
        """Map input `x` to its encoded form (subclass responsibility)."""
        raise NotImplementedError

    def decode(self, z):
        """Map an encoded `z` back to the input domain (subclass responsibility)."""
        raise NotImplementedError

    def tokenize(self, x):
        """Map input `x` to discrete tokens (subclass responsibility)."""
        raise NotImplementedError

    def decode_tokens(self, tokens):
        """Map discrete tokens back to the input domain (subclass responsibility)."""
        raise NotImplementedError
|
||||
|
||||
class AutoencoderPretransform(Pretransform):
    """
    Pretransform backed by a frozen autoencoder model.

    The wrapped model is put in eval mode with gradients disabled. Latents
    are divided by `scale` on encode and multiplied by it on decode.
    """

    def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
        # Evaluated once and reused for the quantizer metadata below.
        has_discrete_bottleneck = model.bottleneck is not None and model.bottleneck.is_discrete

        super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=has_discrete_bottleneck)

        self.model = model
        self.model.requires_grad_(False).eval()

        self.scale = scale
        self.downsampling_ratio = model.downsampling_ratio
        self.sample_rate = model.sample_rate
        self.encoded_channels = model.latent_dim

        self.model_half = model_half
        self.iterate_batch = iterate_batch
        self.chunked = chunked

        # Quantizer metadata only exists for discrete bottlenecks.
        self.num_quantizers = model.bottleneck.num_quantizers if has_discrete_bottleneck else None
        self.codebook_size = model.bottleneck.codebook_size if has_discrete_bottleneck else None

        if self.model_half:
            self.model.half()

    def encode(self, x, **kwargs):
        """Encode audio `x` with the wrapped model; extra kwargs go to `encode_audio`."""
        if self.model_half:
            x = x.half()
            # Re-cast on every call in case the model was moved back to fp32.
            self.model.to(torch.float16)

        encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)

        if self.model_half:
            encoded = encoded.float()

        return encoded / self.scale

    def decode(self, z, **kwargs):
        """Decode latents `z` with the wrapped model; extra kwargs go to `decode_audio`."""
        z = z * self.scale

        if self.model_half:
            z = z.half()
            self.model.to(torch.float16)

        decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)

        if self.model_half:
            decoded = decoded.float()

        return decoded

    def tokenize(self, x, **kwargs):
        """Return the token entry of the info dict produced by the wrapped model's `encode`."""
        assert self.model.is_discrete, "Cannot tokenize with a continuous model"

        _, info = self.model.encode(x, return_info = True, **kwargs)

        return info[self.model.bottleneck.tokens_id]

    def decode_tokens(self, tokens, **kwargs):
        """Decode discrete tokens through the wrapped model."""
        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"

        return self.model.decode_tokens(tokens, **kwargs)

    def load_state_dict(self, state_dict, strict=True):
        """
        Load `state_dict` directly into the wrapped model.

        Returns whatever the wrapped model's `load_state_dict` returns, so
        callers using `strict=False` can inspect missing/unexpected keys
        (previously the result was dropped and None was returned).
        """
        return self.model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
class WaveletPretransform(Pretransform):
    """Pretransform that wraps a multi-level 1-D wavelet encoder/decoder pair."""

    def __init__(self, channels, levels, wavelet):
        super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)

        # Imported here rather than at module level.
        from .wavelets import WaveletEncode1d, WaveletDecode1d

        ratio = 2 ** levels

        self.encoder = WaveletEncode1d(channels, levels, wavelet)
        self.decoder = WaveletDecode1d(channels, levels, wavelet)

        self.downsampling_ratio = ratio
        self.io_channels = channels
        self.encoded_channels = channels * ratio

    def encode(self, x):
        """Run `x` through the wavelet encoder."""
        return self.encoder(x)

    def decode(self, z):
        """Run `z` through the wavelet decoder."""
        return self.decoder(z)
|
||||
|
||||
class PQMFPretransform(Pretransform):
    """Pretransform that splits audio into PQMF sub-bands stacked on the channel axis."""

    def __init__(self, attenuation=100, num_bands=16):
        # TODO: Fix PQMF to take in in-channels
        super().__init__(enable_grad=False, io_channels=1, is_discrete=False)

        # Imported here rather than at module level.
        from .pqmf import PQMF

        self.pqmf = PQMF(attenuation, num_bands)

    def encode(self, x):
        """
        Decompose `x` (Batch x Channels x Time) into sub-bands.

        The PQMF output is (Batch x Channels x Bands x Time); channels and
        bands are merged into one axis so the result is again 3-D.
        """
        bands = self.pqmf.forward(x)
        return rearrange(bands, "b c n t -> b (c n) t")

    def decode(self, x):
        """
        Rebuild audio (Batch x Channels x Time) from `x`, which is laid out
        as (Batch x (Channels Bands) x Time).
        """
        bands = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
        return self.pqmf.inverse(bands)
|
||||
|
||||
class PretrainedDACPretransform(Pretransform):
    """
    Pretransform backed by a pretrained DAC codec downloaded through the `dac` package.

    With `quantize_on_decode` the encoder output is returned unquantized and
    quantization happens in `decode`; otherwise quantization happens in `encode`.
    """

    def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        # Imported here rather than at module level.
        import dac

        checkpoint_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
        self.model = dac.DAC.load(checkpoint_path)

        self.quantize_on_decode = quantize_on_decode

        # 512 for the "44khz" model, 320 for every other model type.
        self.downsampling_ratio = 512 if model_type == "44khz" else 320

        self.io_channels = 1
        self.scale = scale
        self.chunked = chunked

        self.encoded_channels = self.model.latent_dim
        self.num_quantizers = self.model.n_codebooks
        self.codebook_size = self.model.codebook_size

    def encode(self, x):
        """Encode `x` to latents, quantized unless `quantize_on_decode` is set, divided by `scale`."""
        output = self.model.encoder(x)

        if not self.quantize_on_decode:
            output, _, _, _, _ = self.model.quantizer(output, n_quantizers=self.model.n_codebooks)

        if self.scale != 1.0:
            output = output / self.scale

        return output

    def decode(self, z):
        """Undo the scaling, quantize if `quantize_on_decode` is set, then decode to audio."""
        if self.scale != 1.0:
            z = z * self.scale

        if self.quantize_on_decode:
            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)

        return self.model.decode(z)

    def tokenize(self, x):
        """Return element 1 of the model's `encode` output (used as the tokens)."""
        return self.model.encode(x)[1]

    def decode_tokens(self, tokens):
        """Turn codes back into latents via the quantizer, then decode them."""
        latents = self.model.quantizer.from_codes(tokens)
        return self.model.decode(latents)
|
||||
|
||||
class AudiocraftCompressionPretransform(Pretransform):
    """
    Token-only pretransform backed by a pretrained Audiocraft compression model.

    Only `tokenize` and `decode_tokens` are supported; the continuous
    `encode`/`decode` entry points always fail an assertion.
    """

    def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        try:
            from audiocraft.models import CompressionModel
        except ImportError:
            raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")

        codec = CompressionModel.get_pretrained(model_type)
        self.model = codec

        self.quantize_on_decode = quantize_on_decode
        self.scale = scale

        # Everything below is read from the loaded codec; `encoded_channels`
        # is deliberately left at the base-class value (None).
        self.sample_rate = codec.sample_rate
        self.downsampling_ratio = round(codec.sample_rate / codec.frame_rate)
        self.io_channels = codec.channels
        self.num_quantizers = codec.num_codebooks
        self.codebook_size = codec.cardinality

        # Run the codec frozen, in eval mode and half precision.
        self.model.to(torch.float16).eval().requires_grad_(False)

    def encode(self, x):
        """Unsupported: always fails the assertion."""
        assert False, "Audiocraft compression models do not support continuous encoding"

    def decode(self, z):
        """Unsupported: always fails the assertion."""
        assert False, "Audiocraft compression models do not support continuous decoding"

    def tokenize(self, x):
        """Encode `x` (cast to float16) with CUDA autocast disabled; returns element 0 of the result."""
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.encode(x.to(torch.float16))[0]

    def decode_tokens(self, tokens):
        """Decode tokens with CUDA autocast disabled."""
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.decode(tokens)
|
||||
@@ -0,0 +1,993 @@
|
||||
from functools import reduce, partial
|
||||
from packaging import version
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from torch.cuda.amp import autocast
|
||||
from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
|
||||
from typing import Callable, Literal
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
|
||||
HAS_FLASH_ATTN = True
|
||||
except ImportError:
|
||||
HAS_FLASH_ATTN = False
|
||||
flash_attn_kvpacked_func = None
|
||||
flash_attn_func = None
|
||||
|
||||
from .utils import compile
|
||||
try:
|
||||
import natten
|
||||
except ImportError:
|
||||
natten = None
|
||||
|
||||
def checkpoint(function, *args, **kwargs):
    """Call torch's activation checkpointing, defaulting `use_reentrant` to False unless the caller set it."""
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    """Scale `x` by `1 + scale`, then add `shift`."""
    scaled = x * (1 + scale)
    return scaled + shift
|
||||
|
||||
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
|
||||
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
|
||||
|
||||
def create_causal_mask(i, j, device):
    """Return an (i, j) boolean mask that is True strictly above diagonal `j - i`."""
    all_true = torch.ones((i, j), device = device, dtype = torch.bool)
    return torch.triu(all_true, diagonal = j - i + 1)
|
||||
|
||||
def or_reduce(masks):
    """Combine all `masks` with `|`, left to right; an empty input raises ValueError from the unpacking."""
    first, *others = masks
    return reduce(lambda combined, mask: combined | mask, others, first)
|
||||
|
||||
# positional embeddings
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, scaled by `dim ** -0.5`."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """
        Look up embeddings for the positions of `x`.

        Only `x.shape[1]` (sequence length) and `x.device` are read from `x`.
        `pos` overrides the default `0..seq_len-1`; `seq_start_pos` is
        subtracted from the positions, which are then clamped at 0.
        """
        seq_len = x.shape[1]
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        if pos is None:
            pos = torch.arange(seq_len, device = x.device)

        if seq_start_pos is not None:
            # The trailing axis lets per-sequence offsets broadcast over positions.
            pos = (pos - seq_start_pos[..., None]).clamp(min = 0)

        return self.emb(pos) * self.scale
|
||||
|
||||
class ScaledSinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal positional embedding multiplied by a learned scalar scale."""

    def __init__(self, dim, theta = 10000):
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        freq_seq = torch.arange(half_dim).float() / half_dim
        inv_freq = theta ** -freq_seq
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer('inv_freq', inv_freq, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """
        Return sinusoidal embeddings for the positions of `x`.

        Only `x.shape[1]` (sequence length) and `x.device` are read from `x`.
        `pos` overrides the default `0..seq_len-1`; `seq_start_pos` is
        subtracted from the positions. The output has the shape of the
        positions plus a trailing `dim` axis: sines first, then cosines.
        """
        seq_len, device = x.shape[1], x.device

        if pos is None:
            pos = torch.arange(seq_len, device = device)

        if seq_start_pos is not None:
            pos = pos - seq_start_pos[..., None]

        # Broadcasting outer product. The former `einsum('i, j -> i j', ...)`
        # only accepted 1-D positions, so the batched positions produced by
        # `seq_start_pos` could not be embedded. For 1-D input the result is
        # the same.
        emb = pos[..., None] * self.inv_freq
        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
        return emb * self.scale
|
||||
|
||||
class RotaryEmbedding(nn.Module):
    """
    Rotary positional embedding frequencies, with optional xpos length scaling.

    `forward` returns `(freqs, scale)`; `scale` is the float `1.` unless
    `use_xpos` was enabled, in which case it is a per-position tensor.
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        base *= base_rescale_factor ** (dim / (dim - 2))

        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            # Registered as None so `forward` can test `self.scale is None`.
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len):
        """Compute the embedding for positions `0..seq_len-1` on the buffer's device."""
        device = self.inv_freq.device

        t = torch.arange(seq_len, device = device)
        return self.forward(t)

    @autocast(enabled = False)
    def forward(self, t):
        """
        Compute rotary frequencies (and the xpos scale) for a 1-D tensor of positions `t`.

        Returns `(freqs, scale)`, where `freqs` has one row per position and
        each row holds the per-frequency angles twice, back to back.
        """
        device = self.inv_freq.device

        # The xpos branch needs the number of positions. It was previously
        # read from an undefined `seq_len`, raising NameError whenever
        # `use_xpos=True`.
        seq_len = t.shape[-1]

        t = t.to(torch.float32)

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # Exponent is centred on the middle of the sequence, so the centre
        # position gets a scale of exactly 1.
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = self.scale ** power[:, None]
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
|
||||
|
||||
def rotate_half(x):
    """Split the last axis into two halves `(a, b)` and return `(-b, a)` concatenated."""
    halves = rearrange(x, '... (j d) -> ... j d', j = 2)
    first, second = halves.unbind(dim = -2)
    return torch.cat((-second, first), dim = -1)
|
||||
|
||||
@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """
    Apply rotary position embeddings `freqs` (optionally scaled) to `t`.

    Only the first `freqs.shape[-1]` features of `t` are rotated; any
    remaining features pass through unchanged. The result is cast back to
    the dtype `t` arrived in.
    """
    original_dtype = t.dtype

    # cast to float32 if necessary for numerical stability
    compute_dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
    rotary_width = freqs.shape[-1]
    length = t.shape[-2]

    freqs = freqs.to(compute_dtype)
    t = t.to(compute_dtype)

    # Keep only the trailing `length` entries along the first axis of freqs.
    freqs = freqs[-length:, :]

    if t.ndim == 4 and freqs.ndim == 3:
        # Insert a singleton axis so 3-D freqs broadcast against 4-D `t`.
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    # partial rotary embeddings, Wang et al. GPT-J
    rotated = t[..., :rotary_width]
    passthrough = t[..., rotary_width:]

    cos_term = rotated * freqs.cos() * scale
    sin_term = rotate_half(rotated) * freqs.sin() * scale
    rotated = cos_term + sin_term

    return torch.cat((rotated.to(original_dtype), passthrough.to(original_dtype)), dim = -1)
|
||||
|
||||
# norms
|
||||
class DynamicTanh(nn.Module):
    """Element-wise `gamma * tanh(alpha * x) + beta` with a learned scalar `alpha` and per-feature `gamma`/`beta`."""

    def __init__(self, dim, init_alpha=10.0):
        super().__init__()
        # Parameter registration order is kept: alpha, gamma, beta.
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        squashed = F.tanh(self.alpha * x)
        return self.gamma * squashed + self.beta
|
||||
|
||||
class RunningInstanceNorm(nn.Module):
    """
    Normalise features with running mean/std estimates held in buffers.

    The statistics are refreshed on every forward pass in training mode and
    frozen otherwise. Optionally the output is passed through `asinh`
    (`saturate`) and multiplied by a learned scalar (`trainable_gain`).
    """

    def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(1,1,dim))
        self.register_buffer("running_std", torch.ones(1,1,dim))

        self.dim = dim
        self.momentum = momentum
        self.eps = eps
        self.saturate = saturate
        self.trainable_gain = trainable_gain

        if self.trainable_gain:
            self.gain = nn.Parameter(torch.ones(1))

    def _update_stats(self, x):
        """Blend the statistics of `x` (taken over its first two axes) into the running buffers."""
        keep = self.momentum
        fresh = 1 - self.momentum

        detached = x.detach()
        batch_mean = detached.mean(dim = [0,1]).view(1, 1, self.dim)
        batch_std = detached.std(dim = [0,1]).view(1, 1, self.dim)

        self.running_mean = self.running_mean * keep + batch_mean * fresh
        # The std is floored at eps so the division in forward stays finite.
        self.running_std = (self.running_std * keep + batch_std * fresh).clip(min = self.eps)

    def forward(self, x):
        if self.training:
            self._update_stats(x)

        normalised = (x - self.running_mean) / self.running_std

        if self.saturate:
            normalised = torch.asinh(normalised)

        if self.trainable_gain:
            normalised = normalised * self.gain

        return normalised
|
||||
|
||||
class LayerNorm(nn.Module):
    def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5):
        """
        Layer norm over the last axis with an optional bias and an optionally frozen scale.

        bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less

        `gamma` is a buffer when `fix_scale` is set and a parameter otherwise;
        `beta` is a parameter only when `bias` is set and a zero buffer otherwise.
        """
        super().__init__()

        if fix_scale:
            self.register_buffer("gamma", torch.ones(dim))
        else:
            self.gamma = nn.Parameter(torch.ones(dim))

        if bias:
            self.beta = nn.Parameter(torch.zeros(dim))
        else:
            self.register_buffer("beta", torch.zeros(dim))

        self.eps = eps
        self.force_fp32 = force_fp32

    def forward(self, x):
        normalized_shape = x.shape[-1:]

        if self.force_fp32:
            # Normalise in float32, then hand back the caller's dtype.
            normed = F.layer_norm(x.float(), normalized_shape, weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps)
            return normed.to(x.dtype)

        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta, eps=self.eps)
|
||||
|
||||
class LayerScale(nn.Module):
    """Multiply the input by a learned per-feature scale, initialised to `init_val`."""

    def __init__(self, dim, init_val = 1e-5):
        super().__init__()
        initial_scale = torch.full([dim], init_val)
        self.scale = nn.Parameter(initial_scale)

    def forward(self, x):
        return x * self.scale
|
||||
|
||||
class GLU(nn.Module):
    """
    Gated linear unit: one projection produces a value half and a gate half,
    and the output is `value * activation(gate)`.

    With `use_conv` the projection is a Conv1d applied along the sequence
    axis instead of a Linear layer.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        activation: Callable,
        use_conv = False,
        conv_kernel_size = 3,
    ):
        super().__init__()
        self.act = activation

        # The projection is twice as wide as the output: value and gate halves.
        if use_conv:
            self.proj = nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
        else:
            self.proj = nn.Linear(dim_in, dim_out * 2)

        self.use_conv = use_conv

    def forward(self, x):
        if not self.use_conv:
            projected = self.proj(x)
        else:
            # Conv1d wants channels before the sequence axis; swap, project, swap back.
            projected = rearrange(x, 'b n d -> b d n')
            projected = self.proj(projected)
            projected = rearrange(projected, 'b d n -> b n d')

        value, gate = projected.chunk(2, dim = -1)
        return value * self.act(gate)
|
||||
|
||||
class FeedForward(nn.Module):
    """
    Transformer feed-forward block: input projection (SwiGLU by default),
    then an output projection that is zero-initialised unless disabled.

    The layout of `self.ff`, including the Identity placeholders, is kept
    fixed so the state-dict keys do not move.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        use_bias = not no_bias
        conv_padding = conv_kernel_size // 2

        # Default to SwiGLU
        activation = nn.SiLU()

        if dim_out is None:
            dim_out = dim

        if glu:
            linear_in = GLU(dim, inner_dim, activation)
        else:
            if use_conv:
                in_proj = nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = conv_padding, bias = use_bias)
            else:
                in_proj = nn.Linear(dim, inner_dim, bias = use_bias)

            # The two Rearrange layers swap the last two axes around the conv.
            linear_in = nn.Sequential(
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                in_proj,
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                activation
            )

        if use_conv:
            linear_out = nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = conv_padding, bias = use_bias)
        else:
            linear_out = nn.Linear(inner_dim, dim_out, bias = use_bias)

        # init last linear layer to 0
        if zero_init_output:
            nn.init.zeros_(linear_out.weight)
            if use_bias:
                nn.init.zeros_(linear_out.bias)

        self.ff = nn.Sequential(
            linear_in,
            Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            linear_out,
            Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
        )

    def forward(self, x):
        return self.ff(x)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_heads = 64,
|
||||
dim_context = None,
|
||||
causal = False,
|
||||
zero_init_output=True,
|
||||
qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none',
|
||||
differential = False,
|
||||
feat_scale = False
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_heads = dim_heads
|
||||
self.differential = differential
|
||||
|
||||
dim_kv = dim_context if dim_context is not None else dim
|
||||
|
||||
self.num_heads = dim // dim_heads
|
||||
self.kv_heads = dim_kv // dim_heads
|
||||
|
||||
if dim_context is not None:
|
||||
if differential:
|
||||
self.to_q = nn.Linear(dim, dim * 2, bias=False)
|
||||
self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False)
|
||||
else:
|
||||
self.to_q = nn.Linear(dim, dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
|
||||
else:
|
||||
if differential:
|
||||
self.to_qkv = nn.Linear(dim, dim * 5, bias=False)
|
||||
else:
|
||||
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
|
||||
self.to_out = nn.Linear(dim, dim, bias=False)
|
||||
|
||||
if zero_init_output:
|
||||
nn.init.zeros_(self.to_out.weight)
|
||||
|
||||
if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']:
|
||||
raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}')
|
||||
|
||||
self.qk_norm = qk_norm
|
||||
if self.qk_norm == "ln":
|
||||
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
||||
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
|
||||
elif self.qk_norm == 'rns':
|
||||
self.q_norm = nn.RMSNorm(dim_heads)
|
||||
self.k_norm = nn.RMSNorm(dim_heads)
|
||||
elif self.qk_norm == 'dyt':
|
||||
self.q_norm = DynamicTanh(dim_heads)
|
||||
self.k_norm = DynamicTanh(dim_heads)
|
||||
|
||||
self.sdp_kwargs = dict(
|
||||
enable_flash = True,
|
||||
enable_math = True,
|
||||
enable_mem_efficient = True
|
||||
)
|
||||
|
||||
self.feat_scale = feat_scale
|
||||
|
||||
if self.feat_scale:
|
||||
self.lambda_dc = nn.Parameter(torch.zeros(dim))
|
||||
self.lambda_hf = nn.Parameter(torch.zeros(dim))
|
||||
|
||||
self.causal = causal
|
||||
if causal:
|
||||
print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.')
|
||||
|
||||
@compile
|
||||
def apply_qk_layernorm(self, q, k):
|
||||
q_type = q.dtype
|
||||
k_type = k.dtype
|
||||
q = self.q_norm(q).to(q_type)
|
||||
k = self.k_norm(k).to(k_type)
|
||||
return q, k
|
||||
|
||||
|
||||
def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None):
    """Run attention through FlexAttention, FlashAttention or PyTorch SDPA.

    Backend selection, in priority order:

    1. FlexAttention, when a block mask or score mod is supplied and
       ``causal`` is falsy (a truthy ``causal`` disables FlexAttention).
    2. FlashAttention, when ``HAS_FLASH_ATTN`` is set.
    3. ``F.scaled_dot_product_attention`` otherwise.

    Args:
        q, k, v: tensors in ``b h n d`` layout (the FlashAttention branch
            rearranges them from that layout and back).
        causal: whether to apply causal masking; ``None`` is treated as
            ``False``.
        flex_attention_block_mask: optional FlexAttention block mask.
        flex_attention_score_mod: optional FlexAttention score modifier.
        flash_attn_sliding_window: optional window passed to FlashAttention
            as ``window_size``; ignored (with a printed warning) when
            FlashAttention is unavailable.

    Returns:
        The attention output in ``b h n d`` layout.
    """
    if self.num_heads != self.kv_heads:
        # Grouped-query attention: repeat each k/v head so the head count matches q.
        heads_per_kv_head = self.num_heads // self.kv_heads
        k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

    # The default `causal=None` must not reach SDPA, whose `is_causal`
    # argument only accepts a real bool.
    is_causal = bool(causal)

    flash_attn_available = HAS_FLASH_ATTN
    use_flex = flex_attention_block_mask is not None or flex_attention_score_mod is not None

    if flash_attn_sliding_window is not None and (not flash_attn_available):
        print("Cannot use FlashAttention sliding window as FlashAttention is disabled or not available")

    if use_flex and flash_attn_sliding_window is not None:
        print("cannot use both FlashAttention and FlexAttention, favouring FlexAttention")

    if is_causal and use_flex:
        print("Disabling FlexAttention because causal is set")
        flex_attention_block_mask = None
        flex_attention_score_mod = None
        use_flex = False

    if use_flex:
        out = flex_attention_compiled(q,k,v,
                                      block_mask = flex_attention_block_mask,
                                      score_mod = flex_attention_score_mod)
    elif flash_attn_available:
        fa_dtype_in = q.dtype
        # flash_attn_func is called with sequence-major tensors.
        q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v))

        # Inputs that are neither fp16 nor bf16 are cast to bf16 for the
        # kernel; the result is cast back to the incoming dtype below.
        if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16:
            q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v))

        out = flash_attn_func(q, k, v, causal = is_causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1])

        out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
    else:
        out = F.scaled_dot_product_attention(q, k, v, is_causal = is_causal)

    return out
|
||||
|
||||
|
||||
#@compile
|
||||
def forward(
    self,
    x,
    context = None,
    rotary_pos_emb = None,
    causal = None,
    flex_attention_block_mask = None,
    flex_attention_score_mod = None,
    flash_attn_sliding_window = None
):
    """Project the input to q/k/v, run attention and apply ``to_out``.

    Args:
        x: query-side input. The ``'b n (h d)'`` rearranges below treat the
            projections as ``(batch, seq, heads * head_dim)``.
        context: optional key/value source, used only when the module has
            separate ``to_q`` / ``to_kv`` projections. The fused ``to_qkv``
            path always projects from ``x``.
        rotary_pos_emb: optional tuple; its first element is used as the
            rotary frequencies, the second is ignored here.
        causal: overrides ``self.causal`` when not ``None``.
        flex_attention_block_mask, flex_attention_score_mod,
        flash_attn_sliding_window: forwarded unchanged to ``apply_attn``.

    Returns:
        ``self.to_out`` applied to the merged heads, with the optional
        ``feat_scale`` adjustment added.
    """
    h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

    kv_input = context if has_context else x

    if hasattr(self, 'to_q'):
        # Use separate linear projections for q and k/v
        if self.differential:
            # Differential attention: the projections carry a second
            # ("diff") query and key alongside the regular ones.
            q, q_diff = self.to_q(x).chunk(2, dim=-1)
            q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff))
            # Stack on a new dim 1 so norm / rotary below handle both at once.
            q = torch.stack([q, q_diff], dim = 1)
            k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1)
            k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v))
            k = torch.stack([k, k_diff], dim = 1)
        else:
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)
            k, v = self.to_kv(kv_input).chunk(2, dim=-1)
            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
    else:
        # Use fused linear projection
        if self.differential:
            q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1)
            q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff))
            q = torch.stack([q, q_diff], dim = 1)
            k = torch.stack([k, k_diff], dim = 1)
        else:
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

    # Normalize q and k for cosine sim attention
    if self.qk_norm == "l2":
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
    elif self.qk_norm != "none":
        # Every other accepted qk_norm value goes through the q_norm / k_norm modules.
        q, k = self.apply_qk_layernorm(q, k)

    if rotary_pos_emb is not None:
        freqs, _ = rotary_pos_emb
        # NOTE(review): q_dtype / k_dtype are never read; q and k are cast to v.dtype below.
        q_dtype = q.dtype
        k_dtype = k.dtype
        # Rotary embedding is applied in float32.
        q = q.to(torch.float32)
        k = k.to(torch.float32)
        freqs = freqs.to(torch.float32)
        # When q and k have different sequence lengths, the frequencies of
        # the shorter one are scaled by the length ratio.
        if q.shape[-2] >= k.shape[-2]:
            ratio = q.shape[-2] / k.shape[-2]
            q_freqs, k_freqs = freqs, ratio * freqs
        else:
            ratio = k.shape[-2] / q.shape[-2]
            q_freqs, k_freqs = ratio * freqs, freqs
        q = apply_rotary_pos_emb(q, q_freqs)
        k = apply_rotary_pos_emb(k, k_freqs)
        q = q.to(v.dtype)
        k = k.to(v.dtype)

    # NOTE(review): `device` is not used below.
    n, device = q.shape[-2], q.device

    causal = self.causal if causal is None else causal

    # Causal masking is switched off for a single query position.
    if n == 1 and causal:
        causal = False

    if self.differential:
        # Split the stacked pairs again and subtract the "diff" attention
        # output from the main one; both share the same values.
        q, q_diff = q.unbind(dim = 1)
        k, k_diff = k.unbind(dim = 1)
        out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
        out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)
        out = out - out_diff
    else:
        out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window)

    # merge heads
    out = rearrange(out, ' b h n d -> b n (h d)')

    # Communicate between heads

    # with autocast(enabled = False):
    #     out_dtype = out.dtype
    #     out = out.to(torch.float32)
    #     out = self.to_out(out).to(out_dtype)
    out = self.to_out(out)

    if self.feat_scale:
        # DC component = mean over the sequence axis; the remainder is
        # treated as the high-frequency component.
        out_dc = out.mean(dim=-2, keepdim=True)
        out_hf = out - out_dc

        # Selectively modulate DC and high frequency components
        out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf

    return out
|
||||
|
||||
class ConformerModule(nn.Module):
    """Conformer-style convolution branch operating on ``b n d`` tensors.

    Pipeline: norm -> pointwise conv -> GLU -> depthwise conv (kernel 17)
    -> norm -> SiLU -> pointwise conv.
    """

    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):
        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        self.glu = GLU(dim, dim, nn.SiLU())
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        # The original Conformer design uses batch norm at this point; a
        # LayerNorm is used here instead.
        self.mid_norm = LayerNorm(dim, **norm_kwargs)
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    @staticmethod
    def _conv_over_sequence(conv, x):
        """Apply a Conv1d to a ``b n d`` tensor via a ``b d n`` round trip."""
        channels_first = rearrange(x, 'b n d -> b d n')
        return rearrange(conv(channels_first), 'b d n -> b n d')

    def forward(self, x):
        """Run the convolution branch; input and output are both ``b n d``."""
        x = self.in_norm(x)
        x = self._conv_over_sequence(self.pointwise_conv, x)
        x = self.glu(x)
        x = self._conv_over_sequence(self.depthwise_conv, x)
        x = self.swish(self.mid_norm(x))
        return self._conv_over_sequence(self.pointwise_conv_2, x)
|
||||
|
||||
class TransformerBlock(nn.Module):
    """One transformer layer.

    Sub-layers, in order: self-attention, optional cross-attention,
    optional conformer branch, optional sync FiLM modulation, feed-forward.
    When ``global_cond_dim`` is set and a ``global_cond`` is passed to
    ``forward``, the self-attention and feed-forward sub-layers use
    adaLN-style scale / shift / gate modulation.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        cross_attend = False,
        dim_context = None,
        global_cond_dim = None,
        causal = False,
        zero_init_branch_outputs = True,
        conformer = False,
        layer_ix = -1,
        remove_norms = False,
        add_rope = False,
        layer_scale = False,
        use_sync_block_film = False,
        attn_kwargs = {},
        ff_kwargs = {},
        norm_kwargs = {}
    ):
        """Build the sub-layers.

        Args:
            dim: model width.
            dim_heads: attention head size (capped at ``dim``).
            cross_attend: add a cross-attention sub-layer.
            dim_context: context width passed to the cross-attention.
            global_cond_dim: when not ``None``, create the learned
                scale/shift/gate offset used for adaLN modulation.
            causal: passed to both attention modules.
            zero_init_branch_outputs: zero-init attention / feed-forward
                outputs; forced to ``False`` when ``layer_scale`` is set.
            conformer: add a ``ConformerModule`` branch.
            layer_ix: index of this layer, stored for reference.
            remove_norms: use ``DynamicTanh`` in place of ``LayerNorm``.
            add_rope: give the block its own ``RotaryEmbedding``.
            layer_scale: wrap each branch output in ``LayerScale``.
            use_sync_block_film: create a per-block sync FiLM generator.
            attn_kwargs, ff_kwargs, norm_kwargs: extra keyword arguments
                for the attention, feed-forward and norm constructors
                (only read, never mutated).
        """
        super().__init__()
        self.dim = dim
        self.dim_heads = min(dim_heads,dim)
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal
        if layer_scale and zero_init_branch_outputs:
            print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False')
            zero_init_branch_outputs = False

        self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim)

        self.add_rope = add_rope

        self.self_attn = Attention(
            dim,
            dim_heads = self.dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            **attn_kwargs
        )

        self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()

        if cross_attend:
            self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
            self.cross_attn = Attention(
                dim,
                dim_heads = self.dim_heads,
                dim_context=dim_context,
                causal = causal,
                zero_init_output=zero_init_branch_outputs,
                **attn_kwargs
            )
            self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim)
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
        self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.layer_ix = layer_ix

        self.conformer = None
        if conformer:
            self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
            self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity()

        self.global_cond_dim = global_cond_dim
        if global_cond_dim is not None:
            # Learned offset added to the incoming global conditioning
            # before it is split into six scale / shift / gate chunks.
            self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5)

        self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None

        if use_sync_block_film:
            self.sync_film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False)  # produces scale and shift from sync_cond
            )

    def _apply_sync_film(self, x, sync_cond, prepend_length):
        """FiLM-modulate ``x`` with ``sync_cond``, skipping the first ``prepend_length`` tokens."""
        scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
        if prepend_length == 0:
            return x * (1 + scale) + shift
        prepend_part = x[:, :prepend_length, :]
        audio_part = x[:, prepend_length:, :]
        modulated_audio_part = audio_part * (1 + scale) + shift
        return torch.cat([prepend_part, modulated_audio_part], dim=1)

    @compile
    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        rotary_pos_emb = None,
        self_attention_block_mask = None,
        self_attention_score_mod = None,
        cross_attention_block_mask = None,
        cross_attention_score_mod = None,
        self_attention_flash_sliding_window = None,
        cross_attention_flash_sliding_window = None,
        sync_cond = None,
        prepend_length=0
    ):
        """Apply the block to ``x``.

        Args:
            x: input sequence, indexed as ``(batch, seq, dim)`` below.
            context: cross-attention context; used only if the block was
                built with ``cross_attend``.
            global_cond: global conditioning for adaLN modulation. A 2-D
                tensor gets a sequence axis inserted before it is split.
            rotary_pos_emb: rotary embedding for self-attention. When
                ``None`` and the block owns a rope, one is generated for
                the current sequence length.
            self_attention_* / cross_attention_*: forwarded to the
                respective attention module.
            sync_cond: input to the per-block sync FiLM generator, if the
                block has one.
            prepend_length: number of leading tokens in ``x`` that must be
                left untouched by the sync FiLM modulation.

        Returns:
            The transformed sequence.
        """
        if rotary_pos_emb is None and self.add_rope:
            rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2])

        use_sync_film = sync_cond is not None and hasattr(self, 'sync_film_generator')

        if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
            if len(global_cond.shape) == 2:
                scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1)
            else:
                scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1)

            # self-attention with adaLN
            residual = x
            x = self.pre_norm(x)
            x = x * (1 + scale_self) + shift_self
            x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)
            x = x * torch.sigmoid(1 - gate_self)
            x = self.self_attn_scale(x)
            x = x + residual

            if context is not None and self.cross_attend:
                x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))

            if self.conformer is not None:
                x = x + self.conformer_scale(self.conformer(x))

            if use_sync_film:
                # Same prepend-aware modulation as the non-adaLN path, so
                # prepended tokens do not break the shape of the FiLM terms.
                x = self._apply_sync_film(x, sync_cond, prepend_length)

            # feedforward with adaLN
            residual = x
            x = self.ff_norm(x)
            x = x * (1 + scale_ff) + shift_ff
            x = self.ff(x)
            x = x * torch.sigmoid(1 - gate_ff)
            x = self.ff_scale(x)
            x = x + residual

        else:
            x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window))

            if context is not None and self.cross_attend:
                x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window))

            if self.conformer is not None:
                x = x + self.conformer_scale(self.conformer(x))

            if use_sync_film:
                x = self._apply_sync_film(x, sync_cond, prepend_length)

            x = x + self.ff_scale(self.ff(self.ff_norm(x)))

        return x
|
||||
|
||||
class ContinuousTransformer(nn.Module):
    """Stack of ``TransformerBlock`` layers over continuous inputs.

    Handles input / output projection, optional additive / FiLM / gated
    conditioning on the input, prepended embeddings, memory tokens,
    positional embeddings, global conditioning and an optional adaLN
    final layer.
    """

    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        pre_cross_attn_ix=-1,
        final_cross_attn_ix=-1,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        num_memory_tokens=0,
        sliding_window=None,
        use_mlp=False,
        use_add_norm=False,
        use_gated=False,
        use_final_layer=False,
        use_zeros=False,
        use_conv=False,
        use_fusion_mlp=False,
        use_film=False,
        use_sync_film=False,
        use_sync_gated=False,
        **kwargs
    ):
        """Build the projections, conditioning modules and layers.

        Args:
            dim: model width.
            depth: number of ``TransformerBlock`` layers.
            dim_in, dim_out: input / output widths; ``None`` means identity
                projection (``use_mlp`` requires ``dim_in``).
            dim_heads: attention head size.
            cross_attend: enable cross-attention in the layers selected by
                ``pre_cross_attn_ix`` (inclusive start) and
                ``final_cross_attn_ix`` (exclusive end); ``-1`` disables
                the respective bound.
            cond_token_dim: context width for cross-attention.
            global_cond_dim: width of the global conditioning; enables the
                global-cond embedder and, with ``use_final_layer``, the
                adaLN final layer.
            causal, zero_init_branch_outputs, conformer: forwarded to each
                layer.
            rotary_pos_emb, use_sinusoidal_emb, use_abs_pos_emb,
            abs_pos_emb_max_length: positional-embedding options.
            num_memory_tokens: number of learned tokens prepended to the
                sequence and stripped from the output.
            sliding_window: passed to each layer as the self-attention
                FlashAttention sliding window.
            use_mlp: use a two-layer MLP as input projection.
            use_add_norm, use_gated, use_film, use_fusion_mlp: how
                ``add_cond`` is merged into the input (see ``forward``).
            use_sync_film, use_sync_gated: how ``sync_cond`` is merged.
            use_final_layer, use_zeros: adaLN final layer and zero-init of
                the conditioning projections.
            use_conv: create two temporal convolutions (not referenced in
                this class's ``forward``).
            **kwargs: forwarded to every ``TransformerBlock``.
        """
        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])
        if use_mlp:
            self.project_in = nn.Sequential(
                nn.Linear(dim_in, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim, bias=False)
            )
        else:
            self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
        self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
        self.video_temporal_conv = None
        self.audio_temporal_conv = None
        self.fusion_mlp = None
        if use_conv:
            self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
            self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1)
        if use_fusion_mlp:
            self.fusion_mlp = nn.Sequential(
                nn.Linear(dim, dim),
                nn.SiLU(),
                nn.Linear(dim, dim)
            )

        if rotary_pos_emb:
            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
        else:
            self.rotary_pos_emb = None
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens)

        self.adaLN_modulation = None
        # Always define the attribute: forward() tests it against None even
        # when the model was built without global conditioning.
        self.global_cond_embedder = None
        if global_cond_dim is not None:
            if use_final_layer:
                self.norm_final = LayerNorm(dim)
                self.adaLN_modulation = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(
                        dim, 2 * dim, bias=True
                    ),
                )

                if use_zeros:
                    nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
                    nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
                    nn.init.constant_(self.project_out.weight, 0)
            self.global_cond_embedder = nn.Sequential(
                nn.Linear(global_cond_dim, dim),
                nn.SiLU(),
                nn.Linear(dim, dim * 6)
            )
            if use_zeros:
                nn.init.constant_(self.global_cond_embedder[-1].weight, 0)
                nn.init.constant_(self.global_cond_embedder[-1].bias, 0)
                nn.init.constant_(self.global_cond_embedder[0].weight, 0)
                nn.init.constant_(self.global_cond_embedder[0].bias, 0)

        self.final_cross_attn_ix = final_cross_attn_ix
        self.use_gated = use_gated
        self.use_film = use_film
        self.use_add_norm = use_add_norm
        if self.use_add_norm:
            self.add_norm = nn.LayerNorm(dim)
        if use_gated:
            self.gate = nn.Parameter(torch.ones(1, 1, dim))

        if use_film:
            self.film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False)  # produces scale and shift from add_cond
            )
        else:
            self.film_generator = None

        if use_sync_film:
            self.sync_film_generator = nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.SiLU(),
                nn.Linear(dim, dim * 2, bias=False)  # produces scale and shift from sync_cond
            )
        else:
            self.sync_film_generator = None
        if use_sync_gated:
            self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim))
        else:
            self.sync_gate = None

        self.sliding_window = sliding_window

        for i in range(depth):
            should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix))
            self.layers.append(
                TransformerBlock(
                    dim,
                    dim_heads = dim_heads,
                    cross_attend = should_cross_attend,
                    dim_context = cond_token_dim,
                    global_cond_dim = global_cond_dim,
                    causal = causal,
                    zero_init_branch_outputs = zero_init_branch_outputs,
                    conformer=conformer,
                    layer_ix=i,
                    **kwargs
                )
            )

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        add_cond = None,
        sync_cond = None,
        global_cond = None,
        return_info = False,
        use_checkpointing = True,
        exit_layer_ix = None,
        video_dropout_prob = 0.0,
        **kwargs
    ):
        """Run the transformer.

        Args:
            x: input sequence, indexed as ``(batch, seq, features)``.
            mask, prepend_mask, video_dropout_prob: accepted for interface
                compatibility; not referenced in this method.
            prepend_embeds: optional embeddings concatenated in front of
                ``x`` along the sequence axis; last dim must equal the
                model width.
            add_cond: optional conditioning merged into the projected
                input (gated add, FiLM, or plain add).
            sync_cond: optional conditioning merged via sync FiLM or the
                sync gate, and also forwarded to every layer.
            global_cond: optional global conditioning, embedded and passed
                to every layer; also drives the adaLN final layer if one
                was built.
            return_info: also return a dict with per-layer hidden states.
            use_checkpointing: wrap each layer call in ``checkpoint``.
            exit_layer_ix: return right after this layer index, skipping
                the final modulation and output projection.
            **kwargs: forwarded to every layer.

        Returns:
            The output tensor, or ``(output, info)`` when ``return_info``.
        """
        batch = x.shape[0]
        model_dtype = next(self.parameters()).dtype
        x = x.to(model_dtype)

        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)
        if add_cond is not None:
            if self.use_gated:
                gate = torch.sigmoid(self.gate)
                x = x + gate * add_cond
            elif self.use_film:
                scale, shift = self.film_generator(add_cond).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
            else:
                x = x + add_cond

            # NOTE(review): the nesting of the next two steps under
            # `add_cond is not None` was reconstructed from a view of the
            # source that had lost its indentation - confirm against upstream.
            if self.use_add_norm:
                x = self.add_norm(x)
            if self.fusion_mlp is not None:
                x = self.fusion_mlp(x)

        if sync_cond is not None:
            if self.sync_film_generator is not None:
                scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1)
                x = x * (1 + scale) + shift
            elif self.sync_gate is not None:
                gate_value = torch.sigmoid(self.sync_gate)
                x = x + gate_value * sync_cond
            # else:
            #     x = x + sync_cond

        # Default to "nothing prepended" so the layers always receive a
        # defined value, even when no prepend_embeds are given.
        prepend_length = 0
        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

        if self.num_memory_tokens > 0:
            # NOTE(review): memory tokens are placed in front of the
            # prepended embeddings but are not counted in prepend_length.
            memory_tokens = self.memory_tokens.expand(batch, -1, -1)
            x = torch.cat((memory_tokens, x), dim=1)

        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
        else:
            rotary_pos_emb = None

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        if global_cond is not None and self.global_cond_embedder is not None:
            global_cond_embed = self.global_cond_embedder(global_cond)
        else:
            global_cond_embed = global_cond

        # Iterate over the transformer layers
        for layer_ix, layer in enumerate(self.layers):
            if use_checkpointing:
                x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)
            else:
                x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs)

            if return_info:
                info["hidden_states"].append(x)

            if exit_layer_ix is not None and layer_ix == exit_layer_ix:
                # Early exit: drop the memory tokens and skip the final
                # modulation / output projection.
                x = x[:, self.num_memory_tokens:, :]

                if return_info:
                    return x, info

                return x

        x = x[:, self.num_memory_tokens:, :]
        if global_cond is not None and self.adaLN_modulation is not None:
            # The raw (un-embedded) global_cond drives the final layer.
            if len(global_cond.shape) == 2:
                global_cond = global_cond.unsqueeze(1)
            shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1)
            x = modulate(self.norm_final(x), shift, scale)
        x = self.project_out(x)

        if return_info:
            return x, info

        return x
|
||||
@@ -0,0 +1,177 @@
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
|
||||
#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
|
||||
def load_ckpt_state_dict(ckpt_path, prefix=None):
    """Load a checkpoint's state dict, optionally restricted to a key prefix.

    Args:
        ckpt_path: path to the checkpoint. Files ending in ``.safetensors``
            are read with ``load_file``; anything else is read with
            ``torch.load`` and must contain a ``"state_dict"`` entry.
        prefix: when given, only keys starting with ``prefix`` are kept and
            that leading prefix is stripped from them.

    Returns:
        The (possibly filtered) state dict.

    Raises:
        KeyError: if a non-safetensors checkpoint has no ``"state_dict"``.
    """
    if str(ckpt_path).endswith(".safetensors"):
        state_dict = load_file(ckpt_path)
    else:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    if prefix is None:
        return state_dict

    # Strip only the leading occurrence of the prefix: a key such as
    # "model.block.model.weight" must keep its inner "model." segment.
    prefix_len = len(prefix)
    return {k[prefix_len:]: v for k, v in state_dict.items() if k.startswith(prefix)}
|
||||
|
||||
def remove_weight_norm_from_model(model):
    """Strip weight normalization from every sub-module that carries it.

    Modules that expose a ``weight`` attribute but were never wrapped with
    ``torch.nn.utils.weight_norm`` are skipped instead of aborting the whole
    traversal.

    Args:
        model: module whose sub-modules are visited via ``model.modules()``.

    Returns:
        The same ``model`` object, modified in place.
    """
    for module in model.modules():
        if not hasattr(module, "weight"):
            continue
        try:
            remove_weight_norm(module)
        except ValueError:
            # remove_weight_norm raises ValueError when no weight-norm hook
            # is registered on the module; nothing to remove there.
            continue
        print(f"Removing weight norm from {module}")

    return model
|
||||
|
||||
# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
|
||||
# License can be found in LICENSES/LICENSE_META.txt
|
||||
|
||||
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """Draw indices from categorical distributions stored on the last axis.

    Works for inputs with any number of leading dimensions.

    Args:
        input (torch.Tensor): Probabilities, candidates on the last dimension.
        num_samples (int): How many indices to draw per distribution.
        replacement (bool): Draw with replacement (only used when
            ``num_samples`` is not 1).
    Keywords args:
        generator (torch.Generator): Optional random number generator.
    Returns:
        torch.Tensor: Same leading shape as ``input``; the last dimension
            holds the ``num_samples`` sampled indices.
    """
    if num_samples != 1:
        # General case: flatten the leading dims, sample, restore the shape.
        flat_probs = input.reshape(-1, input.shape[-1])
        flat_draws = torch.multinomial(
            flat_probs,
            num_samples=num_samples,
            replacement=replacement,
            generator=generator,
        )
        return flat_draws.reshape(*input.shape[:-1], -1)

    # Single draw: divide by exponential noise and take the argmax.
    noise = torch.empty_like(input).exponential_(1, generator=generator)
    winner = torch.argmax(input / noise, dim=-1, keepdim=True)
    return winner.to(torch.int64)
|
||||
|
||||
|
||||
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample the next token from the top-k entries of ``probs``.

    Entries below the k-th largest value (per distribution on the last
    dimension) are zeroed, the rest are renormalized, and one index is
    drawn. The caller's ``probs`` tensor is left unmodified.

    Args:
        probs (torch.Tensor): Probabilities, candidates on the last dimension.
        k (int): Number of highest-probability candidates to keep.
    Returns:
        torch.Tensor: Sampled token indices, last dimension of size 1.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    min_value_top_k = top_k_value[..., [-1]]
    # Work on a copy: the in-place filtering below must not leak into the
    # tensor owned by the caller.
    filtered = probs.clone()
    filtered *= (filtered >= min_value_top_k).float()
    filtered.div_(filtered.sum(dim=-1, keepdim=True))
    return multinomial(filtered, num_samples=1)
|
||||
|
||||
|
||||
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus sampling over the last dimension of ``probs``.

    Candidates are sorted by descending probability; a candidate is dropped
    once the probability mass of the candidates ranked before it exceeds
    ``p``. The survivors are renormalized and one index is drawn.

    Args:
        probs (torch.Tensor): Probabilities, candidates on the last dimension.
        p (float): Cumulative-probability threshold.
    Returns:
        torch.Tensor: Sampled token indices (in the original ordering).
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    # Mass accumulated strictly before each sorted position.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    dropped = mass_before > p
    sorted_probs *= (~dropped).float()
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    picked = multinomial(sorted_probs, num_samples=1)
    # Translate positions in the sorted order back to original indices.
    return torch.gather(sorted_idx, -1, picked)
|
||||
|
||||
def next_power_of_two(n):
    """Return the smallest power of two that is >= ``n`` (``n`` is an int)."""
    exponent = (n - 1).bit_length()
    return 1 << exponent
|
||||
|
||||
def next_multiple_of_64(n):
    """Round ``n`` up to the nearest multiple of 64."""
    blocks = (n + 63) // 64
    return blocks * 64
|
||||
|
||||
|
||||
# mask construction helpers
|
||||
|
||||
def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    """Build a boolean mask that is True on ``[start, end)`` along a new last axis.

    Args:
        seq_len: length of the mask axis.
        start: tensor of inclusive start positions (any number of dims).
        end: tensor of exclusive end positions, same shape as ``start``.

    Returns:
        Bool tensor of shape ``(*start.shape, seq_len)``.
    """
    assert start.shape == end.shape
    device = start.device

    positions = torch.arange(seq_len, device = device, dtype = torch.long)
    # One singleton axis per dimension of `start`; reshape accepts at most
    # one -1, so explicit 1s are needed for inputs with two or more dims.
    positions = positions.reshape(*((1,) * start.ndim), seq_len)
    positions = positions.expand(*start.shape, seq_len)

    mask = positions >= start[..., None].long()
    mask &= positions < end[..., None].long()
    return mask
|
||||
|
||||
def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    """Build a mask covering a randomly placed span per entry of ``frac_lengths``.

    Each span has length ``frac_lengths * seq_len`` (truncated to an
    integer) and starts at a uniformly random position that keeps the span
    inside ``[0, seq_len]``.

    Returns:
        The boolean mask produced by ``mask_from_start_end_indices``.
    """
    device = frac_lengths.device

    span = (frac_lengths * seq_len).long()
    latest_start = seq_len - span

    jitter = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
    start = (latest_start * jitter).clamp(min = 0)

    return mask_from_start_end_indices(seq_len, start, start + span)
|
||||
|
||||
def _build_spline(video_feat, video_t, target_t):
    """Resample ``video_feat`` from ``video_t`` to ``target_t`` with a natural cubic spline.

    The feature tensor is permuted ``(0, 2, 1)`` before fitting and the
    evaluated result is permuted back the same way.
    """
    # NOTE(review): `natural_cubic_spline_coeffs` and `NaturalCubicSpline`
    # are not imported in the visible part of this file (the torchcubicspline
    # import at the top is commented out) - confirm they are in scope.
    time_major = video_feat.permute(0, 2, 1)
    spline = NaturalCubicSpline(natural_cubic_spline_coeffs(video_t, time_major))
    resampled = spline.evaluate(target_t)
    return resampled.permute(0, 2, 1)
|
||||
|
||||
def resample(video_feat, audio_latent):
    """Resample video features onto the audio latent's time grid.

    Both time axes are laid out on a shared 0..9 grid (the original notes
    call this a 9 s clip) and the features are interpolated with a cubic
    spline.

    Args:
        video_feat: 3-D tensor, documented upstream as ``[B, 72, D]``.
        audio_latent: either the target length as an int, or a tensor
            (documented upstream as ``[B, D', 194]``). For a tensor the
            length is read from dim 1 unless that dim has size 64, in which
            case dim 2 is used.

    Returns:
        The result of ``_build_spline`` permuted ``(0, 2, 1)``; the original
        comments describe it as ``[B, Ta, D]``.

    Raises:
        TypeError: if ``audio_latent`` is neither a tensor nor an int.
    """
    _, Tv, _ = video_feat.shape

    if isinstance(audio_latent, torch.Tensor):
        # A size-64 dim 1 is taken to be the channel axis, so time is dim 2.
        time_axis = 2 if audio_latent.shape[1] == 64 else 1
        Ta = audio_latent.shape[time_axis]
    elif isinstance(audio_latent, int):
        Ta = audio_latent
    else:
        raise TypeError("audio_latent must be either a tensor or an int")

    # Timestamps for both streams on the same 0..9 range.
    device = video_feat.device
    video_time = torch.linspace(0, 9, Tv, device=device)
    audio_time = torch.linspace(0, 9, Ta, device=device)

    # Cubic spline interpolation on the (0, 2, 1)-permuted features.
    aligned_video = _build_spline(video_feat.permute(0, 2, 1), video_time, audio_time)
    return aligned_video.permute(0, 2, 1)
|
||||
|
||||
import os

enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1"


def compile(function, *args, **kwargs):
    """Optionally wrap ``function`` with ``torch.compile``.

    Compilation is attempted only when the ``ENABLE_TORCH_COMPILE``
    environment variable was ``"1"`` at import time. If it is disabled, or
    ``torch.compile`` raises ``RuntimeError``, ``function`` is returned
    unchanged. Extra arguments are forwarded to ``torch.compile``.
    """
    if not enable_torch_compile:
        return function

    try:
        return torch.compile(function, *args, **kwargs)
    except RuntimeError:
        return function
|
||||
Reference in New Issue
Block a user